4
4
from typedecorator import params , Nullable
5
5
6
6
from varspark import java
7
+ from varspark .core import VarsparkContext
7
8
from varspark .importanceanalysis import ImportanceAnalysis
8
9
from varspark .lfdrvsnohail import LocalFdrVs
9
10
11
+
10
12
class RandomForestModel(object):
    """Python-side handle for a VariantSpark random forest model.

    The constructor copies the Spark/JVM handles from a ``VarsparkContext``
    and builds the JVM ``RandomForestParams`` object. The forest itself is
    only trained when :meth:`fit_trees` is called; the accessor methods
    (:meth:`importance_analysis`, :meth:`oob_error`, :meth:`get_lfdr`,
    :meth:`export_to_json`) read state that :meth:`fit_trees` sets.
    """

    @params(
        self=object,
        vc=VarsparkContext,
        mtry_fraction=Nullable(float),
        oob=Nullable(bool),
        seed=Nullable(int),
        var_ordinal_levels=Nullable(int),
        max_depth=Nullable(int),
        min_node_size=Nullable(int),
    )
    def __init__(
        self,
        vc,
        mtry_fraction=None,
        oob=True,
        seed=None,
        var_ordinal_levels=3,
        max_depth=java.MAX_INT,
        min_node_size=1,
    ):
        """Stores the hyper-parameters and builds the JVM parameter object

        :param (VarsparkContext) vc: Context supplying the ``sc``, ``sql``,
            ``_jsql``, ``_jvm``, ``_vs_api`` and ``_jvsc`` handles
        :param (float) mtry_fraction: Passed to the JVM via ``java.jfloat_or``
        :param (bool) oob: Coerced with ``bool()`` before being passed to the JVM
        :param (int) seed: Random seed; a random long is supplied as the
            fallback argument to ``java.jlong_or``
        :param (int) var_ordinal_levels: Stored on the instance; not otherwise
            used in this constructor
        :param (int) max_depth: Forwarded to ``RandomForestParams``
        :param (int) min_node_size: Forwarded to ``RandomForestParams``
        """
        # Reuse the handles already owned by the VarsparkContext.
        self.sc = vc.sc
        self.sql = vc.sql
        self._jsql = vc._jsql
        self._jvm = vc._jvm
        self._vs_api = vc._vs_api
        self._jvsc = vc._jvsc

        self.mtry_fraction = mtry_fraction
        self.oob = oob
        self.seed = seed
        self.var_ordinal_levels = var_ordinal_levels
        self.max_depth = max_depth
        self.min_node_size = min_node_size

        self.vs_algo = self._jvm.au.csiro.variantspark.algo
        # NOTE(review): the literal arguments (True, java.NAN, True, False, 0)
        # are positional settings of the JVM-side RandomForestParams
        # constructor; their meaning is not visible from this file -- confirm
        # against the Scala signature before changing them.
        self.jrf_params = self.vs_algo.RandomForestParams(
            bool(oob),
            java.jfloat_or(mtry_fraction),
            True,
            java.NAN,
            True,
            # presumably falls back to the random long when seed is None;
            # verify against java.jlong_or
            java.jlong_or(seed, randint(java.MIN_LONG, java.MAX_LONG)),
            max_depth,
            min_node_size,
            False,
            0,
        )
        # No model exists until fit_trees() has been called.
        self._jrf_model = None

    @params(
        self=object,
        X=object,
        y=object,
        n_trees=Nullable(int),
        batch_size=Nullable(int),
    )
    def fit_trees(self, X, y, n_trees=1000, batch_size=100):
        """Fits random forest model on provided input features and labels

        :param X: Feature source; its ``_jfs`` attribute is handed to the JVM trainer
        :param y: Labels, passed through to the JVM trainer unchanged
        :param (int) n_trees: Number of trees in the forest
        :param (int) batch_size: Forwarded to ``RFModelTrainer.trainModel``
        """
        self.n_trees = n_trees
        self.batch_size = batch_size
        # Keep the JVM feature source: importance_analysis() and
        # export_to_json() read it later.
        self._jfs = X._jfs
        self._jrf_model = self._vs_api.RFModelTrainer.trainModel(
            X._jfs, y, self.jrf_params, self.n_trees, self.batch_size
        )

    @params(self=object)
    def importance_analysis(self):
        """Returns gini variable importances for a fitted random forest model

        Relies on ``self._jfs`` and ``self._jrf_model``, which are set by
        :meth:`fit_trees`.

        :return ImportanceAnalysis: Class containing importances and associated methods
        """
        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, self._jrf_model)
        return ImportanceAnalysis(jia, self.sql)

    @params(self=object)
    def oob_error(self):
        """Returns the overall out-of-bag error associated with a fitted random forest model

        Relies on ``self._jrf_model``, which is ``None`` until
        :meth:`fit_trees` has been called.

        :return oob_error (float): Out of bag error associated with the fitted model
        """
        return self._jrf_model.oobError()

    @params(self=object)
    def get_lfdr(self):
        """Returns the class with the information preloaded to compute the local FDR

        :return: class LocalFdrVs with the importances loaded
        """
        importances = self.importance_analysis().variable_importance()
        return LocalFdrVs.from_imp_df(importances)

    @params(
        self=object,
        file_name=str,
        resolve_variable_names=Nullable(bool),
        batch_size=Nullable(int),
    )
    def export_to_json(self, file_name, resolve_variable_names=True, batch_size=1000):
        """Exports the random forest model to Json format

        Relies on ``self._jfs`` and ``self._jrf_model``, which are set by
        :meth:`fit_trees`.

        :param (string) file_name: File name to export
        :param (bool) resolve_variable_names: Indicates whether to associate variant ids with exported nodes
        :param (int) batch_size: Forwarded to the JVM exporter's ``toJson`` call
        """
        jexp = self._vs_api.ExportModel(self._jrf_model, self._jfs)
        jexp.toJson(file_name, resolve_variable_names, batch_size)
88
117
118
+
89
119
# Deprecated alias for RandomForestModel; prefer RandomForestModel in new code.
RFModelContext = RandomForestModel
0 commit comments