Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions src/interface/eager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,18 @@ function Statistics.std(
end

function reshape_plan(tns, dims)
if ndims(tns) == 0
combine_mask = ()
split_mask = ntuple(i -> (i:i...,), length(dims))
return combine_mask, split_mask
end

if length(dims) == 0
combine_mask = ndims(tns) == 0 ? () : ((1:ndims(tns)...,),)
split_mask = ()
return combine_mask, split_mask
end

num_colon = count(x -> x === Colon(), dims)
if num_colon > 1
throw(ArgumentError("Only one colon is allowed in the reshape dimensions."))
Expand Down Expand Up @@ -415,6 +427,14 @@ function splitdims_rep_def(tns::RepeatData, dims, mask)
res
end

function splitdims_rep_def(tns::ElementData, dims, mask)
res = splitdims_rep(tns, mask...)
for dim in dims
res = ExtrudeData(res)
end
res
end

@staged function reshape_constructor(tns, dims, combine_mask, split_mask)
combine_mask = combine_mask.parameters[1]
split_mask = split_mask.parameters[1]
Expand All @@ -438,6 +458,16 @@ end
dst_idxs = [Symbol(:j_, m) for m in 1:M]
dst_dims = [Symbol(:dst_dim_, m) for m in 1:M]

if N == 0
return quote
@finch begin
dst .= 0
dst[$(fill(1, M)...)] = src[]
end
return dst
end
end

for (combine_group, split_group) in zip(combine_mask, split_mask)
src_tmps[combine_group[end]] = src_idxs[combine_group[end]]
flat_idx = src_tmps[combine_group[1]]
Expand Down Expand Up @@ -510,8 +540,7 @@ end
return unblock(striplines(res))
end

Base.reshape(tns::AbstractTensor, dims::Union{Integer,Colon}...) =
reshape(tns, (dims...,))
Base.reshape(tns::AbstractTensor, dims::Union{Integer,Colon}...) = reshape(tns, (dims...,))
function Base.reshape(
tns::SwizzleArray{perm}, dims::Tuple{Vararg{Union{Integer,Colon}}}
) where {perm}
Expand Down Expand Up @@ -545,9 +574,15 @@ function Base.reshape(tns::AbstractTensor, dims::Tuple{Vararg{Union{Integer,Colo
)
end
end

(combine_mask, split_mask) = reshape_plan(tns, dims)
dst = reshape_constructor(tns, dims, Val(combine_mask), Val(split_mask))
reshape_kernel(dst, tns, dims, Val(combine_mask), Val(split_mask))

if length(dims) == 0
Tensor(Element(default(tns), tns[fill(1, ndims(tns))...]))
else
dst = reshape_constructor(tns, dims, Val(combine_mask), Val(split_mask))
reshape_kernel(dst, tns, dims, Val(combine_mask), Val(split_mask))
end
end
function reshape!(dst, src::AbstractTensor, dims::Union{Integer,Colon}...)
reshape!(dst, src, dims)
Expand Down
10 changes: 10 additions & 0 deletions test/suites/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,16 @@ end
A = swizzle(Tensor(Dense(Sparse(Element(0))), LinearIndices((6, 6))), 2, 1)
B = permutedims(LinearIndices((6, 6)))
@test reshape(A, 3, 12) == reshape(B, 3, 12)

T1 = swizzle(Tensor(Element(0, 0)))
@test reshape(T1) == reshape(fill(0), ())
@test reshape(T1, 1) == reshape(fill(0), 1)
@test reshape(T1, 1, 1) == reshape(fill(0), 1, 1)

T2 = swizzle(Tensor([0]), 1)
@test reshape(T2) == reshape([0], ())
@test reshape(T2, 1) == reshape([0], 1)
@test reshape(T2, 1, 1) == reshape([0], 1, 1)
end

let
Expand Down
Loading