Skip to content

Commit 7977fd2

Browse files
committed
Refactoring, IO implementation for saving, tests and mlflow running on
CI
1 parent cba8599 commit 7977fd2

File tree

10 files changed

+112
-38
lines changed

10 files changed

+112
-38
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ env:
1515
TEST_MLJBASE: "true"
1616
jobs:
1717
test:
18+
services:
19+
mlflow:
20+
image: adacotechjp/mlflow:2.3.1
21+
ports:
22+
- 5000:5000
1823
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
1924
runs-on: ${{ matrix.os }}
2025
timeout-minutes: 60

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6868
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
6969

7070
[targets]
71-
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
71+
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables", "MLFlowClient"]

ext/LoggersExt/LoggersExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module LoggersExt
22

33
using MLJBase: info, name, Model,
4-
params, Machine, Measure,
5-
flat_params
4+
Machine, Measure, flat_params
65

7-
import MLJBase: save, evaluate!, MLFlowLogger
6+
import MLJBase: save, evaluate!, mlflow_logger
7+
8+
include("utils.jl")
89

910
include("mlflow.jl")
1011

ext/LoggersExt/mlflow.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ using MLFlowClient: MLFlow, logparam, logmetric,
22
createrun, MLFlowRun, updaterun,
33
logartifact, getorcreateexperiment
44

