@@ -267,25 +267,33 @@ def test_upload_metrics():
267267 }
268268 )
269269
270- with patch ("mlflow.log_metric" ) as mock_log_metric :
270+ with patch ("neptune_exporter.loaders.mlflow.MlflowClient" ) as mock_client_class :
271+ mock_client = Mock ()
272+ mock_client_class .return_value = mock_client
273+
271274 loader .upload_metrics (test_data , "RUN-123" )
272275
273- # Verify metrics were logged
274- assert mock_log_metric . call_count == 3
276+ # Verify log_batch was called twice (once for each metric group)
277+ assert mock_client . log_batch . call_count == 2
275278
276- # Check specific calls
277- calls = mock_log_metric .call_args_list
278- metric_names = [call [0 ][0 ] for call in calls ]
279- values = [call [0 ][1 ] for call in calls ]
280- steps = [call [1 ]["step" ] for call in calls ]
279+ # Check the calls
280+ calls = mock_client .log_batch .call_args_list
281281
282- assert "test/metric1" in metric_names
283- assert "test/metric2" in metric_names
284- assert 0.5 in values
285- assert 0.7 in values
286- assert 0.3 in values
287- # Steps should be converted using determined multiplier
288- assert all (isinstance (step , int ) for step in steps )
282+ # First call should be for test/metric1 (2 metrics)
283+ first_call = calls [0 ]
284+ assert first_call [1 ]["run_id" ] == "RUN-123"
285+ metrics = first_call [1 ]["metrics" ]
286+ assert len (metrics ) == 2
287+ assert all (metric .key == "test/metric1" for metric in metrics )
288+ assert all (isinstance (metric .step , int ) for metric in metrics )
289+
290+ # Second call should be for test/metric2 (1 metric)
291+ second_call = calls [1 ]
292+ assert second_call [1 ]["run_id" ] == "RUN-123"
293+ metrics = second_call [1 ]["metrics" ]
294+ assert len (metrics ) == 1
295+ assert metrics [0 ].key == "test/metric2"
296+ assert isinstance (metrics [0 ].step , int )
289297
290298
291299def test_upload_run_data ():
@@ -308,11 +316,13 @@ def test_upload_run_data():
308316 with (
309317 patch ("mlflow.start_run" ) as mock_start_run ,
310318 patch ("mlflow.log_params" ) as mock_log_params ,
311- patch ("mlflow.log_metric " ) as mock_log_metric ,
319+ patch ("neptune_exporter.loaders. mlflow.MlflowClient " ) as mock_client_class ,
312320 patch ("mlflow.log_artifact" ) as mock_log_artifact ,
313321 patch ("pathlib.Path.exists" , return_value = True ),
314322 ):
315323 mock_start_run .return_value .__enter__ .return_value = None
324+ mock_client = Mock ()
325+ mock_client_class .return_value = mock_client
316326
317327 # Convert to PyArrow table
318328 import pyarrow as pa
@@ -324,7 +334,7 @@ def test_upload_run_data():
324334 # Verify methods were called
325335 mock_start_run .assert_called_once ()
326336 mock_log_params .assert_called_once ()
327- mock_log_metric . assert_called_once ()
337+ mock_client . log_batch . assert_called_once () # For metrics
328338 mock_log_artifact .assert_called_once ()
329339
330340
@@ -402,6 +412,7 @@ def test_upload_artifacts_string_series():
402412 "attribute_path" : ["test/string_series" , "test/string_series" ],
403413 "attribute_type" : ["string_series" , "string_series" ],
404414 "step" : [Decimal ("1.0" ), Decimal ("2.0" )],
415+ "timestamp" : [pd .Timestamp ("2023-01-01" ), pd .Timestamp ("2023-01-02" )],
405416 "string_value" : ["value1" , "value2" ],
406417 }
407418 )
0 commit comments