Skip to content

Commit f5d79c1

Browse files
authored
[SW-2639] Expose Fields of Model Output on H2OMOJOModel Classes as Getters (#2692)
* [SW-2639] Expose Fields of Model Output on H2OMOJOModel Classes as Getters * address review comment * Fix tests * address review comments * throw exception in tests if field cannot be parsed. * Fix tests * spotless apply * modify after merge of word2vec * fix double array parsing
1 parent 1716d8f commit f5d79c1

32 files changed

+492
-131
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,6 @@ jenkins/docker/regular-tests/Dockerfile
8686
# Generated code and documentation
8787
src-gen/
8888
doc/src/site/sphinx/parameters
89+
doc/src/site/sphinx/model_details
8990
doc/src/site/sphinx/metrics
9091
doc/src/site/sphinx/configuration/configuration_properties.rst

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/MOJOModelAPIRunner.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ object MOJOModelAPIRunner
4545
entityName = algorithmContext.entityName + "MOJOModel")
4646
}
4747

48-
for ((mojoContext, parameterContext) <- mojoConfiguration.zip(parametersConfiguration)) {
49-
val content = mojoTemplates(languageExtension)(mojoContext, parameterContext)
48+
for (((mojoContext, parameterContext), modelOutputContext) <- mojoConfiguration
49+
.zip(parametersConfiguration)
50+
.zip(modelOutputConfiguration)) {
51+
val content = mojoTemplates(languageExtension)(mojoContext, parameterContext, modelOutputContext)
5052
writeResultToFile(content, mojoContext, languageExtension, destinationDir)
5153
}
5254

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/common/AlgorithmConfigurations.scala

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,22 @@ import hex.glm.GLMModel.GLMParameters
2626
import hex.kmeans.KMeansModel.KMeansParameters
2727
import hex.schemas.CoxPHV3.CoxPHParametersV3
2828
import hex.rulefit.RuleFitModel.RuleFitParameters
29+
import hex.schemas.CoxPHModelV3.CoxPHModelOutputV3
30+
import hex.schemas.DRFModelV3.DRFModelOutputV3
31+
import hex.schemas.DeepLearningModelV3.DeepLearningModelOutputV3
32+
import hex.schemas.GAMModelV3.GAMModelOutputV3
33+
import hex.schemas.GBMModelV3.GBMModelOutputV3
34+
import hex.schemas.GLMModelV3.GLMModelOutputV3
35+
import hex.schemas.IsolationForestModelV3.IsolationForestModelOutputV3
36+
import hex.schemas.KMeansModelV3.KMeansModelOutputV3
37+
import hex.schemas.RuleFitModelV3.RuleFitModelOutputV3
2938
import hex.schemas.RuleFitV3.RuleFitParametersV3
39+
import hex.schemas.XGBoostModelV3.XGBoostModelOutputV3
3040
import hex.schemas.{DRFV3, DeepLearningV3, GAMV3, GBMV3, GLMV3, IsolationForestV3, KMeansV3, XGBoostV3}
3141
import hex.tree.drf.DRFModel.DRFParameters
3242
import hex.tree.gbm.GBMModel.GBMParameters
3343
import hex.tree.isofor.IsolationForestModel.IsolationForestParameters
44+
import hex.tree.xgboost.XGBoostModel
3445
import hex.tree.xgboost.XGBoostModel.XGBoostParameters
3546

3647
trait AlgorithmConfigurations extends ConfigurationsBase {
@@ -85,7 +96,7 @@ trait AlgorithmConfigurations extends ConfigurationsBase {
8596
type KMeansParamsV3 = KMeansV3.KMeansParametersV3
8697

8798
val explicitDefaultValues =
88-
Map[String, Any]("max_w2" -> 3.402823e38f, "response_column" -> "label", "model_id" -> null)
99+
Map[String, Any]("max_w2" -> 3.402823e38f, "response_column" -> "label", "model_id" -> null, "lambda" -> null)
89100

90101
val noDeprecation = Seq.empty
91102

@@ -191,4 +202,25 @@ trait AlgorithmConfigurations extends ConfigurationsBase {
191202
"ai.h2o.sparkling.ml.algos",
192203
parametersToCheck)
193204
}
205+
206+
/**
  * Appends one [[ModelOutputSubstitutionContext]] per entry of the table below to the contexts
  * provided by the parent configuration. Each entry pairs the name of an outputs entity with the
  * H2O schema class that is later inspected for its output fields.
  */
override def modelOutputConfiguration: Seq[ModelOutputSubstitutionContext] = {
  val inheritedContexts = super.modelOutputConfiguration

  val schemaClassByEntityName = Seq[(String, Class[_])](
    ("H2OXGBoostModelOutputs", classOf[XGBoostModelOutputV3]),
    ("H2OGBMModelOutputs", classOf[GBMModelOutputV3]),
    ("H2ODRFModelOutputs", classOf[DRFModelOutputV3]),
    ("H2OGLMModelOutputs", classOf[GLMModelOutputV3]),
    ("H2OGAMModelOutputs", classOf[GAMModelOutputV3]),
    ("H2ODeepLearningModelOutputs", classOf[DeepLearningModelOutputV3]),
    ("H2ORuleFitModelOutputs", classOf[RuleFitModelOutputV3]),
    ("H2OKMeansModelOutputs", classOf[KMeansModelOutputV3]),
    ("H2OCoxPHModelOutputs", classOf[CoxPHModelOutputV3]),
    ("H2OIsolationForestModelOutputs", classOf[IsolationForestModelOutputV3]))

  // The last argument is the list of explicitly ignored outputs; none are declared at this level.
  val ownContexts = schemaClassByEntityName.map {
    case (entityName, schemaClass) =>
      ModelOutputSubstitutionContext("ai.h2o.sparkling.ml.outputs", entityName, schemaClass, Seq.empty)
  }

  inheritedContexts ++ ownContexts
}
194226
}

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/common/ConfigurationsBase.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,6 @@ trait ConfigurationsBase {
3232
/** Algorithm substitution contexts; empty by default. */
def algorithmConfiguration: Seq[AlgorithmSubstitutionContext] = Seq.empty

/** Parameter substitution contexts; empty by default. */
def parametersConfiguration: Seq[ParameterSubstitutionContext] = Seq.empty

/**
  * Model output substitution contexts; empty by default. Configuration traits override this
  * member and append their own contexts to the inherited ones.
  */
def modelOutputConfiguration: Seq[ModelOutputSubstitutionContext] = Seq.empty
3537
}

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/common/FeatureEstimatorConfigurations.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ package ai.h2o.sparkling.api.generation.common
2020
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
2121
import hex.glrm.GLRMModel.GLRMParameters
2222
import hex.pca.PCAModel.PCAParameters
23+
import hex.schemas.DeepLearningModelV3.DeepLearningModelOutputV3
24+
import hex.schemas.GLRMModelV3.GLRMModelOutputV3
25+
import hex.schemas.PCAModelV3.PCAModelOutputV3
26+
import hex.schemas.Word2VecModelV3.Word2VecModelOutputV3
2327
import hex.schemas.{DeepLearningV3, GLRMV3, PCAV3, Word2VecV3}
2428
import hex.word2vec.Word2VecModel.Word2VecParameters
2529

@@ -123,4 +127,19 @@ trait FeatureEstimatorConfigurations extends ConfigurationsBase {
123127
algorithmType,
124128
specificMetricsClass = metricsClass)
125129
}
130+
131+
/**
  * Appends one [[ModelOutputSubstitutionContext]] per entry of the table below to the contexts
  * provided by the parent configuration. Each entry pairs the name of an outputs entity with the
  * H2O schema class that is later inspected for its output fields.
  */
