Skip to content

Commit 15933c1

Browse files
use Base definition of stack
1 parent ce7b203 commit 15933c1

File tree

5 files changed

+9
-71
lines changed

5 files changed

+9
-71
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.2.11"
4+
version = "0.3.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
910
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
1011
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"

src/MLUtils.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ import ChainRulesCore: rrule
1414
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
1515
NoTangent, ZeroTangent, ProjectTo
1616

17+
if VERSION < v"1.9.0-DEV.1163"
18+
import Compat: stack
19+
else
20+
import Base: stack
21+
end
1722

1823
include("observation.jl")
1924
export numobs,
@@ -66,7 +71,7 @@ export batch,
6671
ones_like,
6772
rand_like,
6873
randn_like,
69-
stack,
74+
stack, # in Base since julia v1.9
7075
unbatch,
7176
unsqueeze,
7277
unstack,

src/deprecations.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Deprecations v0.1
2-
@deprecate stack(x, dims) stack(x; dims=dims)
32
@deprecate unstack(x, dims) unstack(x; dims=dims)
43
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
54
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)

src/utils.jl

-66
Original file line numberDiff line numberDiff line change
@@ -57,72 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)
5757

5858
Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")
5959

60-
"""
61-
stack(xs; dims)
62-
63-
Concatenate the given array of arrays `xs` into a single array along the
64-
new dimension `dims`. All arrays need to be of the same size.
65-
66-
See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).
67-
68-
# Examples
69-
70-
```jldoctest
71-
julia> xs = [[1, 2], [3, 4], [5, 6]]
72-
3-element Vector{Vector{Int64}}:
73-
[1, 2]
74-
[3, 4]
75-
[5, 6]
76-
77-
julia> stack(xs, dims=1)
78-
3×2 Matrix{Int64}:
79-
1 2
80-
3 4
81-
5 6
82-
83-
julia> stack(xs, dims=2)
84-
2×3 Matrix{Int64}:
85-
1 3 5
86-
2 4 6
87-
88-
julia> stack(xs, dims=3)
89-
2×1×3 Array{Int64, 3}:
90-
[:, :, 1] =
91-
1
92-
2
93-
94-
[:, :, 2] =
95-
3
96-
4
97-
98-
[:, :, 3] =
99-
5
100-
6
101-
```
102-
"""
103-
function stack(xs; dims::Int)
104-
N = ndims(xs[1])
105-
if dims <= N
106-
vs = unsqueeze.(xs; dims)
107-
else
108-
vs = xs
109-
end
110-
if dims == 1
111-
return reduce(vcat, vs)
112-
elseif dims === 2
113-
return reduce(hcat, vs)
114-
else
115-
return reduce((x, y) -> cat(x, y; dims=dims), vs)
116-
end
117-
end
118-
119-
function rrule(::typeof(stack), xs; dims::Int)
120-
function stack_pullback(Δ)
121-
return (NoTangent(), unstack(unthunk(Δ); dims=dims))
122-
end
123-
return stack(xs; dims=dims), stack_pullback
124-
end
125-
12660
"""
12761
unstack(xs; dims)
12862

test/utils.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414
x = randn(3,3)
1515
stacked = stack([x, x], dims=2)
1616
@test size(stacked) == (3,2,3)
17-
@test_broken @inferred(stack([x, x], dims=2)) == stacked
17+
@test @inferred(stack([x, x], dims=2)) == stacked
1818

1919
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
2020
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@@ -30,7 +30,6 @@ end
3030
a = [[1] for i in 1:10000]
3131
@test size(stack(a, dims=1)) == (10000, 1)
3232
@test size(stack(a, dims=2)) == (1, 10000)
33-
@test size(stack(a, dims=3)) == (1, 1, 10000)
3433
end
3534

3635
@testset "batch and unbatch" begin

0 commit comments

Comments
 (0)