Skip to content

Commit 04daae2

Browse files
STYLE: Format with black (#237)
1 parent b94afcc commit 04daae2

File tree

1 file changed

+63
-33
lines changed

1 file changed

+63
-33
lines changed

python/varspark/rfmodel.py

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,80 +4,109 @@
44
from typedecorator import params, Nullable
55

66
from varspark import java
7+
from varspark.core import VarsparkContext
78
from varspark.importanceanalysis import ImportanceAnalysis
89
from varspark.lfdrvsnohail import LocalFdrVs
910

11+
1012
class RandomForestModel(object):
11-
@params(self=object, ss=SparkSession,
12-
mtry_fraction=Nullable(float), oob=Nullable(bool),
13-
seed=Nullable(int), var_ordinal_levels=Nullable(int),
14-
max_depth=int, min_node_size=int)
15-
def __init__(self, ss, mtry_fraction=None,
16-
oob=True, seed=None, var_ordinal_levels=3,
17-
max_depth=java.MAX_INT, min_node_size=1):
18-
self.ss = ss
19-
self.sc = ss.sparkContext
20-
self._jvm = self.sc._jvm
21-
self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
22-
self.sql = SQLContext.getOrCreate(self.sc)
23-
self._jsql = self.sql._jsqlContext
24-
self.mtry_fraction=mtry_fraction
13+
@params(
14+
self=object,
15+
vc=VarsparkContext,
16+
mtry_fraction=Nullable(float),
17+
oob=Nullable(bool),
18+
seed=Nullable(int),
19+
var_ordinal_levels=Nullable(int),
20+
max_depth=Nullable(int),
21+
min_node_size=Nullable(int),
22+
)
23+
def __init__(
24+
self,
25+
vc,
26+
mtry_fraction=None,
27+
oob=True,
28+
seed=None,
29+
var_ordinal_levels=3,
30+
max_depth=java.MAX_INT,
31+
min_node_size=1,
32+
):
33+
self.sc = vc.sc
34+
self.sql = vc.sql
35+
self._jsql = vc._jsql
36+
self._jvm = vc._jvm
37+
self._vs_api = vc._vs_api
38+
self._jvsc = vc._jvsc
39+
self.mtry_fraction = mtry_fraction
2540
self.oob = oob
2641
self.seed = seed
2742
self.var_ordinal_levels = var_ordinal_levels
2843
self.max_depth = max_depth
2944
self.min_node_size = min_node_size
3045
self.vs_algo = self._jvm.au.csiro.variantspark.algo
31-
self.jrf_params = self.vs_algo.RandomForestParams(bool(oob),
32-
java.jfloat_or(
33-
mtry_fraction),
34-
True, java.NAN, True,
35-
java.jlong_or(seed,
36-
randint(
37-
java.MIN_LONG,
38-
java.MAX_LONG)),
39-
max_depth,
40-
min_node_size, False,
41-
0)
46+
self.jrf_params = self.vs_algo.RandomForestParams(
47+
bool(oob),
48+
java.jfloat_or(mtry_fraction),
49+
True,
50+
java.NAN,
51+
True,
52+
java.jlong_or(seed, randint(java.MIN_LONG, java.MAX_LONG)),
53+
max_depth,
54+
min_node_size,
55+
False,
56+
0,
57+
)
4258
self._jrf_model = None
4359

44-
@params(self=object, X=object, y=object, n_trees=Nullable(int), batch_size=Nullable(int))
60+
@params(
61+
self=object,
62+
X=object,
63+
y=object,
64+
n_trees=Nullable(int),
65+
batch_size=Nullable(int),
66+
)
4567
def fit_trees(self, X, y, n_trees=1000, batch_size=100):
46-
""" Fits random forest model on provided input features and labels
68+
"""Fits random forest model on provided input features and labels
4769
:param (int) n_trees: Number of trees in the forest
4870
:param (int) batch_size:
4971
"""
5072
self.n_trees = n_trees
5173
self.batch_size = batch_size
5274
self._jfs = X._jfs
53-
self._jrf_model = self._vs_api.RFModelTrainer.trainModel(X._jfs, y, self.jrf_params, self.n_trees, self.batch_size)
75+
self._jrf_model = self._vs_api.RFModelTrainer.trainModel(
76+
X._jfs, y, self.jrf_params, self.n_trees, self.batch_size
77+
)
5478

5579
@params(self=object)
5680
def importance_analysis(self):
57-
""" Returns gini variable importances for a fitted random forest model
81+
"""Returns gini variable importances for a fitted random forest model
5882
:return ImportanceAnalysis: Class containing importances and associated methods
5983
"""
6084
jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, self._jrf_model)
6185
return ImportanceAnalysis(jia, self.sql)
6286

6387
@params(self=object)
6488
def oob_error(self):
65-
""" Returns the overall out-of-bag error associated with a fitted random forest model
89+
"""Returns the overall out-of-bag error associated with a fitted random forest model
6690
:return oob_error (float): Out of bag error associated with the fitted model
6791
"""
6892
oob_error = self._jrf_model.oobError()
6993
return oob_error
7094

7195
@params(self=object)
7296
def get_lfdr(self):
73-
""" Returns the class with the information preloaded to compute the local FDR
97+
"""Returns the class with the information preloaded to compute the local FDR
7498
:return: class LocalFdrVs with the importances loaded
7599
"""
76100
return LocalFdrVs.from_imp_df(self.importance_analysis().variable_importance())
77101

78-
@params(self=object, file_name=str, resolve_variable_names=Nullable(bool), batch_size=Nullable(int))
102+
@params(
103+
self=object,
104+
file_name=str,
105+
resolve_variable_names=Nullable(bool),
106+
batch_size=Nullable(int),
107+
)
79108
def export_to_json(self, file_name, resolve_variable_names=True, batch_size=1000):
80-
""" Exports the random forest model to Json format
109+
"""Exports the random forest model to Json format
81110
82111
:param (string) file_name: File name to export
83112
:param (bool) resolve_variable_names: Indicates whether to associate variant ids with exported nodes
@@ -86,5 +115,6 @@ def export_to_json(self, file_name, resolve_variable_names=True, batch_size=1000
86115
jexp = self._vs_api.ExportModel(self._jrf_model, self._jfs)
87116
jexp.toJson(file_name, resolve_variable_names, batch_size)
88117

118+
89119
# Deprecated
90120
RFModelContext = RandomForestModel

0 commit comments

Comments
 (0)