Skip to content

Commit 49d3b9d

Browse files
Pratik Joseph DabrePratik Joseph Dabre
authored andcommitted
Add utility functions to resolve intermediate type and fix bug in json handling of aggregate functions
1 parent 9c6435a commit 49d3b9d

File tree

4 files changed

+148
-11
lines changed

4 files changed

+148
-11
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.common.type;
15+
16+
import java.util.ArrayList;
17+
import java.util.Collections;
18+
import java.util.HashMap;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.Optional;
22+
import java.util.stream.Collectors;
23+
24+
public final class TypeSignatureUtils
25+
{
26+
private TypeSignatureUtils() {}
27+
28+
public static TypeSignature resolveIntermediateType(TypeSignature typeSignature, List<TypeSignature> parameters, List<TypeSignature> argumentTypes)
29+
{
30+
Map<TypeSignature, TypeSignature> typeSignatureMap = getTypeSignatureMap(parameters, argumentTypes);
31+
return resolveTypeSignatures(typeSignature, typeSignatureMap).getTypeSignature();
32+
}
33+
34+
// todo: Change to ImmutableList when open sourcing
35+
36+
private static NamedTypeSignature resolveTypeSignatures(TypeSignature typeSignature, Map<TypeSignature, TypeSignature> typeSignatureMap)
37+
{
38+
TypeSignature resolvedTypeSignature = typeSignatureMap.getOrDefault(typeSignature, typeSignature);
39+
List<NamedTypeSignature> namedTypeSignatures = new ArrayList<>();
40+
List<TypeSignature> typeSignatures = new ArrayList<>();
41+
List<TypeSignatureParameter> typeSignaturesList = typeSignature.getParameters();
42+
for (TypeSignatureParameter typeSignatureParameter : typeSignaturesList) {
43+
TypeSignature typeSignatureOrNamedTypeSignature = typeSignatureParameter.getTypeSignatureOrNamedTypeSignature().orElseThrow(() ->
44+
new IllegalStateException("Could not get type signature for type parameter [" + typeSignatureParameter + "]"));
45+
TypeSignature resolvedTypeParameterSignature = typeSignatureMap.getOrDefault(typeSignatureOrNamedTypeSignature, typeSignatureOrNamedTypeSignature);
46+
if (resolvedTypeSignature.getBase().equals("row")) {
47+
if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) {
48+
namedTypeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap));
49+
}
50+
else {
51+
namedTypeSignatures.add(new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList())));
52+
}
53+
}
54+
else {
55+
if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) {
56+
typeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap).getTypeSignature());
57+
}
58+
else {
59+
typeSignatures.add(new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList()));
60+
}
61+
}
62+
}
63+
return new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeSignature.getBase(),
64+
(typeSignatures.isEmpty() ? namedTypeSignatures : typeSignatures).stream().map(
65+
signature -> signature instanceof NamedTypeSignature ?
66+
TypeSignatureParameter.of((NamedTypeSignature) signature)
67+
: TypeSignatureParameter.of((TypeSignature) signature)).collect(Collectors.toList())));
68+
}
69+
70+
/**
71+
* Parameter and argument type mapping must be consistent
72+
*/
73+
74+
public static Map<TypeSignature, TypeSignature> getTypeSignatureMap(List<TypeSignature> parameters, List<TypeSignature> argumentTypes)
75+
{
76+
HashMap<TypeSignature, TypeSignature> typeSignatureMap = new HashMap<>();
77+
if (argumentTypes.size() != parameters.size()) {
78+
throw new IllegalStateException("Parameters size and argumentTypes size do not match!");
79+
}
80+
for (int i = 0; i < argumentTypes.size(); i++) {
81+
TypeSignature parameter = parameters.get(i);
82+
TypeSignature argumentType = argumentTypes.get(i);
83+
if (argumentTypes.get(i).getParameters().isEmpty()) {
84+
typeSignatureMap.put(parameter, argumentType);
85+
}
86+
else {
87+
typeSignatureMap.putAll(getTypeSignatureMap(parameter.getTypeParametersAsTypeSignatures(), argumentType.getTypeParametersAsTypeSignatures()));
88+
}
89+
}
90+
return typeSignatureMap;
91+
}
92+
}

presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ protected AggregationFunctionImplementation sqlInvokedFunctionToAggregationImple
357357
getClass().getSimpleName(),
358358
implementationType));
359359
case CPP:
360+
case REST:
360361
checkArgument(
361362
function.getAggregationMetadata().isPresent(),
362363
"Need aggregationMetadata to get aggregation function implementation");

presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.facebook.presto.spi.NodeManager;
2929
import com.facebook.presto.spi.PrestoException;
3030
import com.facebook.presto.spi.function.AggregationFunctionImplementation;
31+
import com.facebook.presto.spi.function.AggregationFunctionMetadata;
3132
import com.facebook.presto.spi.function.AlterRoutineCharacteristics;
3233
import com.facebook.presto.spi.function.FunctionHandle;
3334
import com.facebook.presto.spi.function.FunctionMetadata;
@@ -40,6 +41,7 @@
4041
import com.facebook.presto.spi.function.SqlFunctionHandle;
4142
import com.facebook.presto.spi.function.SqlFunctionId;
4243
import com.facebook.presto.spi.function.SqlFunctionSupplier;
44+
import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation;
4345
import com.facebook.presto.spi.function.SqlInvokedFunction;
4446
import com.facebook.presto.spi.function.TypeVariableConstraint;
4547
import com.google.common.base.Suppliers;
@@ -59,7 +61,9 @@
5961
import java.util.concurrent.ConcurrentHashMap;
6062
import java.util.concurrent.TimeUnit;
6163
import java.util.function.Supplier;
64+
import java.util.stream.Collectors;
6265

