diff --git a/Paddle b/Paddle
index 93a5410253b..86238399d4c 160000
--- a/Paddle
+++ b/Paddle
@@ -1 +1 @@
-Subproject commit 93a5410253bf2ca0945f4551e1a58ad7a5aec996
+Subproject commit 86238399d4c720a0dccc07d416ece8168225d757
diff --git a/backends/iluvatar_gpu/CMakeLists.txt b/backends/iluvatar_gpu/CMakeLists.txt
index 1b65161112d..827d5553c01 100644
--- a/backends/iluvatar_gpu/CMakeLists.txt
+++ b/backends/iluvatar_gpu/CMakeLists.txt
@@ -34,6 +34,7 @@ include(version)
 include(generic)
 include(cblas)
 include(external/eigen)
+include(external/magma)
 include(external/xxhash)
 include(external/zlib)
 include(external/protobuf)
@@ -119,11 +120,13 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/*.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math/*.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/eigen/*.cu
+  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/magma/magma_function.cc
   # cudnn/cublas
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cudnn.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublas.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublasLt.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cufft.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/magma.cc
   # kernels/gpu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/spectral_norm_kernel.cu
@@ -1008,7 +1011,7 @@ target_link_libraries(
   ixattnbkd
   nccl # change nccl to ${FLAGCX_LIB} if compiling with FlagCX
   ${FLAGCX_LIB}
-)
+  magma)
 
 include_directories(BEFORE ${PADDLE_SOURCE_DIR})
 
diff --git a/backends/iluvatar_gpu/cmake/external/magma.cmake b/backends/iluvatar_gpu/cmake/external/magma.cmake
new file mode 100644
index 00000000000..612ea1b600e
--- /dev/null
+++ b/backends/iluvatar_gpu/cmake/external/magma.cmake
@@ -0,0 +1,89 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+include(ExternalProject)
+
+set(MAGMA_PREFIX_DIR ${THIRD_PARTY_PATH}/magma)
+set(MAGMA_DOWNLOAD_DIR
+    ${PADDLE_SOURCE_DIR}/third_party/magma/${CMAKE_SYSTEM_NAME})
+set(MAGMA_INSTALL_DIR ${THIRD_PARTY_PATH}/install/magma)
+set(MAGMA_LIB_DIR ${MAGMA_INSTALL_DIR}/lib)
+
+# Note(zhouwei): MAGMA needs a Fortran compiler, which many machines don't
+# have, so a precompiled library is used instead.
+# Uses MAGMA tag v2.9.0 (as of 07/28/2025):
+# https://github.com/icl-utk-edu/magma/tree/v2.9.0
+if(LINUX)
+  set(MAGMA_FILE
+      "magma_local.tar.gz"
+      CACHE STRING "" FORCE)
+  set(MAGMA_URL
+      "file:///home/tianyu.zhou/tyzhou/magma_local.tar.gz"
+      CACHE STRING "" FORCE)
+  set(MAGMA_URL_MD5 9715dfad9eb073e099f46feb6587232d)
+  set(MAGMA_LIB "${MAGMA_LIB_DIR}/libmagma.so")
+elseif(WIN32)
+  message("MAGMA does not support Windows yet, skip ...")
+else() # MacOS
+  message("MAGMA does not support macOS or other platforms yet, skip ...")
+endif()
+
+function(download_magma)
+  message(
+    STATUS "Downloading ${MAGMA_URL} to ${MAGMA_DOWNLOAD_DIR}/${MAGMA_FILE}")
+  # NOTE: If the version is updated, consider emptying the folder; maybe add a
+  # timeout.
+  file(
+    DOWNLOAD ${MAGMA_URL} ${MAGMA_DOWNLOAD_DIR}/${MAGMA_FILE}
+    EXPECTED_MD5 ${MAGMA_URL_MD5}
+    STATUS ERR)
+  if(ERR EQUAL 0)
+    message(STATUS "Download ${MAGMA_FILE} succeeded")
+  else()
+    message(
+      FATAL_ERROR
+        "Download failed, error: ${ERR}\n You can try downloading ${MAGMA_FILE} again"
+    )
+  endif()
+endfunction()
+
+# Download and check magma.
+if(EXISTS ${MAGMA_DOWNLOAD_DIR}/${MAGMA_FILE})
+  file(MD5 ${MAGMA_DOWNLOAD_DIR}/${MAGMA_FILE} MAGMA_MD5)
+  if(NOT MAGMA_MD5 STREQUAL MAGMA_URL_MD5)
+    # clean build files
+    file(REMOVE_RECURSE ${MAGMA_PREFIX_DIR})
+    file(REMOVE_RECURSE ${MAGMA_INSTALL_DIR})
+    download_magma()
+  endif()
+else()
+  download_magma()
+endif()
+
+ExternalProject_Add(
+  extern_magma
+  ${EXTERNAL_PROJECT_LOG_ARGS}
+  URL ${MAGMA_DOWNLOAD_DIR}/${MAGMA_FILE}
+  URL_MD5 ${MAGMA_URL_MD5}
+  DOWNLOAD_DIR ${MAGMA_DOWNLOAD_DIR}
+  SOURCE_DIR ${MAGMA_LIB_DIR}
+  PREFIX ${MAGMA_PREFIX_DIR}
+  DOWNLOAD_NO_PROGRESS 1
+  PATCH_COMMAND ""
+  UPDATE_COMMAND ""
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND ""
+  INSTALL_COMMAND ""
+  BUILD_BYPRODUCTS ${MAGMA_LIB})
+
+add_definitions(-DPADDLE_WITH_MAGMA)
diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/eig_grad_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/eig_grad_kernel.cu
new file mode 100644
index 00000000000..7748e32cd60
--- /dev/null
+++ b/backends/iluvatar_gpu/kernels/cuda_kernels/eig_grad_kernel.cu
@@ -0,0 +1,519 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/backends/dynload/cublas.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/cpu/eig.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/unsqueeze.h" +#include "paddle/phi/kernels/transpose_kernel.h" +#include "runtime/iluvatar_context.h" + +namespace phi { + +template +void SolveLinearSystemGPU(const GPUContext& dev_ctx, + const T* matrix_data, + const T* rhs_data, + T* out_data, + int order, + int rhs_cols, + int batch_count); + +template <> +void SolveLinearSystemGPU>( + const phi::GPUContext& dev_ctx, + const phi::dtype::complex* + matrix_data, // device ptr, row-major, size batch*order*order + const phi::dtype::complex* + rhs_data, // device ptr, row-major, size batch*order*rhs_cols + phi::dtype::complex* + out_data, // device ptr, row-major, size batch*order*rhs_cols + int order, + int rhs_cols, + int batch_count) { + // handles + cublasHandle_t cublas_handle = dev_ctx.cublas_handle(); + // cusolverDnHandle_t cusolver_handle = dev_ctx.cusolver_dn_handle(); + cusolverDnHandle_t cusolver_handle = GetSolverHandle(dev_ctx.stream()); + + auto stream = phi::Stream(reinterpret_cast(dev_ctx.stream())); + + // cuComplex constants + const cuComplex kAlpha = make_cuFloatComplex(1.0f, 0.0f); + const cuComplex kZero = make_cuFloatComplex(0.0f, 0.0f); + + // Sizes + const size_t A_one_bytes = + static_cast(order) * order * sizeof(cuComplex); + const size_t B_one_bytes = + static_cast(order) * rhs_cols * sizeof(cuComplex); + const size_t A_batch_bytes = A_one_bytes * batch_count; + const size_t B_batch_bytes = B_one_bytes * batch_count; + + const cuComplex* A_row_all = reinterpret_cast(matrix_data); + const cuComplex* B_row_all = reinterpret_cast(rhs_data); + cuComplex* X_row_all = reinterpret_cast(out_data); + + auto dA_col_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), A_batch_bytes, stream); + auto dB_col_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), B_batch_bytes, stream); + cuComplex* dA_col = reinterpret_cast(dA_col_alloc->ptr()); + cuComplex* dB_col = reinterpret_cast(dB_col_alloc->ptr()); + + auto d_pivots_alloc = phi::memory_utils::Alloc( + dev_ctx.GetPlace(), + static_cast(batch_count) * order * sizeof(int), + stream); + int* d_pivots = reinterpret_cast(d_pivots_alloc->ptr()); + + auto d_info_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), + static_cast(batch_count) * sizeof(int), + stream); + int* d_info = reinterpret_cast(d_info_alloc->ptr()); + + // A_row layout: row-major (order x order), B_row layout: row-major (order + // x rhs_cols) + for (int i = 0; i < batch_count; ++i) { + const cuComplex* A_row = A_row_all + static_cast(i) * order * order; + cuComplex* A_col = dA_col + static_cast(i) * order * order; + const cuComplex* B_row = + B_row_all + static_cast(i) * order * rhs_cols; + cuComplex* B_col = dB_col + static_cast(i) * order * rhs_cols; + + // transpose A_row (row-major) -> A_col (column-major) via C = A^T + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + order, + order, + &kAlpha, + A_row, + order, // lda: when interpreting A_row as (order x order) row-major, + // using order + &kZero, + nullptr, + order, + A_col, + order)); // ldc = order (column-major leading dim) + 
+ // transpose B_row (row-major order x rhs_cols) -> B_col (column-major order + // x rhs_cols) + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + order, + rhs_cols, + &kAlpha, + B_row, + rhs_cols, // lda when A_row is viewed row-major: leading = rhs_cols + &kZero, + nullptr, + rhs_cols, + B_col, + order)); // ldc = order + } + + int lwork = 0; + cuComplex* dA_col0 = dA_col; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgetrf_bufferSize( + cusolver_handle, order, order, dA_col0, order, &lwork)); + + size_t work_bytes = static_cast(lwork) * sizeof(cuComplex); + auto d_work_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), work_bytes, stream); + cuComplex* d_work = reinterpret_cast(d_work_alloc->ptr()); + + for (int i = 0; i < batch_count; ++i) { + cuComplex* A_col = dA_col + static_cast(i) * order * order; + cuComplex* B_col = dB_col + static_cast(i) * order * rhs_cols; + int* pivots_i = d_pivots + static_cast(i) * order; + int* info_i = d_info + i; + + // getrf (LU factorization) on A_col (column-major) + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgetrf( + cusolver_handle, order, order, A_col, order, d_work, pivots_i, info_i)); + + // getrs: solve A_col * X_col = B_col + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgetrs( + cusolver_handle, + CUBLAS_OP_N, // no transpose on column-major matrix + order, + rhs_cols, + A_col, + order, + pivots_i, + B_col, + order, + info_i)); + } + + for (int i = 0; i < batch_count; ++i) { + cuComplex* B_col = dB_col + static_cast(i) * order * + rhs_cols; // X in column-major + cuComplex* X_row = X_row_all + static_cast(i) * order * + rhs_cols; // target row-major + + // transpose X_col -> X_row + // We use C = A^T : A has shape (order x rhs_cols) in column-major, so C + // will be (rhs_cols x order), but we want X_row with shape (order x + // rhs_cols) in row-major; calling cublasCgeam with op=T and adjusted dims + // works: + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + rhs_cols, + order, // rowsC = rhs_cols, colsC = order + &kAlpha, + B_col, + order, // B_col lda = order (col-major) + &kZero, + nullptr, + order, + X_row, + rhs_cols)); // X_row ldc = rhs_cols (row-major leading dimension) + } + + std::vector h_info(batch_count, 0); + phi::memory_utils::Copy(phi::CPUPlace(), + h_info.data(), + dev_ctx.GetPlace(), + d_info, + static_cast(batch_count) * sizeof(int), + reinterpret_cast(dev_ctx.stream())); + dev_ctx.Wait(); + + for (int i = 0; i < batch_count; ++i) { + PADDLE_ENFORCE_EQ( + h_info[i], + 0, + errors::External( + "cuSOLVER getrf/getrs failed at batch %d, info: %d", i, h_info[i])); + } +} + +template <> +void SolveLinearSystemGPU>( + const phi::GPUContext& dev_ctx, + const phi::dtype::complex* + matrix_data, // device ptr, row-major, size batch*order*order + const phi::dtype::complex* + rhs_data, // device ptr, row-major, size batch*order*rhs_cols + phi::dtype::complex* + out_data, // device ptr, row-major, size batch*order*rhs_cols + int order, + int rhs_cols, + int batch_count) { + // handles + cublasHandle_t cublas_handle = dev_ctx.cublas_handle(); + // cusolverDnHandle_t cusolver_handle = dev_ctx.cusolver_dn_handle(); + cusolverDnHandle_t cusolver_handle = GetSolverHandle(dev_ctx.stream()); + + auto stream = phi::Stream(reinterpret_cast(dev_ctx.stream())); + + // cuDoubleComplex constants + const cuDoubleComplex kAlpha = make_cuDoubleComplex(1.0f, 0.0f); + const cuDoubleComplex kZero = 
make_cuDoubleComplex(0.0f, 0.0f); + + // Sizes + const size_t A_one_bytes = + static_cast(order) * order * sizeof(cuDoubleComplex); + const size_t B_one_bytes = + static_cast(order) * rhs_cols * sizeof(cuDoubleComplex); + const size_t A_batch_bytes = A_one_bytes * batch_count; + const size_t B_batch_bytes = B_one_bytes * batch_count; + + const cuDoubleComplex* A_row_all = + reinterpret_cast(matrix_data); + const cuDoubleComplex* B_row_all = + reinterpret_cast(rhs_data); + cuDoubleComplex* X_row_all = reinterpret_cast(out_data); + + auto dA_col_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), A_batch_bytes, stream); + auto dB_col_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), B_batch_bytes, stream); + cuDoubleComplex* dA_col = + reinterpret_cast(dA_col_alloc->ptr()); + cuDoubleComplex* dB_col = + reinterpret_cast(dB_col_alloc->ptr()); + + auto d_pivots_alloc = phi::memory_utils::Alloc( + dev_ctx.GetPlace(), + static_cast(batch_count) * order * sizeof(int), + stream); + int* d_pivots = reinterpret_cast(d_pivots_alloc->ptr()); + + auto d_info_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), + static_cast(batch_count) * sizeof(int), + stream); + int* d_info = reinterpret_cast(d_info_alloc->ptr()); + + // A_row layout: row-major (order x order), B_row layout: row-major (order + // x rhs_cols) + for (int i = 0; i < batch_count; ++i) { + const cuDoubleComplex* A_row = + A_row_all + static_cast(i) * order * order; + cuDoubleComplex* A_col = dA_col + static_cast(i) * order * order; + const cuDoubleComplex* B_row = + B_row_all + static_cast(i) * order * rhs_cols; + cuDoubleComplex* B_col = dB_col + static_cast(i) * order * rhs_cols; + + // transpose A_row (row-major) -> A_col (column-major) via C = A^T + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + order, + order, + &kAlpha, + A_row, + order, // lda: when interpreting A_row as (order x order) row-major, + // using order + &kZero, + nullptr, + order, + A_col, + order)); // ldc = order (column-major leading dim) + + // transpose B_row (row-major order x rhs_cols) -> B_col (column-major order + // x rhs_cols) + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + order, + rhs_cols, + &kAlpha, + B_row, + rhs_cols, // lda when A_row is viewed row-major: leading = rhs_cols + &kZero, + nullptr, + rhs_cols, + B_col, + order)); // ldc = order + } + + int lwork = 0; + cuDoubleComplex* dA_col0 = dA_col; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgetrf_bufferSize( + cusolver_handle, order, order, dA_col0, order, &lwork)); + + size_t work_bytes = static_cast(lwork) * sizeof(cuDoubleComplex); + auto d_work_alloc = + phi::memory_utils::Alloc(dev_ctx.GetPlace(), work_bytes, stream); + cuDoubleComplex* d_work = + reinterpret_cast(d_work_alloc->ptr()); + + for (int i = 0; i < batch_count; ++i) { + cuDoubleComplex* A_col = dA_col + static_cast(i) * order * order; + cuDoubleComplex* B_col = dB_col + static_cast(i) * order * rhs_cols; + int* pivots_i = d_pivots + static_cast(i) * order; + int* info_i = d_info + i; + + // getrf (LU factorization) on A_col (column-major) + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgetrf( + cusolver_handle, order, order, A_col, order, d_work, pivots_i, info_i)); + + // getrs: solve A_col * X_col = B_col + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgetrs( + cusolver_handle, + CUBLAS_OP_N, // no transpose on column-major matrix + order, + rhs_cols, + A_col, + order, + pivots_i, + B_col, 
+ order, + info_i)); + } + + for (int i = 0; i < batch_count; ++i) { + cuDoubleComplex* B_col = dB_col + static_cast(i) * order * + rhs_cols; // X in column-major + cuDoubleComplex* X_row = X_row_all + static_cast(i) * order * + rhs_cols; // target row-major + + // transpose X_col -> X_row + // We use C = A^T : A has shape (order x rhs_cols) in column-major, so C + // will be (rhs_cols x order), but we want X_row with shape (order x + // rhs_cols) in row-major; calling cublasZgeam with op=T and adjusted dims + // works: + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgeam( + cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + rhs_cols, + order, // rowsC = rhs_cols, colsC = order + &kAlpha, + B_col, + order, // B_col lda = order (col-major) + &kZero, + nullptr, + order, + X_row, + rhs_cols)); // X_row ldc = rhs_cols (row-major leading dimension) + } + + std::vector h_info(batch_count, 0); + phi::memory_utils::Copy(phi::CPUPlace(), + h_info.data(), + dev_ctx.GetPlace(), + d_info, + static_cast(batch_count) * sizeof(int), + reinterpret_cast(dev_ctx.stream())); + dev_ctx.Wait(); + + for (int i = 0; i < batch_count; ++i) { + PADDLE_ENFORCE_EQ( + h_info[i], + 0, + errors::External( + "cuSOLVER getrf/getrs failed at batch %d, info: %d", i, h_info[i])); + } +} + +template +void ComputeBackwardForComplexInputGPU(const DenseTensor& L, + const DenseTensor& V, + const paddle::optional& gL, + const paddle::optional& gV, + T* x_grad_data, + int batch_count, + int order, + const Context& dev_ctx) { + DenseTensor gL_safe; + if (gL.get_ptr()) { + gL_safe = gL.get(); + } else { + gL_safe = + Fill(dev_ctx, common::vectorize(L.dims()), T(0)); + } + + DenseTensor gV_safe; + if (gV.get_ptr()) { + gV_safe = gV.get(); + } else { + gV_safe = + Fill(dev_ctx, common::vectorize(V.dims()), T(0)); + } + DenseTensor trans_v = phi::TransposeLast2Dim(dev_ctx, V); + DenseTensor Vh = phi::Conj(dev_ctx, trans_v); + DenseTensor Lconj = phi::Conj(dev_ctx, L); + DenseTensor Econj = phi::Subtract(dev_ctx, + phi::funcs::Unsqueeze(Lconj, -2), + phi::funcs::Unsqueeze(Lconj, -1)); + DenseTensor VhgV = phi::Matmul(dev_ctx, Vh, gV_safe); + DenseTensor diag_real = phi::Real(dev_ctx, VhgV); + + auto cpu_place = phi::CPUPlace(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* cpu_ctx = static_cast(pool.Get(cpu_place)); + + DenseTensor diag_real_cpu; + diag_real_cpu.Resize(diag_real.dims()); + phi::Copy(dev_ctx, diag_real, cpu_place, false, &diag_real_cpu); + + DenseTensor diag_res_cpu = + phi::funcs::BatchDiag((*cpu_ctx), diag_real_cpu, batch_count); + + DenseTensor diag_res; + dev_ctx.template Alloc(&diag_res); + phi::Copy(dev_ctx, diag_res_cpu, dev_ctx.GetPlace(), false, &diag_res); + + DenseTensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2); + + auto numel = diag_unsqueezed.numel(); + DenseTensor diag_unsqueezed_complex; + auto* data_diag_un = diag_unsqueezed.data>(); + diag_unsqueezed_complex.Resize(diag_unsqueezed.dims()); + auto* data_diag_un_com = dev_ctx.template Alloc( + &diag_unsqueezed_complex, static_cast(numel * sizeof(T))); + + phi::funcs::ForRange for_range(dev_ctx, numel); + phi::funcs::RealToComplexFunctor functor( + data_diag_un, data_diag_un_com, numel); + for_range(functor); + // real tensor multiply complex tensor in broadcast manner + DenseTensor res1 = phi::Multiply(dev_ctx, V, diag_unsqueezed_complex); + DenseTensor res2 = phi::Matmul(dev_ctx, Vh, res1); + DenseTensor result = phi::Subtract(dev_ctx, VhgV, res2); + + result.Resize(V.dims()); + dev_ctx.template Alloc(&result); + 
result = phi::Divide(dev_ctx, result, Econj); + result = phi::funcs::DiagFill( + dev_ctx, order, order, order, 0, gL_safe, result); + DenseTensor rhs = phi::Matmul(dev_ctx, result, Vh); + + // solve linear system + // solve(Vh, rhs, out, m, k) + // Vh: matrix with shape [m,m] + // rhs: rhs with shape [m,k] + // x_grad: out + int m = static_cast(Vh.dims(-1)); + int k = static_cast(rhs.dims(-1)); + auto* matrix_data = Vh.data(); + auto* rhs_data = rhs.data(); + + SolveLinearSystemGPU( + dev_ctx, matrix_data, rhs_data, x_grad_data, m, k, batch_count); +} + +template +void EigGradKernel(const Context& dev_ctx, + const DenseTensor& out_w, + const DenseTensor& out_v, + const paddle::optional& dout_w, + const paddle::optional& dout_v, + DenseTensor* dx) { + auto* dx_data = dev_ctx.template Alloc>(dx); + if (dx->numel() == 0) { + return; + } + auto& dims = out_v.dims(); + phi::DDim dim_origin = dims; + int num_dims = dim_origin.size(); + int batch_count = BatchCount(out_v); + const int order = static_cast(dim_origin[num_dims - 1]); + + ComputeBackwardForComplexInputGPU, Context>( + out_w, out_v, dout_w, dout_v, dx_data, batch_count, order, dev_ctx); +} + +} // namespace phi + +// Register the kernel +PD_REGISTER_PLUGIN_KERNEL(eig_grad, + iluvatar_gpu, + ALL_LAYOUT, + phi::EigGradKernel, + float, + phi::complex64) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); + kernel->InputAt(2).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); + kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype())); +} diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/eig_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/eig_kernel.cu new file mode 100644 index 00000000000..8461b82ece7 --- /dev/null +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/eig_kernel.cu @@ -0,0 +1,123 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/eig.h" +#include "paddle/phi/kernels/eig_kernel.h" + +namespace phi { + +template +void EigKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out_w, + DenseTensor* out_v) { + dev_ctx.template Alloc>(out_w); + dev_ctx.template Alloc>(out_v); + + if (x.numel() == 0) { + return; + } + + auto cpu_place = phi::CPUPlace(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* cpu_ctx = static_cast(pool.Get(cpu_place)); + + // prepare cpu Tensor here, since magma requires output on cpu + DenseTensor out_w_cpu, out_v_cpu; + out_w_cpu.Resize(out_w->dims()); + (*cpu_ctx).template Alloc>(&out_w_cpu); + out_v_cpu.Resize(x.dims()); + (*cpu_ctx).template Alloc>(&out_v_cpu); + + if (!IsComplexType(x.dtype())) { + // output still be complex though input is real + int batch_count = BatchCount(x); + int order = static_cast(x.dims()[x.dims().size() - 1]); + + DenseTensor real_w_cpu, real_v_cpu; + + std::vector real_w_dim = common::vectorize(out_w->dims()); + real_w_dim.back() *= 2; + real_w_cpu.Resize(common::make_ddim(real_w_dim)); + (*cpu_ctx).template Alloc>(&real_w_cpu); + real_v_cpu.Resize(x.dims()); + (*cpu_ctx).template Alloc>(&real_v_cpu); + + phi::ApplyEigKernelMagma, Context>( + dev_ctx, x, &real_w_cpu, &real_v_cpu); + + // 1. extract real part & imag part from real_w_cpu + DenseTensor real_part_cpu = phi::funcs::Slice>( + (*cpu_ctx), real_w_cpu, {-1}, {0}, {order}); + DenseTensor imag_part_cpu = phi::funcs::Slice>( + (*cpu_ctx), real_w_cpu, {-1}, {order}, {order * 2}); + + // 2. construct complex values + auto* real_part_data = real_part_cpu.data>(); + auto* imag_part_data = imag_part_cpu.data>(); + int64_t out_w_numel = static_cast(out_w->numel()); + + phi::funcs::ForRange for_range((*cpu_ctx), out_w_numel); + phi::funcs::RealImagToComplexFunctor> functor( + real_part_data, + imag_part_data, + out_w_cpu.data>(), + out_w_numel); + for_range(functor); + + // 3. 
construct complex vectors + DenseTensor real_v_trans_cpu = + phi::TransposeLast2Dim, phi::CPUContext>( + (*cpu_ctx), real_v_cpu); + DenseTensor out_v_trans_cpu; + out_v_trans_cpu.Resize(x.dims()); + (*cpu_ctx).template Alloc>(&out_v_trans_cpu); + + phi::ConstructComplexVectors, + phi::dtype::Complex, + phi::CPUContext>(&out_v_trans_cpu, + out_w_cpu, + real_v_trans_cpu, + (*cpu_ctx), + batch_count, + order); + + TransposeTwoAxis, phi::CPUContext>( + out_v_trans_cpu, + &out_v_cpu, + x.dims().size() - 1, + x.dims().size() - 2, + (*cpu_ctx)); + + } else { + phi::ApplyEigKernelMagma(dev_ctx, x, &out_w_cpu, &out_v_cpu); + } + + // copy result from cpu to gpu tensor + phi::Copy(dev_ctx, out_w_cpu, dev_ctx.GetPlace(), false, out_w); + phi::Copy(dev_ctx, out_v_cpu, dev_ctx.GetPlace(), false, out_v); +} + +} // namespace phi + +PD_REGISTER_PLUGIN_KERNEL( + eig, iluvatar_gpu, ALL_LAYOUT, phi::EigKernel, float, phi::complex64) { + if (kernel_key.dtype() == phi::DataType::FLOAT32) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype())); + kernel->OutputAt(1).SetDataType(phi::dtype::ToComplex(kernel_key.dtype())); + } +} diff --git a/backends/iluvatar_gpu/runtime/iluvatar_context.h b/backends/iluvatar_gpu/runtime/iluvatar_context.h index a007bf4154c..a6ea37a0caa 100644 --- a/backends/iluvatar_gpu/runtime/iluvatar_context.h +++ b/backends/iluvatar_gpu/runtime/iluvatar_context.h @@ -77,6 +77,11 @@ class DnnWorkspaceHandle { namespace { // NOLINT inline cudnnHandle_t dnn_handle_ = nullptr; inline std::once_flag flag_dnn_; + +inline cusolverDnHandle_t solver_handle_ = nullptr; +inline std::function solver_handle_creator_{nullptr}; +inline std::once_flag flag_solver_; + inline void InitDnnHandle(cudnnHandle_t* handle, gpuStream_t stream, Place place) { @@ -114,6 +119,29 @@ inline DnnWorkspaceHandle GetDnnWorkspace(Allocator* alloactor, const gpuStream_t& stream) { return DnnWorkspaceHandle(alloactor, stream); } + +inline void InitSolverHandle(cusolverDnHandle_t* handle, gpuStream_t stream) { + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(handle)); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnSetStream(*handle, stream)); +} + +inline cusolverDnHandle_t GetSolverHandle(gpuStream_t stream) { + std::call_once(flag_solver_, [&]() { + if (!solver_handle_) { + if (!solver_handle_creator_) { + InitSolverHandle(&solver_handle_, stream); + } else { + solver_handle_ = solver_handle_creator_(); + } + } + }); + PADDLE_ENFORCE_NOT_NULL( + solver_handle_, + common::errors::InvalidArgument( + "The GPU solver handle is nullptr. It must not be null.")); + return solver_handle_; +} + } // namespace phi namespace iluvatar { diff --git a/backends/iluvatar_gpu/tests/unittests/test_eig_op_iluvatar.py b/backends/iluvatar_gpu/tests/unittests/test_eig_op_iluvatar.py new file mode 100644 index 00000000000..3f5fbb9f7f3 --- /dev/null +++ b/backends/iluvatar_gpu/tests/unittests/test_eig_op_iluvatar.py @@ -0,0 +1,391 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, skip_check_grad_ci +from utils import dygraph_guard + +import paddle +from paddle import base + + +# cast output to complex for numpy.linalg.eig +def cast_to_complex(input, output): + if input.dtype == np.float32: + output = output.astype(np.complex64) + elif input.dtype == np.float64: + output = output.astype(np.complex128) + return output + + +# define eig backward function for a single square matrix +def eig_backward(w, v, grad_w, grad_v): + v_tran = np.transpose(v) + v_tran = np.conjugate(v_tran) + w_conj = np.conjugate(w) + w_conj_l = w_conj.reshape(1, w.size) + w_conj_r = w_conj.reshape(w.size, 1) + w_conj_2d = w_conj_l - w_conj_r + + vhgv = np.matmul(v_tran, grad_v) + real_vhgv = np.real(vhgv) + diag_real = real_vhgv.diagonal() + + diag_2d = diag_real.reshape(1, w.size) + rhs = v * diag_2d + mid = np.matmul(v_tran, rhs) + result = vhgv - mid + + res = np.divide(result, w_conj_2d) + row, col = np.diag_indices_from(res) + res[row, col] = 1.0 + + tmp = np.matmul(res, v_tran) + dx = np.linalg.solve(v_tran, tmp) + return dx + + +class TestEigOp(OpTest): + def setUp(self): + paddle.enable_static() + paddle.device.set_device("iluvatar_gpu") + self.op_type = "eig" + self.python_api = paddle.linalg.eig + self.__class__.op_type = self.op_type + self.init_input() + self.inputs = {"X": OpTest.np_dtype_to_base_dtype(self.x)} + self.outputs = {"Eigenvalues": self.out[0], "Eigenvectors": self.out[1]} + + def init_input(self): + self.set_dtype() + self.set_dims() + self.x = np.random.random(self.shape).astype(self.dtype) + self.out = np.linalg.eig(self.x) + self.out = ( + cast_to_complex(self.x, self.out[0]), + cast_to_complex(self.x, self.out[1]), + ) + + # for the real input, a customized checker is needed + def checker(self, outs): + actual_out_w = outs[0].flatten() + expect_out_w = self.out[0].flatten() + actual_out_v = outs[1].flatten() + expect_out_v = self.out[1].flatten() + + length_w = len(expect_out_w) + act_w_real = np.sort( + np.array([np.abs(actual_out_w[i].real) for i in range(length_w)]) + ) + act_w_imag = np.sort( + np.array([np.abs(actual_out_w[i].imag) for i in range(length_w)]) + ) + exp_w_real = np.sort( + np.array([np.abs(expect_out_w[i].real) for i in range(length_w)]) + ) + exp_w_imag = np.sort( + np.array([np.abs(expect_out_w[i].imag) for i in range(length_w)]) + ) + + for i in range(length_w): + np.testing.assert_allclose( + act_w_real[i], + exp_w_real[i], + rtol=1e-06, + atol=1e-05, + err_msg="The eigenvalues real part have diff: \nExpected " + + str(act_w_real[i]) + + "\n" + + "But got: " + + str(exp_w_real[i]), + ) + np.testing.assert_allclose( + act_w_imag[i], + exp_w_imag[i], + rtol=1e-06, + atol=1e-05, + err_msg="The eigenvalues image part have diff: \nExpected " + + str(act_w_imag[i]) + + "\n" + + "But got: " + + str(exp_w_imag[i]), + ) + + length_v = len(expect_out_v) + act_v_real = np.sort( + np.array([np.abs(actual_out_v[i].real) for i in range(length_v)]) + ) + act_v_imag = np.sort( + np.array([np.abs(actual_out_v[i].imag) for i in range(length_v)]) + ) + exp_v_real = np.sort( + np.array([np.abs(expect_out_v[i].real) for i in range(length_v)]) + ) + exp_v_imag = np.sort( + np.array([np.abs(expect_out_v[i].imag) for i in range(length_v)]) + ) + + for i in range(length_v): + np.testing.assert_allclose( + act_v_real[i], + exp_v_real[i], + rtol=1e-06, + atol=1e-05, + err_msg="The eigenvectors 
real part have diff: \nExpected " + + str(act_v_real[i]) + + "\n" + + "But got: " + + str(exp_v_real[i]), + ) + np.testing.assert_allclose( + act_v_imag[i], + exp_v_imag[i], + rtol=1e-06, + atol=1e-05, + err_msg="The eigenvectors image part have diff: \nExpected " + + str(act_v_imag[i]) + + "\n" + + "But got: " + + str(exp_v_imag[i]), + ) + + def set_dtype(self): + self.dtype = np.complex64 + + def set_dims(self): + self.shape = (10, 10) + + def init_grad(self): + # grad_w, grad_v complex dtype + gtype = self.dtype + if self.dtype == np.float32: + gtype = np.complex64 + elif self.dtype == np.float64: + gtype = np.complex128 + self.grad_w = np.ones(self.out[0].shape, gtype) + self.grad_v = np.ones(self.out[1].shape, gtype) + self.grad_x = eig_backward(self.out[0], self.out[1], self.grad_w, self.grad_v) + + def test_check_output(self): + self.check_output_with_place_customized( + checker=self.checker, + place=paddle.CustomPlace("iluvatar_gpu", 0), + check_pir=True, + ) + + def test_check_grad(self): + self.init_grad() + self.check_grad( + ["X"], + ["Eigenvalues", "Eigenvectors"], + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_w, self.grad_v], + check_pir=True, + ) + + +@skip_check_grad_ci( + reason="For float dtype, numpy.linalg.eig forward outputs real or complex when input is real, therefore the grad computation may be not the same with paddle.linalg.eig" +) +class TestFloat(TestEigOp): + def set_dtype(self): + self.dtype = np.float32 + + def test_check_grad(self): + pass + + +class TestEigStatic(TestEigOp): + def test_check_output_with_place(self): + paddle.enable_static() + place = paddle.CustomPlace("iluvatar_gpu", 0) + input_np = np.random.random([3, 3]).astype("complex") + expect_val, expect_vec = np.linalg.eig(input_np) + with base.program_guard(base.Program(), base.Program()): + input = paddle.static.data(name="input", shape=[3, 3], dtype="complex") + act_val, act_vec = paddle.linalg.eig(input) + + exe = base.Executor(place) + fetch_val, fetch_vec = exe.run( + base.default_main_program(), + feed={"input": input_np}, + fetch_list=[act_val, act_vec], + ) + np.testing.assert_allclose( + expect_val, + fetch_val, + rtol=1e-06, + atol=1e-06, + err_msg="The eigen values have diff: \nExpected " + + str(expect_val) + + "\n" + + "But got: " + + str(fetch_val), + ) + np.testing.assert_allclose( + np.abs(expect_vec), + np.abs(fetch_vec), + rtol=1e-06, + atol=1e-06, + err_msg="The eigen vectors have diff: \nExpected " + + str(np.abs(expect_vec)) + + "\n" + + "But got: " + + str(np.abs(fetch_vec)), + ) + + +class TestEigDyGraph(unittest.TestCase): + def test_check_output_with_place(self): + np.random.seed(1024) + input_np = np.random.random([3, 3]).astype("complex64") + expect_val, expect_vec = np.linalg.eig(input_np) + + paddle.set_device("iluvatar_gpu") + paddle.disable_static() + + input_tensor = paddle.to_tensor(input_np) + fetch_val, fetch_vec = paddle.linalg.eig(input_tensor) + + np.testing.assert_allclose( + expect_val, + fetch_val.numpy(), + rtol=1e-06, + atol=1e-06, + err_msg="The eigen values have diff: \nExpected " + + str(expect_val) + + "\n" + + "But got: " + + str(fetch_val), + ) + np.testing.assert_allclose( + np.abs(expect_vec), + np.abs(fetch_vec.numpy()), + rtol=1e-06, + atol=1e-06, + err_msg="The eigen vectors have diff: \nExpected " + + str(np.abs(expect_vec)) + + "\n" + + "But got: " + + str(np.abs(fetch_vec.numpy())), + ) + + # def test_check_grad(self): + # test_shape = [3, 3] + # test_type = 'float32' + # paddle.set_device("iluvatar_gpu") + + # 
np.random.seed(1024) + # input_np = np.random.random(test_shape).astype(test_type) + # real_w, real_v = np.linalg.eig(input_np) + + # grad_w = np.ones(real_w.shape, test_type) + # grad_v = np.ones(real_v.shape, test_type) + # grad_x = eig_backward(real_w, real_v, grad_w, grad_v) + + # with base.dygraph.guard(): + # x = paddle.to_tensor(input_np) + # x.stop_gradient = False + # w, v = paddle.linalg.eig(x) + # (w.sum() + v.sum()).backward() + + # np.testing.assert_allclose( + # np.abs(x.grad.numpy()), + # np.abs(grad_x), + # rtol=1e-05, + # atol=1e-05, + # err_msg='The grad x have diff: \nExpected ' + # + str(np.abs(grad_x)) + # + '\n' + # + 'But got: ' + # + str(np.abs(x.grad.numpy())), + # ) + + +class TestEigWrongDimsError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("iluvatar_gpu") + paddle.disable_static() + a = np.random.random(3).astype("float32") + x = paddle.to_tensor(a) + self.assertRaises(ValueError, paddle.linalg.eig, x) + + +class TestEigNotSquareError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("iluvatar_gpu") + paddle.disable_static() + a = np.random.random((1, 2, 3)).astype("float32") + x = paddle.to_tensor(a) + self.assertRaises(ValueError, paddle.linalg.eig, x) + + +class TestEigUnsupportedDtypeError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("iluvatar_gpu") + paddle.disable_static() + a = (np.random.random((3, 3)) * 10).astype("int64") + x = paddle.to_tensor(a) + self.assertRaises(RuntimeError, paddle.linalg.eig, x) + + +class TestOptionalGradInput(unittest.TestCase): + def test_eager(self): + with dygraph_guard(), paddle.device.device_guard("iluvatar_gpu"): + x = paddle.randn(3, 3, requires_grad=True) + w, v = paddle.linalg.eig(x) + + np.testing.assert_allclose( + (x @ v).numpy(), + (w.unsqueeze(0) * v).numpy(), + atol=1e-5, + rtol=1e-5, + ) # Aμ = λμ + + # (dw_dx,) = paddle.grad(w, x, retain_graph=True) + # (dv_dx,) = paddle.grad(v, x, retain_graph=True) + # (dwdv_dx,) = paddle.grad([w, v], x) + # np.testing.assert_allclose( + # (dw_dx + dv_dx).numpy(), + # dwdv_dx.numpy(), + # atol=1e-5, + # rtol=1e-5, + # ) + + def test_dy2st(self): + with dygraph_guard(), paddle.device.device_guard("iluvatar_gpu"): + x = paddle.randn(3, 3, requires_grad=True) + + def f(x): + w, v = paddle.linalg.eig(x) + return ( + w, + v, + ) + + st_f = paddle.jit.to_static(f, full_graph=True, backend=None) + + w, v = st_f(x) + np.testing.assert_allclose( + (x @ v).numpy(), + (w.unsqueeze(0) * v).numpy(), + atol=1e-5, + rtol=1e-5, + ) # Aμ = λμ + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/iluvatar_gpu/tests/unittests/test_linalg_eig_op_iluvatar.py b/backends/iluvatar_gpu/tests/unittests/test_linalg_eig_op_iluvatar.py new file mode 100644 index 00000000000..18581f25af0 --- /dev/null +++ b/backends/iluvatar_gpu/tests/unittests/test_linalg_eig_op_iluvatar.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+from utils import dygraph_guard
+
+import paddle
+
+
+class TestEigAPI0Size(unittest.TestCase):
+    def test_errors(self):
+        with dygraph_guard(), paddle.device.device_guard("iluvatar_gpu"):
+            for shape in [[0, 0], [0, 4, 4], [1, 0, 2, 3, 3]]:
+                x = paddle.randn(shape=shape, dtype="float32", requires_grad=True)
+                w, v = paddle.linalg.eig(x)
+                self.assertEqual(w.shape, shape[:-1])
+                self.assertEqual(v.shape, shape)
+
+                # (dw_dx,) = paddle.grad(w.abs().sum(), x, retain_graph=True)
+                # self.assertEqual(dw_dx.shape, x.shape)
+                # (dv_dx,) = paddle.grad(v.abs().sum(), x, retain_graph=True)
+                # self.assertEqual(dv_dx.shape, x.shape)
+                # (dwv_dx,) = paddle.grad(
+                #     w.abs().sum() + v.abs().sum(), x, retain_graph=True
+                # )
+                # self.assertEqual(dwv_dx.shape, x.shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
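Minimal usage sketch (not part of the patch): assuming the plugin is built with MAGMA enabled and an iluvatar_gpu device is available, the new eig kernel can be exercised from Python with the same A @ V == V * diag(w) check used in the unit tests above.

import numpy as np
import paddle

# Route paddle.linalg.eig to the custom-device kernels registered in this patch.
paddle.device.set_device("iluvatar_gpu")

x = paddle.randn([3, 3], dtype="float32")
w, v = paddle.linalg.eig(x)  # w: complex64 eigenvalues, v: complex64 eigenvectors

# Each eigenpair should satisfy A v_i = w_i v_i, i.e. A @ V == V * diag(w).
np.testing.assert_allclose(
    (x @ v).numpy(),
    (w.unsqueeze(0) * v).numpy(),
    rtol=1e-5,
    atol=1e-5,
)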