From 982ad75078d1b300397733baaf4ba887dd01aad7 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 06:46:41 -0800 Subject: [PATCH 01/10] Extend ufunc f/w to support binary funcs with two output arrays --- .../elementwise_functions/common.hpp | 390 ++++++++++++++++++ .../elementwise_functions.hpp | 273 ++++++++++++ .../simplify_iteration_space.cpp | 114 +++++ .../simplify_iteration_space.hpp | 21 + .../type_dispatch_building.hpp | 18 + 5 files changed, 816 insertions(+) diff --git a/dpnp/backend/extensions/elementwise_functions/common.hpp b/dpnp/backend/extensions/elementwise_functions/common.hpp index ff5adb1400c5..df2b3afe53b9 100644 --- a/dpnp/backend/extensions/elementwise_functions/common.hpp +++ b/dpnp/backend/extensions/elementwise_functions/common.hpp @@ -316,6 +316,210 @@ struct UnaryTwoOutputsStridedFunctor } }; +/** + * @brief Functor for evaluation of a binary function with two output arrays on + * contiguous arrays. + * + * @note It extends BinaryContigFunctor from + * dpctl::tensor::kernels::elementwise_common namespace. + */ +template +struct BinaryTwoOutputsContigFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT1 *out1 = nullptr; + resT2 *out2 = nullptr; + std::size_t nelems_; + +public: + BinaryTwoOutputsContigFunctor(const argT1 *inp1, + const argT2 *inp2, + resT1 *res1, + resT2 *res2, + std::size_t n_elems) + : in1(inp1), in2(inp2), out1(res1), out2(res2), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + BinaryOperatorT op{}; + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: work-group size must be divisible by sub-group size */ + + if constexpr (enable_sg_loadstore && + BinaryOperatorT::supports_sg_loadstore::value && + BinaryOperatorT::supports_vec::value && (vec_sz > 1)) + { + auto sg = ndit.get_sub_group(); + std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { + sycl::vec res1_vec; + sycl::vec res2_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + std::size_t offset = base + it * sgSize; + auto in1_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in1[offset]); + auto in2_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in2[offset]); + auto out1_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out1[offset]); + auto out2_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out2[offset]); + + const sycl::vec arg1_vec = + sub_group_load(sg, in1_multi_ptr); + const sycl::vec arg2_vec = + sub_group_load(sg, in2_multi_ptr); + res1_vec = op(arg1_vec, arg2_vec, res2_vec); + sub_group_store(sg, res1_vec, out1_multi_ptr); + sub_group_store(sg, res2_vec, out2_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + out1[k] = op(in1[k], in2[k], out2[k]); + } + } + } + else if constexpr (enable_sg_loadstore && + BinaryOperatorT::supports_sg_loadstore::value) + { + auto sg = ndit.get_sub_group(); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const std::size_t offset = base + it * sgSize; + auto in1_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in1[offset]); + auto in2_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in2[offset]); + auto out1_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out1[offset]); + auto out2_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out2[offset]); + + const sycl::vec arg1_vec = + sub_group_load(sg, in1_multi_ptr); + const sycl::vec arg2_vec = + sub_group_load(sg, in2_multi_ptr); + + sycl::vec res1_vec; + sycl::vec res2_vec; +#pragma unroll + for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { + res1_vec[vec_id] = + op(arg1_vec[vec_id], arg2_vec[vec_id], + res2_vec[vec_id]); + } + sub_group_store(sg, res1_vec, out1_multi_ptr); + sub_group_store(sg, res2_vec, out2_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + out1[k] = op(in1[k], in2[k], out2[k]); + } + } + } + else { + const std::size_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + const std::size_t elems_per_sg = sgSize * elems_per_wi; + + const std::size_t start = + (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems_, start + elems_per_sg); + for (std::size_t offset = start; offset < end; offset += sgSize) { + out1[offset] = op(in1[offset], in2[offset], out2[offset]); + } + } + } +}; + +/** + * @brief Functor for evaluation of a binary function with two output arrays on + * strided data. + * + * @note It extends BinaryStridedFunctor from + * dpctl::tensor::kernels::elementwise_common namespace. + */ +template +struct BinaryTwoOutputsStridedFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT1 *out1 = nullptr; + resT2 *out2 = nullptr; + FourOffsets_IndexerT four_offsets_indexer_; + +public: + BinaryTwoOutputsStridedFunctor(const argT1 *inp1_tp, + const argT2 *inp2_tp, + resT1 *res1_tp, + resT2 *res2_tp, + const FourOffsets_IndexerT &inps_res_indexer) + : in1(inp1_tp), in2(inp2_tp), out1(res1_tp), out2(res2_tp), + four_offsets_indexer_(inps_res_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &four_offsets_ = + four_offsets_indexer_(static_cast(wid.get(0))); + + const auto &inp1_offset = four_offsets_.get_first_offset(); + const auto &inp2_offset = four_offsets_.get_second_offset(); + const auto &out1_offset = four_offsets_.get_third_offset(); + const auto &out2_offset = four_offsets_.get_fourth_offset(); + + BinaryOperatorT op{}; + out1[out1_offset] = + op(in1[inp1_offset], in2[inp2_offset], out2[out2_offset]); + } +}; + /** * @brief Function to submit a kernel for unary functor with two output arrays * on contiguous arrays. @@ -454,6 +658,163 @@ sycl::event unary_two_outputs_strided_impl( return comp_ev; } +/** + * @brief Function to submit a kernel for binary functor with two output arrays + * on contiguous arrays. + * + * @note It extends binary_contig_impl from + * dpctl::tensor::kernels::elementwise_common namespace. + */ +template + class BinaryTwoOutputsType, + template + class BinaryTwoOutputsContigFunctorT, + template + class kernel_name, + std::uint8_t vec_sz = 4u, + std::uint8_t n_vecs = 2u> +sycl::event + binary_two_outputs_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res1_p, + ssize_t res1_offset, + char *res2_p, + ssize_t res2_offset, + const std::vector &depends = {}) +{ + const std::size_t n_work_items_needed = nelems / (n_vecs * vec_sz); + const std::size_t lws = + select_lws(exec_q.get_device(), n_work_items_needed); + + const std::size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy1 = typename BinaryTwoOutputsType::value_type1; + using resTy2 = typename BinaryTwoOutputsType::value_type2; + using BaseKernelName = + kernel_name; + + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy1 *res1_tp = reinterpret_cast(res1_p) + res1_offset; + resTy2 *res2_tp = reinterpret_cast(res2_p) + res2_offset; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + if (is_aligned(arg1_tp) && + is_aligned(arg2_tp) && + is_aligned(res1_tp) && + is_aligned(res2_tp)) + { + static constexpr bool enable_sg_loadstore = true; + using KernelName = BaseKernelName; + using Impl = BinaryTwoOutputsContigFunctorT; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems)); + } + else { + static constexpr bool disable_sg_loadstore = false; + using KernelName = + disabled_sg_loadstore_wrapper_krn; + using Impl = BinaryTwoOutputsContigFunctorT; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems)); + } + }); + return comp_ev; +} + +/** + * @brief Function to submit a kernel for binary functor with two output arrays + * on strided data. + * + * @note It extends binary_strided_impl from + * dpctl::tensor::kernels::elementwise_common namespace. + */ +template < + typename argTy1, + typename argTy2, + template + class BinaryTwoOutputsType, + template + class BinaryTwoOutputsStridedFunctorT, + template + class kernel_name> +sycl::event binary_two_outputs_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res1_p, + ssize_t res1_offset, + char *res2_p, + ssize_t res2_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy1 = + typename BinaryTwoOutputsType::value_type1; + using resTy2 = + typename BinaryTwoOutputsType::value_type2; + + using IndexerT = + typename dpctl::tensor::offset_utils::FourOffsets_StridedIndexer; + + const IndexerT indexer{nd, arg1_offset, arg2_offset, + res1_offset, res2_offset, shape_and_strides}; + + const argTy1 *arg1_tp = reinterpret_cast(arg1_p); + const argTy2 *arg2_tp = reinterpret_cast(arg2_p); + resTy1 *res1_tp = reinterpret_cast(res1_p); + resTy2 *res2_tp = reinterpret_cast(res2_p); + + using Impl = BinaryTwoOutputsStridedFunctorT; + + cgh.parallel_for>( + {nelems}, Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, indexer)); + }); + return comp_ev; +} + // Typedefs for function pointers typedef sycl::event (*unary_two_outputs_contig_impl_fn_ptr_t)( @@ -478,4 +839,33 @@ typedef sycl::event (*unary_two_outputs_strided_impl_fn_ptr_t)( const std::vector &, const std::vector &); +typedef sycl::event (*binary_two_outputs_contig_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + ssize_t, + const char *, + ssize_t, + char *, + ssize_t, + char *, + ssize_t, + const std::vector &); + +typedef sycl::event (*binary_two_outputs_strided_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + int, + const ssize_t *, + const char *, + ssize_t, + const char *, + ssize_t, + char *, + ssize_t, + char *, + ssize_t, + const std::vector &, + const std::vector &); + } // namespace dpnp::extensions::py_internal::elementwise_common diff --git a/dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp b/dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp index 8132f0dad824..955d7ea52e80 100644 --- a/dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp +++ b/dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp @@ -859,6 +859,279 @@ py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype, } } +/*! @brief Template implementing Python API for querying of type support by + * binary elementwise functions */ +template +std::pair + py_binary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst1, + const dpctl::tensor::usm_ndarray &dst2, + sycl::queue &exec_q, + const std::vector depends, + // + const output_typesT &output_types_table, + const contig_dispatchT &contig_dispatch_table, + const strided_dispatchT &strided_dispatch_table) +{ + // check type_nums + int src1_typenum = src1.get_typenum(); + int src2_typenum = src2.get_typenum(); + int dst1_typenum = dst1.get_typenum(); + int dst2_typenum = dst2.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum); + int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum); + int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum); + int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum); + + std::pair output_typeids = + output_types_table[src1_typeid][src2_typeid]; + + if (dst1_typeid != output_typeids.first || + dst2_typeid != output_typeids.second) { + throw py::value_error( + "One of destination arrays has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst1, dst2})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1); + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2); + + // check shapes, broadcasting is assumed done by caller + // check that dimensions are the same + int src1_nd = src1.get_ndim(); + int src2_nd = src2.get_ndim(); + int dst1_nd = dst1.get_ndim(); + int dst2_nd = dst2.get_ndim(); + + if (dst1_nd != src1_nd || dst1_nd != src2_nd || dst1_nd != dst2_nd) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *src1_shape = src1.get_shape_raw(); + const py::ssize_t *src2_shape = src2.get_shape_raw(); + const py::ssize_t *dst1_shape = dst1.get_shape_raw(); + const py::ssize_t *dst2_shape = dst2.get_shape_raw(); + bool shapes_equal(true); + std::size_t src_nelems(1); + + for (int i = 0; i < dst1_nd; ++i) { + const auto &sh_i = dst1_shape[i]; + src_nelems *= static_cast(src1_shape[i]); + shapes_equal = + shapes_equal && (src1_shape[i] == sh_i && src2_shape[i] == sh_i && + dst2_shape[i] == sh_i); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1, + src_nelems); + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2, + src_nelems); + + // check memory overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) || + (overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) || + (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2))) + { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + const char *src1_data = src1.get_data(); + const char *src2_data = src2.get_data(); + char *dst1_data = dst1.get_data(); + char *dst2_data = dst2.get_data(); + + // handle contiguous inputs + bool is_src1_c_contig = src1.is_c_contiguous(); + bool is_src1_f_contig = src1.is_f_contiguous(); + + bool is_src2_c_contig = src2.is_c_contiguous(); + bool is_src2_f_contig = src2.is_f_contiguous(); + + bool is_dst1_c_contig = dst1.is_c_contiguous(); + bool is_dst1_f_contig = dst1.is_f_contiguous(); + + bool is_dst2_c_contig = dst2.is_c_contiguous(); + bool is_dst2_f_contig = dst2.is_f_contiguous(); + + bool all_c_contig = (is_src1_c_contig && is_src2_c_contig && + is_dst1_c_contig && is_dst2_c_contig); + bool all_f_contig = (is_src1_f_contig && is_src2_f_contig && + is_dst1_f_contig && is_dst2_f_contig); + + // dispatch for contiguous inputs + if (all_c_contig || all_f_contig) { + auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = + contig_fn(exec_q, src_nelems, src1_data, 0, src2_data, 0, + dst1_data, 0, dst2_data, 0, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst1, dst2}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + + // simplify strides + auto const &src1_strides = src1.get_strides_vector(); + auto const &src2_strides = src2.get_strides_vector(); + auto const &dst1_strides = dst1.get_strides_vector(); + auto const &dst2_strides = dst2.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src1_strides; + shT simplified_src2_strides; + shT simplified_dst1_strides; + shT simplified_dst2_strides; + py::ssize_t src1_offset(0); + py::ssize_t src2_offset(0); + py::ssize_t dst1_offset(0); + py::ssize_t dst2_offset(0); + + int nd = dst1_nd; + const py::ssize_t *shape = src1_shape; + + simplify_iteration_space_4( + nd, shape, src1_strides, src2_strides, dst1_strides, dst2_strides, + // outputs + simplified_shape, simplified_src1_strides, simplified_src2_strides, + simplified_dst1_strides, simplified_dst2_strides, src1_offset, + src2_offset, dst1_offset, dst2_offset); + + std::vector host_tasks{}; + static constexpr auto unit_stride = std::initializer_list{1}; + + if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) && + isEqual(simplified_src2_strides, unit_stride) && + isEqual(simplified_dst1_strides, unit_stride) && + isEqual(simplified_dst2_strides, unit_stride)) + { + auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = + contig_fn(exec_q, src_nelems, src1_data, src1_offset, src2_data, + src2_offset, dst1_data, dst1_offset, dst2_data, + dst2_offset, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst1, dst2}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + + // dispatch to strided code + auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid]; + + if (strided_fn == nullptr) { + throw std::runtime_error( + "Strided implementation is missing for src1_typeid=" + + std::to_string(src1_typeid) + + " and src2_typeid=" + std::to_string(src2_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_sz_event_triple_ = device_allocate_and_pack( + exec_q, host_tasks, simplified_shape, simplified_src1_strides, + simplified_src2_strides, simplified_dst1_strides, + simplified_dst2_strides); + auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_)); + auto ©_shape_ev = std::get<2>(ptr_sz_event_triple_); + + const py::ssize_t *shape_strides = shape_strides_owner.get(); + + sycl::event strided_fn_ev = + strided_fn(exec_q, src_nelems, nd, shape_strides, src1_data, + src1_offset, src2_data, src2_offset, dst1_data, dst1_offset, + dst2_data, dst2_offset, depends, {copy_shape_ev}); + + // async free of shape_strides temporary + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {strided_fn_ev}, shape_strides_owner); + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst1, dst2}, host_tasks), + strided_fn_ev); +} + +/** + * @brief Template implementing Python API for querying of type support by + * a binary elementwise function with two output arrays. + */ +template +std::pair py_binary_two_outputs_ufunc_result_type( + const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + std::pair dst_typeids = + output_types_table[src1_typeid][src2_typeid]; + int dst1_typeid = dst_typeids.first; + int dst2_typeid = dst_typeids.second; + + if (dst1_typeid < 0 || dst2_typeid < 0) { + auto res = py::none(); + auto py_res = py::cast(res); + return std::make_pair(py_res, py_res); + } + else { + using type_utils::_dtype_from_typenum; + + auto dst1_typenum_t = static_cast(dst1_typeid); + auto dst2_typenum_t = static_cast(dst2_typeid); + auto dt1 = _dtype_from_typenum(dst1_typenum_t); + auto dt2 = _dtype_from_typenum(dst2_typenum_t); + + return std::make_pair(py::cast(dt1), + py::cast(dt2)); + } +} + // ==================== Inplace binary functions ======================= template (nd)); } } + +void simplify_iteration_space_4( + int &nd, + const py::ssize_t *const &shape, + // src1 + std::vector const &src1_strides, + // src2 + std::vector const &src2_strides, + // src3 + std::vector const &src3_strides, + // dst + std::vector const &dst_strides, + // output + std::vector &simplified_shape, + std::vector &simplified_src1_strides, + std::vector &simplified_src2_strides, + std::vector &simplified_src3_strides, + std::vector &simplified_dst_strides, + py::ssize_t &src1_offset, + py::ssize_t &src2_offset, + py::ssize_t &src3_offset, + py::ssize_t &dst_offset) +{ + using dpctl::tensor::strides::simplify_iteration_four_strides; + if (nd > 1) { + // Simplify iteration space to reduce dimensionality + // and improve access pattern + simplified_shape.reserve(nd); + simplified_shape.insert(std::end(simplified_shape), shape, shape + nd); + assert(simplified_shape.size() == static_cast(nd)); + + simplified_src1_strides.reserve(nd); + simplified_src1_strides.insert(std::end(simplified_src1_strides), + std::begin(src1_strides), + std::end(src1_strides)); + assert(simplified_src1_strides.size() == static_cast(nd)); + + simplified_src2_strides.reserve(nd); + simplified_src2_strides.insert(std::end(simplified_src2_strides), + std::begin(src2_strides), + std::end(src2_strides)); + assert(simplified_src2_strides.size() == static_cast(nd)); + + simplified_src3_strides.reserve(nd); + simplified_src3_strides.insert(std::end(simplified_src3_strides), + std::begin(src3_strides), + std::end(src3_strides)); + assert(simplified_src3_strides.size() == static_cast(nd)); + + simplified_dst_strides.reserve(nd); + simplified_dst_strides.insert(std::end(simplified_dst_strides), + std::begin(dst_strides), + std::end(dst_strides)); + assert(simplified_dst_strides.size() == static_cast(nd)); + + int contracted_nd = simplify_iteration_four_strides( + nd, simplified_shape.data(), simplified_src1_strides.data(), + simplified_src2_strides.data(), simplified_src3_strides.data(), + simplified_dst_strides.data(), + src1_offset, // modified by reference + src2_offset, // modified by reference + src3_offset, // modified by reference + dst_offset // modified by reference + ); + simplified_shape.resize(contracted_nd); + simplified_src1_strides.resize(contracted_nd); + simplified_src2_strides.resize(contracted_nd); + simplified_src3_strides.resize(contracted_nd); + simplified_dst_strides.resize(contracted_nd); + + nd = contracted_nd; + } + else if (nd == 1) { + src1_offset = 0; + src2_offset = 0; + src3_offset = 0; + dst_offset = 0; + // Populate vectors + simplified_shape.reserve(nd); + simplified_shape.push_back(shape[0]); + assert(simplified_shape.size() == static_cast(nd)); + + simplified_src1_strides.reserve(nd); + simplified_src2_strides.reserve(nd); + simplified_src3_strides.reserve(nd); + simplified_dst_strides.reserve(nd); + + if ((src1_strides[0] < 0) && (src2_strides[0] < 0) && + (src3_strides[0] < 0) && (dst_strides[0] < 0)) + { + simplified_src1_strides.push_back(-src1_strides[0]); + simplified_src2_strides.push_back(-src2_strides[0]); + simplified_src3_strides.push_back(-src3_strides[0]); + simplified_dst_strides.push_back(-dst_strides[0]); + if (shape[0] > 1) { + src1_offset += src1_strides[0] * (shape[0] - 1); + src2_offset += src2_strides[0] * (shape[0] - 1); + src3_offset += src3_strides[0] * (shape[0] - 1); + dst_offset += dst_strides[0] * (shape[0] - 1); + } + } + else { + simplified_src1_strides.push_back(src1_strides[0]); + simplified_src2_strides.push_back(src2_strides[0]); + simplified_src3_strides.push_back(src3_strides[0]); + simplified_dst_strides.push_back(dst_strides[0]); + } + + assert(simplified_src1_strides.size() == static_cast(nd)); + assert(simplified_src2_strides.size() == static_cast(nd)); + assert(simplified_src3_strides.size() == static_cast(nd)); + assert(simplified_dst_strides.size() == static_cast(nd)); + } +} } // namespace dpnp::extensions::py_internal diff --git a/dpnp/backend/extensions/elementwise_functions/simplify_iteration_space.hpp b/dpnp/backend/extensions/elementwise_functions/simplify_iteration_space.hpp index f89e424c84c1..5dc7b196b3e0 100644 --- a/dpnp/backend/extensions/elementwise_functions/simplify_iteration_space.hpp +++ b/dpnp/backend/extensions/elementwise_functions/simplify_iteration_space.hpp @@ -61,4 +61,25 @@ void simplify_iteration_space_3(int &, py::ssize_t &, py::ssize_t &, py::ssize_t &); + +void simplify_iteration_space_4(int &, + const py::ssize_t *const &, + // src1 + std::vector const &, + // src2 + std::vector const &, + // src3 + std::vector const &, + // dst + std::vector const &, + // output + std::vector &, + std::vector &, + std::vector &, + std::vector &, + std::vector &, + py::ssize_t &, + py::ssize_t &, + py::ssize_t &, + py::ssize_t &); } // namespace dpnp::extensions::py_internal diff --git a/dpnp/backend/extensions/elementwise_functions/type_dispatch_building.hpp b/dpnp/backend/extensions/elementwise_functions/type_dispatch_building.hpp index b113457dbf6b..aa8fd58b549c 100644 --- a/dpnp/backend/extensions/elementwise_functions/type_dispatch_building.hpp +++ b/dpnp/backend/extensions/elementwise_functions/type_dispatch_building.hpp @@ -48,6 +48,24 @@ struct TypeMapTwoResultsEntry : std::bool_constant> using result_type2 = ResTy2; }; +/** + * Extends dpctl::tensor::type_dispatch::BinaryTypeMapResultEntry helper + * structure with support of the two result types. + */ +template +struct BinaryTypeMapTwoResultsEntry + : std::bool_constant, + std::is_same>> +{ + using result_type1 = ResTy1; + using result_type2 = ResTy2; +}; + /** * Extends dpctl::tensor::type_dispatch::DefaultResultEntry helper structure * with support of the two result types. From fc95d01448e59862b3fef0d05dae8eea5030d774 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 06:48:37 -0800 Subject: [PATCH 02/10] Add dovmod implementation to ufunc extensions --- dpnp/backend/extensions/ufunc/CMakeLists.txt | 1 + .../ufunc/elementwise_functions/common.cpp | 2 + .../ufunc/elementwise_functions/divmod.cpp | 172 ++++++++++++++++++ .../ufunc/elementwise_functions/divmod.hpp | 38 ++++ .../ufunc/elementwise_functions/populate.hpp | 113 ++++++++++++ 5 files changed, 326 insertions(+) create mode 100644 dpnp/backend/extensions/ufunc/elementwise_functions/divmod.cpp create mode 100644 dpnp/backend/extensions/ufunc/elementwise_functions/divmod.hpp diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index e7a5e5a03222..5609522f58a4 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -31,6 +31,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/bitwise_count.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/degrees.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/divmod.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/erf_funcs.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fabs.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fix.cpp diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 2cf81aead4e1..df409464a5c2 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -30,6 +30,7 @@ #include "bitwise_count.hpp" #include "degrees.hpp" +#include "divmod.hpp" #include "erf_funcs.hpp" #include "fabs.hpp" #include "fix.hpp" @@ -63,6 +64,7 @@ void init_elementwise_functions(py::module_ m) { init_bitwise_count(m); init_degrees(m); + init_divmod(m); init_erf_funcs(m); init_fabs(m); init_fix(m); diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.cpp new file mode 100644 index 000000000000..af87dcc85f53 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.cpp @@ -0,0 +1,172 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include +#include + +#include + +#include "dpctl4pybind11.hpp" + +#include "divmod.hpp" +#include "kernels/elementwise_functions/divmod.hpp" +#include "populate.hpp" + +// include a local copy of elementwise common header from dpctl tensor: +// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +// TODO: replace by including dpctl header once available +#include "../../elementwise_functions/elementwise_functions.hpp" + +#include "../../elementwise_functions/common.hpp" +#include "../../elementwise_functions/type_dispatch_building.hpp" + +// utils extension header +#include "ext/common.hpp" + +// dpctl tensor headers +#include "kernels/elementwise_functions/common.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpnp::extensions::ufunc +{ +namespace py = pybind11; +namespace py_int = dpnp::extensions::py_internal; + +namespace impl +{ +namespace ew_cmn_ns = dpnp::extensions::py_internal::elementwise_common; +namespace td_int_ns = py_int::type_dispatch; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpnp::kernels::divmod::DivmodFunctor; + +template +struct OutputType +{ + using table_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns:: + BinaryTypeMapTwoResultsEntry, + td_int_ns::BinaryTypeMapTwoResultsEntry, + td_int_ns::BinaryTypeMapTwoResultsEntry, + td_int_ns::BinaryTypeMapTwoResultsEntry, + td_int_ns::DefaultTwoResultsEntry>; + using value_type1 = typename table_type::result_type1; + using value_type2 = typename table_type::result_type2; +}; + +template +using ContigFunctor = ew_cmn_ns::BinaryTwoOutputsContigFunctor< + argTy1, + argTy2, + resTy1, + resTy2, + DivmodFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using StridedFunctor = ew_cmn_ns::BinaryTwoOutputsStridedFunctor< + argTy1, + argTy2, + resTy1, + resTy2, + IndexerT, + DivmodFunctor>; + +using ew_cmn_ns::binary_two_outputs_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_two_outputs_strided_impl_fn_ptr_t; + +static binary_two_outputs_contig_impl_fn_ptr_t + divmod_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static std::pair divmod_output_typeid_table[td_ns::num_types] + [td_ns::num_types]; +static binary_two_outputs_strided_impl_fn_ptr_t + divmod_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +MACRO_POPULATE_DISPATCH_2OUTS_TABLES(divmod); +} // namespace impl + +void init_divmod(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_divmod_dispatch_tables(); + using impl::divmod_contig_dispatch_table; + using impl::divmod_output_typeid_table; + using impl::divmod_strided_dispatch_table; + + auto divmod_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst1, const arrayT &dst2, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_int::py_binary_two_outputs_ufunc( + src1, src2, dst1, dst2, exec_q, depends, + divmod_output_typeid_table, divmod_contig_dispatch_table, + divmod_strided_dispatch_table); + }; + m.def("_divmod", divmod_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst1"), py::arg("dst2"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto divmod_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_int::py_binary_two_outputs_ufunc_result_type( + dtype1, dtype2, divmod_output_typeid_table); + }; + m.def("_divmod_result_type", divmod_result_type_pyapi); + } +} +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.hpp new file mode 100644 index 000000000000..e62f9e195c15 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/divmod.hpp @@ -0,0 +1,38 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_divmod(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/populate.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/populate.hpp index de2637013c5b..f0c630562aae 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/populate.hpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/populate.hpp @@ -335,3 +335,116 @@ namespace ext_ns = ext::common; ext_ns::init_dispatch_table( \ __name__##_output_typeid_table); \ }; + +/** + * @brief A macro used to define factories and a populating binary universal + * functions with two output arrays. + */ +#define MACRO_POPULATE_DISPATCH_2OUTS_TABLES(__name__) \ + template \ + class __name__##_contig_kernel; \ + \ + template \ + sycl::event __name__##_contig_impl( \ + sycl::queue &exec_q, size_t nelems, const char *arg1_p, \ + py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \ + char *res1_p, py::ssize_t res1_offset, char *res2_p, \ + py::ssize_t res2_offset, const std::vector &depends = {}) \ + { \ + return ew_cmn_ns::binary_two_outputs_contig_impl< \ + argTy1, argTy2, OutputType, ContigFunctor, \ + __name__##_contig_kernel>( \ + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res1_p, \ + res1_offset, res2_p, res2_offset, depends); \ + } \ + \ + template \ + struct ContigFactory \ + { \ + fnT get() \ + { \ + if constexpr (std::is_same_v< \ + typename OutputType::value_type1, \ + void> || \ + std::is_same_v< \ + typename OutputType::value_type2, void>) \ + { \ + \ + fnT fn = nullptr; \ + return fn; \ + } \ + else { \ + fnT fn = __name__##_contig_impl; \ + return fn; \ + } \ + } \ + }; \ + \ + template \ + struct TypeMapFactory \ + { \ + std::enable_if_t>::value, \ + std::pair> \ + get() \ + { \ + using rT1 = typename OutputType::value_type1; \ + using rT2 = typename OutputType::value_type2; \ + return std::make_pair(td_ns::GetTypeid{}.get(), \ + td_ns::GetTypeid{}.get()); \ + } \ + }; \ + \ + template \ + class __name__##_strided_kernel; \ + \ + template \ + sycl::event __name__##_strided_impl( \ + sycl::queue &exec_q, size_t nelems, int nd, \ + const py::ssize_t *shape_and_strides, const char *arg1_p, \ + py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \ + char *res1_p, py::ssize_t res1_offset, char *res2_p, \ + py::ssize_t res2_offset, const std::vector &depends, \ + const std::vector &additional_depends) \ + { \ + return ew_cmn_ns::binary_two_outputs_strided_impl< \ + argTy1, argTy2, OutputType, StridedFunctor, \ + __name__##_strided_kernel>( \ + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \ + arg2_p, arg2_offset, res1_p, res1_offset, res2_p, res2_offset, \ + depends, additional_depends); \ + } \ + \ + template \ + struct StridedFactory \ + { \ + fnT get() \ + { \ + if constexpr (std::is_same_v< \ + typename OutputType::value_type1, \ + void> || \ + std::is_same_v< \ + typename OutputType::value_type2, void>) \ + { \ + fnT fn = nullptr; \ + return fn; \ + } \ + else { \ + fnT fn = __name__##_strided_impl; \ + return fn; \ + } \ + } \ + }; \ + \ + void populate_##__name__##_dispatch_tables(void) \ + { \ + ext_ns::init_dispatch_table( \ + __name__##_contig_dispatch_table); \ + ext_ns::init_dispatch_table( \ + __name__##_strided_dispatch_table); \ + ext_ns::init_dispatch_table, TypeMapFactory>( \ + __name__##_output_typeid_table); \ + }; From b16668c25aff7eb02562597a99e253a37ed6c399 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:26:54 -0800 Subject: [PATCH 03/10] Add a function with SYCL kernel for divmod --- .../kernels/elementwise_functions/divmod.hpp | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 dpnp/backend/kernels/elementwise_functions/divmod.hpp diff --git a/dpnp/backend/kernels/elementwise_functions/divmod.hpp b/dpnp/backend/kernels/elementwise_functions/divmod.hpp new file mode 100644 index 000000000000..34c138b4f903 --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/divmod.hpp @@ -0,0 +1,120 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace dpnp::kernels::divmod +{ +template +struct DivmodFunctor +{ + using argT = argT1; + + static_assert(std::is_same_v, + "Input types are expected to be the same"); + static_assert(std::is_integral_v || std::is_floating_point_v || + std::is_same_v, + "Input types are expected to be integral or floating"); + + using supports_vec = typename std::false_type; + using supports_sg_loadstore = typename std::true_type; + + divT operator()(const argT &in1, const argT &in2, modT &mod) const + { + if constexpr (std::is_integral_v) { + if (in2 == argT(0)) { + mod = modT(0); + return divT(0); + } + + if constexpr (std::is_signed_v) { + if ((in1 == std::numeric_limits::min()) && + (in2 == argT(-1))) { + mod = modT(0); + return std::numeric_limits::min(); + } + } + + divT div = in1 / in2; + mod = in1 % in2; + + if constexpr (std::is_signed_v) { + if (l_xor(in1 > 0, in2 > 0) && (mod != 0)) { + div -= divT(1); + mod += in2; + } + } + return div; + } + else { + mod = sycl::fmod(in1, in2); + if (!in2) { + // in2 == 0 (not NaN): return result of fmod (for IEEE is nan) + return in1 / in2; + } + + // (in1 - mod) should be very nearly an integer multiple of in2 + auto div = (in1 - mod) / in2; + + // adjust fmod result to conform to Python convention of remainder + if (mod) { + if (l_xor(in2 < 0, mod < 0)) { + mod += in2; + div -= divT(1.0); + } + } + else { + // if mod is zero ensure correct sign + mod = sycl::copysign(modT(0), in2); + } + + // snap quotient to nearest integral value + if (div) { + auto floordiv = sycl::floor(div); + if (div - floordiv > divT(0.5)) { + floordiv += divT(1.0); + } + div = floordiv; + } + else { + // if div is zero ensure correct sign + div = sycl::copysign(divT(0), in1 / in2); + } + return div; + } + } + +private: + bool l_xor(bool b1, bool b2) const + { + return (b1 != b2); + } +}; +} // namespace dpnp::kernels::divmod From b99ff9cc9bbd32a8f9fad65a759b1bf186eafc3e Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:28:23 -0800 Subject: [PATCH 04/10] Add new DPNPBinaryTwoOutputsFunc class --- doc/conf.py | 2 + dpnp/dpnp_algo/dpnp_elementwise_common.py | 313 +++++++++++++++++++++- dpnp/dpnp_utils/dpnp_utils_common.py | 25 ++ 3 files changed, 337 insertions(+), 3 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 32754fc0453a..469e6d5f5353 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,6 +13,7 @@ from dpnp.dpnp_algo.dpnp_elementwise_common import ( DPNPBinaryFunc, DPNPBinaryFuncOutKw, + DPNPBinaryTwoOutputsFunc, DPNPUnaryFunc, DPNPUnaryTwoOutputsFunc, ) @@ -215,6 +216,7 @@ def _can_document_member(member, *args, **kwargs): ( DPNPBinaryFunc, DPNPBinaryFuncOutKw, + DPNPBinaryTwoOutputsFunc, DPNPUnaryFunc, DPNPUnaryTwoOutputsFunc, ), diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index f8fcb96c732f..13be0ec169eb 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -39,12 +39,19 @@ BinaryElementwiseFunc, UnaryElementwiseFunc, ) +from dpctl.tensor._scalar_utils import ( + _get_dtype, + _get_shape, + _validate_dtype, +) import dpnp import dpnp.backend.extensions.vm._vm_impl as vmi from dpnp.dpnp_array import dpnp_array +from dpnp.dpnp_utils import get_usm_allocations from dpnp.dpnp_utils.dpnp_utils_common import ( find_buf_dtype_3out, + find_buf_dtype_4out, ) __all__ = [ @@ -52,6 +59,7 @@ "DPNPAngle", "DPNPBinaryFunc", "DPNPBinaryFuncOutKw", + "DPNPBinaryTwoOutputsFunc", "DPNPFix", "DPNPImag", "DPNPReal", @@ -347,7 +355,7 @@ def __call__( buf_dt, res1_dt, res2_dt = find_buf_dtype_3out( x.dtype, - self.result_type_resolver_fn_, + self.get_type_result_resolver_function(), x.sycl_device, ) if res1_dt is None or res2_dt is None: @@ -444,13 +452,12 @@ def __call__( out[i] = dpt.empty_like(x, dtype=res_dt, order=order) # Call the unary function with input and output arrays - dep_evs = _manager.submitted_events ht_unary_ev, unary_ev = self.get_implementation_function()( x, dpnp.get_usm_ndarray(out[0]), dpnp.get_usm_ndarray(out[1]), sycl_queue=exec_q, - depends=dep_evs, + depends=_manager.submitted_events, ) _manager.add_event_pair(ht_unary_ev, unary_ev) @@ -795,6 +802,306 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) +class DPNPBinaryTwoOutputsFunc(BinaryElementwiseFunc): + """ + Class that implements unary element-wise functions with two output arrays. + + Parameters + ---------- + name : {str} + Name of the unary function + result_type_resolver_fn : {callable} + Function that takes dtype of the input and returns the dtype of + the result if the implementation functions supports it, or + returns `None` otherwise. + unary_dp_impl_fn : {callable} + Data-parallel implementation function with signature + `impl_fn(src: usm_ndarray, dst: usm_ndarray, + sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])` + where the `src` is the argument array, `dst` is the + array to be populated with function values, effectively + evaluating `dst = func(src)`. + The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s. + The first event corresponds to data-management host tasks, + including lifetime management of argument Python objects to ensure + that their associated USM allocation is not freed before offloaded + computational tasks complete execution, while the second event + corresponds to computational tasks associated with function evaluation. + docs : {str} + Documentation string for the unary function. + mkl_fn_to_call : {None, str} + Check input arguments to answer if function from OneMKL VM library + can be used. + mkl_impl_fn : {None, str} + Function from OneMKL VM library to call. + + """ + + def __init__( + self, + name, + result_type_resolver_fn, + binary_dp_impl_fn, + docs, + ): + super().__init__( + name, + result_type_resolver_fn, + binary_dp_impl_fn, + docs, + ) + self.__name__ = "DPNPBinaryTwoOutputsFunc" + + @property + def nout(self): + """Returns the number of arguments treated as outputs.""" + return 2 + + def __call__( + self, + x1, + x2, + out1=None, + out2=None, + /, + *, + out=(None, None), + where=True, + order="K", + dtype=None, + subok=True, + **kwargs, + ): + if kwargs: + raise NotImplementedError( + f"Requested function={self.name_} with kwargs={kwargs} " + "isn't currently supported." + ) + elif where is not True: + raise NotImplementedError( + f"Requested function={self.name_} with where={where} " + "isn't currently supported." + ) + elif dtype is not None: + raise NotImplementedError( + f"Requested function={self.name_} with dtype={dtype} " + "isn't currently supported." + ) + elif subok is not True: + raise NotImplementedError( + f"Requested function={self.name_} with subok={subok} " + "isn't currently supported." + ) + + dpnp.check_supported_arrays_type(x1, x2, scalar_type=True) + + if order is None: + order = "K" + elif order in "afkcAFKC": + order = order.upper() + else: + raise ValueError( + "order must be one of 'C', 'F', 'A', or 'K' " f"(got '{order}')" + ) + + res_usm_type, exec_q = get_usm_allocations([x1, x2]) + x1 = dpnp.get_usm_ndarray_or_scalar(x1) + x2 = dpnp.get_usm_ndarray_or_scalar(x2) + + x1_sh = _get_shape(x1) + x2_sh = _get_shape(x2) + try: + res_shape = dpnp.broadcast_shapes(x1_sh, x2_sh) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{x1_sh} and {x2_sh}" + ) + + sycl_dev = exec_q.sycl_device + x1_dt = _get_dtype(x1, sycl_dev) + x2_dt = _get_dtype(x2, sycl_dev) + if not all(_validate_dtype(dt) for dt in [x1_dt, x2_dt]): + raise ValueError("Operands have unsupported data types") + + x1_dt, x2_dt = self.get_array_dtype_scalar_type_resolver_function()( + x1_dt, x2_dt, sycl_dev + ) + + buf1_dt, buf2_dt, res1_dt, res2_dt = find_buf_dtype_4out( + x1_dt, + x2_dt, + self.get_type_result_resolver_function(), + sycl_dev, + ) + if res1_dt is None or res2_dt is None: + raise ValueError( + f"function '{self.name_}' does not support input type " + f"({x1_dt}, {x2_dt}), " + "and the input could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + buf_dts = [buf1_dt, buf2_dt] + + if not isinstance(out, tuple): + raise TypeError("'out' must be a tuple of arrays") + + if len(out) != self.nout: + raise ValueError( + "'out' tuple must have exactly one entry per ufunc output" + ) + + if not (out1 is None and out2 is None): + if all(res is None for res in out): + out = (out1, out2) + else: + raise TypeError( + "cannot specify 'out' as both a positional and keyword argument" + ) + + orig_out, out = list(out), list(out) + res_dts = [res1_dt, res2_dt] + + for i in range(self.nout): + if out[i] is None: + continue + + res = dpnp.get_usm_ndarray(out[i]) + if not res.flags.writable: + raise ValueError("output array is read-only") + + if res.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {res_shape}, got {res.shape}" + ) + + if dpu.get_execution_queue((exec_q, res.sycl_queue)) is None: + raise dpnp.exceptions.ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + res_dt = res_dts[i] + if res_dt != res.dtype: + if not dpnp.can_cast(res_dt, res.dtype, casting="same_kind"): + raise TypeError( + f"Cannot cast ufunc '{self.name_}' output {i + 1} from " + f"{res_dt} to {res.dtype} with casting rule 'same_kind'" + ) + + # Allocate a temporary buffer with the required dtype + out[i] = dpt.empty_like(res, dtype=res_dt) + else: + for x, dt in zip([x1, x2], buf_dts): + if dpnp.isscalar(x): + pass + elif dt is not None: + pass + elif not dti._array_overlap(x, res): + pass + elif dti._same_logical_tensors(x, res): + pass + + # Allocate a temporary buffer to avoid memory overlapping. + # Note if `dt` is not None, a temporary copy of `x` will be + # created, so the array overlap check isn't needed. + out[i] = dpt.empty_like(res) + break + + x1 = dpnp.as_usm_ndarray(x1, dtype=x1_dt, sycl_queue=exec_q) + x2 = dpnp.as_usm_ndarray(x2, dtype=x2_dt, sycl_queue=exec_q) + + if order == "A": + if x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + else: + order = "C" + + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # Cast input array to the supported type if needed + if any(dt is not None for dt in buf_dts): + if all(dt is not None for dt in buf_dts): + if x1.flags.c_contiguous and x2.flags.c_contiguous: + order = "C" + elif x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + + arrs = [x1, x2] + buf_dts = [buf1_dt, buf2_dt] + for i in range(self.nout): + buf_dt = buf_dts[i] + if buf_dt is None: + continue + + x = arrs[i] + if order == "K": + buf = dtc._empty_like_orderK(x, buf_dt) + else: + buf = dpt.empty_like(x, dtype=buf_dt, order=order) + + ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray( + src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + + arrs[i] = buf + x1, x2 = arrs + + # Allocate a buffer for the output arrays if needed + for i in range(self.nout): + if out[i] is None: + res_dt = res_dts[i] + if order == "K": + out[i] = dtc._empty_like_pair_orderK( + x1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out[i] = dpt.empty( + res_shape, + dtype=res_dt, + order=order, + usm_type=res_usm_type, + sycl_queue=exec_q, + ) + + # Broadcast shapes of input arrays + if x1.shape != res_shape: + x1 = dpt.broadcast_to(x1, res_shape) + if x2.shape != res_shape: + x2 = dpt.broadcast_to(x2, res_shape) + + # Call the binary function with input and output arrays + ht_binary_ev, binary_ev = self.get_implementation_function()( + x1, + x2, + dpnp.get_usm_ndarray(out[0]), + dpnp.get_usm_ndarray(out[1]), + sycl_queue=exec_q, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + + for i in range(self.nout): + orig_res, res = orig_out[i], out[i] + if not (orig_res is None or orig_res is res): + # Copy the out data from temporary buffer to original memory + ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray( + src=res, + dst=dpnp.get_usm_ndarray(orig_res), + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + res = out[i] = orig_res + + if not isinstance(res, dpnp_array): + # Always return dpnp.ndarray + out[i] = dpnp_array._create_from_usm_ndarray(res) + return tuple(out) + + class DPNPAngle(DPNPUnaryFunc): """Class that implements dpnp.angle unary element-wise functions.""" diff --git a/dpnp/dpnp_utils/dpnp_utils_common.py b/dpnp/dpnp_utils/dpnp_utils_common.py index 2cf5973d1e8c..e4bde2e1ec86 100644 --- a/dpnp/dpnp_utils/dpnp_utils_common.py +++ b/dpnp/dpnp_utils/dpnp_utils_common.py @@ -36,6 +36,7 @@ __all__ = [ "find_buf_dtype_3out", + "find_buf_dtype_4out", "result_type_for_device", "to_supported_dtypes", ] @@ -60,6 +61,30 @@ def find_buf_dtype_3out(arg_dtype, query_fn, sycl_dev): return None, None, None +def find_buf_dtype_4out(arg1_dtype, arg2_dtype, query_fn, sycl_dev): + """Works as dpu._find_buf_dtype2, but with two output arrays.""" + + res1_dt, res2_dt = query_fn(arg1_dtype, arg2_dtype) + if res1_dt and res2_dt: + return None, None, res1_dt, res2_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + all_dts = dtu._all_data_types(_fp16, _fp64) + for buf1_dt in all_dts: + for buf2_dt in all_dts: + if dtu._can_cast( + arg1_dtype, buf1_dt, _fp16, _fp64 + ) and dtu._can_cast(arg2_dtype, buf2_dt, _fp16, _fp64): + res1_dt, res2_dt = query_fn(buf1_dt, buf2_dt) + if res1_dt and res2_dt: + ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt + ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt + return ret_buf1_dt, ret_buf2_dt, res1_dt, res2_dt + + return None, None, None, None + + def result_type_for_device(dtypes, device): """Works as dpnp.result_type, but taking into account the device capabilities.""" From cac6399c34a4dbefb85423fa278e17cce196e573 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:29:08 -0800 Subject: [PATCH 05/10] Add python implementation of divmod --- dpnp/dpnp_iface_mathematical.py | 104 +++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index b4b28695145a..cab654959eb4 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -66,6 +66,7 @@ DPNPAngle, DPNPBinaryFunc, DPNPBinaryFuncOutKw, + DPNPBinaryTwoOutputsFunc, DPNPFix, DPNPImag, DPNPReal, @@ -102,6 +103,7 @@ "cumulative_prod", "cumulative_sum", "diff", + "divmod", "divide", "ediff1d", "fabs", @@ -1630,6 +1632,100 @@ def diff(a, n=1, axis=-1, prepend=None, append=None): ) +_DIVMOD_DOCSTRING = r""" +Calculates the quotient and the remainder for each element :math:`x1_i` of the +input array `x1` with the respective element :math:`x2_i` of the input array +`x2`. + +For full documentation refer to :obj:`numpy.divmod`. + +Parameters +---------- +x1 : {dpnp.ndarray, usm_ndarray} + Dividend input array, expected to have a real-valued floating-point data + type. +x2 : {dpnp.ndarray, usm_ndarray} + Divisor input array, expected to have a real-valued floating-point data + type. +out1 : {None, dpnp.ndarray, usm_ndarray}, optional + Output array for the quotient to populate. Array must have the same shape + as `x` and the expected data type. + + Default: ``None``. +out2 : {None, dpnp.ndarray, usm_ndarray}, optional + Output array for the remainder to populate. Array must have the same shape + as `x` and the expected data type. + + Default: ``None``. +out : tuple of None, dpnp.ndarray, or usm_ndarray, optional + A location into which the result is stored. If provided, it must be a tuple + and have length equal to the number of outputs. Each provided array must + have the same shape as `x` and the expected data type. + It is prohibited to pass output arrays through `out` keyword when either + `out1` or `out2` is passed. + + Default: ``(None, None)``. +order : {None, "C", "F", "A", "K"}, optional + Memory layout of the newly output array, if parameter `out` is ``None``. + + Default: ``"K"``. + +Returns +------- +quotient : dpnp.ndarray + Element-wise quotient resulting from floor division. +remainder : dpnp.ndarray + Element-wise remainder from floor division. + +Limitations +----------- +Parameters `where`, `dtype` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + +Notes +----- +At least one of `x1` or `x2` must be an array. + +If ``x1.shape != x2.shape``, they must be broadcastable to a common shape +(which becomes the shape of the output). + +Equivalent to :math:`(x1 // x2, x1 \% x2)`, but faster because it avoids +redundant work. It is used to implement the Python built-in function ``divmod`` +on :class:`dpnp.ndarray`. + +Complex dtypes are not supported, they will raise a ``TypeError``. + +See Also +-------- +:obj:`dpnp.floor_divide` : Equivalent to Python's :math:`//` operator. +:obj:`dpnp.remainder` : Equivalent to Python's :math:`\%` operator. +:obj:`dpnp.modf` : Equivalent to ``divmod(x, 1)`` for positive `x` with the + return values switched. + +Examples +-------- +>>> import dpnp as np +>>> np.divmod(np.arange(5), 3) +(array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 1])) + +The Python built-in function ``divmod`` function can be used as a shorthand for +``np.divmod`` on :class:`dpnp.ndarray`. + +>>> x = np.arange(5) +>>> divmod(x, 3) +(array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 1])) + +""" + +divmod = DPNPBinaryTwoOutputsFunc( + "divmod", + ufi._divmod_result_type, + ufi._divmod, + _DIVMOD_DOCSTRING, +) + + def ediff1d(ary, to_end=None, to_begin=None): """ The differences between consecutive elements of an array. @@ -2065,6 +2161,7 @@ def ediff1d(ary, to_end=None, to_begin=None): See Also -------- :obj:`dpnp.remainder` : Remainder complementary to floor_divide. +:obj:`dpnp.divmod` : Simultaneous floor division and remainder. :obj:`dpnp.divide` : Standard division. :obj:`dpnp.floor` : Round a number to the nearest integer toward minus infinity. :obj:`dpnp.ceil` : Round a number to the nearest integer toward infinity. @@ -2445,7 +2542,7 @@ def ediff1d(ary, to_end=None, to_begin=None): """ frexp = DPNPUnaryTwoOutputsFunc( - "_frexp", + "frexp", ufi._frexp_result_type, ufi._frexp, _FREXP_DOCSTRING, @@ -3207,7 +3304,7 @@ def interp(x, xp, fp, left=None, right=None, period=None): """ ldexp = DPNPBinaryFunc( - "_ldexp", + "ldexp", ufi._ldexp_result_type, ufi._ldexp, _LDEXP_DOCSTRING, @@ -3487,7 +3584,7 @@ def interp(x, xp, fp, left=None, right=None, period=None): """ modf = DPNPUnaryTwoOutputsFunc( - "_modf", + "modf", ufi._modf_result_type, ufi._modf, _MODF_DOCSTRING, @@ -4344,6 +4441,7 @@ def real_if_close(a, tol=100): See Also -------- :obj:`dpnp.fmod` : Calculate the element-wise remainder of division. +:obj:`dpnp.divmod` : Simultaneous floor division and remainder. :obj:`dpnp.divide` : Standard division. :obj:`dpnp.floor` : Round a number to the nearest integer toward minus infinity. :obj:`dpnp.floor_divide` : Compute the largest integer smaller or equal to the From 95c6f0b86ee8641097d19808baea40f3399686c0 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:29:37 -0800 Subject: [PATCH 06/10] Enabled muted umath tests --- dpnp/tests/test_umath.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpnp/tests/test_umath.py b/dpnp/tests/test_umath.py index ca025fd27c55..abd16fa25eb7 100644 --- a/dpnp/tests/test_umath.py +++ b/dpnp/tests/test_umath.py @@ -115,8 +115,6 @@ def test_umaths(test_cases): and not (vmi._is_available() and has_support_aspect64()) ): pytest.skip("dpctl-2031") - elif umath in ["divmod"]: - pytest.skip("Not implemented umath") elif umath in ["vecmat", "matvec"]: if is_win_platform() and not is_gpu_device(): pytest.skip("SAT-8073") From e3b51e991fbcef88b4c0ff94779fb7b593493dac Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:30:35 -0800 Subject: [PATCH 07/10] Add new test scope for binary ufuncs with two output arrays --- dpnp/tests/test_binary_two_outputs_ufuncs.py | 271 +++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 dpnp/tests/test_binary_two_outputs_ufuncs.py diff --git a/dpnp/tests/test_binary_two_outputs_ufuncs.py b/dpnp/tests/test_binary_two_outputs_ufuncs.py new file mode 100644 index 000000000000..92e30468b299 --- /dev/null +++ b/dpnp/tests/test_binary_two_outputs_ufuncs.py @@ -0,0 +1,271 @@ +import itertools + +import numpy +import pytest +from numpy.testing import ( + assert_array_equal, +) + +import dpnp + +from .helper import ( + generate_random_numpy_array, + get_all_dtypes, + get_complex_dtypes, + get_float_dtypes, + get_integer_dtypes, +) + +""" +The scope includes tests with only functions which are instances of +`DPNPBinaryTwoOutputsFunc` class. + +""" + + +@pytest.mark.parametrize("func", ["divmod"]) +class TestBinaryTwoOutputs: + ALL_DTYPES = get_all_dtypes(no_none=True) + ALL_DTYPES_NO_COMPLEX = get_all_dtypes( + no_none=True, no_float16=False, no_complex=True + ) + ALL_FLOAT_DTYPES = get_float_dtypes(no_float16=False) + + def _signs(self, dtype): + if numpy.issubdtype(dtype, numpy.unsignedinteger): + return (+1,) + else: + return (+1, -1) + + @pytest.mark.usefixtures("suppress_divide_numpy_warnings") + @pytest.mark.parametrize("dt", ALL_DTYPES_NO_COMPLEX) + def test_basic(self, func, dt): + a = generate_random_numpy_array((2, 5), dtype=dt) + b = generate_random_numpy_array((2, 5), dtype=dt) + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = getattr(dpnp, func)(ia, ib) + exp1, exp2 = getattr(numpy, func)(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt1", ALL_DTYPES_NO_COMPLEX) + @pytest.mark.parametrize("dt2", ALL_DTYPES_NO_COMPLEX) + def test_signs(self, func, dt1, dt2): + for sign1, sign2 in itertools.product( + self._signs(dt1), self._signs(dt2) + ): + a = numpy.array(sign1 * 71, dtype=dt1) + b = numpy.array(sign2 * 19, dtype=dt2) + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = getattr(dpnp, func)(ia, ib) + exp1, exp2 = getattr(numpy, func)(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_float_exact(self, func, dt): + # test that float results are exact for small integers + nlst = list(range(-127, 0)) + plst = list(range(1, 128)) + dividend = nlst + [0] + plst + divisor = nlst + plst + arg = list(itertools.product(dividend, divisor)) + + a, b = numpy.array(arg, dtype=dt).T + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = getattr(dpnp, func)(ia, ib) + exp1, exp2 = getattr(numpy, func)(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt1", get_float_dtypes()) + @pytest.mark.parametrize("dt2", get_float_dtypes()) + @pytest.mark.parametrize( + "sign1, sign2", [(+1, +1), (+1, -1), (-1, +1), (-1, -1)] + ) + def test_float_roundoff(self, func, dt1, dt2, sign1, sign2): + a = numpy.array(sign1 * 78 * 6e-8, dtype=dt1) + b = numpy.array(sign2 * 6e-8, dtype=dt2) + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = getattr(dpnp, func)(ia, ib) + exp1, exp2 = getattr(numpy, func)(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings") + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + @pytest.mark.parametrize( + "val1", [0.0, 1.0, numpy.inf, -numpy.inf, numpy.nan] + ) + @pytest.mark.parametrize( + "val2", [0.0, 1.0, numpy.inf, -numpy.inf, numpy.nan] + ) + def test_special_float_values(self, func, dt, val1, val2): + a = numpy.array(val1, dtype=dt) + b = numpy.array(val2, dtype=dt) + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = getattr(dpnp, func)(ia, ib) + exp1, exp2 = getattr(numpy, func)(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_float_overflow(self, func, dt): + a = numpy.finfo(dt).tiny + a = numpy.array(a, dtype=dt) + ia = dpnp.array(a, dtype=dt) + + res1, res2 = getattr(dpnp, func)(4, ia) + exp1, exp2 = getattr(numpy, func)(4, a) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_out(self, func, dt): + a = numpy.array(5.7, dtype=dt) + ia = dpnp.array(a) + + out1 = numpy.empty((), dtype=dt) + out2 = numpy.empty((), dtype=dt) + iout1, iout2 = dpnp.array(out1), dpnp.array(out2) + + res1, res2 = getattr(dpnp, func)(ia, 2, iout1) + exp1, exp2 = getattr(numpy, func)(a, 2, out1) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + assert res1 is iout1 + + res1, res2 = getattr(dpnp, func)(ia, 2, None, iout2) + exp1, exp2 = getattr(numpy, func)(a, 2, None, out2) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + assert res2 is iout2 + + res1, res2 = getattr(dpnp, func)(ia, 2, iout1, iout2) + exp1, exp2 = getattr(numpy, func)(a, 2, out1, out2) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + assert res1 is iout1 + assert res2 is iout2 + + @pytest.mark.parametrize("dt1", ALL_DTYPES_NO_COMPLEX) + @pytest.mark.parametrize("dt2", ALL_DTYPES_NO_COMPLEX) + @pytest.mark.parametrize("out1_dt", ALL_DTYPES) + @pytest.mark.parametrize("out2_dt", ALL_DTYPES) + def test_2out_all_dtypes(self, func, dt1, dt2, out1_dt, out2_dt): + a = numpy.ones((3, 1), dtype=dt1) + b = numpy.ones((3, 4), dtype=dt2) + ia, ib = dpnp.array(a), dpnp.array(b) + + out1 = numpy.zeros_like(b, dtype=out1_dt) + out2 = numpy.zeros_like(b, dtype=out2_dt) + iout1, iout2 = dpnp.array(out1), dpnp.array(out2) + + try: + res1, res2 = getattr(dpnp, func)(ia, ib, out=(iout1, iout2)) + except TypeError: + # expect numpy to fail with the same reason + with pytest.raises(TypeError): + _ = getattr(numpy, func)(a, b, out=(out1, out2)) + else: + exp1, exp2 = getattr(numpy, func)(a, b, out=(out1, out2)) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + assert res1 is iout1 + assert res2 is iout2 + + @pytest.mark.usefixtures("suppress_invalid_numpy_warnings") + @pytest.mark.parametrize("stride", [-4, -2, -1, 1, 2, 4]) + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_strides_out(self, func, stride, dt): + a = numpy.array( + [numpy.nan, numpy.nan, numpy.inf, -numpy.inf, 0.0, -0.0, 1.0, -1.0], + dtype=dt, + ) + ia = dpnp.array(a) + + out1 = numpy.ones_like(a, dtype=dt) + out2 = 2 * numpy.ones_like(a, dtype=dt) + iout_mant, iout_exp = dpnp.array(out1), dpnp.array(out2) + + res1, res2 = getattr(dpnp, func)( + ia[::stride], 2, out=(iout_mant[::stride], iout_exp[::stride]) + ) + exp1, exp2 = getattr(numpy, func)( + a[::stride], 2, out=(out1[::stride], out2[::stride]) + ) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + assert_array_equal(iout_mant, out1) + assert_array_equal(iout_exp, out2) + + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_out1_overlap(self, func, dt): + size = 15 + a = numpy.ones(2 * size, dtype=dt) + ia = dpnp.array(a) + + # out1 overlaps memory of input array + _ = getattr(dpnp, func)(ia[size::], 1, ia[::2]) + _ = getattr(numpy, func)(a[size::], 1, a[::2]) + assert_array_equal(ia, a) + + @pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES) + def test_empty(self, func, dt): + a = numpy.empty(0, dtype=dt) + ia = dpnp.array(a) + + res1, res2 = getattr(dpnp, func)(ia, ia) + exp1, exp2 = getattr(numpy, func)(a, a) + assert_array_equal(res1, exp1, strict=True) + assert_array_equal(res2, exp2, strict=True) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("dt", get_complex_dtypes()) + def test_complex_dtype(self, func, xp, dt): + a = xp.array( + [0.9 + 1j, -0.1 + 1j, 0.9 + 0.5 * 1j, 0.9 + 2.0 * 1j], dtype=dt + ) + with pytest.raises((TypeError, ValueError)): + _ = getattr(xp, func)(a, 7) + + +class TestDivmod: + @pytest.mark.usefixtures("suppress_divide_numpy_warnings") + @pytest.mark.parametrize("dt", get_integer_dtypes()) + def test_int_zero(self, dt): + a = numpy.array(0, dtype=dt) + ia = dpnp.array(a) + + res1, res2 = dpnp.divmod(ia, 0) + exp1, exp2 = numpy.divmod(a, 0) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt", get_integer_dtypes(no_unsigned=True)) + def test_min_int(self, dt): + a = numpy.array(numpy.iinfo(dt).min, dtype=dt) + ia = dpnp.array(a) + + res1, res2 = dpnp.divmod(ia, -1) + exp1, exp2 = numpy.divmod(a, -1) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) + + @pytest.mark.parametrize("dt", get_integer_dtypes(no_unsigned=True)) + def test_special_int(self, dt): + # a and b have different sign and mod != 0 + a, b = numpy.array(-1, dtype=dt), numpy.array(3, dtype=dt) + ia, ib = dpnp.array(a), dpnp.array(b) + + res1, res2 = dpnp.divmod(ia, ib) + exp1, exp2 = numpy.divmod(a, b) + assert_array_equal(res1, exp1) + assert_array_equal(res2, exp2) From 25fee152c5081ab9db655535c21e0587ff03d19f Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 11:33:33 -0800 Subject: [PATCH 08/10] Enable third party tests --- .../cupy/core_tests/test_ndarray_ufunc.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_ufunc.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_ufunc.py index 0a6624acc59b..1df3bbfc8fbc 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_ufunc.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_ufunc.py @@ -1,11 +1,11 @@ +from __future__ import annotations + import numpy import pytest import dpnp as cupy from dpnp.tests.third_party.cupy import testing -pytest.skip("UFunc interface is not supported", allow_module_level=True) - class C(cupy.ndarray): @@ -20,6 +20,7 @@ def __array_finalize__(self, obj): self.info = getattr(obj, "info", None) +@pytest.mark.skip("UFunc interface is not supported") class TestArrayUfunc: @testing.for_all_dtypes() @@ -200,8 +201,8 @@ def test_types(self, xp, ufunc): sig for sig in types # CuPy does not support the following dtypes: - # (c)longdouble, datetime, timedelta, and object. - if not any(t in sig for t in "GgMmO") + # longlong, (c)longdouble, datetime, timedelta, and object. + if not any(t in sig for t in "QqGgMmO") ) ) return types @@ -210,7 +211,7 @@ def test_types(self, xp, ufunc): def test_unary_out_tuple(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) - out = xp.zeros((2, 3), dtype) + out = xp.zeros((2, 3), dtype=dtype) ret = xp.sin(a, out=(out,)) assert ret is out return ret @@ -225,8 +226,8 @@ def test_unary_out_positional_none(self, xp): def test_binary_out_tuple(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) - b = xp.ones((2, 3), dtype) - out = xp.zeros((2, 3), dtype) + b = xp.ones((2, 3), dtype=dtype) + out = xp.zeros((2, 3), dtype=dtype) ret = xp.add(a, b, out=(out,)) assert ret is out return ret @@ -235,7 +236,7 @@ def test_binary_out_tuple(self, xp): def test_biary_out_positional_none(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) - b = xp.ones((2, 3), dtype) + b = xp.ones((2, 3), dtype=dtype) return xp.add(a, b, None) @testing.numpy_cupy_allclose() @@ -243,8 +244,8 @@ def test_divmod_out_tuple(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) b = testing.shaped_reverse_arange((2, 3), xp, dtype) - out0 = xp.zeros((2, 3), dtype) - out1 = xp.zeros((2, 3), dtype) + out0 = xp.zeros((2, 3), dtype=dtype) + out1 = xp.zeros((2, 3), dtype=dtype) ret = xp.divmod(a, b, out=(out0, out1)) assert ret[0] is out0 assert ret[1] is out1 @@ -254,7 +255,7 @@ def test_divmod_out_tuple(self, xp): def test_divmod_out_positional_none(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) - b = xp.ones((2, 3), dtype) + b = xp.ones((2, 3), dtype=dtype) return xp.divmod(a, b, None, None) @testing.numpy_cupy_allclose() @@ -262,7 +263,7 @@ def test_divmod_out_partial(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) b = testing.shaped_reverse_arange((2, 3), xp, dtype) - out0 = xp.zeros((2, 3), dtype) + out0 = xp.zeros((2, 3), dtype=dtype) ret = xp.divmod(a, b, out0) # out1 is None assert ret[0] is out0 return ret @@ -272,7 +273,7 @@ def test_divmod_out_partial_tuple(self, xp): dtype = xp.float64 a = testing.shaped_arange((2, 3), xp, dtype) b = testing.shaped_reverse_arange((2, 3), xp, dtype) - out1 = xp.zeros((2, 3), dtype) + out1 = xp.zeros((2, 3), dtype=dtype) ret = xp.divmod(a, b, out=(None, out1)) assert ret[1] is out1 return ret From 9ec6b8fea6c39826d63f513f785f29fd9ac36597 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 14:39:30 -0800 Subject: [PATCH 09/10] Add support of python built-in function divmod() --- dpnp/dpnp_array.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 656f099a0c4a..301626acce85 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -212,7 +212,10 @@ def __copy__(self): # '__deepcopy__', # '__dir__', - # '__divmod__', + + def __divmod__(self, other, /): + r"""Return :math:`\text{divmod(self, value)}`.""" + return dpnp.divmod(self, other) def __dlpack__( self, /, *, stream=None, max_version=None, dl_device=None, copy=None @@ -493,7 +496,10 @@ def __rand__(self, other, /): r"""Return :math:`\text{value & self}`.""" return dpnp.bitwise_and(other, self) - # '__rdivmod__', + def __rdivmod__(self, other, /): + r"""Return :math:`\text{divmod(value, self)}`.""" + return dpnp.divmod(other, self) + # '__reduce__', # '__reduce_ex__', @@ -503,7 +509,7 @@ def __repr__(self): def __rfloordiv__(self, other, /): r"""Return :math:`\text{value // self}`.""" - return dpnp.floor_divide(self, other) + return dpnp.floor_divide(other, self) def __rlshift__(self, other, /): r"""Return :math:`\text{value << self}`.""" From 3cb0fea0d76ecdb7022bcb9ddfba0e211de151eb Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 19 Nov 2025 14:40:50 -0800 Subject: [PATCH 10/10] Enable third party tests for elementwise operations --- .../core_tests/test_ndarray_elementwise_op.py | 66 +++++++++++++++---- .../cupy/creation_tests/test_ranges.py | 1 - 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_elementwise_op.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_elementwise_op.py index e240f73ddb4d..ea164f4e3167 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_elementwise_op.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_elementwise_op.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import functools import operator import numpy @@ -6,13 +9,40 @@ import dpnp as cupy from dpnp.tests.third_party.cupy import testing -pytest.skip("operator interface is not supported", allow_module_level=True) + +def cast_exception_type(): + """ + Decorator for parameterized tests to cast raising exception + ValueError(...does not support input types...) to TypeError(...) matching + NumPy behavior. + + The exception raised when a pair of input dtypes is not supported and could + not be safely coerced to any supported one according to the casting rule. + + """ + + def decorator(impl): + @functools.wraps(impl) + def test_func(self, *args, **kw): + xp = kw["xp"] + + try: + return impl(self, *args, **kw) + except ValueError as e: + if xp is cupy and "does not support input types" in str(e): + raise TypeError(e) + raise + + return test_func + + return decorator class TestArrayElementwiseOp: @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(rtol=1e-6, accept_error=TypeError) + @cast_exception_type() def check_array_scalar_op( self, op, @@ -106,6 +136,7 @@ def test_rpow_scalar(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(atol=1.0, accept_error=TypeError) + @cast_exception_type() def check_ipow_scalar(self, xp, x_type, y_type): a = xp.array([[1, 2, 3], [4, 5, 6]], x_type) return operator.ipow(a, y_type(3)) @@ -157,6 +188,7 @@ def test_ne_scalar(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_array_op( self, op, xp, x_type, y_type, no_bool=False, no_complex=False ): @@ -216,13 +248,14 @@ def check_pow_array(self, xp, x_type, y_type): def test_pow_array(self): # There are some precision issues in HIP that prevent # checking with atol=0 - if cupy.cuda.runtime.is_hip: - self.check_pow_array() - else: - self.check_array_array_op(operator.pow) + # if cupy.cuda.runtime.is_hip: + # self.check_pow_array() + # else: + self.check_array_array_op(operator.pow) @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(atol=1.0, accept_error=TypeError) + @cast_exception_type() def check_ipow_array(self, xp, x_type, y_type): a = xp.array([[1, 2, 3], [4, 5, 6]], x_type) b = xp.array([[6, 5, 4], [3, 2, 1]], y_type) @@ -259,6 +292,7 @@ def test_ne_array(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_broadcasted_op( self, op, xp, x_type, y_type, no_bool=False, no_complex=False ): @@ -320,13 +354,14 @@ def check_broadcasted_pow(self, xp, x_type, y_type): def test_broadcasted_pow(self): # There are some precision issues in HIP that prevent # checking with atol=0 - if cupy.cuda.runtime.is_hip: - self.check_broadcasted_pow() - else: - self.check_array_broadcasted_op(operator.pow) + # if cupy.cuda.runtime.is_hip: + # self.check_broadcasted_pow() + # else: + self.check_array_broadcasted_op(operator.pow) @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(atol=1.0, accept_error=TypeError) + @cast_exception_type() def check_broadcasted_ipow(self, xp, x_type, y_type): a = xp.array([[1, 2, 3], [4, 5, 6]], x_type) b = xp.array([[1], [2]], y_type) @@ -480,6 +515,7 @@ def test_typecast_(self, xp, op, dtype, val): a = op(val, (testing.shaped_arange((5,), xp, dtype) - 2)) return a + @pytest.mark.skip("TODO") @pytest.mark.parametrize( "val", [ @@ -513,9 +549,11 @@ def check_array_boolarray_op(self, op, xp, x_type): b = xp.array([[3, 1, 4], [-1, -5, -9]], numpy.int8).view(bool) return op(a, b) + @testing.with_requires("dpctl>=0.22.0dev0") def test_add_array_boolarray(self): self.check_array_boolarray_op(operator.add) + @testing.with_requires("dpctl>=0.22.0dev0") def test_iadd_array_boolarray(self): self.check_array_boolarray_op(operator.iadd) @@ -524,6 +562,7 @@ class TestArrayIntElementwiseOp: @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_scalar_op(self, op, xp, x_type, y_type, swap=False): a = xp.array([[0, 1, 2], [1, 0, 2]], dtype=x_type) if swap: @@ -571,6 +610,7 @@ def test_rmod_scalar(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_scalarzero_op(self, op, xp, x_type, y_type, swap=False): a = xp.array([[0, 1, 2], [1, 0, 2]], dtype=x_type) if swap: @@ -618,6 +658,7 @@ def test_rmod_scalarzero(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_array_op(self, op, xp, x_type, y_type): a = xp.array([[0, 1, 2], [1, 0, 2]], dtype=x_type) b = xp.array([[0, 0, 1], [0, 1, 2]], dtype=y_type) @@ -663,6 +704,7 @@ def test_imod_array(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_broadcasted_op(self, op, xp, x_type, y_type): a = xp.array([[0, 1, 2], [1, 0, 2], [2, 1, 0]], dtype=x_type) b = xp.array([[0, 0, 1]], dtype=y_type) @@ -708,6 +750,7 @@ def test_broadcasted_imod(self): @testing.for_all_dtypes_combination(names=["x_type", "y_type"]) @testing.numpy_cupy_allclose(accept_error=TypeError) + @cast_exception_type() def check_array_doubly_broadcasted_op(self, op, xp, x_type, y_type): a = xp.array([[[0, 1, 2]], [[1, 0, 2]]], dtype=x_type) b = xp.array([[0], [0], [1]], dtype=y_type) @@ -733,6 +776,7 @@ def test_doubly_broadcasted_mod(self): self.check_array_doubly_broadcasted_op(operator.mod) +@pytest.mark.skip("objects as input are not supported") @pytest.mark.parametrize( "value", [ @@ -807,7 +851,7 @@ def test_eq_object(self, dtype, value): except TypeError: pytest.skip() - cupy.testing.assert_array_equal(res, expected) + testing.assert_array_equal(res, expected) def test_ne_object(self, dtype, value): expected = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) != value @@ -818,4 +862,4 @@ def test_ne_object(self, dtype, value): except TypeError: pytest.skip() - cupy.testing.assert_array_equal(res, expected) + testing.assert_array_equal(res, expected) diff --git a/dpnp/tests/third_party/cupy/creation_tests/test_ranges.py b/dpnp/tests/third_party/cupy/creation_tests/test_ranges.py index 5849d98d90ef..3790eae96462 100644 --- a/dpnp/tests/third_party/cupy/creation_tests/test_ranges.py +++ b/dpnp/tests/third_party/cupy/creation_tests/test_ranges.py @@ -1,6 +1,5 @@ import functools import math -import sys import unittest import numpy