Skip to content

Commit

Permalink
Removing the deprecated log_xxx_metadata calls (#28)
Browse files Browse the repository at this point in the history
* removing the deprecated calls

* correcting the ref

* fixing the review comments

* fixing the steps

* new way to fetch artifacts

* fixed the errors

* fixed imports
  • Loading branch information
bcdurak authored Dec 2, 2024
1 parent a200a13 commit a992a9e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: ${{ inputs.ref-zenml || 'develop' }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/followup-run-metadata' }}
ref-template: ${{ inputs.ref-template || github.ref }}
7 changes: 4 additions & 3 deletions template/steps/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.preprocessing import MinMaxScaler
from typing_extensions import Annotated
from utils.preprocess import ColumnsDropper, DataFrameCaster, NADropper
from zenml import log_artifact_metadata, step
from zenml import log_metadata, step


@step
Expand Down Expand Up @@ -67,8 +67,9 @@ def data_preprocessor(
dataset_tst = preprocess_pipeline.transform(dataset_tst)

# Log metadata so we can load it in the inference pipeline
log_artifact_metadata(
artifact_name="preprocess_pipeline",
log_metadata(
metadata={"random_state": random_state, "target": target},
artifact_name="preprocess_pipeline",
infer_artifact=True,
)
return dataset_trn, dataset_tst, preprocess_pipeline
37 changes: 23 additions & 14 deletions template/steps/model_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@

import pandas as pd
from sklearn.base import ClassifierMixin
from zenml import log_artifact_metadata, step

from zenml import log_metadata, step
from zenml.client import Client
from zenml.logger import get_logger

logger = get_logger(__name__)


@step
def model_evaluator(
model: ClassifierMixin,
dataset_trn: pd.DataFrame,
dataset_tst: pd.DataFrame,
min_train_accuracy: float = 0.0,
min_test_accuracy: float = 0.0,
target: Optional[str] = "target",
model: ClassifierMixin,
dataset_trn: pd.DataFrame,
dataset_tst: pd.DataFrame,
min_train_accuracy: float = 0.0,
min_test_accuracy: float = 0.0,
target: Optional[str] = "target",
) -> float:
"""Evaluate a trained model.
Expand Down Expand Up @@ -63,24 +65,31 @@ def model_evaluator(
dataset_tst.drop(columns=[target]),
dataset_tst[target],
)
logger.info(f"Train accuracy={trn_acc*100:.2f}%")
logger.info(f"Test accuracy={tst_acc*100:.2f}%")
logger.info(f"Train accuracy={trn_acc * 100:.2f}%")
logger.info(f"Test accuracy={tst_acc * 100:.2f}%")

messages = []
if trn_acc < min_train_accuracy:
messages.append(
f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !"
f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !"
)
if tst_acc < min_test_accuracy:
messages.append(
f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !"
f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !"
)
else:
for message in messages:
logger.warning(message)

log_artifact_metadata(
metadata={"train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc)},
artifact_name="sklearn_classifier",
client = Client()
latest_classifier = client.get_artifact_version("sklearn_classifier")

log_metadata(
metadata={
"train_accuracy": float(trn_acc),
"test_accuracy": float(tst_acc)
},
artifact_version_id=latest_classifier.id
)

return float(tst_acc)

0 comments on commit a992a9e

Please sign in to comment.