66+
import static com.facebook.presto.common.type.TypeSignatureUtils.resolveIntermediateType;
6367
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR;
6468
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
6569
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
@@ -141,19 +145,52 @@ public final AggregationFunctionImplementation getAggregateFunctionImplementatio
141145
checkArgument(functionHandle instanceof SqlFunctionHandle, "Unsupported FunctionHandle type '%s'", functionHandle.getClass().getSimpleName());
142146

143147
SqlFunctionHandle sqlFunctionHandle = (SqlFunctionHandle) functionHandle;
148+
if (functionHandle instanceof NativeFunctionHandle) {
149+
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
150+
return processNativeFunctionHandle(nativeFunctionHandle, typeManager);
151+
}
152+
else {
153+
return processSqlFunctionHandle(sqlFunctionHandle, typeManager);
154+
}
155+
}
156+
157+
private AggregationFunctionImplementation processNativeFunctionHandle(NativeFunctionHandle nativeFunctionHandle, TypeManager typeManager)
158+
{
159+
if (!aggregationImplementationByHandle.containsKey(nativeFunctionHandle)) {
160+
Signature signature = nativeFunctionHandle.getSignature();
161+
SqlFunction function = getSqlFunctionFromSignature(signature);
162+
SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function;
163+
164+
checkArgument(
165+
sqlFunction.getAggregationMetadata().isPresent(),
166+
"Need aggregationMetadata to get aggregation function implementation");
167+
168+
AggregationFunctionMetadata aggregationMetadata = sqlFunction.getAggregationMetadata().get();
169+
TypeSignature intermediateType = aggregationMetadata.getIntermediateType();
170+
List<TypeSignature> typeSignatures = sqlFunction.getParameters().stream().map(Parameter::getType).collect(Collectors.toList());
171+
TypeSignature resolvedIntermediateType = resolveIntermediateType(intermediateType, typeSignatures, signature.getArgumentTypes());
172+
aggregationImplementationByHandle.put(
173+
nativeFunctionHandle,
174+
new SqlInvokedAggregationFunctionImplementation(
175+
typeManager.getType(resolvedIntermediateType),
176+
typeManager.getType(signature.getReturnType()),
177+
aggregationMetadata.isOrderSensitive()));
178+
}
179+
return aggregationImplementationByHandle.get(nativeFunctionHandle);
180+
}
144181

145-
// Cache results if applicable
182+
private AggregationFunctionImplementation processSqlFunctionHandle(SqlFunctionHandle sqlFunctionHandle, TypeManager typeManager)
183+
{
146184
if (!aggregationImplementationByHandle.containsKey(sqlFunctionHandle)) {
147185
SqlFunctionId functionId = sqlFunctionHandle.getFunctionId();
148-
if (!latestFunctions.containsKey(functionId)) {
186+
if (!memoizedFunctionsSupplier.get().containsKey(functionId)) {
149187
throw new PrestoException(GENERIC_USER_ERROR, format("Function '%s' is missing from cache", functionId.getId()));
150188
}
151189

152190
aggregationImplementationByHandle.put(
153191
sqlFunctionHandle,
154-
sqlInvokedFunctionToAggregationImplementation(latestFunctions.get(functionId), typeManager));
192+
sqlInvokedFunctionToAggregationImplementation(memoizedFunctionsSupplier.get().get(functionId), typeManager));
155193
}
156-
157194
return aggregationImplementationByHandle.get(sqlFunctionHandle);
158195
}
159196

@@ -303,18 +340,23 @@ public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespa
303340
return functionHandle;
304341
}
305342

306-
private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
343+
private SqlFunction getSqlFunctionFromSignature(Signature signature)
307344
{
308-
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
309-
Signature signature = nativeFunctionHandle.getSignature();
310345
SqlFunctionSupplier functionKey;
311346
try {
312347
functionKey = specializedFunctionKeyCache.getUnchecked(signature);
313348
}
314349
catch (UncheckedExecutionException e) {
315-
throw convertToPrestoException(e, format("Error getting FunctionMetadata for handle: %s", functionHandle));
350+
throw convertToPrestoException(e, format("Error getting FunctionMetadata for signature: %s", signature));
316351
}
317-
SqlFunction function = functionKey.getFunction();
352+
return functionKey.getFunction();
353+
}
354+
355+
private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
356+
{
357+
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
358+
Signature signature = nativeFunctionHandle.getSignature();
359+
SqlFunction function = getSqlFunctionFromSignature(signature);
318360

319361
// todo: verify this metadata return
320362
SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function;

presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,15 +842,17 @@ void VeloxQueryPlanConverterBase::toAggregations(
842842
auto pos = functionId.find(";", start + 1);
843843
if (pos == std::string::npos) {
844844
auto argumentType = functionId.substr(start + 1);
845-
aggregate.rawInputTypes.push_back(
845+
if (!argumentType.empty()) {
846+
aggregate.rawInputTypes.push_back(
846847
stringToType(argumentType, typeParser_));
848+
}
847849
break;
848850
}
849851

850852
auto argumentType = functionId.substr(start + 1, pos - start - 1);
851853
aggregate.rawInputTypes.push_back(
852854
stringToType(argumentType, typeParser_));
853-
pos = start + 1;
855+
start = pos;
854856
}
855857
}
856858
} else {

0 commit comments

Comments
 (0)