Commit b0d6e92

fix
1 parent 59a2a74 commit b0d6e92

File tree

3 files changed: +51 -16 lines changed


mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 2 additions & 12 deletions
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support saving to local filesystem path."
-        )
-      }
       val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
         extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
-      instance.freqItemsets.write.parquet(dataPath)
+      ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
     }
   }
 
@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support loading from local filesystem path."
-        )
-      }
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
         (metadata.metadata \ "numTrainingRecords").extract[Long]
       }
       val dataPath = new Path(path, "data").toString
-      val frequentItems = sparkSession.read.parquet(dataPath)
+      val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
       val itemSupport = if (numTrainingRecords == 0L) {
         Map.empty[Any, Double]
       } else {
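
With the guard clauses gone, FPGrowthModelWriter and FPGrowthModelReader hand the freqItemsets DataFrame to the new ReadWriteUtils helpers, which choose between parquet and local object-stream serialization themselves. A minimal sketch of that round trip, assuming a local SparkSession and an invented /tmp path; note ReadWriteUtils is private[spark] and toggling localSavingModeState is internal, so this only illustrates the call shape, not a public API:

```scala
// Sketch under the assumptions above; the itemset data is made up.
import org.apache.spark.ml.util.ReadWriteUtils
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// Shape matches FPGrowthModel.freqItemsets: (items, freq).
val freqItemsets = Seq(
  (Seq("a"), 3L),
  (Seq("a", "b"), 2L)
).toDF("items", "freq")

// In cluster mode this writes parquet; in local saving mode it streams the
// schema and rows through an ObjectOutputStream into a single local file.
ReadWriteUtils.saveDataFrame("/tmp/fpgrowth-model/data", freqItemsets)
val restored = ReadWriteUtils.loadDataFrame("/tmp/fpgrowth-model/data", spark)
restored.show()
```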

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 48 additions & 3 deletions
@@ -19,10 +19,11 @@ package org.apache.spark.ml.util
 
 import java.io.{
   BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream,
-  File, FileInputStream, FileOutputStream, IOException
+  File, FileInputStream, FileOutputStream, IOException, ObjectInputStream,
+  ObjectOutputStream
 }
 import java.nio.file.{Files, Paths}
-import java.util.{Locale, ServiceLoader}
+import java.util.{ArrayList, Locale, ServiceLoader}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -46,7 +47,8 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{Utils, VersionUtils}
 
 /**
@@ -1142,4 +1144,47 @@ private[spark] object ReadWriteUtils {
       spark.read.parquet(path).as[T].collect()
     }
   }
+
+  def saveDataFrame(path: String, df: DataFrame): Unit = {
+    if (localSavingModeState.get()) {
+      val filePath = Paths.get(path)
+      Files.createDirectories(filePath.getParent)
+
+      Using.resource(
+        new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
+      ) { oos =>
+        val schema: StructType = df.schema
+        oos.writeObject(schema)
+        val it = df.toLocalIterator()
+        while (it.hasNext) {
+          oos.writeBoolean(true) // hasNext = True
+          val row: Row = it.next()
+          oos.writeObject(row)
+        }
+        oos.writeBoolean(false) // hasNext = False
+      }
+    } else {
+      df.write.parquet(path)
+    }
+  }
+
+  def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
+    if (localSavingModeState.get()) {
+      Using.resource(
+        new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))
+      ) { ois =>
+        val schema = ois.readObject().asInstanceOf[StructType]
+        val rows = new ArrayList[Row]
+        var hasNext = ois.readBoolean()
+        while (hasNext) {
+          val row = ois.readObject().asInstanceOf[Row]
+          rows.add(row)
+          hasNext = ois.readBoolean()
+        }
+        spark.createDataFrame(rows, schema)
+      }
+    } else {
+      spark.read.parquet(path)
+    }
+  }
 }
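
The local branch of both helpers relies on a simple framing protocol: one serialized header object (the StructType schema), then a writeBoolean(true) tag before every Row so the reader never needs a row count up front, and a final writeBoolean(false) as the end marker. A self-contained sketch of that pattern with plain java.io serialization and no Spark (Header, writeFramed, and readFramed are invented names for illustration):

```scala
import java.io._
import scala.util.Using

// Stand-in for the serialized schema object.
case class Header(name: String, fields: Int)

def writeFramed(path: String, header: Header, items: Iterator[String]): Unit =
  Using.resource(new ObjectOutputStream(
      new BufferedOutputStream(new FileOutputStream(path)))) { oos =>
    oos.writeObject(header)
    items.foreach { item =>
      oos.writeBoolean(true) // one more element follows
      oos.writeObject(item)
    }
    oos.writeBoolean(false) // end-of-stream marker
  }

def readFramed(path: String): (Header, List[String]) =
  Using.resource(new ObjectInputStream(
      new BufferedInputStream(new FileInputStream(path)))) { ois =>
    val header = ois.readObject().asInstanceOf[Header]
    val buf = List.newBuilder[String]
    while (ois.readBoolean()) buf += ois.readObject().asInstanceOf[String]
    (header, buf.result())
  }
```

The boolean tags let the writer stay fully streaming (df.toLocalIterator() never materializes the whole DataFrame), whereas the reader in the patch collects everything into an ArrayList before calling createDataFrame, so the local path is only suitable for driver-sized data.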

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     }
     val fPGrowth = new FPGrowth()
     testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
-      FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 }
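
Dropping skipTestSaveLocal = true means the shared harness now runs its local save/load leg for FPGrowthModel as well, instead of exempting it. The distributed round trip it checks has roughly this shape (a condensed sketch, not the harness code; the toy dataset and minSupport value are invented):

```scala
import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// Invented toy transactions, one per row.
val data = Seq("a b", "a b c", "b c").map(_.split(" ")).toDF("items")
val model = new FPGrowth().setMinSupport(0.5).fit(data)

// Save, reload, and compare frequent itemsets, as checkModelData does.
val path = java.nio.file.Files.createTempDirectory("fpm-test").resolve("model").toString
model.write.save(path)
val loaded = FPGrowthModel.load(path)
assert(model.freqItemsets.collect().toSet == loaded.freqItemsets.collect().toSet)
```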
