55
66Return `x` reshaped into an array one dimensionality higher than `x`,
77where `dims` indicates in which dimension `x` is extended.
8+ `dims` can be an integer between 1 and `ndims(x)+1`.
89
910See also [`flatten`](@ref), [`stack`](@ref).
1011
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
3334 [1, 2] [3, 4] [5, 6]
3435```
3536"""
36- function unsqueeze (x:: AbstractArray ; dims:: Int )
37- sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), ndims (x) + 1 )
37+ function unsqueeze (x:: AbstractArray{T,N} ; dims:: Int ) where {T, N}
38+ @assert 1 <= dims <= N + 1
39+ sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), N + 1 )
3840 return reshape (x, sz)
3941end
4042
@@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)
5557
5658Base. show_function (io:: IO , u:: Base.Fix2{typeof(_unsqueeze)} , :: Bool ) = print (io, " unsqueeze(dims=" , u. x, " )" )
5759
58- """
59- stack(xs; dims)
60-
61- Concatenate the given array of arrays `xs` into a single array along the
62- given dimension `dims`.
63-
64- See also [`stack`](@ref) and [`batch`](@ref).
65-
66- # Examples
67-
68- ```jldoctest
69- julia> xs = [[1, 2], [3, 4], [5, 6]]
70- 3-element Vector{Vector{Int64}}:
71- [1, 2]
72- [3, 4]
73- [5, 6]
74-
75- julia> stack(xs, dims=1)
76- 3×2 Matrix{Int64}:
77- 1 2
78- 3 4
79- 5 6
80-
81- julia> stack(xs, dims=2)
82- 2×3 Matrix{Int64}:
83- 1 3 5
84- 2 4 6
85-
86- julia> stack(xs, dims=3)
87- 2×1×3 Array{Int64, 3}:
88- [:, :, 1] =
89- 1
90- 2
91-
92- [:, :, 2] =
93- 3
94- 4
95-
96- [:, :, 3] =
97- 5
98- 6
99- ```
100- """
101- stack (xs; dims:: Int ) = cat (unsqueeze .(xs; dims)... ; dims)
102-
10360"""
10461 unstack(xs; dims)
10562
329286
330287batchindex (xs, i) = (reverse (Base. tail (reverse (axes (xs))))... , i)
331288
332- function batch (xs:: AbstractArray{<:AbstractArray} )
333- # Don't use stack(xs, dims=N+1), it is much slower.
334- # Here we do reduce(vcat, xs) along with some reshapes.
335- szxs = size (xs)
336- @assert length (xs) > 0 " Minimum batch size is 1."
337- szx = size (xs[1 ])
338- @assert all (x -> size (x) == szx, xs) " All arrays must be of the same size."
339- vxs = vec (vec .(xs))
340- y = reduce (vcat, vxs)
341- return reshape (y, szx... , szxs... )
342- end
289+ batch (xs:: AbstractArray{<:AbstractArray} ) = stack (xs)
343290
344291function batch (xs:: Vector{<:Tuple} )
345292 @assert length (xs) > 0 " Input should be non-empty"
0 commit comments