From c65da874f477a857bd28bc46bc5fbfaedacbfbd4 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:47:50 +0100 Subject: [PATCH] Improve remote MlFlow behaviour (#33) --- template/steps/training/model_trainer.py | 10 ++++------ template/utils/promote_in_model_registry.py | 1 + 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/template/steps/training/model_trainer.py b/template/steps/training/model_trainer.py index 3479818..e64ef4a 100644 --- a/template/steps/training/model_trainer.py +++ b/template/steps/training/model_trainer.py @@ -30,9 +30,7 @@ def model_trainer( model: ClassifierMixin, target: str, name: str, -) -> Annotated[ - ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True) -]: +) -> Annotated[ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True)]: """Configure and train a model on the training dataset. This is an example of a model training step that takes in a dataset artifact @@ -82,10 +80,10 @@ def model_trainer( # keep track of mlflow version for future use model_registry = Client().active_stack.model_registry if model_registry: - versions = model_registry.list_model_versions(name=name) - if versions: + version = model_registry.get_latest_model_version(name=name, stage=None) + if version: model_ = get_step_context().model - model_.log_metadata({"model_registry_version": versions[-1].version}) + model_.log_metadata({"model_registry_version": version.version}) ### YOUR CODE ENDS HERE ### return model diff --git a/template/utils/promote_in_model_registry.py b/template/utils/promote_in_model_registry.py index aedf995..4e09b06 100644 --- a/template/utils/promote_in_model_registry.py +++ b/template/utils/promote_in_model_registry.py @@ -20,6 +20,7 @@ def promote_in_model_registry( target_env: stage for promotion """ model_registry = Client().active_stack.model_registry + model_registry.configure_mlflow() if latest_version != current_version: model_registry.update_model_version( name=model_name,