|
23 | 23 | import java.math.BigDecimal;
|
24 | 24 | import java.sql.Date;
|
25 | 25 | import java.sql.Timestamp;
|
| 26 | +import java.util.Arrays; |
26 | 27 | import java.util.List;
|
27 | 28 | import java.util.Locale;
|
28 | 29 | import org.apache.iceberg.CatalogUtil;
|
|
36 | 37 | import org.apache.iceberg.spark.CatalogTestBase;
|
37 | 38 | import org.apache.iceberg.spark.TestBase;
|
38 | 39 | import org.apache.spark.sql.SparkSession;
|
| 40 | +import org.assertj.core.api.Assertions; |
39 | 41 | import org.junit.jupiter.api.AfterEach;
|
40 | 42 | import org.junit.jupiter.api.BeforeAll;
|
41 | 43 | import org.junit.jupiter.api.TestTemplate;
|
@@ -478,6 +480,126 @@ public void testAggregateWithComplexType() {
|
478 | 480 | .isFalse();
|
479 | 481 | }
|
480 | 482 |
|
| 483 | + @TestTemplate |
| 484 | + public void testAggregationPushdownStructInteger() { |
| 485 | + sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING iceberg", tableName); |
| 486 | + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName); |
| 487 | + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName); |
| 488 | + sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName); |
| 489 | + |
| 490 | + String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s"; |
| 491 | + String aggField = "struct_with_int.c1"; |
| 492 | + assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L); |
| 493 | + assertExplainContains( |
| 494 | + sql("EXPLAIN " + query, aggField, aggField, aggField, tableName), |
| 495 | + "count(struct_with_int.c1)", |
| 496 | + "max(struct_with_int.c1)", |
| 497 | + "min(struct_with_int.c1)"); |
| 498 | + } |
| 499 | + |
| 500 | + @TestTemplate |
| 501 | + public void testAggregationPushdownNestedStruct() { |
| 502 | + sql( |
| 503 | + "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg", |
| 504 | + tableName); |
| 505 | + sql( |
| 506 | + "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))", |
| 507 | + tableName); |
| 508 | + sql( |
| 509 | + "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))", |
| 510 | + tableName); |
| 511 | + sql( |
| 512 | + "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))", |
| 513 | + tableName); |
| 514 | + |
| 515 | + String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s"; |
| 516 | + String aggField = "struct_with_int.c1.c2.c3.c4"; |
| 517 | + |
| 518 | + assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L); |
| 519 | + |
| 520 | + assertExplainContains( |
| 521 | + sql("EXPLAIN " + query, aggField, aggField, aggField, tableName), |
| 522 | + "count(struct_with_int.c1.c2.c3.c4)", |
| 523 | + "max(struct_with_int.c1.c2.c3.c4)", |
| 524 | + "min(struct_with_int.c1.c2.c3.c4)"); |
| 525 | + } |
| 526 | + |
| 527 | + @TestTemplate |
| 528 | + public void testAggregationPushdownStructTimestamp() { |
| 529 | + sql( |
| 530 | + "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) USING iceberg", |
| 531 | + tableName); |
| 532 | + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName); |
| 533 | + sql( |
| 534 | + "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))", |
| 535 | + tableName); |
| 536 | + sql( |
| 537 | + "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))", |
| 538 | + tableName); |
| 539 | + |
| 540 | + String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s"; |
| 541 | + String aggField = "struct_with_ts.c1"; |
| 542 | + |
| 543 | + assertAggregates( |
| 544 | + sql(query, aggField, aggField, aggField, tableName), |
| 545 | + 2L, |
| 546 | + new Timestamp(1675117403000L), |
| 547 | + new Timestamp(1675117342000L)); |
| 548 | + |
| 549 | + assertExplainContains( |
| 550 | + sql("EXPLAIN " + query, aggField, aggField, aggField, tableName), |
| 551 | + "count(struct_with_ts.c1)", |
| 552 | + "max(struct_with_ts.c1)", |
| 553 | + "min(struct_with_ts.c1)"); |
| 554 | + } |
| 555 | + |
| 556 | + @TestTemplate |
| 557 | + public void testAggregationPushdownOnBucketedColumn() { |
| 558 | + sql( |
| 559 | + "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING iceberg PARTITIONED BY (bucket(8, id))", |
| 560 | + tableName); |
| 561 | + |
| 562 | + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName); |
| 563 | + sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName); |
| 564 | + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName); |
| 565 | + |
| 566 | + String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s"; |
| 567 | + String aggField = "id"; |
| 568 | + assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L); |
| 569 | + assertExplainContains( |
| 570 | + sql("EXPLAIN " + query, aggField, aggField, aggField, tableName), |
| 571 | + "count(id)", |
| 572 | + "max(id)", |
| 573 | + "min(id)"); |
| 574 | + } |
| 575 | + |
| 576 | + private void assertAggregates( |
| 577 | + List<Object[]> actual, Object expectedCount, Object expectedMax, Object expectedMin) { |
| 578 | + Object actualCount = actual.get(0)[0]; |
| 579 | + Object actualMax = actual.get(0)[1]; |
| 580 | + Object actualMin = actual.get(0)[2]; |
| 581 | + |
| 582 | + Assertions.assertThat(actualCount) |
| 583 | + .as("Expected and actual count should equal") |
| 584 | + .isEqualTo(expectedCount); |
| 585 | + Assertions.assertThat(actualMax) |
| 586 | + .as("Expected and actual max should equal") |
| 587 | + .isEqualTo(expectedMax); |
| 588 | + Assertions.assertThat(actualMin) |
| 589 | + .as("Expected and actual min should equal") |
| 590 | + .isEqualTo(expectedMin); |
| 591 | + } |
| 592 | + |
| 593 | + private void assertExplainContains(List<Object[]> explain, String... expectedFragments) { |
| 594 | + String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); |
| 595 | + Arrays.stream(expectedFragments) |
| 596 | + .forEach( |
| 597 | + fragment -> |
| 598 | + Assertions.assertThat(explainString.contains(fragment)) |
| 599 | + .isTrue() |
| 600 | + .as("Expected to find plan fragment in explain plan")); |
| 601 | + } |
| 602 | + |
481 | 603 | @TestTemplate
|
482 | 604 | public void testAggregatePushDownInDeleteCopyOnWrite() {
|
483 | 605 | sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
|
|
0 commit comments