@@ -439,16 +439,67 @@ namespace xt
439439 using requested_value_type = detail::conditional_promote_to_complex_t <e1_value_type, e2_requested_value_type>;
440440 };
441441
442+ /* *********************************
443+ * Expression Order Optimizations *
444+ **********************************/
445+
446+ class optimize_expression
447+ {
448+ private:
449+ template <class E1 , class E2 >
450+ struct equal_rank
451+ {
452+ static constexpr bool value = get_rank<E1 >::value == get_rank<E2 >::value;
453+ };
454+
455+ template <class E1 , class ... E>
456+ struct all_equal_rank
457+ {
458+ static constexpr bool value = xtl::conjunction<equal_rank<E1 , E>...>::value
459+ && (get_rank<E1 >::value != SIZE_MAX);
460+ };
461+
462+ template <class F , class ... CT, class ... S, size_t ... I, size_t ... J>
463+ inline auto impl_reorder_function (const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
464+ {
465+ return make_lambda_xfunction (F (), view (std::get<I>(e.arguments ()), std::get<J>(slices)...)...);
466+ }
467+
468+ public:
469+ // when we have a view of a function where the closures of the functions are of equal rank (i.e no broadcasting)
470+ // we can flip the order of the function and the view such that we have a function of views of containers which
471+ // can be linearly assigned unlike the inverse.
472+ template <class F , class ... CT, class ... S, class = std::enable_if_t <all_equal_rank<std::decay_t <CT>...>::value>>
473+ inline auto reorder (const xview<xfunction<F, CT...>, S...>& e)
474+ {
475+ return impl_reorder_function (
476+ e.expression (),
477+ e.slices (),
478+ std::make_index_sequence<sizeof ...(CT)>(),
479+ std::make_index_sequence<sizeof ...(S)>()
480+ );
481+ }
482+
483+ // base case no applicable optimization
484+ template <class E >
485+ inline auto & reorder (E&& e)
486+ {
487+ return std::forward<E>(e);
488+ }
489+ };
490+
442491 template <class E1 , class E2 >
443492 inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444493 xexpression<E1 >& e1 ,
445494 const xexpression<E2 >& e2 ,
446495 bool trivial
447496 )
448497 {
449- E1 & de1 = e1 .derived_cast ();
450- const E2 & de2 = e2 .derived_cast ();
451- using traits = xassign_traits<E1 , E2 >;
498+ auto & de1 = e1 .derived_cast ();
499+ const auto & de2 = optimize_expression ().reorder (e2 .derived_cast ());
500+ using dst_type = typename std::decay_t <decltype (de1)>;
501+ using src_type = typename std::decay_t <decltype (de2)>;
502+ using traits = xassign_traits<dst_type, src_type>;
452503
453504 bool linear_assign = traits::linear_assign (de1, de2, trivial);
454505 constexpr bool simd_assign = traits::simd_assign ();
0 commit comments