Commit 8fe0d6e (parent: 7fb027a)

apply arrow
nit

5 files changed: +95, -29 lines


mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
Lines changed: 49 additions & 1 deletion

@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{CLASS_NAME, LABEL_COLUMN, NUM_CLASSES}
 import org.apache.spark.ml.PredictorParams
@@ -28,6 +28,7 @@ import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -212,4 +213,51 @@ private[spark] object DatasetUtils extends Logging {
       dataset.select(columnToVector(dataset, vectorCol)).head().getAs[Vector](0).size
     }
   }
+
+  private[ml] def toArrowBatchRDD(
+      dataFrame: DataFrame,
+      timeZoneId: String): RDD[Array[Byte]] = {
+    dataFrame match {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        val spark = df.sparkSession
+        val schema = df.schema
+        val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+        df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter =>
+          val context = TaskContext.get()
+          ArrowConverters.toBatchIterator(
+            iter,
+            schema,
+            maxRecordsPerBatch,
+            timeZoneId,
+            true,
+            false,
+            context)
+        }
+
+      case _ => throw new UnsupportedOperationException("Not implemented")
+    }
+  }
+
+  private[ml] def fromArrowBatchRDD(
+      rdd: RDD[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      sparkSession: SparkSession): DataFrame = {
+    sparkSession match {
+      case spark: org.apache.spark.sql.classic.SparkSession =>
+        val rowRDD = rdd.mapPartitions { iter =>
+          val context = TaskContext.get()
+          ArrowConverters.fromBatchIterator(
+            iter,
+            schema,
+            timeZoneId,
+            true,
+            false,
+            context)
+        }
+        spark.internalCreateDataFrame(rowRDD.setName("arrow"), schema)
+
+      case _ => throw new UnsupportedOperationException("Not implemented")
+    }
+  }
 }
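Note: the two helpers added above round-trip a DataFrame through serialized Arrow record batches. A minimal usage sketch, assuming the caller sits inside the org.apache.spark.ml package (both methods are private[ml]) and that a local classic SparkSession is available; the session setup and column name are illustrative, not part of this commit:

// Sketch only: exercises toArrowBatchRDD / fromArrowBatchRDD end to end.
import org.apache.spark.ml.util.DatasetUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().master("local[2]").appName("arrow-roundtrip").getOrCreate()
val df: DataFrame = spark.range(0, 10).toDF("id")

// DataFrame -> RDD of serialized Arrow record batches
val batches: RDD[Array[Byte]] = DatasetUtils.toArrowBatchRDD(df, "UTC")

// RDD of Arrow batches -> DataFrame, reusing the original schema
val restored: DataFrame = DatasetUtils.fromArrowBatchRDD(batches, df.schema, "UTC", spark)
assert(restored.count() == df.count())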

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Lines changed: 41 additions & 23 deletions

@@ -19,11 +19,10 @@ package org.apache.spark.ml.util
 
 import java.io.{
   BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream,
-  File, FileInputStream, FileOutputStream, IOException, ObjectInputStream,
-  ObjectOutputStream
+  File, FileInputStream, FileOutputStream, IOException
 }
 import java.nio.file.{Files, Paths}
-import java.util.{ArrayList, Locale, ServiceLoader}
+import java.util.{Locale, ServiceLoader}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -47,7 +46,7 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{Utils, VersionUtils}
 
@@ -1151,17 +1150,21 @@ private[spark] object ReadWriteUtils {
       Files.createDirectories(filePath.getParent)
 
       Using.resource(
-        new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
-      ) { oos =>
+        new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
+      ) { dos =>
+        dos.writeUTF("ARROW") // format
+
         val schema: StructType = df.schema
-        oos.writeObject(schema)
-        val it = df.toLocalIterator()
-        while (it.hasNext) {
-          oos.writeBoolean(true) // hasNext = True
-          val row: Row = it.next()
-          oos.writeObject(row)
+        dos.writeUTF(schema.json)
+
+        val iter = DatasetUtils.toArrowBatchRDD(df, "UTC").toLocalIterator
+        while (iter.hasNext) {
+          val bytes = iter.next()
+          require(bytes != null)
+          dos.writeInt(bytes.length)
+          dos.write(bytes)
         }
-        oos.writeBoolean(false) // hasNext = False
+        dos.writeInt(-1) // End
       }
     } else {
      df.write.parquet(path)
@@ -1170,18 +1173,33 @@
 
   def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
     if (localSavingModeState.get()) {
+      val sc = spark match {
+        case s: org.apache.spark.sql.classic.SparkSession => s.sparkContext
+      }
+
       Using.resource(
-        new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))
-      ) { ois =>
-        val schema = ois.readObject().asInstanceOf[StructType]
-        val rows = new ArrayList[Row]
-        var hasNext = ois.readBoolean()
-        while (hasNext) {
-          val row = ois.readObject().asInstanceOf[Row]
-          rows.add(row)
-          hasNext = ois.readBoolean()
+        new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
+      ) { dis =>
+        val format = dis.readUTF()
+        require(format == "ARROW")
+
+        val schema: StructType = StructType.fromString(dis.readUTF())
+
+        val buff = mutable.ListBuffer.empty[Array[Byte]]
+        var nextBytes = dis.readInt()
+        while (nextBytes >= 0) {
+          val bytes = dis.readNBytes(nextBytes)
+          buff.append(bytes)
+          nextBytes = dis.readInt()
         }
-        spark.createDataFrame(rows, schema)
+        require(nextBytes == -1)
+
+        DatasetUtils.fromArrowBatchRDD(
+          sc.parallelize[Array[Byte]](buff.result()),
+          schema,
+          "UTC",
+          spark
+        )
       }
     } else {
       spark.read.parquet(path)
sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
Lines changed: 1 addition & 1 deletion

@@ -531,7 +531,7 @@ object StructType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "struct"
 
-  private[sql] def fromString(raw: String): StructType = {
+  private[spark] def fromString(raw: String): StructType = {
     Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match {
       case t: StructType => t
       case _ => throw DataTypeErrors.failedParsingStructTypeError(raw)
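Note: widening fromString to private[spark] lets ReadWriteUtils parse the schema JSON that saveDataFrame writes. The same round-trip can be sketched with the public DataType.fromJson API (field names here are hypothetical, chosen only for illustration):

import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("name", StringType)))

val json = schema.json                                   // what saveDataFrame serializes via writeUTF
val restored = DataType.fromJson(json).asInstanceOf[StructType]
assert(restored == schema)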

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
Lines changed: 1 addition & 1 deletion

@@ -399,7 +399,7 @@ class SparkSession private(
   /**
    * Creates a `DataFrame` from an `RDD[InternalRow]`.
    */
-  private[sql] def internalCreateDataFrame(
+  private[spark] def internalCreateDataFrame(
      catalystRows: RDD[InternalRow],
      schema: StructType,
      isStreaming: Boolean = false): DataFrame = {

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
Lines changed: 3 additions & 3 deletions

@@ -80,7 +80,7 @@ private[sql] class ArrowBatchStreamWriter(
   }
 }
 
-private[sql] object ArrowConverters extends Logging {
+private[spark] object ArrowConverters extends Logging {
   private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
@@ -231,7 +231,7 @@ private[sql] object ArrowConverters extends Logging {
   * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
   * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
   */
-  private[sql] def toBatchIterator(
+  private[spark] def toBatchIterator(
      rowIter: Iterator[InternalRow],
      schema: StructType,
      maxRecordsPerBatch: Long,
@@ -484,7 +484,7 @@ private[sql] object ArrowConverters extends Logging {
  /**
   * Maps iterator from serialized ArrowRecordBatches to InternalRows.
   */
-  private[sql] def fromBatchIterator(
+  private[spark] def fromBatchIterator(
      arrowBatchIter: Iterator[Array[Byte]],
      schema: StructType,
      timeZoneId: String,
