mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala (14 changes: 2 additions & 12 deletions)
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
- if (ReadWriteUtils.localSavingModeState.get()) {
- throw new UnsupportedOperationException(
- "FPGrowthModel does not support saving to local filesystem path."
- )
- }
val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
- instance.freqItemsets.write.parquet(dataPath)
+ ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
}
}

@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
- if (ReadWriteUtils.localSavingModeState.get()) {
- throw new UnsupportedOperationException(
- "FPGrowthModel does not support loading from local filesystem path."
- )
- }
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
(metadata.metadata \ "numTrainingRecords").extract[Long]
}
val dataPath = new Path(path, "data").toString
- val frequentItems = sparkSession.read.parquet(dataPath)
+ val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
val itemSupport = if (numTrainingRecords == 0L) {
Map.empty[Any, Double]
} else {
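With the writer and reader above delegating to ReadWriteUtils.saveDataFrame / loadDataFrame (defined later in this PR), a round trip of a freqItemsets-shaped DataFrame looks like the sketch below. It assumes an existing SparkSession named spark and a writable path string, and relies on localSavingModeState to choose between the Arrow-based local format and Parquet.

```scala
// Sketch only, not part of the diff: round-trip a freqItemsets-shaped DataFrame
// through the new helpers. `spark` and `path` are assumed to be provided by the caller.
val freqItemsets = spark.createDataFrame(Seq(
  (Array("a"), 3L),
  (Array("a", "b"), 2L)
)).toDF("items", "freq")

ReadWriteUtils.saveDataFrame(s"$path/data", freqItemsets)
val restored = ReadWriteUtils.loadDataFrame(s"$path/data", spark)
restored.show(false)
```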
mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala (50 changes: 49 additions & 1 deletion)
@@ -17,7 +17,7 @@

package org.apache.spark.ml.util

- import org.apache.spark.SparkException
+ import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{CLASS_NAME, LABEL_COLUMN, NUM_CLASSES}
import org.apache.spark.ml.PredictorParams
@@ -28,6 +28,7 @@ import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+ import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

@@ -212,4 +213,51 @@ private[spark] object DatasetUtils extends Logging {
dataset.select(columnToVector(dataset, vectorCol)).head().getAs[Vector](0).size
}
}

private[ml] def toArrowBatchRDD(
dataFrame: DataFrame,
timeZoneId: String): RDD[Array[Byte]] = {
dataFrame match {
case df: org.apache.spark.sql.classic.DataFrame =>
val spark = df.sparkSession
val schema = df.schema
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter =>
Review comment (Contributor): Dataset already has def toArrowBatchRdd; shall we reuse it? (See the sketch after this file's diff.)

val context = TaskContext.get()
ArrowConverters.toBatchIterator(
iter,
schema,
maxRecordsPerBatch,
timeZoneId,
true,
false,
context)
}

case _ => throw new UnsupportedOperationException("Not implemented")
}
}

private[ml] def fromArrowBatchRDD(
rdd: RDD[Array[Byte]],
schema: StructType,
timeZoneId: String,
sparkSession: SparkSession): DataFrame = {
sparkSession match {
case spark: org.apache.spark.sql.classic.SparkSession =>
val rowRDD = rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromBatchIterator(
iter,
schema,
timeZoneId,
true,
false,
context)
}
spark.internalCreateDataFrame(rowRDD.setName("arrow"), schema)

case _ => throw new UnsupportedOperationException("Not implemented")
}
}
}
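On the review comment above: the classic Dataset already defines toArrowBatchRdd(timeZoneId), which drives the same ArrowConverters conversion as the helper here. A hedged sketch of that reuse follows; it assumes toArrowBatchRdd's private[sql] visibility is widened (or an equivalent hook is exposed), since it is not callable from org.apache.spark.ml as-is.

```scala
// Sketch only: delegate to Dataset.toArrowBatchRdd instead of driving
// ArrowConverters.toBatchIterator directly. Assumes a visibility change on
// toArrowBatchRdd comparable to the private[spark] widenings elsewhere in this PR.
private[ml] def toArrowBatchRDD(
    dataFrame: DataFrame,
    timeZoneId: String): RDD[Array[Byte]] = {
  dataFrame match {
    case df: org.apache.spark.sql.classic.DataFrame =>
      df.toArrowBatchRdd(timeZoneId)
    case _ =>
      throw new UnsupportedOperationException("Not implemented")
  }
}
```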
mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala (65 changes: 64 additions & 1 deletion)
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
- import org.apache.spark.sql.{SparkSession, SQLContext}
+ import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{Utils, VersionUtils}

/**
@@ -1142,4 +1143,66 @@ private[spark] object ReadWriteUtils {
spark.read.parquet(path).as[T].collect()
}
}

def saveDataFrame(path: String, df: DataFrame): Unit = {
if (localSavingModeState.get()) {
val filePath = Paths.get(path)
Files.createDirectories(filePath.getParent)

Using.resource(
new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
) { dos =>
dos.writeUTF("ARROW") // format

val schema: StructType = df.schema
dos.writeUTF(schema.json)

val iter = DatasetUtils.toArrowBatchRDD(df, "UTC").toLocalIterator
Review comment (Contributor): Does the Arrow library provide APIs to write to a local file? (See the sketch after this file's diff.)

while (iter.hasNext) {
val bytes = iter.next()
require(bytes != null)
dos.writeInt(bytes.length)
dos.write(bytes)
}
dos.writeInt(-1) // End
}
} else {
df.write.parquet(path)
}
}

def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
if (localSavingModeState.get()) {
val sc = spark match {
case s: org.apache.spark.sql.classic.SparkSession => s.sparkContext
}

Using.resource(
new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
) { dis =>
val format = dis.readUTF()
require(format == "ARROW")

val schema: StructType = StructType.fromString(dis.readUTF())

val buff = mutable.ListBuffer.empty[Array[Byte]]
var nextBytes = dis.readInt()
while (nextBytes >= 0) {
val bytes = dis.readNBytes(nextBytes)
buff.append(bytes)
nextBytes = dis.readInt()
}
require(nextBytes == -1)

DatasetUtils.fromArrowBatchRDD(
sc.parallelize[Array[Byte]](buff.result()),
schema,
"UTC",
spark
)
}
} else {
spark.read.parquet(path)
}
}
}
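On the review question above: the Arrow Java library does ship writers for its standard IPC formats, org.apache.arrow.vector.ipc.ArrowFileWriter (random-access file format) and ArrowStreamWriter (streaming format). The sketch below writes one hand-built batch to a local file with those APIs; it is an alternative to weigh, not what saveDataFrame does here, which wraps the already-serialized batch bytes in a small custom container (an "ARROW" tag, the schema JSON, then length-prefixed batches terminated by -1).

```scala
// Standalone sketch using Arrow's own file writer; the single-column schema,
// values, and /tmp path are illustrative assumptions, not taken from this PR.
import java.io.FileOutputStream
import scala.jdk.CollectionConverters._
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowFileWriter
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

val allocator = new RootAllocator()
val schema = new Schema(Seq(Field.nullable("freq", new ArrowType.Int(64, true))).asJava)
val root = VectorSchemaRoot.create(schema, allocator)

val vec = root.getVector("freq").asInstanceOf[BigIntVector]
vec.allocateNew(2)
vec.setSafe(0, 3L)
vec.setSafe(1, 2L)
root.setRowCount(2)

val out = new FileOutputStream("/tmp/freq.arrow")
val writer = new ArrowFileWriter(root, null, out.getChannel)
writer.start()
writer.writeBatch()  // one record batch per call
writer.end()
writer.close()
root.close()
allocator.close()
```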
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
- FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
+ FPGrowthSuite.allParamSettings, checkModelData)
}
}

@@ -531,7 +531,7 @@ object StructType extends AbstractDataType {

override private[sql] def simpleString: String = "struct"

- private[sql] def fromString(raw: String): StructType = {
+ private[spark] def fromString(raw: String): StructType = {
Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match {
case t: StructType => t
case _ => throw DataTypeErrors.failedParsingStructTypeError(raw)
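The change above widens StructType.fromString from private[sql] to private[spark] so that ReadWriteUtils (in org.apache.spark.ml) can parse the schema JSON stored in the local file header. A minimal sketch of that round trip, usable from any org.apache.spark package after this change:

```scala
// Sketch only: the schema layout here is illustrative, not taken from the PR.
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("items", ArrayType(StringType)),
  StructField("freq", LongType)))

val json = schema.json                    // what saveDataFrame writes after the "ARROW" tag
val parsed = StructType.fromString(json)  // what loadDataFrame parses back
assert(parsed == schema)
```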
@@ -399,7 +399,7 @@ class SparkSession private(
/**
* Creates a `DataFrame` from an `RDD[InternalRow]`.
*/
- private[sql] def internalCreateDataFrame(
+ private[spark] def internalCreateDataFrame(
catalystRows: RDD[InternalRow],
schema: StructType,
isStreaming: Boolean = false): DataFrame = {
@@ -80,7 +80,7 @@ private[sql] class ArrowBatchStreamWriter(
}
}

- private[sql] object ArrowConverters extends Logging {
+ private[spark] object ArrowConverters extends Logging {
private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
@@ -231,7 +231,7 @@ private[sql] object ArrowConverters extends Logging {
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
* in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
*/
- private[sql] def toBatchIterator(
+ private[spark] def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
@@ -484,7 +484,7 @@ private[sql] object ArrowConverters extends Logging {
/**
* Maps iterator from serialized ArrowRecordBatches to InternalRows.
*/
- private[sql] def fromBatchIterator(
+ private[spark] def fromBatchIterator(
arrowBatchIter: Iterator[Array[Byte]],
schema: StructType,
timeZoneId: String,