Skip to content

Commit 202c928

Browse files
authored
Update named tuple distribution
1 parent 33b9da2 commit 202c928

File tree

3 files changed

+14
-28
lines changed

3 files changed

+14
-28
lines changed

src/namedtuple/productnamedtuple.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,17 @@ julia> var(d) # var of marginals
3838
(x = 1.0, y = [0.031746031746031744, 0.031746031746031744])
3939
```
4040
"""
41-
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
41+
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport} <:
4242
Distribution{NamedTupleVariate{Tnames},S}
4343
dists::NamedTuple{Tnames,Tdists}
4444
end
4545
function ProductNamedTupleDistribution(
4646
dists::NamedTuple{K,V}
4747
) where {K,V<:Tuple{Distribution,Vararg{Distribution}}}
4848
vs = _product_valuesupport(values(dists))
49-
eltypes = _product_namedtuple_eltype(values(dists))
50-
return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists)
49+
return ProductNamedTupleDistribution{K,V,vs}(dists)
5150
end
5251

53-
_gentype(d::UnivariateDistribution) = eltype(d)
54-
_gentype(d::Distribution{<:ArrayLikeVariate{S}}) where {S} = Array{eltype(d),S}
55-
function _gentype(d::Distribution{CholeskyVariate})
56-
T = eltype(d)
57-
return LinearAlgebra.Cholesky{T,Matrix{T}}
58-
end
59-
function _gentype(d::ProductNamedTupleDistribution{K}) where {K}
60-
return NamedTuple{K,Tuple{map(_gentype, values(d.dists))...}}
61-
end
62-
_gentype(::Distribution) = Any
63-
64-
_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...)
65-
6652
function Base.show(io::IO, d::ProductNamedTupleDistribution)
6753
return show_multline(io, d, collect(pairs(d.dists)))
6854
end
@@ -88,8 +74,6 @@ end
8874

8975
# Properties
9076

91-
Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T
92-
9377
Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists)
9478

9579
Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists)
@@ -166,9 +150,9 @@ end
166150
function Base.rand(
167151
rng::AbstractRNG, d::ProductNamedTupleDistribution{K}, dims::Dims
168152
) where {K}
169-
return convert(AbstractArray{<:NamedTuple{K}}, _rand(rng, sampler(d), dims))
153+
return rand(rng, sampler(d), dims)
170154
end
171155

172-
function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
173-
return _rand!(rng, sampler(d), xs)
156+
Base.@propagate_inbounds function Base.rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
157+
return rand!(rng, sampler(d), xs)
174158
end

src/samplers/productnamedtuple.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ function Base.rand(rng::AbstractRNG, spl::ProductNamedTupleSampler{K}) where {K}
77
return NamedTuple{K}(map(Base.Fix1(rand, rng), spl.samplers))
88
end
99

10-
function _rand(rng::AbstractRNG, spl::ProductNamedTupleSampler, dims::Dims)
11-
return map(CartesianIndices(dims)) do _
12-
return rand(rng, spl)
13-
end
10+
function Base.rand(rng::AbstractRNG, s::ProductNamedTupleSampler, dims::Dims)
11+
r = rand(rng, s)
12+
out = Array{typeof(r)}(undef, dims)
13+
out[1] = r
14+
rand!(rng, s, @view(out[2:end]))
15+
return out
1416
end
1517

16-
function _rand!(
18+
function Base.rand!(
1719
rng::AbstractRNG, spl::ProductNamedTupleSampler, xs::AbstractArray{<:NamedTuple{K}}
1820
) where {K}
1921
for i in eachindex(xs)

test/namedtuple/productnamedtuple.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ using Test
6868
(x=product_distribution((x=Normal(), y=Gamma())),),
6969
]
7070
d = ProductNamedTupleDistribution(nt)
71-
@test eltype(d) === eltype(rand(d))
71+
@test @test_deprecated(eltype(d)) === eltype(rand(d))
7272
end
7373
end
7474

@@ -168,7 +168,7 @@ using Test
168168
d = ProductNamedTupleDistribution(nt)
169169
rng = MersenneTwister(973)
170170
x1 = @inferred rand(rng, d)
171-
@test eltype(x1) === eltype(d)
171+
@test eltype(x1) === @test_deprecated(eltype(d))
172172
rng = MersenneTwister(973)
173173
x2 = (
174174
x=rand(rng, nt.x), y=rand(rng, nt.y), z=rand(rng, nt.z), w=rand(rng, nt.w)

0 commit comments

Comments
 (0)