Skip to content

Commit

Permalink
[SPARK-14681][ML] Provide label/impurity stats for spark.ml decision …
Browse files Browse the repository at this point in the history
…tree nodes

## What changes were proposed in this pull request?

API:
```
trait ClassificationNode extends Node
  def getLabelCount(label: Int): Double

trait RegressionNode extends Node
  def getCount(): Double
  def getSum(): Double
  def getSquareSum(): Double

// turn LeafNode to be trait
trait LeafNode extends Node {
  def prediction: Double
  def impurity: Double
  ...
}

class ClassificationLeafNode extends ClassificationNode with LeafNode

class RegressionLeafNode extends RegressionNode with LeafNode

// turn InternalNode to be trait
trait InternalNode extends Node{
  def gain: Double
  def leftChild: Node
  def rightChild: Node
  def split: Split
  ...
}

class ClassificationInternalNode extends ClassificationNode with InternalNode
  override def leftChild: ClassificationNode
  override def rightChild: ClassificationNode

class RegressionInternalNode extends RegressionNode with InternalNode
  override val leftChild: RegressionNode
  override val rightChild: RegressionNode

class DecisionTreeClassificationModel
  override val rootNode: ClassificationNode

class DecisionTreeRegressionModel
  override val rootNode: RegressionNode
```
Closes apache#17466

## How was this patch tested?

UT will be added soon.

Author: WeichenXu <[email protected]>
Author: jkbradley <[email protected]>

Closes apache#20786 from WeichenXu123/tree_stat_api_2.
  • Loading branch information
WeichenXu123 authored and jkbradley committed Apr 9, 2018
1 parent 7c1654e commit 252468a
Show file tree
Hide file tree
Showing 16 changed files with 333 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
@Since("1.4.0")
class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")override val uid: String,
@Since("1.4.0")override val rootNode: Node,
@Since("1.4.0")override val rootNode: ClassificationNode,
@Since("1.6.0")override val numFeatures: Int,
@Since("1.5.0")override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
Expand All @@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)

override def predict(features: Vector): Double = {
Expand Down Expand Up @@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
val model = new DecisionTreeClassificationModel(metadata.uid,
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
Expand All @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
// Can't infer number of features from old model, so default to -1
new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
new DecisionTreeClassificationModel(uid,
rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override def load(path: String): GBTClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeClassificationModel] = treesData.map {
case (treeMetadata, root) =>
val tree =
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
@Since("1.4.0")
class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: Node,
override val rootNode: RegressionNode,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
Expand All @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
* Construct a decision tree regression model.
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)

override def predict(features: Vector): Double = {
Expand Down Expand Up @@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
val model = new DecisionTreeRegressionModel(metadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
Expand All @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
require(oldModel.algo == OldAlgo.Regression,
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
override def load(path: String): GBTRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)

val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
DefaultParamsReader.getAndSetParams(tree, treeMetadata)
tree
}
Expand Down
Loading

0 comments on commit 252468a

Please sign in to comment.