override def modelOutputConfiguration: Seq[ModelOutputSubstitutionContext] = {
  val inheritedContexts = super.modelOutputConfiguration

  val schemaClassByEntityName = Seq[(String, Class[_])](
    ("H2OAutoEncoderModelOutputs", classOf[DeepLearningModelOutputV3]),
    ("H2OPCAModelOutputs", classOf[PCAModelOutputV3]),
    ("H2OGLRMModelOutputs", classOf[GLRMModelOutputV3]),
    ("H2OWord2VecModelOutputs", classOf[Word2VecModelOutputV3]))

  // The last argument is the list of explicitly ignored outputs; none are declared at this level.
  val ownContexts = schemaClassByEntityName.map {
    case (entityName, schemaClass) =>
      ModelOutputSubstitutionContext("ai.h2o.sparkling.ml.outputs", entityName, schemaClass, Seq.empty)
  }

  inheritedContexts ++ ownContexts
}
126145
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package ai.h2o.sparkling.api.generation.common
19+
20+
/** Lists the model output fields and field types that the MOJO model templates filter out. */
object IgnoredOutputs {

  /** Simple names of field types whose outputs are filtered out regardless of the field name. */
  val ignoredTypes: Set[String] = Set("FrameKeyV3", "FrameKeyV3[]")

