Skip to content

Commit 325350a

Browse files
Michał Sośnicki (michalsosn)
authored and committed
feat(mlflow): use log_batch in mlflow loader
1 parent ba1830c commit 325350a

File tree

2 files changed

+45
-20
lines changed

2 files changed

+45
-20
lines changed

src/neptune_exporter/loaders/mlflow.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pandas as pd
2323
import pyarrow as pa
2424
import mlflow
25-
import mlflow.tracking
25+
from mlflow.tracking import MlflowClient
26+
from mlflow.entities import Metric
2627
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
2728

2829

@@ -252,13 +253,16 @@ def upload_metrics(self, run_data: pd.DataFrame, run_id: str) -> None:
252253
# Determine step multiplier from actual data
253254
step_multiplier = self._determine_step_multiplier(metrics_data["step"])
254255

256+
mlflow_client = MlflowClient()
257+
255258
# Group by attribute path and log metrics
256259
for attr_path, group in metrics_data.groupby("attribute_path"):
257260
attr_name = self._sanitize_attribute_name(attr_path)
258261

259262
# Sort by step
260263
group = group.sort_values("step")
261264

265+
metrics = []
262266
for _, row in group.iterrows():
263267
if pd.notna(row["float_value"]) and pd.notna(row["step"]):
264268
step = self._convert_step_to_int(
@@ -268,10 +272,20 @@ def upload_metrics(self, run_data: pd.DataFrame, run_id: str) -> None:
268272
timestamp = None
269273
if pd.notna(row["timestamp"]):
270274
timestamp = int(row["timestamp"].timestamp() * 1000)
271-
mlflow.log_metric(
272-
attr_name, row["float_value"], step=step, timestamp=timestamp
275+
metrics.append(
276+
Metric(
277+
key=attr_name,
278+
value=row["float_value"],
279+
step=step,
280+
timestamp=timestamp,
281+
)
273282
)
274283

284+
mlflow_client.log_batch(
285+
run_id=run_id,
286+
metrics=metrics,
287+
)
288+
275289
self._logger.info(f"Uploaded metrics for run {run_id}")
276290

277291
def upload_artifacts(

tests/unit/test_mlflow_loader.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -267,25 +267,33 @@ def test_upload_metrics():
267267
}
268268
)
269269

270-
with patch("mlflow.log_metric") as mock_log_metric:
270+
with patch("neptune_exporter.loaders.mlflow.MlflowClient") as mock_client_class:
271+
mock_client = Mock()
272+
mock_client_class.return_value = mock_client
273+
271274
loader.upload_metrics(test_data, "RUN-123")
272275

273-
# Verify metrics were logged
274-
assert mock_log_metric.call_count == 3
276+
# Verify log_batch was called twice (once for each metric group)
277+
assert mock_client.log_batch.call_count == 2
275278

276-
# Check specific calls
277-
calls = mock_log_metric.call_args_list
278-
metric_names = [call[0][0] for call in calls]
279-
values = [call[0][1] for call in calls]
280-
steps = [call[1]["step"] for call in calls]
279+
# Check the calls
280+
calls = mock_client.log_batch.call_args_list
281281

282-
assert "test/metric1" in metric_names
283-
assert "test/metric2" in metric_names
284-
assert 0.5 in values
285-
assert 0.7 in values
286-
assert 0.3 in values
287-
# Steps should be converted using determined multiplier
288-
assert all(isinstance(step, int) for step in steps)
282+
# First call should be for test/metric1 (2 metrics)
283+
first_call = calls[0]
284+
assert first_call[1]["run_id"] == "RUN-123"
285+
metrics = first_call[1]["metrics"]
286+
assert len(metrics) == 2
287+
assert all(metric.key == "test/metric1" for metric in metrics)
288+
assert all(isinstance(metric.step, int) for metric in metrics)
289+
290+
# Second call should be for test/metric2 (1 metric)
291+
second_call = calls[1]
292+
assert second_call[1]["run_id"] == "RUN-123"
293+
metrics = second_call[1]["metrics"]
294+
assert len(metrics) == 1
295+
assert metrics[0].key == "test/metric2"
296+
assert isinstance(metrics[0].step, int)
289297

290298

291299
def test_upload_run_data():
@@ -308,11 +316,13 @@ def test_upload_run_data():
308316
with (
309317
patch("mlflow.start_run") as mock_start_run,
310318
patch("mlflow.log_params") as mock_log_params,
311-
patch("mlflow.log_metric") as mock_log_metric,
319+
patch("neptune_exporter.loaders.mlflow.MlflowClient") as mock_client_class,
312320
patch("mlflow.log_artifact") as mock_log_artifact,
313321
patch("pathlib.Path.exists", return_value=True),
314322
):
315323
mock_start_run.return_value.__enter__.return_value = None
324+
mock_client = Mock()
325+
mock_client_class.return_value = mock_client
316326

317327
# Convert to PyArrow table
318328
import pyarrow as pa
@@ -324,7 +334,7 @@ def test_upload_run_data():
324334
# Verify methods were called
325335
mock_start_run.assert_called_once()
326336
mock_log_params.assert_called_once()
327-
mock_log_metric.assert_called_once()
337+
mock_client.log_batch.assert_called_once() # For metrics
328338
mock_log_artifact.assert_called_once()
329339

330340

@@ -402,6 +412,7 @@ def test_upload_artifacts_string_series():
402412
"attribute_path": ["test/string_series", "test/string_series"],
403413
"attribute_type": ["string_series", "string_series"],
404414
"step": [Decimal("1.0"), Decimal("2.0")],
415+
"timestamp": [pd.Timestamp("2023-01-01"), pd.Timestamp("2023-01-02")],
405416
"string_value": ["value1", "value2"],
406417
}
407418
)

0 commit comments

Comments
 (0)