5-
struct MLFlowInstance
5+
struct MLFlowLogger
66
client::MLFlow
77
experiment_name::String
88
artifact_location::Union{String, Missing}
99
end
10-
MLFlowLogger(base_uri::String, experiment_name::String,
10+
mlflow_logger(base_uri::String, experiment_name::String,
1111
artifact_location::Union{String, Missing}) =
12-
MLFlowInstance(MLFlow(base_uri), experiment_name, artifact_location)
12+
MLFlowLogger(MLFlow(base_uri), experiment_name, artifact_location)
1313

1414
function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
1515
model_params = params(model) |> flat_params |> collect
@@ -18,8 +18,8 @@ function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
1818
end
1919
end
2020

21-
function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{Measure},
22-
measurements::Vector{Float64})
21+
function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{T},
22+
measurements::Vector{Float64}) where T<:Measure
2323
measure_names = measures .|> info .|> x -> x.name
2424
for (name, value) in zip(measure_names, measurements)
2525
logmetric(client, run, name, value)
@@ -29,7 +29,7 @@ end
2929
function evaluate!(mach::Machine, resampling, weights,
3030
class_weights, rows, verbosity,
3131
repeats, measures, operations,
32-
acceleration, force, logger::MLFlowInstance)
32+
acceleration, force, logger::MLFlowLogger)
3333
performance_evaluation = evaluate!(mach, resampling, weights,
3434
class_weights, rows, verbosity,
3535
repeats, measures, operations,
@@ -46,17 +46,19 @@ function evaluate!(mach::Machine, resampling, weights,
4646
return performance_evaluation
4747
end
4848

49-
function save(logger::MLFlowInstance, mach::Machine)
49+
function save(logger::MLFlowLogger, mach::Machine)
50+
io = IOBuffer()
51+
save(io, mach)
52+
5053
model_name = name(mach.model)
51-
fname = "$(model_name).jls"
52-
save(fname, mach)
5354

5455
experiment = getorcreateexperiment(logger.client, logger.experiment_name,
5556
artifact_location=logger.artifact_location)
5657
run = createrun(logger.client, experiment;
5758
run_name="$(model_name) run")
5859

5960
_logmodelparams(logger.client, run, mach.model)
60-
logartifact(logger.client, run, fname)
61-
rm(fname)
61+
fname = "$(model_name).jls"
62+
logartifact(logger.client, run, fname, io)
63+
updaterun(logger.client, run, "FINISHED")
6264
end
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
istransparent(::Any) = false
2-
istransparent(::MLJType) = true
1+
isamodel(::Any) = false
2+
isamodel(::Model) = true
33

44
"""
5-
params(m::MLJType)
5+
params(m::Model)
66
77
Recursively convert any transparent object `m` into a named tuple,
88
keyed on the property names of `m`. An object is *transparent* if
9-
`MLJBase.istransparent(m) == true`. The named tuple is possibly nested
9+
`isamodel(m) == true`. The named tuple is possibly nested
1010
because `params` is recursively applied to the property values, which
1111
themselves might be transparent.
1212
13-
For most `MLJType` objects, properties are synonymous with fields, but
13+
For most `Model` objects, properties are synonymous with fields, but
1414
this is not a hard requirement.
1515
16-
Most objects of type `MLJType` are transparent.
16+
Most objects of type `Model` are transparent.
1717
1818
julia> params(EnsembleModel(atom=ConstantClassifier()))
1919
(atom = (target_type = Bool,),
@@ -24,7 +24,7 @@ Most objects of type `MLJType` are transparent.
2424
parallel = true,)
2525
2626
"""
27-
params(m) = params(m, Val(istransparent(m)))
27+
params(m) = params(m, Val(isamodel(m)))
2828
params(m, ::Val{false}) = m
2929
function params(m, ::Val{true})
3030
fields = propertynames(m)

src/MLJBase.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,15 @@ export coerce, coerce!, autotype, schema, info
294294
export UnivariateFiniteArray, UnivariateFiniteVector
295295

296296
# -----------------------------------------------------------------------
297-
# abstract model types defined in MLJModelInterface.jl and extended here:
297+
# re-export from MLJModelInterface.jl
298+
299+
# abstract model types defined in MLJModelInterface.jl and extended here:
298300
for T in EXTENDED_ABSTRACT_MODEL_TYPES
299301
@eval(export $T)
300302
end
301303

304+
export params
305+
302306
# -------------------------------------------------------------------
303307
# exports from this module, MLJBase
304308

@@ -308,9 +312,6 @@ export default_resource
308312
# one_dimensional_ranges.jl:
309313
export ParamRange, NumericRange, NominalRange, iterator, scale
310314

311-
# parameter_inspection.jl:
312-
export params # note this is *not* an extension of StatsBase.params
313-
314315
# data.jl:
315316
export partition, unpack, complement, restrict, corestrict
316317

@@ -381,7 +382,7 @@ export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
381382
levels, levels!, std, Not, support, logpdf, LittleDict
382383

383384
# loggers.jl
384-
export MLFlowLogger
385+
export mlflow_logger
385386

386387
if !isdefined(Base, :get_extension)
387388
include("../ext/LoggersExt/LoggersExt.jl")

src/loggers.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
2-
MLFlowLogger(; base_uri="localhost:5000", experiment_name=missing)
2+
mlflow_logger(; base_uri="http://localhost:5000", experiment_name="MLJ experiments", artifact_location=missing)
33
4-
Base type for MLFlow logger. Creates an instance of MLFlow, as defined in
4+
Constructor for the base type for MLFlow logger. Creates an instance of MLFlow,
5+
as defined in
56
[`MLFlowClient.jl`](https://juliaai.github.io/MLFlowClient.jl/dev/), and logs
67
to an experiment.
78
@@ -13,14 +14,14 @@ If `experiment_name` is not provided, a new experiment with the name
1314
"MLJ experiments" will be created.
1415
1516
### Return value
16-
A `MLFlowInstance` object, containing a
17+
An `MLFlowLogger` object, containing a
1718
[`MLFlow`](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
1819
object and the experiment name
1920
2021
"""
21-
MLFlowLogger(; base_uri="http://localhost:5000",
22+
mlflow_logger(; base_uri="http://localhost:5000",
2223
experiment_name="MLJ experiments",
2324
artifact_location=missing) =
24-
MLFlowLogger(base_uri, experiment_name, artifact_location)
25-
MLFlowLogger(_, _, _) =
26-
error("Please run `import MLFlowClient` to use MLFlowLogger.")
25+
mlflow_logger(base_uri, experiment_name, artifact_location)
26+
mlflow_logger(_, _, _) =
27+
error("Please run `import MLFlowClient` to use mlflow_logger.")

src/utilities.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ LittleDict{...} with 3 entries:
4848
"Y" => 3
4949
```
5050
"""
51-
function flat_params(params::NamedTuple)
51+
function flat_params(parameters::NamedTuple)
5252
result = LittleDict{String, Any}()
53-
for key in keys(params)
54-
value = getproperty(params, key)
53+
for key in keys(parameters)
54+
value = params(getproperty(parameters, key))
5555
if value isa NamedTuple
5656
sub_dict = flat_params(value)
5757
for (sub_key, sub_value) in pairs(sub_dict)
58-
new_key = string(key, "_", sub_key)
58+
new_key = string(key, "__", sub_key)
5959
result[new_key] = sub_value
6060
end
6161
else

test/extensions/loggers.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module TestLoggers
2+
3+
using Test
4+
using MLJBase
5+
using ..Models
6+
7+
@testset "mlflow logger" begin
8+
artifact_directory = "mlj-test"
9+
experiment_name = "mlflow logger tests"
10+
11+
@testset "outside extension tests" begin
12+
@test_throws ErrorException mlflow_logger()
13+
14+
using MLFlowClient
15+
logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)
16+
17+
@test logger.client isa MLFlow
18+
@test logger.experiment_name == experiment_name
19+
@test logger.artifact_location == artifact_directory
20+
end # @testset "outside extension tests"
21+
22+
@testset "extension tests" begin
23+
X = (x=rand(4),)
24+
y = ["Chenta", "Missy", "Gala", "Wendy"] |> categorical
25+
26+
mach = machine(ConstantClassifier(), X, y)
27+
fit!(mach, verbosity=0)
28+
29+
logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)
30+
31+
@testset "save" begin
32+
run = MLJBase.save(logger, mach)
33+
experiment = getexperiment(logger.client, run.info.experiment_id)
34+
@test run isa MLFlowRun
35+
@test experiment isa MLFlowExperiment
36+
37+
deleterun(logger.client, run)
38+
deleteexperiment(logger.client, experiment)
39+
end # @testset "save"
40+
41+
@testset "evaluate!" begin
42+
evaluate!(mach, resampling=Holdout(), logger=logger)
43+
44+
experiments = searchexperiments(logger.client)
45+
experiments_ids = experiments .|> (e -> e.experiment_id)
46+
runs = searchruns(logger.client, experiments_ids)
47+
48+
# it's 2 because of the default experiment
49+
@test length(experiments_ids) == 2
50+
@test length(runs) == 1
51+
52+
deleterun(logger.client, runs[1])
53+
deleteexperiment(logger.client, experiments[2])
54+
end # @testset "evaluate!"
55+
end # @testset "extension tests"
56+
end # @testset "mlflow logger"
57+
58+
end # module
59+
60+
true

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,7 @@ end
8080
@test include("hyperparam/one_dimensional_ranges.jl")
8181
@test include("hyperparam/one_dimensional_range_methods.jl")
8282
end
83+
84+
@conditional_testset "extensions" begin
85+
@test include("extensions/loggers.jl")
86+
end

0 commit comments

Comments
 (0)