  /** H2O names of outputs excluded from generation for every model (per the name: handled by the parent). */
  val implementedInParent: Seq[String] = Seq(
    "names",
    "original_names",
    "column_types",
    "domains",
    "cross_validation_models",
    "model_category",
    "scoring_history",
    "training_metrics",
    "validation_metrics",
    "cross_validation_metrics",
    "cross_validation_metrics_summary",
    "cv_scoring_history",
    "reproducibility_information_table",
    "model_summary",
    "start_time",
    "end_time",
    "run_time",
    "default_threshold")

  /** H2O names of outputs that are never exposed. */
  val ignored: Seq[String] = Seq("status", "help", "__meta")

  // Outputs dropped for a single MOJO model only, keyed by the model's entity name.
  private val modelSpecific: Map[String, Seq[String]] = Map(
    "H2OGLRMMOJOModel" -> Seq("representation_name"), // Collision with a parameter
    "H2OWord2VecMOJOModel" -> Seq("epochs")) // Collision with a parameter

  /**
    * Returns every output name to be skipped for the given MOJO model.
    *
    * @param mojoModel entity name of the MOJO model, e.g. "H2OGLRMMOJOModel"
    * @return [[implementedInParent]], then [[ignored]], then the model-specific names (if any)
    */
  def all(mojoModel: String): Seq[String] = {
    val specificToModel = modelSpecific.getOrElse(mojoModel, Seq.empty)
    implementedInParent ++ ignored ++ specificToModel
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package ai.h2o.sparkling.api.generation.common
19+
20+
/**
  * Describes the outputs of one H2O model type for the code generator.
  *
  * @param namespace      namespace associated with the outputs entity (presumably its package — confirm)
  * @param entityName     name of the outputs entity
  * @param h2oSchemaClass H2O schema class whose fields are inspected to discover the outputs
  * @param ignoredOutputs output names to ignore for this entity
  *                       (NOTE(review): not read by any code visible in this change — confirm usage)
  */
case class ModelOutputSubstitutionContext(
    namespace: String,
    entityName: String,
    h2oSchemaClass: Class[_],
    ignoredOutputs: Seq[String])
  extends SubstitutionContextBase
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package ai.h2o.sparkling.api.generation.common
19+
20+
import water.api.API
21+
22+
/** Discovers model outputs by reflecting over the public fields of an H2O schema class. */
trait OutputResolver {

  /**
    * Creates one [[Parameter]] for each public field of the schema class that carries the `API`
    * annotation; fields without the annotation are skipped.
    *
    * @param outputSubstitutionContext context providing the schema class to inspect
    * @return the resolved outputs; the help text of the annotation is passed as the last
    *         constructor argument of each [[Parameter]]
    */
  def resolveOutputs(outputSubstitutionContext: ModelOutputSubstitutionContext): Seq[Parameter] = {
    val annotatedFields = outputSubstitutionContext.h2oSchemaClass.getFields
      .filter(_.getAnnotation(classOf[API]) != null)

    val resolved = annotatedFields.map { schemaField =>
      val h2oName = schemaField.getName
      val fieldType = schemaField.getType
      // Primitive fields cannot hold null, so 0 serves as their default value.
      val defaultValue: Any = if (fieldType.isPrimitive) 0 else null
      val helpText = schemaField.getAnnotation(classOf[API]).help()

      Parameter(ParameterNameConverter.convertFromH2OToSW(h2oName), h2oName, defaultValue, fieldType, helpText)
    }
    resolved
  }
}
}

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/python/MOJOModelTemplate.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,27 @@
1818
package ai.h2o.sparkling.api.generation.python
1919

2020
import ai.h2o.sparkling.api.generation.common._
21+
import ai.h2o.sparkling.api.generation.scala.MOJOModelTemplate.resolveOutputs
2122

2223
object MOJOModelTemplate
23-
extends ((AlgorithmSubstitutionContext, ParameterSubstitutionContext) => String)
24+
extends ((AlgorithmSubstitutionContext, ParameterSubstitutionContext, ModelOutputSubstitutionContext) => String)
2425
with PythonEntityTemplate
25-
with ParameterResolver {
26+
with ParameterResolver
27+
with OutputResolver {
2628

2729
def apply(
2830
algorithmSubstitutionContext: AlgorithmSubstitutionContext,
29-
parameterSubstitutionContext: ParameterSubstitutionContext): String = {
31+
parameterSubstitutionContext: ParameterSubstitutionContext,
32+
outputSubstitutionContext: ModelOutputSubstitutionContext): String = {
33+
3034
val parameters = resolveParameters(parameterSubstitutionContext)
31-
.filter(parameter =>
32-
!IgnoredParameters.ignoredInMOJOs(algorithmSubstitutionContext.entityName).contains(parameter.h2oName))
35+
.filterNot(parameter =>
36+
IgnoredParameters.ignoredInMOJOs(algorithmSubstitutionContext.entityName).contains(parameter.h2oName))
37+
38+
val outputs = resolveOutputs(outputSubstitutionContext)
39+
.filterNot(output => IgnoredOutputs.all(algorithmSubstitutionContext.entityName).contains(output.h2oName))
40+
.filterNot(output => IgnoredOutputs.ignoredTypes(output.dataType.getSimpleName))
41+
3342
val entityName = algorithmSubstitutionContext.entityName
3443
val namespace = algorithmSubstitutionContext.namespace
3544
val algorithmType = algorithmSubstitutionContext.algorithmType
@@ -72,8 +81,9 @@ object MOJOModelTemplate
7281
| else:
7382
| raise TypeError("Invalid type.")
7483
|
75-
|""".stripMargin ++
76-
generateGetterMethods(parameters)
84+
|""".stripMargin +
85+
generateGetterMethods(parameters) + "\n\n" +
86+
generateGetterMethods(outputs)
7787
}
7888
}
7989

@@ -93,6 +103,7 @@ object MOJOModelTemplate
93103
/**
  * Picks the Python expression that converts the raw `value` of a getter according to the
  * parameter's data type: 2D arrays, 1D arrays and `TwoDimTableV3` tables go through the matching
  * `H2OTypeConverters` function, everything else is passed through unchanged.
  */
private def generateValueConversion(parameter: Parameter): String = {
  val dataType = parameter.dataType
  if (dataType.isArray) {
    // Arrays of arrays need the 2D converter; any other array uses the 1D one.
    if (dataType.getComponentType.isArray()) "H2OTypeConverters.scala2DArrayToPython2DArray(value)"
    else "H2OTypeConverters.scalaArrayToPythonArray(value)"
  } else if (dataType.getSimpleName == "TwoDimTableV3") {
    "H2OTypeConverters.scalaToPythonDataFrame(value)"
  } else {
    "value"
  }
}
98109
}

api-generation/src/main/scala/ai/h2o/sparkling/api/generation/scala/MOJOModelTemplate.scala

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,30 @@ import ai.h2o.sparkling.api.generation.common._
2121
import ai.h2o.sparkling.api.generation.scala.ParametersTemplate.resolveParameterConstructorMethodType
2222

2323
object MOJOModelTemplate
24-
extends ((AlgorithmSubstitutionContext, ParameterSubstitutionContext) => String)
24+
extends ((AlgorithmSubstitutionContext, ParameterSubstitutionContext, ModelOutputSubstitutionContext) => String)
2525
with ParametersTemplateBase
2626
with ScalaEntityTemplate
27-
with ParameterResolver {
27+
with ParameterResolver
28+
with OutputResolver {
2829

2930
def apply(
3031
algorithmSubstitutionContext: AlgorithmSubstitutionContext,
31-
parameterSubstitutionContext: ParameterSubstitutionContext): String = {
32+
parameterSubstitutionContext: ParameterSubstitutionContext,
33+
outputSubstitutionContext: ModelOutputSubstitutionContext): String = {
3234

3335
val parameters = resolveParameters(parameterSubstitutionContext)
34-
.filter(parameter =>
35-
!IgnoredParameters.ignoredInMOJOs(algorithmSubstitutionContext.entityName).contains(parameter.h2oName))
36+
.filterNot(parameter =>
37+
IgnoredParameters.ignoredInMOJOs(algorithmSubstitutionContext.entityName).contains(parameter.h2oName))
38+
39+
val outputs = resolveOutputs(outputSubstitutionContext)
40+
.filterNot(output => IgnoredOutputs.all(algorithmSubstitutionContext.entityName).contains(output.h2oName))
41+
.filterNot(output => IgnoredOutputs.ignoredTypes(output.dataType.getSimpleName))
3642

3743
val explicitFieldImplementations = parameterSubstitutionContext.explicitFields.flatMap(_.mojoImplementation) ++
3844
parameterSubstitutionContext.deprecatedFields.flatMap(_.mojoImplementation)
3945

4046
val imports = Seq(
47+
"com.google.gson.JsonObject",
4148
"ai.h2o.sparkling.ml.params.ParameterConstructorMethods",
4249
"hex.genmodel.MojoModel",
4350
"org.apache.spark.expose.Logging") ++
@@ -70,10 +77,20 @@ object MOJOModelTemplate
7077
|${generateParameterDefinitions(parameters)}
7178
|
7279
| //
80+
| // Output definitions
81+
| //
82+
|${generateParameterDefinitions(outputs)}
83+
|
84+
| //
7385
| // Getters
7486
| //
7587
|${generateGetters(parameters)}
7688
|
89+
| //
90+
| // Output Getters
91+
| //
92+
|${generateGetters(outputs)}
93+
|
7794
| override private[sparkling] def setSpecificParams(h2oMojo: MojoModel): Unit = {
7895
| super.setSpecificParams(h2oMojo)
7996
| try {
@@ -86,6 +103,10 @@ object MOJOModelTemplate
86103
| }
87104
| }
88105
|
106+
| override private[sparkling] def setOutputParameters(outputSection: JsonObject): Unit = {
107+
|${generateOutputParameterAssignments(outputs)}
108+
| }
109+
|
89110
|${generateMetricsOverrides(algorithmSubstitutionContext.specificMetricsClass)}""".stripMargin
90111
}
91112

@@ -136,6 +157,38 @@ object MOJOModelTemplate
136157
.mkString("\n\n")
137158
}
138159

160+
/**
  * Generates the body of `setOutputParameters`: for every output, a snippet that reads the
  * field from the `outputSection` JSON object and stores it under the output's Sparkling Water name.
  *
  * The generated snippet logs a warning when the extraction fails, unless the system property
  * `spark.testing` is "true", in which case the failure propagates. When the field is missing and
  * `spark.testing` is "true", the snippet throws an `AssertionError`.
  *
  * @param outputs outputs to generate assignments for
  * @return the generated snippets separated by blank lines
  * @throws IllegalArgumentException if an output has a data type for which no JSON extraction is
  *                                  defined (previously surfaced as an opaque `scala.MatchError`)
  */
def generateOutputParameterAssignments(outputs: Seq[Parameter]): String = {
  outputs
    .map { output =>
      val h2oName = output.h2oName
      val swName = output.swName
      val value = output.dataType.getSimpleName match {
        // Gson names its primitive accessors getAs<CapitalizedPrimitiveName>().
        case primitive @ ("boolean" | "byte" | "short" | "int" | "long" | "float" | "double") =>
          s"""outputSection.get("$h2oName").getAs${primitive.capitalize}()"""
        case "double[]" => s"""jsonFieldToDoubleArray(outputSection, "$h2oName")"""
        case "TwoDimTableV3" => s"""jsonFieldToDataFrame(outputSection, "$h2oName")"""
        case unsupported =>
          throw new IllegalArgumentException(
            s"Cannot generate an assignment for the output field '$h2oName': " +
              s"the data type '$unsupported' is not supported.")
      }
      s"""    if (outputSection.has("$h2oName")) {
         |      try {
         |        val extractedValue = $value
         |        set("$swName", extractedValue)
         |      } catch {
         |        case e: Throwable if System.getProperty("spark.testing", "false") != "true" =>
         |          logWarning("An error occurred during setting up the '$swName' parameter. The method " +
         |            "get${swName.capitalize}() on the MOJO model object won't be able to provide the actual value.", e)
         |      }
         |    } else if (System.getProperty("spark.testing", "false") == "true") {
         |      throw new AssertionError("The output field '$h2oName' does not exist.")
         |    }""".stripMargin
    }
    .mkString("\n\n")
}
191+
139192
protected def resolveParameterConstructorMethod(dataType: Class[_], defaultValue: Any): String = {
140193
val rawPrefix = resolveParameterConstructorMethodType(dataType, defaultValue)
141194
val finalPrefix = if (defaultValue == null || dataType.isEnum) s"nullable${rawPrefix.capitalize}" else rawPrefix

0 commit comments

Comments
 (0)