Skip to content

Commit 5edfbfa

Browse files
DEV: Created standalone class for Random Forest Models (#237)
FEAT: Implemented RF class method for fitting the model FEAT: Implemented RF class method for obtaining importance analysis from a fitted RF FEAT: Implemented RF class method for returning oob error FEAT: Implemented RF class method for obtaining FDR from a fitted model FEAT: Implemented RF class method for exporting forest to JSON REFACTOR: Make RF model available at package level CHORE: Added type checking to all methods
1 parent 89108aa commit 5edfbfa

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

python/varspark/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
try:
44
from varspark.core import VarsparkContext, VariantsContext
5+
from varspark.rfmodel import RandomForestModel, RFModelContext
56
except Exception:
67
if not os.environ.get('VS_FIND_JAR'):
78
raise

python/varspark/rfmodel.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from random import randint
2+
3+
from pyspark.sql import SparkSession, SQLContext
4+
from typedecorator import params, Nullable
5+
6+
from varspark import java
7+
from varspark.importanceanalysis import ImportanceAnalysis
8+
from varspark.lfdrvsnohail import LocalFdrVs
9+
10+
class RandomForestModel(object):
11+
@params(self=object, ss=SparkSession,
12+
mtry_fraction=Nullable(float), oob=Nullable(bool),
13+
seed=Nullable(int), var_ordinal_levels=Nullable(int),
14+
max_depth=int, min_node_size=int)
15+
def __init__(self, ss, mtry_fraction=None,
16+
oob=True, seed=None, var_ordinal_levels=3,
17+
max_depth=java.MAX_INT, min_node_size=1):
18+
self.ss = ss
19+
self.sc = ss.sparkContext
20+
self._jvm = self.sc._jvm
21+
self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
22+
self.sql = SQLContext.getOrCreate(self.sc)
23+
self._jsql = self.sql._jsqlContext
24+
self.mtry_fraction=mtry_fraction
25+
self.oob = oob
26+
self.seed = seed
27+
self.var_ordinal_levels = var_ordinal_levels
28+
self.max_depth = max_depth
29+
self.min_node_size = min_node_size
30+
self.vs_algo = self._jvm.au.csiro.variantspark.algo
31+
self.jrf_params = self.vs_algo.RandomForestParams(bool(oob),
32+
java.jfloat_or(
33+
mtry_fraction),
34+
True, java.NAN, True,
35+
java.jlong_or(seed,
36+
randint(
37+
java.MIN_LONG,
38+
java.MAX_LONG)),
39+
max_depth,
40+
min_node_size, False,
41+
0)
42+
self._jrf_model = None
43+
44+
@params(self=object, X=object, y=object, n_trees=Nullable(int), batch_size=Nullable(int))
45+
def fit_trees(self, X, y, n_trees=1000, batch_size=100):
46+
""" Fits random forest model on provided input features and labels
47+
:param (int) n_trees: Number of trees in the forest
48+
:param (int) batch_size:
49+
"""
50+
self.n_trees = n_trees
51+
self.batch_size = batch_size
52+
self._jfs = X._jfs
53+
self._jrf_model = self._vs_api.RFModelTrainer.trainModel(X._jfs, y, self.jrf_params, self.n_trees, self.batch_size)
54+
55+
@params(self=object)
56+
def importance_analysis(self):
57+
""" Returns gini variable importances for a fitted random forest model
58+
:return ImportanceAnalysis: Class containing importances and associated methods
59+
"""
60+
jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, self._jrf_model)
61+
return ImportanceAnalysis(jia, self.sql)
62+
63+
@params(self=object)
64+
def oob_error(self):
65+
""" Returns the overall out-of-bag error associated with a fitted random forest model
66+
:return oob_error (float): Out of bag error associated with the fitted model
67+
"""
68+
oob_error = self._jrf_model.oobError()
69+
return oob_error
70+
71+
@params(self=object)
72+
def get_lfdr(self):
73+
""" Returns the class with the information preloaded to compute the local FDR
74+
:return: class LocalFdrVs with the importances loaded
75+
"""
76+
return LocalFdrVs.from_imp_df(self.importance_analysis().variable_importance())
77+
78+
@params(self=object, file_name=str, resolve_variable_names=Nullable(bool), batch_size=Nullable(int))
79+
def export_to_json(self, file_name, resolve_variable_names=True, batch_size=1000):
80+
""" Exports the random forest model to Json format
81+
82+
:param (string) file_name: File name to export
83+
:param (bool) resolve_variable_names: Indicates whether to associate variant ids with exported nodes
84+
:param (int) batch_size: Number of trees to process in a single batch during export
85+
"""
86+
jexp = self._vs_api.ExportModel(self._jrf_model, self._jfs)
87+
jexp.toJson(file_name, resolve_variable_names, batch_size)
88+
89+
# Deprecated
90+
RFModelContext = RandomForestModel

0 commit comments

Comments
 (0)