Skip to content

Commit a98a501

Browse files
authored
Merge pull request #32 from JuliaDecisionFocusedLearning/cleanup-argmax
Improve argmax benchmarks
2 parents fd4c1f8 + 475e7b4 commit a98a501

File tree

6 files changed

+35
-88
lines changed

6 files changed

+35
-88
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DecisionFocusedLearningBenchmarks"
22
uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
33
authors = ["Members of JuliaDecisionFocusedLearning"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

src/Argmax/Argmax.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothin
4040
return ArgmaxBenchmark(instance_dim, nb_features, model)
4141
end
4242

43+
function Utils.is_minimization_problem(::ArgmaxBenchmark)
44+
return false
45+
end
46+
4347
"""
4448
$TYPEDSIGNATURES
4549

src/Argmax2D/Argmax2D.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,15 @@ Custom constructor for [`Argmax2DBenchmark`](@ref).
4040
"""
4141
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
4242
Random.seed!(seed)
43-
model = Chain(Dense(nb_features => 2; bias=false), vec)
43+
model = Dense(nb_features => 2; bias=false)
4444
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
4545
end
4646

47-
maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]
47+
function Utils.is_minimization_problem(::Argmax2DBenchmark)
48+
return false
49+
end
50+
51+
maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instance)]
4852

4953
"""
5054
$TYPEDSIGNATURES
@@ -56,7 +60,7 @@ function Utils.generate_dataset(
5660
)
5761
(; nb_features, encoder, polytope_vertex_range) = bench
5862
return map(1:dataset_size) do _
59-
x = randn(rng, nb_features)
63+
x = randn(rng, Float32, nb_features)
6064
θ_true = encoder(x)
6165
θ_true ./= 2 * norm(θ_true)
6266
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
@@ -84,23 +88,30 @@ function Utils.generate_statistical_model(
8488
)
8589
Random.seed!(rng, seed)
8690
(; nb_features) = bench
87-
model = Chain(Dense(nb_features => 2; bias=false), vec)
91+
model = Dense(nb_features => 2; bias=false)
8892
return model
8993
end
9094

95+
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
96+
pl = init_plot()
97+
plot_polytope!(pl, instance)
98+
plot_objective!(pl, θ)
99+
return plot_maximizer!(pl, θ, instance, maximizer)
100+
end
101+
91102
"""
92103
$TYPEDSIGNATURES
93104
94105
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
95106
"""
96107
function Utils.plot_data(
97-
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
108+
bench::Argmax2DBenchmark,
109+
sample::DataSample;
110+
instance=sample.instance,
111+
θ=sample.θ_true,
112+
kwargs...,
98113
)
99-
(; instance) = sample
100-
pl = init_plot()
101-
plot_polytope!(pl, instance)
102-
plot_objective!(pl, θ_true)
103-
return plot_maximizer!(pl, θ_true, instance, maximizer)
114+
return Utils.plot_data(bench; instance, θ, kwargs...)
104115
end
105116

106117
export Argmax2DBenchmark

src/Argmax2D/polytope.jl

Lines changed: 3 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,16 @@ function plot_polytope!(pl, vertices)
2121
fillcolor=:gray,
2222
fillalpha=0.2,
2323
linecolor=:black,
24-
label=L"\mathrm{conv}(\mathcal{V})",
24+
label=L"\mathrm{conv}(\mathcal{Y}(x))",
2525
)
2626
end;
2727

28-
const logocolors = Colors.JULIA_LOGO_COLORS
29-
3028
function plot_objective!(pl, θ)
3129
Plots.plot!(
3230
pl,
3331
[0.0, θ[1]],
3432
[0.0, θ[2]];
35-
color=logocolors.purple,
33+
color=Colors.JULIA_LOGO_COLORS.purple,
3634
arrow=true,
3735
lw=2,
3836
label=nothing,
@@ -47,81 +45,9 @@ function plot_maximizer!(pl, θ, instance, maximizer)
4745
pl,
4846
[ŷ[1]],
4947
[ŷ[2]];
50-
color=logocolors.red,
48+
color=Colors.JULIA_LOGO_COLORS.red,
5149
markersize=9,
5250
markershape=:square,
5351
label=L"f(\theta)",
5452
)
5553
end;
56-
57-
# function get_angle(v)
58-
# @assert !(norm(v) ≈ 0)
59-
# v = v ./ norm(v)
60-
# if v[2] >= 0
61-
# return acos(v[1])
62-
# else
63-
# return π + acos(-v[1])
64-
# end
65-
# end;
66-
67-
# function plot_distribution!(pl, probadist)
68-
# A = probadist.atoms
69-
# As = sort(A; by=get_angle)
70-
# p = probadist.weights
71-
# Plots.plot!(
72-
# pl,
73-
# vcat(map(first, As), first(As[1])),
74-
# vcat(map(last, As), last(As[1]));
75-
# fillrange=0,
76-
# fillcolor=:blue,
77-
# fillalpha=0.1,
78-
# linestyle=:dash,
79-
# linecolor=logocolors.blue,
80-
# label=L"\mathrm{conv}(\hat{p}(\theta))",
81-
# )
82-
# return Plots.scatter!(
83-
# pl,
84-
# map(first, A),
85-
# map(last, A);
86-
# markersize=25 .* p .^ 0.5,
87-
# markercolor=logocolors.blue,
88-
# markerstrokewidth=0,
89-
# markeralpha=0.4,
90-
# label=L"\hat{p}(\theta)",
91-
# )
92-
# end;
93-
94-
# function plot_expectation!(pl, probadist)
95-
# ŷΩ = compute_expectation(probadist)
96-
# return scatter!(
97-
# pl,
98-
# [ŷΩ[1]],
99-
# [ŷΩ[2]];
100-
# color=logocolors.blue,
101-
# markersize=6,
102-
# markershape=:hexagon,
103-
# label=L"\hat{f}(\theta)",
104-
# )
105-
# end;
106-
107-
# function compress_distribution!(
108-
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
109-
# ) where {A,W}
110-
# (; atoms, weights) = probadist
111-
# to_delete = Int[]
112-
# for i in length(probadist):-1:1
113-
# ai = atoms[i]
114-
# for j in 1:(i - 1)
115-
# aj = atoms[j]
116-
# if isapprox(ai, aj; atol=atol)
117-
# weights[j] += weights[i]
118-
# push!(to_delete, i)
119-
# break
120-
# end
121-
# end
122-
# end
123-
# sort!(to_delete)
124-
# deleteat!(atoms, to_delete)
125-
# deleteat!(weights, to_delete)
126-
# return probadist
127-
# end;

test/argmax.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
model = generate_statistical_model(b)
1515
maximizer = generate_maximizer(b)
1616

17+
gap = compute_gap(b, dataset, model, maximizer)
18+
@test gap >= 0
19+
1720
for (i, sample) in enumerate(dataset)
1821
(; x, θ_true, y_true) = sample
1922
@test size(x) == (nb_features, instance_dim)

test/argmax_2d.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
model = generate_statistical_model(b)
1414
maximizer = generate_maximizer(b)
1515

16+
gap = compute_gap(b, dataset, model, maximizer)
17+
@test gap >= 0
18+
1619
# Test plot_data
1720
figure = plot_data(b, dataset[1])
1821
@test figure isa Plots.Plot

0 commit comments

Comments
 (0)