|
| 1 | +from random import randint |
| 2 | + |
| 3 | +from pyspark.sql import SparkSession, SQLContext |
| 4 | +from typedecorator import params, Nullable |
| 5 | + |
| 6 | +from varspark import java |
| 7 | +from varspark.importanceanalysis import ImportanceAnalysis |
| 8 | +from varspark.lfdrvsnohail import LocalFdrVs |
| 9 | + |
class RandomForestModel(object):
    """ Python-side handle for a random forest trained through the variant-spark JVM API

    The constructor only captures the Spark/JVM entry points and builds the JVM
    ``RandomForestParams`` object; no training happens until :meth:`fit_trees`
    is called. Methods that read the trained model raise ``AttributeError``
    with an explanatory message when called before :meth:`fit_trees`.
    """

    @params(self=object, ss=SparkSession,
            mtry_fraction=Nullable(float), oob=Nullable(bool),
            seed=Nullable(int), var_ordinal_levels=Nullable(int),
            max_depth=int, min_node_size=int)
    def __init__(self, ss, mtry_fraction=None,
                 oob=True, seed=None, var_ordinal_levels=3,
                 max_depth=java.MAX_INT, min_node_size=1):
        """ Captures the Spark handles and builds the JVM random forest parameters

        :param (SparkSession) ss: Spark session used to reach the JVM and the SQL context
        :param (float) mtry_fraction: Passed to the JVM through ``java.jfloat_or``; may be None
        :param (bool) oob: Coerced with ``bool()`` before being passed to the JVM (None becomes False)
        :param (int) seed: Random seed; when None a random long is supplied as the fallback
        :param (int) var_ordinal_levels: Stored on the instance; not forwarded to the JVM parameters here
        :param (int) max_depth: Forwarded unchanged to the JVM parameters
        :param (int) min_node_size: Forwarded unchanged to the JVM parameters
        """
        self.ss = ss
        self.sc = ss.sparkContext
        self._jvm = self.sc._jvm
        # getattr with a dotted name resolves the JVM package in a single lookup.
        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
        self.sql = SQLContext.getOrCreate(self.sc)
        self._jsql = self.sql._jsqlContext
        self.mtry_fraction = mtry_fraction
        self.oob = oob
        self.seed = seed
        self.var_ordinal_levels = var_ordinal_levels
        self.max_depth = max_depth
        self.min_node_size = min_node_size
        self.vs_algo = self._jvm.au.csiro.variantspark.algo
        # The fallback seed is drawn eagerly, i.e. also when an explicit seed is given;
        # presumably java.jlong_or only uses it when seed is None -- TODO confirm.
        fallback_seed = randint(java.MIN_LONG, java.MAX_LONG)
        # NOTE(review): the literal positional arguments (True, java.NAN, True, ..., False, 0)
        # are defined by the Scala RandomForestParams constructor, which is not visible
        # here -- confirm their meaning against the JVM side before changing them.
        self.jrf_params = self.vs_algo.RandomForestParams(bool(oob),
                                                          java.jfloat_or(mtry_fraction),
                                                          True, java.NAN, True,
                                                          java.jlong_or(seed, fallback_seed),
                                                          max_depth,
                                                          min_node_size, False,
                                                          0)
        # Both stay None until fit_trees() succeeds, so every method can test for a
        # fitted model instead of tripping over a missing attribute.
        self._jfs = None
        self._jrf_model = None

    def _require_fitted(self):
        """ Raises a descriptive error when no model has been trained yet

        ``AttributeError`` is used on purpose: an unfitted model used to fail with
        an incidental ``AttributeError`` (missing ``_jfs`` / ``None.oobError``), so
        callers that caught that type keep working.
        """
        if self._jrf_model is None:
            raise AttributeError(
                "RandomForestModel has not been fitted yet; call fit_trees() first")

    @params(self=object, X=object, y=object, n_trees=Nullable(int), batch_size=Nullable(int))
    def fit_trees(self, X, y, n_trees=1000, batch_size=100):
        """ Fits random forest model on provided input features and labels

        :param X: Feature source; its ``_jfs`` attribute is handed to the JVM trainer
        :param y: Labels, passed to the JVM trainer as given
        :param (int) n_trees: Number of trees in the forest
        :param (int) batch_size: Batch size forwarded to the JVM trainer
        """
        self.n_trees = n_trees
        self.batch_size = batch_size
        jfs = X._jfs
        jrf_model = self._vs_api.RFModelTrainer.trainModel(jfs, y, self.jrf_params,
                                                           self.n_trees, self.batch_size)
        # Publish the feature source and the model together, and only after training
        # succeeded, so a failed re-fit cannot pair new features with an old model.
        self._jfs = jfs
        self._jrf_model = jrf_model

    @params(self=object)
    def importance_analysis(self):
        """ Returns gini variable importances for a fitted random forest model

        :return ImportanceAnalysis: Class containing importances and associated methods
        :raises AttributeError: If called before :meth:`fit_trees`
        """
        self._require_fitted()
        jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, self._jrf_model)
        return ImportanceAnalysis(jia, self.sql)

    @params(self=object)
    def oob_error(self):
        """ Returns the overall out-of-bag error associated with a fitted random forest model

        :return oob_error (float): Out of bag error associated with the fitted model
        :raises AttributeError: If called before :meth:`fit_trees`
        """
        self._require_fitted()
        return self._jrf_model.oobError()

    @params(self=object)
    def get_lfdr(self):
        """ Returns the class with the information preloaded to compute the local FDR

        :return: class LocalFdrVs with the importances loaded
        :raises AttributeError: If called before :meth:`fit_trees`
        """
        return LocalFdrVs.from_imp_df(self.importance_analysis().variable_importance())

    @params(self=object, file_name=str, resolve_variable_names=Nullable(bool), batch_size=Nullable(int))
    def export_to_json(self, file_name, resolve_variable_names=True, batch_size=1000):
        """ Exports the random forest model to Json format

        :param (string) file_name: File name to export
        :param (bool) resolve_variable_names: Indicates whether to associate variant ids with exported nodes
        :param (int) batch_size: Number of trees to process in a single batch during export
        :raises AttributeError: If called before :meth:`fit_trees`
        """
        self._require_fitted()
        jexp = self._vs_api.ExportModel(self._jrf_model, self._jfs)
        jexp.toJson(file_name, resolve_variable_names, batch_size)


# Deprecated: legacy name kept as an alias of RandomForestModel so existing imports keep working.
RFModelContext = RandomForestModel
0 commit comments