Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@

namespace phi {

template <typename DeviceContext, typename InT>
template <typename Context, typename InT>
struct OneHotV2OpFunctor {
const DenseTensor* in_;
DenseTensor* out_;
int depth_;
const DeviceContext& dev_ctx_;
const Context& dev_ctx_;

OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& dev_ctx)
const Context& dev_ctx)
: in_(in), out_(out), depth_(depth), dev_ctx_(dev_ctx) {}

template <typename OutT>
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ __global__ void FillOutputKernel(const InT* p_in_data,
}
}

template <typename DeviceContext, typename InT>
template <typename Context, typename InT>
struct OneHotV2OpCUDAFunctor {
const DenseTensor* in_;
DenseTensor* out_;
const DeviceContext& dev_ctx_;
const Context& dev_ctx_;
int depth_;

OneHotV2OpCUDAFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& dev_ctx)
const Context& dev_ctx)
: in_(in), out_(out), depth_(depth), dev_ctx_(dev_ctx) {}

template <typename OutT>
Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/kernels/xpu/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ int XPUReduce(const Context& dev_ctx,
return r;
}

template <typename DeviceContext, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const DeviceContext& dev_ctx,
template <typename Context, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& xdims,
Expand All @@ -118,8 +118,8 @@ void ReduceKernelImpl(const DeviceContext& dev_ctx,
}
}

template <typename DeviceContext, typename T, typename Functor>
void XPUReduce(const DeviceContext& dev_ctx,
template <typename Context, typename T, typename Functor>
void XPUReduce(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
Expand All @@ -142,17 +142,17 @@ void XPUReduce(const DeviceContext& dev_ctx,
// do reduce sum
PD_VISIT_XPU_REDUCE_TYPES(
x.dtype(), "ReduceKernelImpl", ([&] {
phi::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
phi::ReduceKernelImpl<Context, T, data_t, Functor>(
dev_ctx, x, out, xdims, reduce_dims);
}));
} else {
// cast x tensor to out_dtype
auto tmp_tensor = phi::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
auto tmp_tensor = phi::Cast<T, Context>(dev_ctx, x, out_dtype);

// do reduce sum
PD_VISIT_XPU_REDUCE_TYPES(
out_dtype, "ReduceKernelImpl", ([&] {
phi::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
phi::ReduceKernelImpl<Context, T, data_t, Functor>(
dev_ctx, tmp_tensor, out, xdims, reduce_dims);
}));

Expand Down
Loading