@@ -1412,6 +1412,43 @@ struct test_conv_transpose_1d : public test_case {
14121412 }
14131413};
14141414
1415+ struct test_conv_transpose_1d_gemm : public test_case {
1416+ const std::array<int64_t , 4 > ne_input;
1417+ const std::array<int64_t , 4 > ne_kernel;
1418+
1419+ const int s0; // stride
1420+ const int p0; // padding
1421+ const int d0; // dilation
1422+
1423+ ggml_type input_type;
1424+ ggml_type kernel_type;
1425+
1426+ std::string vars () override {
1427+ return VARS_TO_STR5 (ne_input, ne_kernel, s0, p0, d0);
1428+ }
1429+
1430+ test_conv_transpose_1d_gemm (std::array<int64_t , 4 > ne_input = {197 , 32 , 1 , 1 }, // [input_width, input_height, input_channels, 1]
1431+ std::array<int64_t , 4 > ne_kernel = {16 , 32 , 32 , 1 }, // [kernel_width, kernel_height, input_channels, 1]
1432+ int s0 = 1 , int p0 = 0 , int d0 = 1 ,
1433+ ggml_type input_type = GGML_TYPE_F32,
1434+ ggml_type kernel_type = GGML_TYPE_F16)
1435+ : ne_input(ne_input)
1436+ , ne_kernel(ne_kernel)
1437+ , s0(s0)
1438+ , p0(p0)
1439+ , d0(d0)
1440+ , input_type(input_type)
1441+ , kernel_type(kernel_type)
1442+ {}
1443+
1444+ ggml_tensor * build_graph (ggml_context * ctx) override {
1445+ ggml_tensor * input = ggml_new_tensor (ctx, input_type, 4 , ne_input.data ());
1446+ ggml_tensor * kernel = ggml_new_tensor (ctx, kernel_type, 4 , ne_kernel.data ());
1447+ ggml_tensor * out = ggml_conv_transpose_1d_gemm (ctx, kernel, input, s0, p0, d0);
1448+ return out;
1449+ }
1450+ };
1451+
14151452// GGML_OP_IM2COL
14161453struct test_im2col : public test_case {
14171454 const ggml_type type_input;
@@ -2330,6 +2367,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
23302367 test_cases.emplace_back (new test_conv_transpose_1d ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 }, 1 , 0 , 1 ));
23312368 test_cases.emplace_back (new test_conv_transpose_1d ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 }, 1 , 0 , 1 ));
23322369
2370+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ());
2371+ for (int64_t s0 = 1 ; s0 < 4 ; ++s0) {
2372+ for (int64_t p0 = 0 ; p0 < 2 ; ++p0) {
2373+ for (int64_t d0 = 1 ; d0 < 4 ; ++d0) {
2374+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {2 ,3 ,2 ,1 }, s0, p0, d0));
2375+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,2 ,2 ,1 }, s0, p0, d0));
2376+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 }, s0, p0, d0));
2377+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 }, s0, p0, d0));
2378+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {2 ,3 ,2 ,1 },
2379+ s0, p0, d0, GGML_TYPE_F16));
2380+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,2 ,2 ,1 },
2381+ s0, p0, d0, GGML_TYPE_F16));
2382+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 },
2383+ s0, p0, d0, GGML_TYPE_F16));
2384+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 },
2385+ s0, p0, d0, GGML_TYPE_F16));
2386+ }
2387+ }
2388+ }
23332389
23342390 test_cases.emplace_back (new test_repeat (GGML_TYPE_F32, {10 , 10 , 10 , 10 }, {1 , 1 , 1 , 1 }));
23352391 test_cases.emplace_back (new test_repeat (GGML_TYPE_F32, {10 , 10 , 10 , 10 }, {2 , 1 , 1 , 1 }));
0 commit comments