diff --git a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java index 5023670b..ef16ce4c 100644 --- a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java +++ b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java @@ -224,13 +224,14 @@ public SlidingWindowPreprocessAggregateFunction( @Override public Row createAccumulator() { - Row acc = Row.withNames(); + int arity = keyFields.size() + 1 + aggDescriptors.getAggFieldDescriptors().size(); + Object[] values = new Object[arity]; for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { - acc.setField( - descriptor.fieldName, descriptor.aggFuncWithoutRetract.createAccumulator()); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + values[pos] = descriptor.aggFuncWithoutRetract.createAccumulator(); } - return acc; + return Row.of(values); } @Override @@ -239,16 +240,17 @@ public Row add(Row row, Row acc) { for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { Object fieldValue = row.getFieldAs(descriptor.fieldName); - Object fieldAcc = acc.getField(descriptor.fieldName); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + Object fieldAcc = acc.getFieldAs(pos); descriptor.aggFuncWithoutRetract.add(fieldAcc, fieldValue, timestamp); } - if (acc.getField(rowTimeFieldName) == null) { - acc.setField( - rowTimeFieldName, - Instant.ofEpochMilli(getWindowTime(timestamp, size, offset))); + if (acc.getField(keyFields.size()) == null) { + acc.setField(keyFields.size(), Instant.ofEpochMilli(getWindowTime(timestamp, size, offset))); + int idx = 0; for (String key : keyFields) { - acc.setField(key, row.getField(key)); + acc.setField(idx, row.getField(key)); + idx += 1; } } @@ -264,14 +266,17 @@ public Row getResult(Row acc) { public Row merge(Row acc1, Row acc2) { for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { - Object fieldAcc1 = acc1.getField(descriptor.fieldName); - Object fieldAcc2 = acc2.getField(descriptor.fieldName); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + Object fieldAcc1 = acc1.getField(pos); + Object fieldAcc2 = acc2.getField(pos); descriptor.aggFuncWithoutRetract.merge(fieldAcc1, fieldAcc2); } - if (acc1.getField(rowTimeFieldName) == null) { - acc1.setField(rowTimeFieldName, acc2.getField(rowTimeFieldName)); + if (acc1.getField(keyFields.size()) == null) { + acc1.setField(keyFields.size(), acc2.getField(keyFields.size())); + int idx = 0; for (String key : keyFields) { - acc1.setField(key, acc2.getField(key)); + acc1.setField(idx, acc2.getField(idx)); + idx += 1; } } return acc1; @@ -354,27 +359,33 @@ public static Table applySlidingWindowAggregationProcess( rowTimeFieldName, keyFieldNames); } - rowDataStream = - rowDataStream - .keyBy( - (KeySelector) - value -> - Row.of( - Arrays.stream(keyFieldNames) - .map(value::getField) - .toArray())) - .process( - new SlidingWindowKeyedProcessFunction( - aggregationFieldsDescriptor, - rowTypeSerializer, - resultRowTypeInfo.createSerializer(null), - keyFieldNames, - rowTimeFieldName, - windowDescriptor.stepSize.toMillis(), - expiredRowHandler, - skipSameWindowOutput)) - .setParallelism(rowDataStream.getParallelism()) - .returns(resultRowTypeInfo); + rowDataStream = rowDataStream + .keyBy((KeySelector) row -> { + List values = new ArrayList<>(); + for (int i = 0; i < keyFieldNames.length; i += 1) { + Object value; + try { + value = row.getField(i); + } catch (IllegalArgumentException e) { + value = row.getField(keyFieldNames[i]); + } + values.add(value); + } + return Row.of(values.toArray(new Object[0])); + }) + .process( + new SlidingWindowKeyedProcessFunction( + aggregationFieldsDescriptor, + rowTypeSerializer, + resultRowTypeInfo.createSerializer(null), + keyFieldNames, + rowTimeFieldName, + windowDescriptor.stepSize.toMillis(), + expiredRowHandler, + skipSameWindowOutput) + ).setParallelism(rowDataStream.getParallelism()) + .returns(resultRowTypeInfo); + Table table = tEnv.fromDataStream( diff --git a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java index 9124a534..838b82d8 100644 --- a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java +++ b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java @@ -237,8 +237,14 @@ public void onTimer( break; } for (Row row : state.timestampToRows.get(rowTime)) { - descriptor.aggFunc.retractAccumulator( - accumulatorState, row.getField(descriptor.fieldName)); + Object value; + try { + int idx = keyFieldNames.length + 1 + aggregationFieldsDescriptor.getAggFieldIdx(descriptor); + value = row.getField(idx); + } catch (IllegalArgumentException e) { + value = row.getField(descriptor.fieldName); + } + descriptor.aggFunc.retractAccumulator(accumulatorState, value); } } if (leftIdx < timestampList.size() && timestampList.get(leftIdx) <= timestamp) {