
Commit bbb13d5

[SW-2682] do constant check & row count in one iteration (#2730)
1 parent 233c41f commit bbb13d5
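
Summary of the change: the check that the selected feature columns are not all constant previously ran as its own Spark job (dataset.select(featureColumns: _*).distinct().count() in H2OAlgoCommonUtils), while the per-partition row counts needed by the writer were collected in a separate pass (getNonEmptyPartitionSizes in Writer). Both are now computed in a single iteration over the RDD partitions by the new PartitionStatsGenerator, and the input RDD is persisted around the conversion so the remaining passes can reuse it. Below is a minimal, self-contained sketch of that single-pass idea; the name countAndCheck and the simplifications (no nested-column flattening, a plain Boolean instead of an Option) are illustrative only, the actual implementation is PartitionStatsGenerator.getPartitionStats in the diff further down.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// Sketch only: one pass per partition that counts rows and keeps at most two
// distinct value combinations of the checked columns. If a single combination
// survives the merge across partitions, the columns are constant.
def countAndCheck(rdd: RDD[Row], cols: Seq[String]): (Map[Int, Int], Boolean) = {
  val (sizes, seenValues) = rdd
    .mapPartitionsWithIndex { (idx, it) =>
      var rowCount = 0
      var seen = Set.empty[Seq[Any]]
      it.foreach { row =>
        rowCount += 1
        if (seen.size < 2) seen += cols.map(c => row.getAs[Any](c))
      }
      Iterator.single((Map(idx -> rowCount), seen))
    }
    .fold((Map.empty[Int, Int], Set.empty[Seq[Any]])) {
      case ((m1, s1), (m2, s2)) => (m1 ++ m2, s1 ++ s2)
    }
  // An empty RDD yields false here; the real code distinguishes that case with an Option.
  (sizes, seenValues.size == 1)
}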

File tree

10 files changed, +237 -47 lines changed

benchmarks/src/main/scala/ai/h2o/sparkling/benchmarks/DataFrameToH2OFrameConversionViaCsvFilesBenchmark.scala

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 
 package ai.h2o.sparkling.benchmarks
 
-import java.net.URI
 import ai.h2o.sparkling.H2OFrame
 import org.apache.spark.sql.{DataFrame, SaveMode}
 

benchmarks/src/main/scala/ai/h2o/sparkling/benchmarks/DataFrameToH2OFrameConversionViaCsvFilesIncludingS3LoadBenchmark.scala

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 package ai.h2o.sparkling.benchmarks
 
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.DataFrame
 
 class DataFrameToH2OFrameConversionViaCsvFilesIncludingS3LoadBenchmark(context: BenchmarkContext)
   extends DataFrameToH2OFrameConversionViaCsvFilesBenchmark(context) {

core/src/main/scala/ai/h2o/sparkling/H2OContext.scala

Lines changed: 12 additions & 3 deletions
@@ -159,12 +159,21 @@ class H2OContext private[sparkling] (private val conf: H2OConf) extends H2OConte
   }
 
   /** Transform DataFrame to H2OFrame */
-  def asH2OFrame(df: DataFrame): H2OFrame = asH2OFrame(df, None)
+  def asH2OFrame(df: DataFrame): H2OFrame = asH2OFrame(df, frameName = None)
+
+  def asH2OFrame(df: DataFrame, featureColumns: Seq[String]): H2OFrame =
+    asH2OFrame(df, frameName = None, Some(featureColumns))
 
   def asH2OFrame(df: DataFrame, frameName: String): H2OFrame = asH2OFrame(df, Option(frameName))
 
-  def asH2OFrame(df: DataFrame, frameName: Option[String]): H2OFrame = {
-    withConversionDebugPrints(sparkContext, "Dataframe", SparkDataFrameConverter.toH2OFrame(this, df, frameName))
+  def asH2OFrame(
+      df: DataFrame,
+      frameName: Option[String] = None,
+      featureColumns: Option[Seq[String]] = None): H2OFrame = {
+    withConversionDebugPrints(
+      sparkContext,
+      "Dataframe",
+      SparkDataFrameConverter.toH2OFrame(this, df, frameName, featureColumns))
   }
 
   /** Transforms Dataset[Supported type] to H2OFrame */
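
On the user-facing API, the conversion gains an overload that accepts the feature columns to be checked while the frame is being built. A usage sketch under assumed names (trainingDf and the column names "age" and "income" are made up for illustration):

// Assumed setup: an existing SparkSession and a DataFrame `trainingDf`.
val hc = H2OContext.getOrCreate()

// The conversion counts rows per partition and, in the same pass, verifies that
// the listed feature columns are not all constant; if they are, the writer
// throws an IllegalArgumentException before the chunks are uploaded.
val frame = hc.asH2OFrame(trainingDf, Seq("age", "income"))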

core/src/main/scala/ai/h2o/sparkling/backend/PartitionStats.scala

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.h2o.sparkling.backend
+
+case class PartitionStats(partitionSizes: Map[Int, Int], areFeatureColumnsConstant: Option[Boolean])

core/src/main/scala/ai/h2o/sparkling/backend/PartitionStatsGenerator.scala

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.h2o.sparkling.backend
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+
+/**
+ * Goes over RDD partitions counting records and checking if given set of columns has constant values
+ */
+private[backend] object PartitionStatsGenerator {
+
+  def getPartitionStats(rdd: RDD[Row], maybeColumnsForConstantCheck: Option[Seq[String]] = None): PartitionStats = {
+    val partitionStats = rdd
+      .mapPartitionsWithIndex {
+        case (partitionIdx, iterator) =>
+          maybeColumnsForConstantCheck
+            .map(rowCountWithColumnsConstantCheck(partitionIdx, iterator, _))
+            .getOrElse(rowCountWithoutColumnsConstantCheck(partitionIdx, iterator))
+      }
+      .fold((Map.empty, Set.empty))((a, b) => (a._1 ++ b._1, a._2 ++ b._2))
+
+    val areProvidedColumnsConstant = if (partitionStats._2.isEmpty || maybeColumnsForConstantCheck.isEmpty) {
+      None
+    } else {
+      Some(partitionStats._2.size < 2)
+    }
+    PartitionStats(partitionStats._1, areProvidedColumnsConstant)
+  }
+
+  private def rowCountWithoutColumnsConstantCheck(partitionIdx: Int, iterator: Iterator[Row]) =
+    Iterator.single(Map(partitionIdx -> iterator.size), Set.empty)
+
+  private def rowCountWithColumnsConstantCheck(
+      partitionIdx: Int,
+      iterator: Iterator[Row],
+      columnsForConstantCheck: Seq[String]) = {
+    var atMostTwoDistinctColumnSetValues = Set[Map[String, Any]]()
+    var recordCount = 0
+    var constantCheckColumnsFlattened: Option[Seq[String]] = None
+    while (iterator.hasNext) {
+      val row = iterator.next()
+      if (constantCheckColumnsFlattened.isEmpty) {
+        constantCheckColumnsFlattened = Some(
+          findFlattenedColumnNamesByPrefix(columnsForConstantCheck, row.schema.fieldNames))
+      }
+      if (atMostTwoDistinctColumnSetValues.size < 2) {
+        atMostTwoDistinctColumnSetValues += row.getValuesMap(constantCheckColumnsFlattened.get)
+      }
+      recordCount += 1
+    }
+    Iterator.single(Map(partitionIdx -> recordCount), atMostTwoDistinctColumnSetValues)
+  }
+
+  private def findFlattenedColumnNamesByPrefix(
+      columnPrefixes: Seq[String],
+      flattenedFields: Array[String]): Seq[String] =
+    columnPrefixes.flatMap(
+      colNameBeforeFlatten =>
+        flattenedFields
+          .filter(col => col == colNameBeforeFlatten || col.startsWith(colNameBeforeFlatten + ".")))
+
+}
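
Worth noting: the generator never collects all distinct values. Each partition remembers at most two distinct combinations of the checked columns, and after the fold a single surviving combination means the columns never varied, which is why the result is Some(partitionStats._2.size < 2); an empty set (empty RDD) maps to None. A tiny local illustration of that bounded-set trick, with made-up sample values:

// Made-up per-row values of the checked column, shaped like Row.getValuesMap output.
val rowValues = Seq(
  Map[String, Any]("surname" -> "Doe"),
  Map[String, Any]("surname" -> "Doe"),
  Map[String, Any]("surname" -> "Smith"))

// Remember at most two distinct combinations: that is enough to decide constant vs. not.
val seen = rowValues.foldLeft(Set.empty[Map[String, Any]]) { (acc, values) =>
  if (acc.size < 2) acc + values else acc
}

val isConstant = seen.size < 2 // false here; it would be true if every row carried "Doe"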

core/src/main/scala/ai/h2o/sparkling/backend/Writer.scala

Lines changed: 16 additions & 29 deletions
@@ -17,22 +17,21 @@
 
 package ai.h2o.sparkling.backend
 
-import java.io.Closeable
-
-import ai.h2o.sparkling.{H2OConf, H2OFrame}
 import ai.h2o.sparkling.H2OFrame.query
+import ai.h2o.sparkling.backend.converters.{CategoricalDomainBuilder, TimeZoneConverter}
 import ai.h2o.sparkling.backend.utils.RestApiUtils.getClusterEndpoint
 import ai.h2o.sparkling.extensions.rest.api.Paths
 import ai.h2o.sparkling.extensions.rest.api.schema.UploadPlanV3
-import ai.h2o.sparkling.backend.converters.{CategoricalDomainBuilder, TimeZoneConverter}
-import ai.h2o.sparkling.extensions.serde.{ChunkAutoBufferWriter, ExpectedTypes, SerdeUtils}
+import ai.h2o.sparkling.extensions.serde.{ChunkAutoBufferWriter, SerdeUtils}
 import ai.h2o.sparkling.utils.ScalaUtils.withResource
 import ai.h2o.sparkling.utils.SparkSessionUtils
-import org.apache.spark.rdd.RDD
+import ai.h2o.sparkling.{H2OConf, H2OFrame}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
 import org.apache.spark.{ExposeUtils, TaskContext, ml, mllib}
 
+import java.io.Closeable
+
 private[backend] class Writer(nodeDesc: NodeDesc, metadata: WriterMetadata, numRows: Int, chunkId: Int)
   extends Closeable {
 
@@ -91,16 +90,22 @@ private[backend] object Writer {
 
   def convert(rdd: H2OAwareRDD[Row], colNames: Array[String], metadata: WriterMetadata): H2OFrame = {
     H2OFrame.initializeFrame(metadata.conf, metadata.frameId, colNames)
-    val partitionSizes = getNonEmptyPartitionSizes(rdd)
-    val nonEmptyPartitions = getNonEmptyPartitions(partitionSizes)
+    val partitionStats = PartitionStatsGenerator.getPartitionStats(rdd, metadata.featureColsForConstCheck)
+    if (partitionStats.areFeatureColumnsConstant.getOrElse(false)) {
+      throw new IllegalArgumentException(s"H2O could not use any of the specified input" +
+        s" columns: '${metadata.featureColsForConstCheck.get.mkString(", ")}' because they are all constants. H2O requires at least one non-constant column.")
+    }
+
+    val partitionSizes = partitionStats.partitionSizes
+    val nonEmptyPartitions = partitionSizes.filter(_._2 > 0).keys.toSeq.sorted
 
     val uploadPlan = getUploadPlan(metadata.conf, nonEmptyPartitions.length)
     val operation: SparkJob = perDataFramePartition(metadata, uploadPlan, nonEmptyPartitions, partitionSizes)
     val rows = SparkSessionUtils.active.sparkContext.runJob(rdd, operation, nonEmptyPartitions)
-    val res = new Array[Long](nonEmptyPartitions.size)
-    rows.foreach { case (chunkIdx, numRows) => res(chunkIdx) = numRows }
+    val rowsPerChunk = new Array[Long](nonEmptyPartitions.size)
+    rows.foreach { case (chunkIdx, numRows) => rowsPerChunk(chunkIdx) = numRows }
     val types = SerdeUtils.expectedTypesToVecTypes(metadata.expectedTypes, metadata.maxVectorSizes)
-    H2OFrame.finalizeFrame(metadata.conf, metadata.frameId, res, types)
+    H2OFrame.finalizeFrame(metadata.conf, metadata.frameId, rowsPerChunk, types)
     H2OFrame(metadata.frameId)
   }
 
@@ -164,24 +169,6 @@
     }
   }
 
-  private def getNonEmptyPartitionSizes[T](rdd: RDD[T]): Map[Int, Int] = {
-    rdd
-      .mapPartitionsWithIndex {
-        case (idx, it) =>
-          if (it.nonEmpty) {
-            Iterator.single((idx, it.size))
-          } else {
-            Iterator.empty
-          }
-      }
-      .collect()
-      .toMap
-  }
-
-  private def getNonEmptyPartitions(partitionSizes: Map[Int, Int]): Seq[Int] = {
-    partitionSizes.keys.toSeq.sorted
-  }
-
   private def getUploadPlan(conf: H2OConf, numberOfPartitions: Int): UploadPlan = {
     val endpoint = getClusterEndpoint(conf)
     val parameters = Map("number_of_chunks" -> numberOfPartitions)

core/src/main/scala/ai/h2o/sparkling/backend/WriterMetadata.scala

Lines changed: 2 additions & 1 deletion
@@ -26,4 +26,5 @@ case class WriterMetadata(
     frameId: String,
     expectedTypes: Array[ExpectedType],
     maxVectorSizes: Array[Int],
-    timezone: TimeZone)
+    timezone: TimeZone,
+    featureColsForConstCheck: Option[Seq[String]])

core/src/main/scala/ai/h2o/sparkling/backend/converters/SparkDataFrameConverter.scala

Lines changed: 26 additions & 6 deletions
@@ -23,6 +23,7 @@ import ai.h2o.sparkling.utils.SparkSessionUtils
 import ai.h2o.sparkling.{H2OContext, H2OFrame, SparkTimeZone}
 import org.apache.spark.expose.Logging
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.storage.StorageLevel
 
 object SparkDataFrameConverter extends Logging {
 
@@ -40,23 +41,42 @@ object SparkDataFrameConverter extends Logging {
     spark.baseRelationToDataFrame(relation)
   }
 
-  def toH2OFrame(hc: H2OContext, dataFrame: DataFrame, frameKeyName: Option[String]): H2OFrame = {
+  def toH2OFrame(
+      hc: H2OContext,
+      dataFrame: DataFrame,
+      frameKeyName: Option[String] = None,
+      featureColsForConstCheck: Option[Seq[String]] = None): H2OFrame = {
     val df = dataFrame.toDF() // Because of PySparkling, we can receive Dataset[Primitive] in this method, ensure that
     // we are dealing with Dataset[Row]
     val flatDataFrame = flattenDataFrame(df)
     val schema = flatDataFrame.schema
-    val rdd = flatDataFrame.rdd // materialized the data frame
+    val rdd = flatDataFrame.rdd
+    if (hc.getConf.runsInInternalClusterMode) {
+      rdd.persist(StorageLevel.DISK_ONLY)
+    } else {
+      rdd.persist()
+    }
 
     val elemMaxSizes = collectMaxElementSizes(rdd, schema)
     val vecIndices = collectVectorLikeTypes(schema).toArray
-    val flattenSchema = expandedSchema(schema, elemMaxSizes)
-    val colNames = flattenSchema.map(field => "\"" + field.name + "\"").toArray
+    val flattenedSchema = expandedSchema(schema, elemMaxSizes)
+    val h2oColNames = flattenedSchema.map(field => "\"" + field.name + "\"").toArray
    val maxVecSizes = vecIndices.map(elemMaxSizes(_))
 
     val expectedTypes = DataTypeConverter.determineExpectedTypes(schema)
 
     val uniqueFrameId = frameKeyName.getOrElse("frame_rdd_" + rdd.id + scala.util.Random.nextInt())
-    val metadata = WriterMetadata(hc.getConf, uniqueFrameId, expectedTypes, maxVecSizes, SparkTimeZone.current())
-    Writer.convert(new H2OAwareRDD(hc.getH2ONodes(), rdd), colNames, metadata)
+    val metadata =
+      WriterMetadata(
+        hc.getConf,
+        uniqueFrameId,
+        expectedTypes,
+        maxVecSizes,
+        SparkTimeZone.current(),
+        featureColsForConstCheck)
+    val result = Writer.convert(new H2OAwareRDD(hc.getH2ONodes(), rdd), h2oColNames, metadata)
+    rdd.unpersist(blocking = false)
+    result
   }
+
 }

core/src/test/scala/ai/h2o/sparkling/backend/PartitionStatsGeneratorTestSuite.scala

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.h2o.sparkling.backend
+
+import ai.h2o.sparkling.SparkTestContext
+import org.apache.spark.sql.SparkSession
+import org.junit.runner.RunWith
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.{FunSuite, Matchers, OptionValues}
+
+@RunWith(classOf[JUnitRunner])
+class PartitionStatsGeneratorTestSuite extends FunSuite with SparkTestContext with Matchers with OptionValues {
+
+  override def createSparkSession(): SparkSession = sparkSession("local[*]")
+
+  import spark.implicits._
+
+  private final val dataset =
+    Seq((1, "John", "Doe", 1999), (2, "John", "Doe", 1999), (3, "Jane", "Doe", 1999), (4, "Jane", "Doe", 1999))
+
+  private val datasetCols = Seq("id", "name", "surname", "birthYear")
+
+  test("should correctly detect constant columns") {
+    val input = dataset.toDF(datasetCols: _*).rdd
+
+    val resultOnConstantColumn = PartitionStatsGenerator.getPartitionStats(input, Some(Seq("surname")))
+    val resultOnConstantColumns = PartitionStatsGenerator.getPartitionStats(input, Some(Seq("surname", "birthYear")))
+    val resultOnNotConstantColumn = PartitionStatsGenerator.getPartitionStats(input, Some(Seq("name")))
+    val resultOnNotConstantColumns = PartitionStatsGenerator.getPartitionStats(input, Some(Seq("name", "id")))
+    val resultWhereOnlyOneColumnIsConstant =
+      PartitionStatsGenerator.getPartitionStats(input, Some(Seq("surname", "id")))
+
+    resultOnConstantColumn.areFeatureColumnsConstant.value shouldBe true
+    resultOnConstantColumns.areFeatureColumnsConstant.value shouldBe true
+    resultOnNotConstantColumn.areFeatureColumnsConstant.value shouldBe false
+    resultOnNotConstantColumns.areFeatureColumnsConstant.value shouldBe false
+    resultWhereOnlyOneColumnIsConstant.areFeatureColumnsConstant.value shouldBe false
+  }
+
+  test("should correctly count values") {
+    val inputWithTwoPartitions = dataset.toDF(datasetCols: _*).rdd.coalesce(numPartitions = 2)
+
+    val result = PartitionStatsGenerator.getPartitionStats(inputWithTwoPartitions, Some(Seq("id")))
+
+    result.areFeatureColumnsConstant.value shouldBe false
+    result.partitionSizes should have size 2
+    result.partitionSizes should contain theSameElementsAs Map(0 -> 2, 1 -> 2)
+  }
+
+  test("should not fail given an empty dataset") {
+    val emptyInput = Seq.empty[String].toDF.rdd
+
+    val result = PartitionStatsGenerator.getPartitionStats(emptyInput, Some(Seq("id")))
+
+    result.areFeatureColumnsConstant shouldBe None
+  }
+
+  test("should not fail given one element dataset") {
+    val oneElementInput = Seq(dataset.head).toDF(datasetCols: _*).rdd
+
+    val result = PartitionStatsGenerator.getPartitionStats(oneElementInput, Some(Seq("id")))
+
+    result.areFeatureColumnsConstant.value shouldBe true
+    result.partitionSizes should have size 1
+    result.partitionSizes shouldBe Map(0 -> 1)
+  }
+
+}

ml/src/main/scala/ai/h2o/sparkling/ml/algos/H2OAlgoCommonUtils.scala

Lines changed: 2 additions & 6 deletions
@@ -20,7 +20,7 @@ import ai.h2o.sparkling.backend.utils.H2OFrameLifecycle
 import ai.h2o.sparkling.ml.models.H2OBinaryModel
 import ai.h2o.sparkling.ml.utils.EstimatorCommonUtils
 import ai.h2o.sparkling.{H2OContext, H2OFrame}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions.col
 
 trait H2OAlgoCommonUtils extends EstimatorCommonUtils with H2OFrameLifecycle {
@@ -76,16 +76,12 @@ trait H2OAlgoCommonUtils extends EstimatorCommonUtils with H2OFrameLifecycle {
 
     val featureColumns = getInputCols().map(sanitize).map(col)
 
-    if (dataset.select(featureColumns: _*).distinct().count() == 1) {
-      throw new IllegalArgumentException(s"H2O could not use any of the specified input" +
-        s" columns: '${getInputCols().mkString(", ")}' because they are all constants. H2O requires at least one non-constant column.")
-    }
     val excludedColumns = excludedCols.map(sanitize).map(col)
     val additionalColumns = getAdditionalCols().map(sanitize).map(col)
     val columns = (featureColumns ++ excludedColumns ++ additionalColumns).distinct
     val h2oContext = H2OContext.ensure(
       "H2OContext needs to be created in order to train the model. Please create one as H2OContext.getOrCreate().")
-    val trainFrame = h2oContext.asH2OFrame(dataset.select(columns: _*).toDF())
+    val trainFrame = h2oContext.asH2OFrame(dataset.select(columns: _*).toDF(), getInputCols())
 
     trainFrame.convertColumnsToStrings(getColumnsToString())
