Skip to content

Commit

Permalink
Support Distributions.params
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jan 22, 2024
1 parent 496188e commit a62ac0a
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/mv_binned_dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ Base.length(d::MvBinnedDist{T,N}) where {T,N} = N
Base.size(d::MvBinnedDist{T,N}) where {T,N} = (N,)
Base.eltype(d::MvBinnedDist{T,N}) where {T,N} = T

Distributions.params(d::MvBinnedDist) = (d._edges, d._bin_pdf, d._closed_left)

Statistics.mean(d::MvBinnedDist) = d._mean
StatsBase.mode(d::MvBinnedDist) = d._mode
Statistics.var(d::MvBinnedDist) = d._var
Expand Down
1 change: 1 addition & 0 deletions src/uv_binned_dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Base.length(d::UvBinnedDist) = 1
Base.size(d::UvBinnedDist) = ()
Base.eltype(d::UvBinnedDist{T}) where {T} = T

Distributions.params(d::UvBinnedDist) = (d._edge, d._bin_pdf, d._closed_left)

Statistics.mean(d::UvBinnedDist) = d._mean
StatsBase.mode(d::UvBinnedDist) = d._mode
Expand Down
2 changes: 2 additions & 0 deletions test/test_mv_binned_dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ using Adapt
@test @inferred(size(d)) == (2,)
@test @inferred(eltype(d)) == Float64

@test @inferred(params(d)) == (d._edges, d._bin_pdf, d._closed_left)

@test all(isapprox.(mean(true_dist), @inferred(mean(d)), atol = 0.01))
@test all(isapprox.(mode(true_dist), @inferred(mode(d)), atol = 0.2))
@test all(isapprox.(var(true_dist), @inferred(var(d)), atol = 0.01))
Expand Down
2 changes: 2 additions & 0 deletions test/test_uv_binned_dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using Adapt, ForwardDiff
@test @inferred(size(d)) == ()
@test @inferred(eltype(d)) == Float64

@test @inferred(params(d)) == (d._edge, d._bin_pdf, d._closed_left)

@test all(isapprox.(mean(true_dist), @inferred(mean(d)), atol = 0.01))
@test all(isapprox.(mode(true_dist), @inferred(mode(d)), atol = 0.05))
@test all(isapprox.(var(true_dist), @inferred(var(d)), atol = 0.01))
Expand Down

0 comments on commit a62ac0a

Please sign in to comment.