Commit 1cd258b

Moving params from MLJFlow.jl and preparing evaluate! to work with log_evaluation
1 parent a86babf commit 1cd258b

5 files changed: +113 -30 lines changed

5 files changed

+113
-30
lines changed

src/MLJBase.jl

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ export TransformedTargetModel
 
 # resampling.jl:
 export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV,
-    evaluate!, Resampler, PerformanceEvaluation, log_evaluation
+    evaluate!, Resampler, PerformanceEvaluation
 
 # `MLJType` and the abstract `Model` subtypes are exported from within
 # src/composition/abstract_types.jl

src/composition/models/stacking.jl

Lines changed: 1 addition & 1 deletion
@@ -388,8 +388,8 @@ function internal_stack_report(
     # For each model we record the results mimicking the fields PerformanceEvaluation
     results = NamedTuple{modelnames}(
         [(
-            measure = stack.measures,
             model = model,
+            measure = stack.measures,
             measurement = Vector{Any}(undef, n_measures),
             operation = _actual_operations(nothing, stack.measures, model, verbosity),
             per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],

src/resampling.jl

Lines changed: 37 additions & 28 deletions
@@ -474,11 +474,11 @@ be interpreted with caution. See, for example, Bates et al.
 These fields are part of the public API of the `PerformanceEvaluation`
 struct.
 
-- `measure`: vector of measures (metrics) used to evaluate performance
-
 - `model`: model used to create the performance evaluation. In the case a
   tuning model, this is the best model found.
 
+- `measure`: vector of measures (metrics) used to evaluate performance
+
 - `measurement`: vector of measurements - one for each element of
   `measure` - aggregating the performance measurements over all
   train/test pairs (folds). The aggregation method applied for a given
@@ -512,15 +512,15 @@ struct.
   training and evaluation respectively.
 """
 struct PerformanceEvaluation{M,
-                             Model,
+                             Measure,
                              Measurement,
                              Operation,
                              PerFold,
                              PerObservation,
                              FittedParamsPerFold,
                              ReportPerFold} <: MLJType
-    measure::M
-    model::Model
+    model::M
+    measure::Measure
     measurement::Measurement
     operation::Operation
     per_fold::PerFold
@@ -573,9 +573,9 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
 
     println(io, "PerformanceEvaluation object "*
                 "with these fields:")
-    println(io, " measure, operation, measurement, per_fold,\n"*
+    println(io, " model, measure, operation, measurement, per_fold,\n"*
                 " per_observation, fitted_params_per_fold,\n"*
-                " report_per_fold, train_test_rows", "model")
+                " report_per_fold, train_test_rows")
     println(io, "Extract:")
     show_color = MLJBase.SHOW_COLOR[]
     color_off()
@@ -812,6 +812,24 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
 
 # --------------------------------------------------------------
 # User interface points: `evaluate!` and `evaluate`
+#
+"""
+    log_evaluation(logger, performance_evaluation)
+
+Log a performance evaluation to `logger`, an object specific to some logging
+platform, such as mlflow. If `logger=nothing` then no logging is performed.
+The method is called at the end of every call to `evaluate/evaluate!` using
+the logger provided by the `logger` keyword argument.
+
+# Implementations for new logging platforms
+#
+Julia interfaces to workflow logging platforms, such as mlflow (provided by
+the MLFlowClient.jl interface) should overload
+`log_evaluation(logger::LoggerType, performance_evaluation)`,
+where `LoggerType` is a platform-specific type for logger objects. For an
+example, see the implementation provided by the MLJFlow.jl package.
+"""
+log_evaluation(logger, performance_evaluation) = nothing
 
 """
     evaluate!(mach,
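
For orientation, the kind of overload the docstring added above is asking for might look like the following minimal sketch. The `MyLogger` type and its printed output are invented for illustration and are not part of MLJBase or MLJFlow.jl; a real backend would forward the evaluation to its tracking service instead.

```julia
# Hypothetical sketch only: a platform-specific overload of `log_evaluation`.
# `MyLogger` is an invented logger type, not something MLJBase or MLJFlow.jl define.
import MLJBase

struct MyLogger
    run_name::String
end

function MLJBase.log_evaluation(logger::MyLogger, performance_evaluation)
    # A real implementation (e.g. MLJFlow.jl via MLFlowClient.jl) would push
    # these values to its tracking server; here we only report them locally.
    measurements = performance_evaluation.measurement
    @info "Run $(logger.run_name): logged evaluation" measurements
end
```
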
@@ -825,7 +843,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
               acceleration=default_resource(),
               force=false,
               verbosity=1,
-              check_measure=true)
+              check_measure=true,
+              logger=nothing)
 
 Estimate the performance of a machine `mach` wrapping a supervised
 model in data, using the specified `resampling` strategy (defaulting
@@ -924,6 +943,8 @@ untouched.
 
 - `check_measure` - default is `true`
 
+- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
+
 ### Return value
 
 A [`PerformanceEvaluation`](@ref) object. See
@@ -943,7 +964,8 @@ function evaluate!(mach::Machine{<:Measurable};
                    repeats=1,
                    force=false,
                    check_measure=true,
-                   verbosity=1)
+                   verbosity=1,
+                   logger=nothing)
 
     # this method just checks validity of options, preprocess the
     # weights, measures, operations, and dispatches a
@@ -984,9 +1006,12 @@ function evaluate!(mach::Machine{<:Measurable};
 
     _acceleration= _process_accel_settings(acceleration)
 
-    evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
-              repeats, _measures, _operations, _acceleration, force)
+    evaluation = evaluate!(mach, resampling, weights, class_weights, rows,
+                           verbosity, repeats, _measures, _operations,
+                           _acceleration, force)
+    log_evaluation(logger, evaluation)
 
+    evaluation
 end
 
 """
@@ -1161,22 +1186,6 @@ function measure_specific_weights(measure, weights, class_weights, test)
     return nothing
 end
 
-# Workflow logging interfaces, such as MLJFlow (MLFlow connection via
-# MLFlowClient.jl), overload the following method but replace the `logger`
-# argument with `logger::LoggerType`, where `LoggerType` is specific to the
-# logging platform.
-"""
-    log_evaluation(logger, performance_evaluation)
-
-Logs a performance evaluation (the returning object from `evaluate!`)
-to a logging platform. The default implementation does nothing.
-Can be overloaded by logging interfaces, such as MLJFlow (MLFlow
-connection via MLFlowClient.jl), which replace the `logger` argument
-with `logger::LoggerType`, where `LoggerType` is specific to the logging
-platform.
-"""
-log_evaluation(logger, performance_evaluation) = nothing
-
 # Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
 function evaluate!(mach::Machine, resampling, weights,
                    class_weights, rows, verbosity, repeats,
@@ -1285,8 +1294,8 @@ function evaluate!(mach::Machine, resampling, weights,
     end
 
     return PerformanceEvaluation(
-        measures,
         mach.model,
+        measures,
         per_measure,
         operations,
         per_fold,

src/utilities.jl

Lines changed: 63 additions & 0 deletions
@@ -469,3 +469,66 @@ end
 
 generate_name!(model, existing_names; kwargs...) =
     generate_name!(typeof(model), existing_names; kwargs...)
+
+isamodel(::Any) = false
+isamodel(::Model) = true
+
+"""
+    deep_params(m::Model)
+
+Recursively convert any object subtyping `Model` into a named tuple,
+keyed on the property names of `m`. The named tuple is possibly nested
+because `deep_params` is recursively applied to the property values, which
+themselves might subtype `Model`.
+
+For most `Model` objects, properties are synonymous with fields, but
+this is not a hard requirement.
+
+    julia> deep_params(EnsembleModel(atom=ConstantClassifier()))
+    (atom = (target_type = Bool,),
+     weights = Float64[],
+     bagging_fraction = 0.8,
+     rng_seed = 0,
+     n = 100,
+     parallel = true,)
+
+"""
+deep_params(m) = deep_params(m, Val(isamodel(m)))
+deep_params(m, ::Val{false}) = m
+function deep_params(m, ::Val{true})
+    fields = propertynames(m)
+    NamedTuple{fields}(Tuple([deep_params(getproperty(m, field))
+                              for field in fields]))
+end
+
+"""
+    flat_params(t::NamedTuple)
+
+View a nested named tuple `t` as a tree and return, as a Dict, the key subtrees
+and the values at the leaves, in the order they appear in the original tuple.
+
+```julia-repl
+julia> t = (X = (x = 1, y = 2), Y = 3)
+julia> flat_params(t)
+LittleDict{...} with 3 entries:
+  "X__x" => 1
+  "X__y" => 2
+  "Y" => 3
+```
+"""
+function flat_params(parameters::NamedTuple)
+    result = LittleDict{String, Any}()
+    for key in keys(parameters)
+        value = params(getproperty(parameters, key))
+        if value isa NamedTuple
+            sub_dict = flat_params(value)
+            for (sub_key, sub_value) in pairs(sub_dict)
+                new_key = string(key, "__", sub_key)
+                result[new_key] = sub_value
+            end
+        else
+            result[string(key)] = value
+        end
+    end
+    return result
+end
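
Taken together, and in line with the commit title moving these helpers over from MLJFlow.jl, `deep_params` followed by `flat_params` produces the flat, string-keyed hyperparameter dictionary a logging backend can record. A brief sketch; the composing `flatten` helper below is invented here for illustration only.

```julia
# Sketch: composing the two new helpers. `deep_params` converts a model into a
# nested named tuple; `flat_params` flattens any nested named tuple into
# "outer__inner" => value pairs, returned as a LittleDict.
using MLJBase

flatten(model) = MLJBase.flat_params(MLJBase.deep_params(model))  # hypothetical helper

# The flattening step on a plain nested tuple:
MLJBase.flat_params((a = (ax = 1, ay = 2), b = 3))
# entries: "a__ax" => 1, "a__ay" => 2, "b" => 3
```
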

test/utilities.jl

Lines changed: 11 additions & 0 deletions
@@ -20,6 +20,17 @@ struct Baz <: Foo end
     @test flat_values(t) == (1, 2, 3)
 end
 
+@testset "flattening parameters" begin
+    t = (a = (ax = (ax1 = 1, ax2 = 2), ay = 3), b = 4)
+    dict_t = Dict(
+        "a__ax__ax1" => 1,
+        "a__ax__ax2" => 2,
+        "a__ay" => 3,
+        "b" => 4,
+    )
+    @test MLJBase.flat_params(t) == dict_t
+end
+
 mutable struct M
     a1
     a2
