|
28 | 28 | import com.facebook.presto.spi.NodeManager; |
29 | 29 | import com.facebook.presto.spi.PrestoException; |
30 | 30 | import com.facebook.presto.spi.function.AggregationFunctionImplementation; |
| 31 | +import com.facebook.presto.spi.function.AggregationFunctionMetadata; |
31 | 32 | import com.facebook.presto.spi.function.AlterRoutineCharacteristics; |
32 | 33 | import com.facebook.presto.spi.function.FunctionHandle; |
33 | 34 | import com.facebook.presto.spi.function.FunctionMetadata; |
|
40 | 41 | import com.facebook.presto.spi.function.SqlFunctionHandle; |
41 | 42 | import com.facebook.presto.spi.function.SqlFunctionId; |
42 | 43 | import com.facebook.presto.spi.function.SqlFunctionSupplier; |
| 44 | +import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation; |
43 | 45 | import com.facebook.presto.spi.function.SqlInvokedFunction; |
44 | 46 | import com.facebook.presto.spi.function.TypeVariableConstraint; |
45 | 47 | import com.google.common.base.Suppliers; |
|
59 | 61 | import java.util.concurrent.ConcurrentHashMap; |
60 | 62 | import java.util.concurrent.TimeUnit; |
61 | 63 | import java.util.function.Supplier; |
| 64 | +import java.util.stream.Collectors; |
62 | 65 |
|
| 66 | +import static com.facebook.presto.common.type.TypeSignatureUtils.resolveIntermediateType; |
63 | 67 | import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; |
64 | 68 | import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; |
65 | 69 | import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; |
@@ -141,19 +145,52 @@ public final AggregationFunctionImplementation getAggregateFunctionImplementatio |
141 | 145 | checkArgument(functionHandle instanceof SqlFunctionHandle, "Unsupported FunctionHandle type '%s'", functionHandle.getClass().getSimpleName()); |
142 | 146 |
|
143 | 147 | SqlFunctionHandle sqlFunctionHandle = (SqlFunctionHandle) functionHandle; |
| 148 | + if (functionHandle instanceof NativeFunctionHandle) { |
| 149 | + NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle; |
| 150 | + return processNativeFunctionHandle(nativeFunctionHandle, typeManager); |
| 151 | + } |
| 152 | + else { |
| 153 | + return processSqlFunctionHandle(sqlFunctionHandle, typeManager); |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + private AggregationFunctionImplementation processNativeFunctionHandle(NativeFunctionHandle nativeFunctionHandle, TypeManager typeManager) |
| 158 | + { |
| 159 | + if (!aggregationImplementationByHandle.containsKey(nativeFunctionHandle)) { |
| 160 | + Signature signature = nativeFunctionHandle.getSignature(); |
| 161 | + SqlFunction function = getSqlFunctionFromSignature(signature); |
| 162 | + SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; |
| 163 | + |
| 164 | + checkArgument( |
| 165 | + sqlFunction.getAggregationMetadata().isPresent(), |
| 166 | + "Need aggregationMetadata to get aggregation function implementation"); |
| 167 | + |
| 168 | + AggregationFunctionMetadata aggregationMetadata = sqlFunction.getAggregationMetadata().get(); |
| 169 | + TypeSignature intermediateType = aggregationMetadata.getIntermediateType(); |
| 170 | + List<TypeSignature> typeSignatures = sqlFunction.getParameters().stream().map(Parameter::getType).collect(Collectors.toList()); |
| 171 | + TypeSignature resolvedIntermediateType = resolveIntermediateType(intermediateType, typeSignatures, signature.getArgumentTypes()); |
| 172 | + aggregationImplementationByHandle.put( |
| 173 | + nativeFunctionHandle, |
| 174 | + new SqlInvokedAggregationFunctionImplementation( |
| 175 | + typeManager.getType(resolvedIntermediateType), |
| 176 | + typeManager.getType(signature.getReturnType()), |
| 177 | + aggregationMetadata.isOrderSensitive())); |
| 178 | + } |
| 179 | + return aggregationImplementationByHandle.get(nativeFunctionHandle); |
| 180 | + } |
144 | 181 |
|
145 | | - // Cache results if applicable |
| 182 | + private AggregationFunctionImplementation processSqlFunctionHandle(SqlFunctionHandle sqlFunctionHandle, TypeManager typeManager) |
| 183 | + { |
146 | 184 | if (!aggregationImplementationByHandle.containsKey(sqlFunctionHandle)) { |
147 | 185 | SqlFunctionId functionId = sqlFunctionHandle.getFunctionId(); |
148 | | - if (!latestFunctions.containsKey(functionId)) { |
| 186 | + if (!memoizedFunctionsSupplier.get().containsKey(functionId)) { |
149 | 187 | throw new PrestoException(GENERIC_USER_ERROR, format("Function '%s' is missing from cache", functionId.getId())); |
150 | 188 | } |
151 | 189 |
|
152 | 190 | aggregationImplementationByHandle.put( |
153 | 191 | sqlFunctionHandle, |
154 | | - sqlInvokedFunctionToAggregationImplementation(latestFunctions.get(functionId), typeManager)); |
| 192 | + sqlInvokedFunctionToAggregationImplementation(memoizedFunctionsSupplier.get().get(functionId), typeManager)); |
155 | 193 | } |
156 | | - |
157 | 194 | return aggregationImplementationByHandle.get(sqlFunctionHandle); |
158 | 195 | } |
159 | 196 |
|
@@ -303,18 +340,23 @@ public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespa |
303 | 340 | return functionHandle; |
304 | 341 | } |
305 | 342 |
|
306 | | - private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle) |
| 343 | + private SqlFunction getSqlFunctionFromSignature(Signature signature) |
307 | 344 | { |
308 | | - NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle; |
309 | | - Signature signature = nativeFunctionHandle.getSignature(); |
310 | 345 | SqlFunctionSupplier functionKey; |
311 | 346 | try { |
312 | 347 | functionKey = specializedFunctionKeyCache.getUnchecked(signature); |
313 | 348 | } |
314 | 349 | catch (UncheckedExecutionException e) { |
315 | | - throw convertToPrestoException(e, format("Error getting FunctionMetadata for handle: %s", functionHandle)); |
| 350 | + throw convertToPrestoException(e, format("Error getting FunctionMetadata for signature: %s", signature)); |
316 | 351 | } |
317 | | - SqlFunction function = functionKey.getFunction(); |
| 352 | + return functionKey.getFunction(); |
| 353 | + } |
| 354 | + |
| 355 | + private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle) |
| 356 | + { |
| 357 | + NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle; |
| 358 | + Signature signature = nativeFunctionHandle.getSignature(); |
| 359 | + SqlFunction function = getSqlFunctionFromSignature(signature); |
318 | 360 |
|
319 | 361 | // todo: verify this metadata return |
320 | 362 | SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; |
|
0 commit comments