Commit

Merge branch 'pop' into develop
opcode81 committed Jul 25, 2023
2 parents 9454b03 + ac2b26f commit cce1b1f
Showing 6 changed files with 191 additions and 22 deletions.
35 changes: 26 additions & 9 deletions src/sensai/evaluation/eval_util.py
@@ -32,6 +32,7 @@
from ..data import InputOutputData
from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plotFeatureImportance, FeatureImportance
from ..tracking import TrackedExperiment
from ..tracking.tracking_base import TrackingContext
from ..util.deprecation import deprecated
from ..util.io import ResultWriter
from ..util.string import prettyStringRepr
@@ -135,13 +136,17 @@ def evalModelViaEvaluator(model: TModel, inputOutputData: InputOutputData, testF


class EvaluationResultCollector:
def __init__(self, showPlots: bool = True, resultWriter: Optional[ResultWriter] = None):
def __init__(self, showPlots: bool = True, resultWriter: Optional[ResultWriter] = None,
trackingContext: TrackingContext = None):
self.showPlots = showPlots
self.resultWriter = resultWriter
self.trackingContext = trackingContext

def addFigure(self, name: str, fig: matplotlib.figure.Figure):
if self.resultWriter is not None:
self.resultWriter.writeFigure(name, fig, closeFigure=not self.showPlots)
if self.trackingContext is not None:
self.trackingContext.trackFigure(name, fig)

def addDataFrameCsvFile(self, name: str, df: pd.DataFrame):
if self.resultWriter is not None:
@@ -272,15 +277,16 @@ def gatherResults(evalResultData: VectorModelEvaluationData, resultWriter, subti
if resultWriter is not None:
resultWriter.writeTextFile("evaluator-results", strEvalResults)
if createPlots:
self.createPlots(evalResultData, showPlots=showPlots, resultWriter=resultWriter, subtitlePrefix=subtitlePrefix)
with TrackingContext.fromOptionalExperiment(trackedExperiment, model=model) as trackingContext:
self.createPlots(evalResultData, showPlots=showPlots, resultWriter=resultWriter,
subtitlePrefix=subtitlePrefix, trackingContext=trackingContext)

evalResultData = evaluator.evalModel(model)
gatherResults(evalResultData, resultWriter)
if additionalEvaluationOnTrainingData:
evalResultDataTrain = evaluator.evalModel(model, onTrainingData=True)
additionalResultWriter = resultWriter.childWithAddedPrefix("onTrain-") if resultWriter is not None else None
gatherResults(evalResultDataTrain, additionalResultWriter, subtitlePrefix="[onTrain] ")

return evalResultData

@staticmethod
@@ -303,28 +309,37 @@ def performCrossValidation(self, model: TModel, showPlots=False, logResults=True
:return: cross-validation result data
"""
resultWriter = self._resultWriterForModel(resultWriter, model)

if crossValidator is None:
crossValidator = self.createCrossValidator(model)
if trackedExperiment is not None:
crossValidator.setTrackedExperiment(trackedExperiment)

crossValidationData = crossValidator.evalModel(model)

aggStatsByVar = {varName: crossValidationData.getEvalStatsCollection(predictedVarName=varName).aggMetricsDict()
for varName in crossValidationData.predictedVarNames}
df = pd.DataFrame.from_dict(aggStatsByVar, orient="index")

strEvalResults = df.to_string()
if logResults:
log.info(f"Cross-validation results:\n{strEvalResults}")
if resultWriter is not None:
resultWriter.writeTextFile("crossval-results", strEvalResults)
self.createPlots(crossValidationData, showPlots=showPlots, resultWriter=resultWriter)

with TrackingContext.fromOptionalExperiment(trackedExperiment, model=model) as trackingContext:
self.createPlots(crossValidationData, showPlots=showPlots, resultWriter=resultWriter,
trackingContext=trackingContext)

return crossValidationData

def compareModels(self, models: Sequence[TModel], resultWriter: Optional[ResultWriter] = None, useCrossValidation=False,
fitModels=True, writeIndividualResults=True, sortColumn: Optional[str] = None, sortAscending: bool = True,
sortColumnMoveToLeft=True,
alsoIncludeUnsortedResults: bool = False, alsoIncludeCrossValGlobalStats: bool = False,
visitors: Optional[Iterable["ModelComparisonVisitor"]] = None,
writeVisitorResults=False, writeCSV=False) -> "ModelComparisonData":
writeVisitorResults=False, writeCSV=False,
trackedExperiment: Optional[TrackedExperiment] = None) -> "ModelComparisonData":
"""
Compares several models via simple evaluation or cross-validation
@@ -363,7 +378,7 @@ def compareModels(self, models: Sequence[TModel], resultWriter: Optional[ResultW
if crossValidator is None:
crossValidator = self.createCrossValidator(model)
crossValData = self.performCrossValidation(model, resultWriter=resultWriter if writeIndividualResults else None,
crossValidator=crossValidator)
crossValidator=crossValidator, trackedExperiment=trackedExperiment)
modelResult = ModelComparisonData.Result(crossValData=crossValData)
resultByModelName[modelName] = modelResult
evalStatsCollection = crossValData.getEvalStatsCollection()
@@ -372,7 +387,7 @@ def compareModels(self, models: Sequence[TModel], resultWriter: Optional[ResultW
if evaluator is None:
evaluator = self.createEvaluator(model)
evalData = self.performSimpleEvaluation(model, resultWriter=resultWriter if writeIndividualResults else None,
fitModel=fitModels, evaluator=evaluator)
fitModel=fitModels, evaluator=evaluator, trackedExperiment=trackedExperiment)
modelResult = ModelComparisonData.Result(evalData=evalData)
resultByModelName[modelName] = modelResult
evalStats = evalData.getEvalStats()
@@ -453,7 +468,8 @@ def compareModelsCrossValidation(self, models: Sequence[TModel], resultWriter: O
"""
return self.compareModels(models, resultWriter=resultWriter, useCrossValidation=True)

def createPlots(self, data: Union[TEvalData, TCrossValData], showPlots=True, resultWriter: Optional[ResultWriter] = None, subtitlePrefix: str = ""):
def createPlots(self, data: Union[TEvalData, TCrossValData], showPlots=True, resultWriter: Optional[ResultWriter] = None,
subtitlePrefix: str = "", trackingContext: Optional[TrackingContext] = None):
"""
Creates default plots that visualise the results in the given evaluation data
@@ -464,7 +480,8 @@ def createPlots(self, data: Union[TEvalData, TCrossValData], showPlots=True, res
"""
if not showPlots and resultWriter is None:
return
resultCollector = EvaluationResultCollector(showPlots=showPlots, resultWriter=resultWriter)
resultCollector = EvaluationResultCollector(showPlots=showPlots, resultWriter=resultWriter,
trackingContext=trackingContext)
self._createPlots(data, resultCollector, subtitle=subtitlePrefix + data.modelName)

def _createPlots(self, data: Union[TEvalData, TCrossValData], resultCollector: EvaluationResultCollector, subtitle=None):
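
(Not part of the diff: a minimal usage sketch of the trackingContext plumbing added above. It assumes the package is importable as sensai, matching the src/sensai paths; with no tracked experiment, fromOptionalExperiment yields a DummyTrackingContext and figure tracking is a no-op.)

import matplotlib.pyplot as plt

from sensai.evaluation.eval_util import EvaluationResultCollector
from sensai.tracking.tracking_base import TrackingContext

# No experiment given: fromOptionalExperiment returns a DummyTrackingContext, so
# addFigure only writes via the resultWriter (None here) and tracking is a no-op.
with TrackingContext.fromOptionalExperiment(None, name="demo") as trackingContext:
    collector = EvaluationResultCollector(showPlots=False, resultWriter=None,
        trackingContext=trackingContext)
    collector.addFigure("example-figure", plt.figure())
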
10 changes: 6 additions & 4 deletions src/sensai/evaluation/evaluator.py
@@ -12,6 +12,7 @@
from ..data import DataSplitter, DataSplitterFractional, InputOutputData
from ..data_transformation import DataFrameTransformer
from ..tracking import TrackingMixin, TrackedExperiment
from ..tracking.tracking_base import TrackingContext
from ..util.string import ToStringMixin
from ..util.typing import PandasNamedTuple
from ..vector_model import VectorClassificationModel, VectorModel, VectorModelBase, VectorModelFittableBase, VectorRegressionModel
@@ -47,7 +48,7 @@ def computeMetrics(self, model, **kwargs) -> Optional[Dict[str, float]]:
"""
valuesDict = self._computeMetrics(model, **kwargs)
if self.trackedExperiment is not None:
self.trackedExperiment.trackValues(valuesDict, addValuesDict={"str(model)": str(model)})
self.trackedExperiment.trackValues(valuesDict, addValuesDict={"str(model)": str(model)}) # TODO
return valuesDict


@@ -200,10 +201,11 @@ def evalModel(self, model: VectorModelBase, onTrainingData=False, track=True) ->
"""
data = self.trainingData if onTrainingData else self.testData
result: VectorModelEvaluationData = self._evalModel(model, data)
if track and self.trackedExperiment is not None:
with TrackingContext.fromOptionalExperiment(self.trackedExperiment if track else None, model=model) as trackingContext:
multipleVars = len(result.predictedVarNames) > 1
for predVarName in result.predictedVarNames:
addValuesDict = {"str(model)": str(model), "predVarName": predVarName}
self.trackedExperiment.trackValues(result.getEvalStats(predVarName).metricsDict(), addValuesDict=addValuesDict)
metrics = result.getEvalStats(predVarName).metricsDict()
trackingContext.trackMetrics(metrics, predVarName if multipleVars else None)
return result

@abstractmethod
5 changes: 4 additions & 1 deletion src/sensai/tracking/azure_tracking.py
@@ -1,7 +1,7 @@
from azureml.core import Experiment, Workspace
from typing import Dict, Any

from .tracking_base import TrackedExperiment
from .tracking_base import TrackedExperiment, TContext
from .. import VectorModel
from ..evaluation.evaluator import MetricsDictProvider

@@ -47,3 +47,6 @@ def _trackValues(self, valuesDict: Dict[str, Any]):
with self.experiment.start_logging() as run:
for name, value in valuesDict.items():
run.log(name, value)

def _createTrackingContext(self, name: str, description: str) -> TContext:
raise NotImplementedError()
22 changes: 22 additions & 0 deletions src/sensai/tracking/clearml_tracking.py
@@ -1,12 +1,31 @@
import logging
from typing import Dict

from matplotlib import pyplot as plt

from .tracking_base import TrackingContext, TContext
from ..tracking import TrackedExperiment

from clearml import Task

log = logging.getLogger(__name__)


class ClearMLTrackingContext(TrackingContext):
def __init__(self, name, experiment, task):
super().__init__(name, experiment)
self.task = task

def _trackMetrics(self, metrics: Dict[str, float]):
self.task.connect(metrics)

def trackFigure(self, name: str, fig: plt.Figure):
fig.show()

def _end(self):
pass


# TODO: this is an initial working implementation, it should eventually be improved
class ClearMLExperiment(TrackedExperiment):
def __init__(self, task: Task = None, projectName: str = None, taskName: str = None,
@@ -39,3 +58,6 @@ def __init__(self, task: Task = None, projectName: str = None, taskName: str = N

def _trackValues(self, valuesDict):
self.task.connect(valuesDict)

def _createTrackingContext(self, name: str, description: str) -> TContext:
return ClearMLTrackingContext(name, self, self.task)
40 changes: 35 additions & 5 deletions src/sensai/tracking/mlflow_tracking.py
@@ -1,20 +1,50 @@
from typing import Dict, Any

import mlflow
from matplotlib import pyplot as plt

from .tracking_base import TrackedExperiment
from .tracking_base import TrackedExperiment, TrackingContext


class MlFlowExperiment(TrackedExperiment):
def __init__(self, experimentName: str, trackingUri: str, additionalLoggingValuesDict=None):
"""
class MlFlowTrackingContext(TrackingContext):
def __init__(self, name: str, experiment: "MlFlowExperiment", run_id=None, description=""):
super().__init__(name, experiment)
if run_id is not None:
run = mlflow.start_run(run_id)
else:
run = mlflow.start_run(run_name=name, description=description)
self.run = run

def _trackMetrics(self, metrics: Dict[str, float]):
mlflow.log_metrics(metrics)

def trackFigure(self, name: str, fig: plt.Figure):
mlflow.log_figure(fig, name + ".png")

def _end(self):
mlflow.end_run()


class MlFlowExperiment(TrackedExperiment[MlFlowTrackingContext]):
def __init__(self, experimentName: str, trackingUri: str, additionalLoggingValuesDict=None,
instancePrefix: str = ""):
"""
:param experimentName:
:param trackingUri:
:param additionalLoggingValuesDict:
:param instancePrefix:
"""
mlflow.set_tracking_uri(trackingUri)
mlflow.set_experiment(experiment_name=experimentName)
super().__init__(additionalLoggingValuesDict=additionalLoggingValuesDict)
super().__init__(instancePrefix=instancePrefix, additionalLoggingValuesDict=additionalLoggingValuesDict)
self._run_name_to_id = {}

def _trackValues(self, valuesDict: Dict[str, Any]):
with mlflow.start_run():
mlflow.log_metrics(valuesDict)

def _createTrackingContext(self, name: str, description: str) -> MlFlowTrackingContext:
run_id = self._run_name_to_id.get(name)
context = MlFlowTrackingContext(name, self, run_id=run_id, description=description)
self._run_name_to_id[name] = context.run.info.run_id
return context
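
(Not part of the diff: a sketch of the new run-per-context behaviour of MlFlowExperiment introduced above. The tracking URI and names are placeholders; mlflow must be installed and the URI writable.)

from sensai.tracking.mlflow_tracking import MlFlowExperiment

experiment = MlFlowExperiment("demo-experiment", trackingUri="file:./mlruns",
                              instancePrefix="demo-")

# beginContext starts an MLflow run named "demo-modelA"; a later context with the
# same name resumes that run via the stored run id.
with experiment.beginContext("modelA", description="illustrative run") as context:
    context.trackMetrics({"RMSE": 0.42})
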
101 changes: 98 additions & 3 deletions src/sensai/tracking/tracking_base.py
@@ -1,15 +1,92 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Generic, TypeVar

from matplotlib import pyplot as plt

class TrackedExperiment(ABC):
def __init__(self, additionalLoggingValuesDict=None):
from ..util import countNone
from ..util.deprecation import deprecated
from ..vector_model import VectorModelBase


class TrackingContext(ABC):
def __init__(self, name: str, experiment: Optional["TrackedExperiment"]):
self.name = name
self._experiment = experiment
self._isRunning = False

@staticmethod
def fromOptionalExperiment(experiment: Optional["TrackedExperiment"], model: Optional[VectorModelBase] = None,
name: Optional[str] = None, description: str = ""):
if experiment is None:
return DummyTrackingContext(name)
else:
if countNone(name, model) != 1:
raise ValueError("Must provide exactly one of {model, name}")
if model is not None:
return experiment.beginContextForModel(model)
else:
return experiment.beginContext(name, description)

@abstractmethod
def _trackMetrics(self, metrics: Dict[str, float]):
pass

def trackMetrics(self, metrics: Dict[str, float], predictedVarName: Optional[str] = None):
if predictedVarName is not None:
metrics = {f"{predictedVarName}_{k}": v for k, v in metrics.items()}
self._trackMetrics(metrics)

@abstractmethod
def trackFigure(self, name: str, fig: plt.Figure):
pass

def __enter__(self):
self._isRunning = True
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.end()

@abstractmethod
def _end(self):
pass

def end(self):
self._end()
if self._isRunning:
if self._experiment is not None:
self._experiment.endContext(self)
self._isRunning = False


class DummyTrackingContext(TrackingContext):
def __init__(self, name):
super().__init__(name, None)

def _trackMetrics(self, metrics: Dict[str, float]):
pass

def trackFigure(self, name: str, fig: plt.Figure):
pass

def _end(self):
pass


TContext = TypeVar("TContext", bound=TrackingContext)


class TrackedExperiment(Generic[TContext], ABC):
def __init__(self, instancePrefix: str = "", additionalLoggingValuesDict=None):
"""
Base class for tracking
:param additionalLoggingValuesDict: additional values to be logged for each run
"""
self.instancePrefix = instancePrefix
self.additionalLoggingValuesDict = additionalLoggingValuesDict
self._contexts = []

@deprecated("Use a tracking context instead")
def trackValues(self, valuesDict: Dict[str, Any], addValuesDict: Dict[str, Any] = None):
valuesDict = dict(valuesDict)
if addValuesDict is not None:
@@ -22,6 +99,24 @@ def trackValues(self, valuesDict: Dict[str, Any], addValuesDict: Dict[str, Any]
def _trackValues(self, valuesDict):
pass

@abstractmethod
def _createTrackingContext(self, name: str, description: str) -> TContext:
pass

def beginContext(self, name: str, description: str = "") -> TContext:
instance = self._createTrackingContext(self.instancePrefix + name, description)
self._contexts.append(instance)
return instance

def beginContextForModel(self, model: VectorModelBase):
return self.beginContext(model.getName(), str(model))

def endContext(self, instance: TContext):
runningInstance = self._contexts[-1]
if instance != runningInstance:
raise ValueError(f"Passed instance ({instance}) is not the currently running instance ({runningInstance})")
self._contexts.pop()


class TrackingMixin(ABC):
_objectId2trackedExperiment = {}
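
(Not part of the diff: a minimal in-memory backend built on the new base classes, showing which methods a concrete TrackingContext/TrackedExperiment pair has to implement. All class names below are hypothetical.)

from typing import Dict

from matplotlib import pyplot as plt

from sensai.tracking.tracking_base import TrackedExperiment, TrackingContext


class InMemoryTrackingContext(TrackingContext):
    def __init__(self, name, experiment):
        super().__init__(name, experiment)
        self.metrics: Dict[str, float] = {}

    def _trackMetrics(self, metrics: Dict[str, float]):
        self.metrics.update(metrics)

    def trackFigure(self, name: str, fig: plt.Figure):
        pass  # this backend records metrics only

    def _end(self):
        pass


class InMemoryExperiment(TrackedExperiment[InMemoryTrackingContext]):
    def _trackValues(self, valuesDict):
        pass  # legacy (deprecated) path, unused here

    def _createTrackingContext(self, name: str, description: str) -> InMemoryTrackingContext:
        return InMemoryTrackingContext(name, self)


experiment = InMemoryExperiment()
with experiment.beginContext("run-1") as context:
    context.trackMetrics({"MAE": 1.23}, predictedVarName="y")
# context.metrics is now {"y_MAE": 1.23}: metrics are prefixed with the predicted variable name
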
