Skip to content

Commit 9de693f

Browse files
API, Spark: Fix aggregation pushdown on struct fields (apache#9176)
1 parent 26d62c0 commit 9de693f

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

api/src/main/java/org/apache/iceberg/expressions/ValueAggregate.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ public int size() {
6060
@Override
6161
@SuppressWarnings("unchecked")
6262
public <T> T get(int pos, Class<T> javaClass) {
63-
return (T) value;
63+
if (javaClass.isAssignableFrom(StructLike.class)) {
64+
return (T) this;
65+
} else {
66+
return (T) value;
67+
}
6468
}
6569

6670
@Override

spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.math.BigDecimal;
2424
import java.sql.Date;
2525
import java.sql.Timestamp;
26+
import java.util.Arrays;
2627
import java.util.List;
2728
import java.util.Locale;
2829
import org.apache.iceberg.CatalogUtil;
@@ -36,6 +37,7 @@
3637
import org.apache.iceberg.spark.CatalogTestBase;
3738
import org.apache.iceberg.spark.TestBase;
3839
import org.apache.spark.sql.SparkSession;
40+
import org.assertj.core.api.Assertions;
3941
import org.junit.jupiter.api.AfterEach;
4042
import org.junit.jupiter.api.BeforeAll;
4143
import org.junit.jupiter.api.TestTemplate;
@@ -478,6 +480,126 @@ public void testAggregateWithComplexType() {
478480
.isFalse();
479481
}
480482

483+
@TestTemplate
484+
public void testAggregationPushdownStructInteger() {
485+
sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING iceberg", tableName);
486+
sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
487+
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
488+
sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
489+
490+
String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
491+
String aggField = "struct_with_int.c1";
492+
assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
493+
assertExplainContains(
494+
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
495+
"count(struct_with_int.c1)",
496+
"max(struct_with_int.c1)",
497+
"min(struct_with_int.c1)");
498+
}
499+
500+
@TestTemplate
501+
public void testAggregationPushdownNestedStruct() {
502+
sql(
503+
"CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
504+
tableName);
505+
sql(
506+
"INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
507+
tableName);
508+
sql(
509+
"INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
510+
tableName);
511+
sql(
512+
"INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
513+
tableName);
514+
515+
String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
516+
String aggField = "struct_with_int.c1.c2.c3.c4";
517+
518+
assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
519+
520+
assertExplainContains(
521+
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
522+
"count(struct_with_int.c1.c2.c3.c4)",
523+
"max(struct_with_int.c1.c2.c3.c4)",
524+
"min(struct_with_int.c1.c2.c3.c4)");
525+
}
526+
527+
@TestTemplate
528+
public void testAggregationPushdownStructTimestamp() {
529+
sql(
530+
"CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) USING iceberg",
531+
tableName);
532+
sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
533+
sql(
534+
"INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))",
535+
tableName);
536+
sql(
537+
"INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))",
538+
tableName);
539+
540+
String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
541+
String aggField = "struct_with_ts.c1";
542+
543+
assertAggregates(
544+
sql(query, aggField, aggField, aggField, tableName),
545+
2L,
546+
new Timestamp(1675117403000L),
547+
new Timestamp(1675117342000L));
548+
549+
assertExplainContains(
550+
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
551+
"count(struct_with_ts.c1)",
552+
"max(struct_with_ts.c1)",
553+
"min(struct_with_ts.c1)");
554+
}
555+
556+
@TestTemplate
557+
public void testAggregationPushdownOnBucketedColumn() {
558+
sql(
559+
"CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
560+
tableName);
561+
562+
sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
563+
sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName);
564+
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
565+
566+
String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
567+
String aggField = "id";
568+
assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L);
569+
assertExplainContains(
570+
sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
571+
"count(id)",
572+
"max(id)",
573+
"min(id)");
574+
}
575+
576+
private void assertAggregates(
577+
List<Object[]> actual, Object expectedCount, Object expectedMax, Object expectedMin) {
578+
Object actualCount = actual.get(0)[0];
579+
Object actualMax = actual.get(0)[1];
580+
Object actualMin = actual.get(0)[2];
581+
582+
Assertions.assertThat(actualCount)
583+
.as("Expected and actual count should equal")
584+
.isEqualTo(expectedCount);
585+
Assertions.assertThat(actualMax)
586+
.as("Expected and actual max should equal")
587+
.isEqualTo(expectedMax);
588+
Assertions.assertThat(actualMin)
589+
.as("Expected and actual min should equal")
590+
.isEqualTo(expectedMin);
591+
}
592+
593+
private void assertExplainContains(List<Object[]> explain, String... expectedFragments) {
594+
String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
595+
Arrays.stream(expectedFragments)
596+
.forEach(
597+
fragment ->
598+
Assertions.assertThat(explainString.contains(fragment))
599+
.isTrue()
600+
.as("Expected to find plan fragment in explain plan"));
601+
}
602+
481603
@TestTemplate
482604
public void testAggregatePushDownInDeleteCopyOnWrite() {
483605
sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);

0 commit comments

Comments
 (0)