diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 6fd20ceb562b..e25fdc3e05ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support saving to local filesystem path."
-        )
-      }
       val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
         extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
-      instance.freqItemsets.write.parquet(dataPath)
+      ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
     }
   }
 
@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support loading from local filesystem path."
-        )
-      }
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
         (metadata.metadata \ "numTrainingRecords").extract[Long]
       }
       val dataPath = new Path(path, "data").toString
-      val frequentItems = sparkSession.read.parquet(dataPath)
+      val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
       val itemSupport = if (numTrainingRecords == 0L) {
         Map.empty[Any, Double]
       } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index 06de43260b30..c64e8d3007e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{CLASS_NAME, LABEL_COLUMN, NUM_CLASSES}
 import org.apache.spark.ml.PredictorParams
@@ -28,6 +28,7 @@ import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -212,4 +213,51 @@ private[spark] object DatasetUtils extends Logging {
       dataset.select(columnToVector(dataset, vectorCol)).head().getAs[Vector](0).size
     }
   }
+
+  private[ml] def toArrowBatchRDD(
+      dataFrame: DataFrame,
+      timeZoneId: String): RDD[Array[Byte]] = {
+    dataFrame match {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        val spark = df.sparkSession
+        val schema = df.schema
+        val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+        df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter =>
+          val context = TaskContext.get()
+          ArrowConverters.toBatchIterator(
+            iter,
+            schema,
+            maxRecordsPerBatch,
+            timeZoneId,
+            true,
+            false,
+            context)
+        }
+
+      case _ => throw new UnsupportedOperationException("Not implemented")
+    }
+  }
+
+  private[ml] def fromArrowBatchRDD(
+      rdd: RDD[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      sparkSession: SparkSession): DataFrame = {
+    sparkSession match {
+      case spark: org.apache.spark.sql.classic.SparkSession =>
+        val rowRDD = rdd.mapPartitions { iter =>
+          val context = TaskContext.get()
+          ArrowConverters.fromBatchIterator(
+            iter,
+            schema,
+            timeZoneId,
+            true,
+            false,
+            context)
+        }
+        spark.internalCreateDataFrame(rowRDD.setName("arrow"), schema)
+
+      case _ => throw new UnsupportedOperationException("Not implemented")
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index e3f31874a4c2..c04e798dccae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{Utils, VersionUtils}
 
 /**
@@ -1142,4 +1143,66 @@ private[spark] object ReadWriteUtils {
       spark.read.parquet(path).as[T].collect()
     }
   }
+
+  def saveDataFrame(path: String, df: DataFrame): Unit = {
+    if (localSavingModeState.get()) {
+      val filePath = Paths.get(path)
+      Files.createDirectories(filePath.getParent)
+
+      Using.resource(
+        new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
+      ) { dos =>
+        dos.writeUTF("ARROW") // format
+
+        val schema: StructType = df.schema
+        dos.writeUTF(schema.json)
+
+        val iter = DatasetUtils.toArrowBatchRDD(df, "UTC").toLocalIterator
+        while (iter.hasNext) {
+          val bytes = iter.next()
+          require(bytes != null)
+          dos.writeInt(bytes.length)
+          dos.write(bytes)
+        }
+        dos.writeInt(-1) // End
+      }
+    } else {
+      df.write.parquet(path)
+    }
+  }
+
+  def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
+    if (localSavingModeState.get()) {
+      val sc = spark match {
+        case s: org.apache.spark.sql.classic.SparkSession => s.sparkContext
+      }
+
+      Using.resource(
+        new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
+      ) { dis =>
+        val format = dis.readUTF()
+        require(format == "ARROW")
+
+        val schema: StructType = StructType.fromString(dis.readUTF())
+
+        val buff = mutable.ListBuffer.empty[Array[Byte]]
+        var nextBytes = dis.readInt()
+        while (nextBytes >= 0) {
+          val bytes = dis.readNBytes(nextBytes)
+          buff.append(bytes)
+          nextBytes = dis.readInt()
+        }
+        require(nextBytes == -1)
+
+        DatasetUtils.fromArrowBatchRDD(
+          sc.parallelize[Array[Byte]](buff.result()),
+          schema,
+          "UTC",
+          spark
+        )
+      }
+    } else {
+      spark.read.parquet(path)
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 1630a5d07d8e..3d994366b891 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     }
     val fPGrowth = new FPGrowth()
     testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
-      FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 }
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 5b1d9f1f116a..bbc27b2a73f9 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -531,7 +531,7 @@ object StructType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "struct"
 
-  private[sql] def fromString(raw: String): StructType = {
+  private[spark] def fromString(raw: String): StructType = {
     Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match {
       case t: StructType => t
       case _ => throw DataTypeErrors.failedParsingStructTypeError(raw)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
index 5811fe759d3e..e002a1a616fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
@@ -399,7 +399,7 @@ class SparkSession private(
   /**
    * Creates a `DataFrame` from an `RDD[InternalRow]`.
    */
-  private[sql] def internalCreateDataFrame(
+  private[spark] def internalCreateDataFrame(
       catalystRows: RDD[InternalRow],
       schema: StructType,
       isStreaming: Boolean = false): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 8b031af14e8b..ac2b873d9beb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -80,7 +80,7 @@ private[sql] class ArrowBatchStreamWriter(
   }
 }
 
-private[sql] object ArrowConverters extends Logging {
+private[spark] object ArrowConverters extends Logging {
   private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
@@ -231,7 +231,7 @@ private[sql] object ArrowConverters extends Logging {
    * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
    * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
    */
-  private[sql] def toBatchIterator(
+  private[spark] def toBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Long,
@@ -484,7 +484,7 @@ private[sql] object ArrowConverters extends Logging {
   /**
    * Maps iterator from serialized ArrowRecordBatches to InternalRows.
   */
-  private[sql] def fromBatchIterator(
+  private[spark] def fromBatchIterator(
       arrowBatchIter: Iterator[Array[Byte]],
       schema: StructType,
       timeZoneId: String,
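
A rough usage sketch (not part of the patch), showing the intent of the new ReadWriteUtils.saveDataFrame / loadDataFrame helpers: outside of localSavingModeState they simply delegate to parquet, and in local saving mode they stream the DataFrame as a schema JSON header followed by length-prefixed Arrow batches. The path and sample data below are invented, and since ReadWriteUtils is private[spark] this only compiles from code under the org.apache.spark namespace:

  import org.apache.spark.ml.util.ReadWriteUtils
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  // Shape mirrors FPGrowthModel.freqItemsets: (items, freq)
  val freqItemsets = Seq((Seq("a", "b"), 3L), (Seq("b"), 5L)).toDF("items", "freq")

  // Writes parquet here; writes the Arrow stream format when localSavingModeState is set
  ReadWriteUtils.saveDataFrame("/tmp/fp-itemsets", freqItemsets)
  val restored = ReadWriteUtils.loadDataFrame("/tmp/fp-itemsets", spark)
  assert(restored.count() == freqItemsets.count())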