Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions src/read.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,46 @@

# Convenience function to allow for things like Array(tp) or CuArray(tp)
# Not sure if this counts as type piracy...
(::Type{T})(p::TensorProto) where T = array(p) |> T
(::Type{Ref{T}})(p::TensorProto) where T = array(p) |> T |> Ref


"""
array(p::TensorProto)
array(p::TensorProto, wrap=Array)

Return `p` as an reshaped and reinterpreted array.
Return `p` as an `Array` of the correct type. Second argument can be used to change type of the returned array
"""
function array(p::TensorProto)
function array(p::TensorProto, wrap=Array)
# Copy pasted from jl
# Can probably be cleaned up a bit
# TODO: Add missing datatypes...
if p.data_type === TensorProto_DataType.INT64
if isdefined(p, :int64_data) && !isempty(p.int64_data)
return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...)
return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...) |> wrap
end

if p.data_type === TensorProto_DataType.INT32
if isdefined(p, :int32_data) && !isempty(p.int32_data)
return reshape(p.int32_data , reverse(p.dims)...)
return reshape(p.int32_data , reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...) |> wrap
end

if p.data_type === TensorProto_DataType.INT8
return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...) |> wrap
end

if p.data_type === TensorProto_DataType.DOUBLE
if isdefined(p, :double_data) && !isempty(p.double_data)
return reshape(p.double_data , reverse(p.dims)...)
return reshape(p.double_data , reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...) |> wrap
end

if p.data_type === TensorProto_DataType.FLOAT
if isdefined(p,:float_data) && !isempty(p.float_data)
return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...)
return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...) |> wrap
end
return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...) |> wrap
end

if p.data_type === TensorProto_DataType.FLOAT16
return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...)
return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...) |> wrap
end
end

Expand Down
10 changes: 4 additions & 6 deletions test/readwrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
end

@testset "TensorProto" begin
import BaseOnnx: TensorProto
import BaseOnnx: TensorProto, array

@testset "Tensor type $T size $s" for T in (Int8, Int32, Int64, Float16, Float32, Float64), s in ((1,),
(1, 2),
(1, 2, 3),
(1, 2, 3, 4),
(1, 2, 3, 4, 5))
exp = reshape(collect(T, 1:prod(s)), s...)
@test TensorProto(exp) |> serdeser |> Array == exp

@test TensorProto(exp) |> serdeser |> array == exp
end

end

@testset "ValueInfo" begin
Expand All @@ -41,7 +39,7 @@
end

@testset "Attribute" begin
import BaseOnnx: AttributeProto, TensorProto, attribute
import BaseOnnx: AttributeProto, TensorProto, attribute, array

@testset "Attribute type $(first(p))" for p in (
:Int64 => 12,
Expand All @@ -62,7 +60,7 @@
@testset "Attribute type TensorProto" begin
# TensorProto has undef fields which mess up straigh comparison
arr = collect(1:4)
@test AttributeProto(:ff => TensorProto(arr)) |> serdeser |> attribute |> last |> Array == arr
@test AttributeProto(:ff => TensorProto(arr)) |> serdeser |> attribute |> last |> array == arr
end

@testset "Attribute Dict" begin
Expand Down