@@ -7,6 +7,7 @@ using TypedTables: Table
77using DataAPI: refarray, refvalue
88using Adapt: adapt, Adapt
99using JLArrays
10+ using Random
1011using Test
1112
1213using Documenter: doctest
@@ -1151,29 +1152,39 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
11511152 # used inside of broadcast but we also test it here explicitly
11521153 @test isa (@inferred (Base. dataids (s)), NTuple{N, UInt} where {N})
11531154
1154- # Make sure we can handle style with similar defined
1155- # And we can handle most conflicts
1156- # `s1` and `s2` have similar defined, but `s3` does not
1157- # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
1158- s1 = StructArray {ComplexF64} ((MyArray1 (rand (2 )), MyArray1 (rand (2 ))))
1159- s2 = StructArray {ComplexF64} ((MyArray2 (rand (2 )), MyArray2 (rand (2 ))))
1160- s3 = StructArray {ComplexF64} ((MyArray3 (rand (2 )), MyArray3 (rand (2 ))))
1161- s4 = StructArray {ComplexF64} ((rand (2 ), rand (2 )))
1162-
1163- function _test_similar (a, b, c)
1164- try
1165- d = StructArray {ComplexF64} ((a. re .+ b. re .- c. re, a. im .+ b. im .- c. im))
1166- @test typeof (a .+ b .- c) == typeof (d)
1167- catch
1168- @test_throws MethodError a .+ b .- c
1155+
1156+ @testset " style conflict check" begin
1157+ using StructArrays: StructArrayStyle
1158+ # Make sure we can handle style with similar defined
1159+ # And we can handle most conflicts
1160+ # `s1` and `s2` have similar defined, but `s3` does not
1161+ # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
1162+ s1 = StructArray {ComplexF64} ((MyArray1 (rand (2 )), MyArray1 (rand (2 ))))
1163+ s2 = StructArray {ComplexF64} ((MyArray2 (rand (2 )), MyArray2 (rand (2 ))))
1164+ s3 = StructArray {ComplexF64} ((MyArray3 (rand (2 )), MyArray3 (rand (2 ))))
1165+ s4 = StructArray {ComplexF64} ((rand (2 ), rand (2 )))
1166+ test_set = Any[s1, s2, s3, s4]
1167+ tested_style = Any[]
1168+ dotaddadd ((a, b, c),) = @. a + b + c
1169+ for is in Iterators. product (randperm (4 ), randperm (4 ), randperm (4 ))
1170+ as = map (i -> test_set[i], is)
1171+ ares = map (a-> a. re, as)
1172+ aims = map (a-> a. im, as)
1173+ style = Broadcast. combine_styles (ares... )
1174+ if ! (style in tested_style)
1175+ push! (tested_style, style)
1176+ if style isa Broadcast. ArrayStyle{MyArray3}
1177+ @test_throws MethodError dotaddadd (as)
1178+ else
1179+ d = StructArray {ComplexF64} ((dotaddadd (ares), dotaddadd (aims)))
1180+ @test @inferred (dotaddadd (as)):: typeof (d) == d
1181+ end
1182+ end
11691183 end
1184+ @test length (tested_style) == 5
11701185 end
1171- for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
1172- _test_similar (s, s′, s″)
1173- end
1174-
11751186 # test for dimensionality track
1176- s = s1
1187+ s = StructArray {ComplexF64} (( MyArray1 ( rand ( 2 )), MyArray1 ( rand ( 2 ))))
11771188 @test Base. broadcasted (+ , s, s) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
11781189 @test Base. broadcasted (+ , s, 1 : 2 ) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
11791190 @test Base. broadcasted (+ , s, reshape (1 : 2 ,1 ,2 )) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{2} }
@@ -1197,22 +1208,25 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
11971208 @test (x -> x. x. x. a). (StructArray (x= StructArray (x= StructArray (a= 1 : 3 )))) == [1 , 2 , 3 ]
11981209
11991210 @testset " ambiguity check" begin
1200- function _test (a, b, c)
1201- if a isa StructArray || b isa StructArray || c isa StructArray
1202- d = @inferred a .+ b .- c
1203- @test d == collect (a) .+ collect (b) .- collect (c)
1204- @test d isa StructArray
1205- end
1206- end
1207- testset = Any[StructArray ([1 ;2 + im]),
1211+ test_set = Any[StructArray ([1 ;2 + im]),
12081212 1 : 2 ,
12091213 (1 ,2 ),
1210- StructArray (@SArray [1 1 + 2im ]),
1211- (@SArray [1 2 ])
1212- ]
1213- for aa in testset, bb in testset, cc in testset
1214- _test (aa, bb, cc)
1214+ StructArray (@SArray [1 ;1 + 2im ]),
1215+ (@SArray [1 2 ]),
1216+ 1 ]
1217+ tested_style = StructArrayStyle[]
1218+ dotaddsub ((a, b, c),) = @. a + b - c
1219+ for is in Iterators. product (randperm (6 ), randperm (6 ), randperm (6 ))
1220+ as = map (i -> test_set[i], is)
1221+ if any (a -> a isa StructArray, as)
1222+ style = Broadcast. combine_styles (as... )
1223+ if ! (style in tested_style)
1224+ push! (tested_style, style)
1225+ @test @inferred (dotaddsub (as)):: StructArray == dotaddsub (map (collect, as))
1226+ end
1227+ end
12151228 end
1229+ @test length (tested_style) == 4
12161230 end
12171231
12181232 @testset " StructStaticArray" begin
0 commit comments