diff --git a/python/varspark/core.py b/python/varspark/core.py index ae8ee9ae..8d85ebd4 100644 --- a/python/varspark/core.py +++ b/python/varspark/core.py @@ -1,5 +1,4 @@ import sys -from random import randint from pyspark import SparkConf from pyspark.sql import SQLContext @@ -7,7 +6,7 @@ from varspark import java from varspark.etc import find_jar - +from varspark.featuresource import FeatureSource class VarsparkContext(object): """The main entry point for VariantSpark functionality. @@ -78,85 +77,3 @@ def stop(self): # Deprecated VariantsContext = VarsparkContext - - -class FeatureSource(object): - - def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs): - self._jfs = _jfs - self._jvm = _jvm - self._vs_api = _vs_api - self._jsql = _jsql - self.sql = sql - - @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float), - oob=Nullable(bool), seed=Nullable(int), batch_size=Nullable(int), - var_ordinal_levels=Nullable(int), max_depth=int, min_node_size=int) - def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None, - oob=True, seed=None, batch_size=100, var_ordinal_levels=3, - max_depth=java.MAX_INT, min_node_size=1): - """Builds random forest classifier. - - :param label_source: The ingested label source - :param int n_trees: The number of trees to build in the forest. - :param float mtry_fraction: The fraction of variables to try at each split. - :param bool oob: Should OOB error be calculated. - :param int seed: Random seed to use. - :param int batch_size: The number of trees to build in one batch. - :param int var_ordinal_levels: - - :return: Importance analysis model. - :rtype: :py:class:`ImportanceAnalysis` - """ - vs_algo = self._jvm.au.csiro.variantspark.algo - jrf_params = vs_algo.RandomForestParams(bool(oob), - java.jfloat_or( - mtry_fraction), - True, java.NAN, True, - java.jlong_or(seed, - randint( - java.MIN_LONG, - java.MAX_LONG)), - max_depth, - min_node_size, False, - 0) - jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source, - jrf_params, n_trees, batch_size, var_ordinal_levels) - return ImportanceAnalysis(jia, self.sql) - - -class ImportanceAnalysis(object): - """ Model for random forest based importance analysis - """ - - def __init__(self, _jia, sql): - self._jia = _jia - self.sql = sql - - @params(self=object, limit=Nullable(int)) - def important_variables(self, limit=10): - """ Gets the top limit important variables as a list of tuples (name, importance) where: - - name: string - variable name - - importance: double - gini importance - """ - jimpvarmap = self._jia.importantVariablesJavaMap(limit) - return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True) - - def oob_error(self): - """ OOB (Out of Bag) error estimate for the model - - :rtype: float - """ - return self._jia.oobError() - - def variable_importance(self): - """ Returns a DataFrame with the gini importance of variables. - - The DataFrame has two columns: - - variable: string - variable name - - importance: double - gini importance - """ - jdf = self._jia.variableImportance() - jdf.count() - jdf.createTempView("df") - return self.sql.table("df")