diff --git a/.github/workflows/oneapi_githubactions_build.yml b/.github/workflows/oneapi_githubactions_build.yml
new file mode 100644
index 0000000000..abbb84d839
--- /dev/null
+++ b/.github/workflows/oneapi_githubactions_build.yml
@@ -0,0 +1,82 @@
+name: oneapi_ghactions_buildrun
+
+on:
+  push:
+    branches: [ "feature/sycl" ]
+
+defaults:
+  run:
+    shell: bash
+
+env:
+  BUILD_TYPE: RELEASE
+
+jobs:
+  buildrun:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Install software
+        run: |
+          sudo apt update
+          sudo apt install -y gpg-agent wget
+          # download the key to system keyring
+          wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | sudo tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null
+          # add signed entry to apt sources and configure the APT client to use Intel repository:
+          echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
+          sudo apt update
+          sudo apt install -y intel-oneapi-hpc-toolkit
+
+      - name: Setup oneAPI
+        run: |
+          source /opt/intel/oneapi/setvars.sh
+          printenv >> $GITHUB_ENV
+          which icpx
+          icpx -v
+          cat /proc/cpuinfo
+
+      - uses: actions/checkout@v4
+
+      - name: Ccache for gh actions
+        uses: hendrikmuhs/ccache-action@v1.2.16
+        with:
+          key: ${{ github.job }}
+          max-size: 2000M
+
+      - name: Configure CMake
+        run: >
+          cmake
+          -B ${{github.workspace}}/build
+          -GNinja
+          -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}}
+          -DCMAKE_C_COMPILER=icx
+          -DCMAKE_CXX_COMPILER=icpx
+          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
+          -DQUDA_TARGET_TYPE=SYCL
+          -DQUDA_SYCL_TARGETS=spir64_x86_64
+          -DCMAKE_CXX_FLAGS="-Wno-unsupported-floating-point-opt"
+          -DCMAKE_SYCL_FLAGS="-Xs -march=avx512 -Wno-unsupported-floating-point-opt"
+          -DSYCL_LINK_FLAGS="-Xs -march=avx512 -fsycl-device-code-split=per_kernel -fsycl-max-parallel-link-jobs=4 -flink-huge-device-code"
+          -DQUDA_DIRAC_COVDEV=OFF
+          -DQUDA_DIRAC_DISTANCE_PRECONDITIONING=OFF
+          -DQUDA_MULTIGRID=ON
+          -DQUDA_INTERFACE_QDPJIT=ON
+          -DQUDA_FAST_COMPILE_REDUCE=ON
+          -DQUDA_FAST_COMPILE_DSLASH=ON
+          -DQUDA_OPENMP=OFF
+          -DQUDA_MPI=ON
+          -DQUDA_PRECISION=12
+          -DQUDA_DIRAC_DEFAULT_OFF=ON
+          -DQUDA_DIRAC_STAGGERED=ON
+          -DQUDA_DIRAC_WILSON=ON
+
+      - name: Build
+        run: cmake --build ${{github.workspace}}/build
+
+      - name: Install
+        run: cmake --install ${{github.workspace}}/build
+
+      - name: Run
+        run: |
+          cd ${{github.workspace}}/build
+          #ctest
+          ctest -E 'invert_test_asqtad_single|invert_test_splitgrid_asqtad_single|unitarize_link_single'
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c3fa2a949c..d91b1aa8bd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -211,7 +211,11 @@ if(QUDA_MAX_MULTI_RHS_TILE GREATER QUDA_MAX_MULTI_RHS)
   message(SEND_ERROR "QUDA_MAX_MULTI_RHS_TILE is greater than QUDA_MAX_MULTI_RHS")
 endif()
 
-set(QUDA_MAX_KERNEL_ARG_SIZE "4096" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+if(${QUDA_TARGET_TYPE} STREQUAL "SYCL")
+  set(QUDA_MAX_KERNEL_ARG_SIZE "2048" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+else()
+  set(QUDA_MAX_KERNEL_ARG_SIZE "4096" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+endif()
 if(QUDA_MAX_KERNEL_ARG_SIZE GREATER 32764)
   message(SEND_ERROR "Maximum QUDA_MAX_KERNEL_ARG_SIZE is 32764")
 endif()
diff --git a/cmake/CMakeDetermineSYCLCompiler.cmake
b/cmake/CMakeDetermineSYCLCompiler.cmake
new file mode 100644
index 0000000000..144b288e92
--- /dev/null
+++ b/cmake/CMakeDetermineSYCLCompiler.cmake
@@ -0,0 +1,36 @@
+if(NOT CMAKE_SYCL_COMPILER)
+  set(CMAKE_SYCL_COMPILER ${CMAKE_CXX_COMPILER})
+endif()
+mark_as_advanced(CMAKE_SYCL_COMPILER)
+message(STATUS "The SYCL compiler is " ${CMAKE_SYCL_COMPILER})
+
+if(NOT CMAKE_SYCL_COMPILER_ID_RUN)
+  set(CMAKE_SYCL_COMPILER_ID_RUN 1)
+
+  # Try to identify the compiler.
+  set(CMAKE_SYCL_COMPILER_ID)
+  set(CMAKE_SYCL_PLATFORM_ID)
+  file(READ ${CMAKE_ROOT}/Modules/CMakePlatformId.h.in CMAKE_SYCL_COMPILER_ID_PLATFORM_CONTENT)
+
+  set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS_FIRST)
+  set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS)
+
+  set(CMAKE_CXX_COMPILER_ID_CONTENT "#if defined(__INTEL_LLVM_COMPILER)\n# define COMPILER_ID \"IntelLLVM\"\n")
+  string(APPEND CMAKE_CXX_COMPILER_ID_CONTENT "#elif defined(__clang__)\n# define COMPILER_ID \"Clang\"\n")
+  string(APPEND CMAKE_CXX_COMPILER_ID_CONTENT "#endif\n")
+  include(${CMAKE_ROOT}/Modules/CMakeDetermineCompilerId.cmake)
+  CMAKE_DETERMINE_COMPILER_ID(SYCL SYCLFLAGS CMakeCXXCompilerId.cpp)
+
+  _cmake_find_compiler_sysroot(SYCL)
+endif()
+
+
+#set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS_FIRST)
+#set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS "-c")
+#include(${CMAKE_ROOT}/Modules/CMakeDetermineCompilerId.cmake)
+#CMAKE_DETERMINE_COMPILER_ID(SYCL SYCLFLAGS CMakeCXXCompilerId.cpp)
+
+configure_file(${CMAKE_CURRENT_LIST_DIR}/CMakeSYCLCompiler.cmake.in
+  ${CMAKE_PLATFORM_INFO_DIR}/CMakeSYCLCompiler.cmake)
+
+set(CMAKE_SYCL_COMPILER_ENV_VAR "SYCL")
diff --git a/cmake/CMakeSYCLCompiler.cmake.in b/cmake/CMakeSYCLCompiler.cmake.in
new file mode 100644
index 0000000000..2dc0b7acd2
--- /dev/null
+++ b/cmake/CMakeSYCLCompiler.cmake.in
@@ -0,0 +1,3 @@
+set(CMAKE_SYCL_COMPILER "@CMAKE_SYCL_COMPILER@")
+set(CMAKE_SYCL_COMPILER_LOADED 1)
+set(CMAKE_SYCL_COMPILER_ENV_VAR "SYCL")
diff --git a/cmake/CMakeSYCLInformation.cmake b/cmake/CMakeSYCLInformation.cmake
new file mode 100644
index 0000000000..6572616fbf
--- /dev/null
+++ b/cmake/CMakeSYCLInformation.cmake
@@ -0,0 +1,47 @@
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_PIC)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_PIC ${CMAKE_CXX_COMPILE_OPTIONS_PIC})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_PIE)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_PIE ${CMAKE_CXX_COMPILE_OPTIONS_PIE})
+endif()
+if(NOT CMAKE_SYCL_LINK_OPTIONS_PIE)
+  set(CMAKE_SYCL_LINK_OPTIONS_PIE ${CMAKE_CXX_LINK_OPTIONS_PIE})
+endif()
+if(NOT CMAKE_SYCL_LINK_OPTIONS_NO_PIE)
+  set(CMAKE_SYCL_LINK_OPTIONS_NO_PIE ${CMAKE_CXX_LINK_OPTIONS_NO_PIE})
+endif()
+
+if(NOT CMAKE_SYCL_OUTPUT_EXTENSION)
+  set(CMAKE_SYCL_OUTPUT_EXTENSION ${CMAKE_CXX_OUTPUT_EXTENSION})
+endif()
+
+if(NOT CMAKE_INCLUDE_FLAG_SYCL)
+  set(CMAKE_INCLUDE_FLAG_SYCL ${CMAKE_INCLUDE_FLAG_CXX})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_EXPLICIT_LANGUAGE)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_EXPLICIT_LANGUAGE ${CMAKE_CXX_COMPILE_OPTIONS_EXPLICIT_LANGUAGE})
+endif()
+
+if(NOT CMAKE_SYCL_DEPENDS_USE_COMPILER)
+  set(CMAKE_SYCL_DEPENDS_USE_COMPILER ${CMAKE_CXX_DEPENDS_USE_COMPILER})
+endif()
+
+if(NOT CMAKE_DEPFILE_FLAGS_SYCL)
+  set(CMAKE_DEPFILE_FLAGS_SYCL ${CMAKE_DEPFILE_FLAGS_CXX})
+endif()
+
+if(NOT CMAKE_SYCL_DEPFILE_FORMAT)
+  set(CMAKE_SYCL_DEPFILE_FORMAT ${CMAKE_CXX_DEPFILE_FORMAT})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OBJECT)
+  set(CMAKE_SYCL_COMPILE_OBJECT "<CMAKE_SYCL_COMPILER> <DEFINES> <INCLUDES> <FLAGS> -o <OBJECT> -c <SOURCE>")
+endif()
+
+if(NOT CMAKE_SYCL_LINK_EXECUTABLE)
+  set(CMAKE_SYCL_LINK_EXECUTABLE "<CMAKE_SYCL_COMPILER> <FLAGS> <CMAKE_SYCL_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
+endif()
+
+set(CMAKE_SYCL_INFORMATION_LOADED 1)
diff --git a/cmake/CMakeTestSYCLCompiler.cmake
b/cmake/CMakeTestSYCLCompiler.cmake
new file mode 100644
index 0000000000..e7c7219631
--- /dev/null
+++ b/cmake/CMakeTestSYCLCompiler.cmake
@@ -0,0 +1 @@
+set(CMAKE_SYCL_COMPILER_WORKS 1 CACHE INTERNAL "")
diff --git a/include/array.h b/include/array.h
index 3005087c85..e5e65b1493 100644
--- a/include/array.h
+++ b/include/array.h
@@ -34,6 +34,8 @@ namespace quda
     return output;
   }
 
+  template <typename T, int n> constexpr T &elem(array<T, n> &a, int i) { return a[i]; }
+
   /**
    * @brief Element-wise maximum of two arrays
    * @param a first array
diff --git a/include/blas_helper.cuh b/include/blas_helper.cuh
index 806eef5f5e..fa92ec024b 100644
--- a/include/blas_helper.cuh
+++ b/include/blas_helper.cuh
@@ -193,10 +193,10 @@ namespace quda
       norm_t max_[n];
       // two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued)
 #pragma unroll
-      for (int i = 0; i < n; i++) max_[i] = fmaxf(fabsf((norm_t)v[i].real()), fabsf((norm_t)v[i].imag()));
+      for (int i = 0; i < n; i++) max_[i] = quda::max(quda::abs((norm_t)v[i].real()), quda::abs((norm_t)v[i].imag()));
       norm_t scale = 0.0;
 #pragma unroll
-      for (int i = 0; i < n; i++) scale = fmaxf(max_[i], scale);
+      for (int i = 0; i < n; i++) scale = quda::max(max_[i], scale);
       norm = scale * fixedInvMaxValue<store_t>::value;
       return fdividef(fixedMaxValue<store_t>::value, scale);
     }
@@ -309,7 +309,7 @@ namespace quda
       memcpy(&vecTmp[6], &norm, sizeof(norm_t)); // pack the norm
       array vecTmp2;
       copy_and_scale(vecTmp2, &v_[0], scale_inv);
-      std::memcpy(&vecTmp, &vecTmp2, sizeof(vecTmp2));
+      memcpy(&vecTmp, &vecTmp2, sizeof(vecTmp2));
       // second do vectorized copy into memory
       vector_store(data.spinor, parity * cb_offset + x, vecTmp);
     }
diff --git a/include/clover_field_order.h b/include/clover_field_order.h
index af1a8d73c8..63949971e3 100644
--- a/include/clover_field_order.h
+++ b/include/clover_field_order.h
@@ -860,8 +860,8 @@ namespace quda {
       if (clover.Order() != QUDA_QDPJIT_CLOVER_ORDER) { errorQuda("Invalid clover order %d for this accessor", clover.Order()); }
-      offdiag = clover_ ? ((Float **)clover_)[0] : clover.data<Float **>(inverse)[0];
-      diag = clover_ ? ((Float **)clover_)[1] : clover.data<Float **>(inverse)[1];
+      offdiag = clover_ ? reinterpret_cast<Float **>(clover_)[0] : clover.data<Float **>(inverse)[0];
+      diag = clover_ ? reinterpret_cast<Float **>(clover_)[1] : clover.data<Float **>(inverse)[1];
     }
 
     QudaTwistFlavorType TwistFlavor() const { return twist_flavor; }
diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h
index 6362b4677b..add4b30fc6 100644
--- a/include/color_spinor_field.h
+++ b/include/color_spinor_field.h
@@ -488,7 +488,7 @@ namespace quda
     template <typename T = void *> auto data() const
     {
       if (ghost_only) errorQuda("Not defined for ghost-only field");
-      return reinterpret_cast<T>(v.data());
+      return static_cast<T>(v.data());
     }
 
     /**
@@ -635,6 +635,7 @@
       @param[in] gdr_recv Whether we are using GDR on the receive side
     */
     int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const;
+    void commsQuery(int n, int d[], bool done[], bool gdr_send, bool gdr_recv) const;
 
     /**
       @brief Wait on halo communication to complete
@@ -872,7 +873,6 @@ namespace quda
     /**
      * @brief Print the site vector
-     * @param[in] a The field we are printing from
      * @param[in] parity Parity index
      * @param[in] x_cb Checkerboard space-time index
      * @param[in] rank The rank we are requesting from (default is rank = 0)
diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h
index 46ad849079..ee935fadfe 100644
--- a/include/color_spinor_field_order.h
+++ b/include/color_spinor_field_order.h
@@ -241,8 +241,9 @@ namespace quda
       constexpr int M = nSpinBlock * nColor * nVec;
 #pragma unroll
       for (int i = 0; i < M; i++) {
-        vec_t tmp
-          = vector_load<vec_t>(reinterpret_cast<const vec_t *>(in + parity * offset_cb), x_cb * N + chi * M + i);
+        // vec_t tmp
+        //   = vector_load<vec_t>(reinterpret_cast<const vec_t *>(in + parity * offset_cb), x_cb * N + chi * M + i);
+        vec_t tmp = vector_load<vec_t>(in + parity * offset_cb, x_cb * N + chi * M + i);
         memcpy(&out[i], &tmp, sizeof(vec_t));
       }
     }
@@ -1061,7 +1062,7 @@
         for (int i = 0; i < length_ghost / 2; i++)
           max_[i] = fmaxf((norm_type)fabsf((norm_type)v[i]), (norm_type)fabsf((norm_type)v[i + length_ghost / 2]));
 #pragma unroll
-        for (int i = 0; i < length_ghost / 2; i++) scale = fmaxf(max_[i], scale);
+        for (int i = 0; i < length_ghost / 2; i++) scale = max(max_[i], scale);
         ghost_norm[2 * dim + dir][parity * faceVolumeCB[dim] + x] = scale * fixedInvMaxValue<storeFloat>::value;
         scale_inv = fdividef(fixedMaxValue<storeFloat>::value, scale);
       }
@@ -1203,7 +1204,7 @@
         for (int i = 0; i < length / 2; i++)
           max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2]));
 #pragma unroll
-        for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale);
+        for (int i = 0; i < length / 2; i++) scale = max(max_[i], scale);
         norm[x + parity * norm_offset] = scale * fixedInvMaxValue<storeFloat>::value;
         scale_inv = fdividef(fixedMaxValue<storeFloat>::value, scale);
       }
@@ -1306,10 +1307,10 @@
         // two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued)
 #pragma unroll
         for (int i = 0; i < length_ghost / 2; i++)
-          max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length_ghost / 2]));
+          max_[i] = max(abs((norm_type)v[i]), abs((norm_type)v[i + length_ghost / 2]));
         norm_type scale = 0.0;
 #pragma unroll
-        for (int i = 0; i < length_ghost / 2; i++) scale = fmaxf(max_[i], scale);
+        for (int i = 0; i < length_ghost / 2; i++) scale = max(max_[i], scale);
         norm_type nrm = scale * fixedInvMaxValue<storeFloat>::value;
         real scale_inv = fdividef(fixedMaxValue<storeFloat>::value, scale);
@@ -1411,11 +1412,10 @@
         norm_type max_[length / 2];
         // two-pass to increase ILP (assumes length divisible by two, e.g.
complex-valued)
 #pragma unroll
-        for (int i = 0; i < length / 2; i++)
-          max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2]));
+        for (int i = 0; i < length / 2; i++) max_[i] = max(abs((norm_type)v[i]), abs((norm_type)v[i + length / 2]));
         norm_type scale = 0.0;
 #pragma unroll
-        for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale);
+        for (int i = 0; i < length / 2; i++) scale = max(max_[i], scale);
         norm_type nrm = scale * fixedInvMaxValue<storeFloat>::value;
         real scale_inv = fdividef(fixedMaxValue<storeFloat>::value, scale);
diff --git a/include/comm_quda.h b/include/comm_quda.h
index 8496efab5a..66c5dba0e4 100644
--- a/include/comm_quda.h
+++ b/include/comm_quda.h
@@ -415,6 +415,7 @@ namespace quda
   void comm_start(MsgHandle *mh);
   void comm_wait(MsgHandle *mh);
   int comm_query(MsgHandle *mh);
+  // void comm_query(int n, MsgHandle *mh[], int *outcount, int array_of_indices[]);
 
   template <typename T> void comm_allreduce_sum(T &v);
   template <typename T> void comm_allreduce_max(T &v);
diff --git a/include/communicator_quda.h b/include/communicator_quda.h
index bf0e9ffba5..cdca2655a7 100644
--- a/include/communicator_quda.h
+++ b/include/communicator_quda.h
@@ -747,6 +747,8 @@ namespace quda
   int comm_query(MsgHandle *mh);
 
+  // void comm_query(int n, MsgHandle *mh[], int *outcount, int array_of_indices[]);
+
   template <typename T> T deterministic_reduce(T *array, int n)
   {
     std::sort(array, array + n); // sort reduction into ascending order for deterministic reduction
diff --git a/include/convert.h b/include/convert.h
index f56751873c..2d1026cb31 100644
--- a/include/convert.h
+++ b/include/convert.h
@@ -128,6 +128,7 @@ namespace quda
     }
   };
 
+#if 0
   /**
      @brief Fast float-to-integer round used on the device
   */
@@ -148,6 +149,7 @@ namespace quda
       return i;
     }
   };
+#endif
 
   /**
      @brief Regular double-to-integer round used on the host
@@ -156,6 +158,7 @@ namespace quda
   */
     constexpr int operator()(double d) { return static_cast<int>(rint(d)); }
   };
 
+#if 0
   /**
      @brief Fast double-to-integer round used on the device
   */
@@ -166,6 +169,7 @@ namespace quda
       return reinterpret_cast<int &>(d);
     }
   };
+#endif
 
   /**
     @brief Copy function which is trival between floating point
diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh
index 834b59425c..41883d7b1a 100644
--- a/include/dslash_helper.cuh
+++ b/include/dslash_helper.cuh
@@ -500,6 +500,12 @@ namespace quda
     dslash.template operator()(x_cb, s, parity);
   }
 
+  template
+  __forceinline__ __device__ void apply_dslash(D &dslash, int x_cb, int s, int parity, bool alive)
+  {
+    dslash.template operator()(x_cb, s, parity, alive);
+  }
+
 #ifdef NVSHMEM_COMMS
   /**
    * @brief helper function for nvshmem uber kernel to signal that the interior kernel has completed.
@@ -678,7 +684,8 @@ namespace quda { } - __forceinline__ __device__ void operator()(int, int s, int parity) + template + __forceinline__ __device__ void operator()(int, int s, int parity, bool alive = true) { typename Arg::D dslash(*this); // for full fields set parity from z thread index else use arg setting @@ -686,10 +693,11 @@ namespace quda if ((kernel_type == INTERIOR_KERNEL || kernel_type == UBER_KERNEL) && target::block_idx().x < static_cast(arg.pack_blocks)) { - // first few blocks do packing kernel - typename Arg::template P packer; - packer(arg, s, 1 - parity, dslash.twist_pack()); // flip parity since pack is on input - + if (!allthreads || alive) { + // first few blocks do packing kernel + typename Arg::template P packer; + packer(arg, s, 1 - parity, dslash.twist_pack()); // flip parity since pack is on input + } // we use that when running the exterior -- this is either // * an explicit call to the exterior when not merged with the interior or // * the interior with exterior_blocks > 0 @@ -731,9 +739,19 @@ namespace quda } #endif } else { - if (x_cb >= arg.threads) return; + if (x_cb >= arg.threads) { + if constexpr (allthreads) + alive = false; + else + return; + } + + if constexpr (allthreads) { + apply_dslash(dslash, x_cb, s, parity, alive); + } else { + apply_dslash(dslash, x_cb, s, parity); + } - apply_dslash(dslash, x_cb, s, parity); if constexpr (use_nvshmem_comms && kernel_type == UBER_KERNEL) { __syncthreads(); if (target::thread_idx().x == 0 && target::thread_idx().y == 0 && target::thread_idx().z == 0) diff --git a/include/externals/json.hpp b/include/externals/json.hpp index cb27e05811..443aa9a665 100644 --- a/include/externals/json.hpp +++ b/include/externals/json.hpp @@ -21895,7 +21895,7 @@ inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC /// @brief user-defined string literal for JSON values /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json operator "" _json(const char* s, std::size_t n) +inline nlohmann::json operator""_json(const char* s, std::size_t n) { return nlohmann::json::parse(s, s + n); } @@ -21903,7 +21903,7 @@ inline nlohmann::json operator "" _json(const char* s, std::size_t n) /// @brief user-defined string literal for JSON pointer /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json_pointer/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +inline nlohmann::json::json_pointer operator""_json_pointer(const char* s, std::size_t n) { return nlohmann::json::json_pointer(std::string(s, n)); } diff --git a/include/gauge_field_order.h b/include/gauge_field_order.h index 4561f1f21f..a851569746 100644 --- a/include/gauge_field_order.h +++ b/include/gauge_field_order.h @@ -1945,6 +1945,7 @@ namespace quda { LegacyOrder(u, ghost_), volumeCB(u.VolumeCB()) { for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? ((Float **)gauge_)[i] : u.data(i); + // for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? gauge_[i] : u.data(i); } __device__ __host__ inline void load(complex v[length / 2], int x, int dir, int parity, real = 1.0) const @@ -1991,6 +1992,7 @@ namespace quda { LegacyOrder(u, ghost_), volumeCB(u.VolumeCB()) { for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? ((Float **)gauge_)[i] : u.data(i); + // for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? 
gauge_[i] : u.data(i); } __device__ __host__ inline void load(complex v[length / 2], int x, int dir, int parity, real = 1.0) const diff --git a/include/kernels/block_orthogonalize.cuh b/include/kernels/block_orthogonalize.cuh index 3a70a4a096..e3e0868c7f 100644 --- a/include/kernels/block_orthogonalize.cuh +++ b/include/kernels/block_orthogonalize.cuh @@ -135,7 +135,8 @@ namespace quda { for (int c = 0; c < nColor; c++) arg.V(parity, x_cb, chirality * spinBlock + s, c, i) = v(s, c); } - __device__ __host__ inline void operator()(dim3 block, dim3 thread) + template // true if all threads in block will enter, even if out of range + __device__ __host__ inline void operator()(dim3 block, dim3 thread, bool alive = true) { int x_coarse = block.x; int x_fine_offset = thread.x; @@ -146,14 +147,20 @@ namespace quda { int x_cb[n_sites_per_thread]; for (int tx = 0; tx < n_sites_per_thread; tx++) { - int x_fine_offset_tx = x_fine_offset * n_sites_per_thread + tx; - // all threads with x_fine_offset greater than aggregate_size_cb are second parity - int parity_offset = (x_fine_offset_tx >= arg.aggregate_size_cb && fineSpin != 1) ? 1 : 0; - x_offset_cb[tx] = x_fine_offset_tx - parity_offset * arg.aggregate_size_cb; - parity[tx] = fineSpin == 1 ? chirality : arg.nParity == 2 ? parity_offset : arg.parity; - - x_cb[tx] = x_offset_cb[tx] >= arg.aggregate_size_cb ? 0 : - arg.coarse_to_fine[ (x_coarse*2 + parity[tx]) * arg.aggregate_size_cb + x_offset_cb[tx] ] - parity[tx]*arg.fineVolumeCB; + if (!allthreads || alive) { + int x_fine_offset_tx = x_fine_offset * n_sites_per_thread + tx; + // all threads with x_fine_offset greater than aggregate_size_cb are second parity + int parity_offset = (x_fine_offset_tx >= arg.aggregate_size_cb && fineSpin != 1) ? 1 : 0; + x_offset_cb[tx] = x_fine_offset_tx - parity_offset * arg.aggregate_size_cb; + parity[tx] = fineSpin == 1 ? chirality : arg.nParity == 2 ? parity_offset : arg.parity; + + x_cb[tx] = x_offset_cb[tx] >= arg.aggregate_size_cb ? + 0 : + arg.coarse_to_fine[(x_coarse * 2 + parity[tx]) * arg.aggregate_size_cb + x_offset_cb[tx]] + - parity[tx] * arg.fineVolumeCB; + } else { + x_offset_cb[tx] = arg.aggregate_size_cb; + } } if (fineSpin == 1) chirality = 0; // when using staggered chirality is mapped to parity diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 4499dfa07c..d0eb84673e 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -73,7 +73,8 @@ namespace quda - V: spatial -> spin/color -> nVec The transpose uses shared memory to avoid strided memory accesses. */ - __device__ __host__ inline void operator()(int x_cb, int) + template // true if all threads in block will enter, even if out of range + __device__ __host__ inline void operator()(int x_cb, int, bool = true) { int parity_color = target::block_idx().z; int color = parity_color % Arg::nColor; diff --git a/include/kernels/clover_outer_product.cuh b/include/kernels/clover_outer_product.cuh index e887e65f0d..65953e4008 100644 --- a/include/kernels/clover_outer_product.cuh +++ b/include/kernels/clover_outer_product.cuh @@ -40,7 +40,7 @@ namespace quda { const ColorSpinorField &p_halo, cvector_ref &x, const ColorSpinorField &x_halo, const std::vector &coeff) : kernel_param(dim3(dim == -1 ? static_cast(x_halo.getDslashConstant().volume_4d_cb) : - x_halo.getDslashConstant().ghostFaceCB[dim], + x_halo.getDslashConstant().ghostFaceCB[dim == -1 ? 0 : dim], x.SiteSubset(), dim == -1 ? 
4 : dim)), n_src(p.size()), force(force), diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index 7a59bdfae5..f91f5873d6 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -1382,7 +1382,7 @@ namespace quda { }; template struct storeCoarseSharedAtomic_impl { - template void operator()(Args...) + template void operator()(Args...) { errorQuda("Shared-memory atomic aggregation not supported on host"); } @@ -1402,9 +1402,9 @@ namespace quda { template using Cache = SharedMemoryCache, DimsStaticConditional<2, 1, 1>>; template using Ops = KernelOps>; - template + template inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, - int parity, const Pack &pack, const Ftor &ftor) + int parity, const Pack &pack, const Ftor &ftor, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -1468,57 +1468,61 @@ namespace quda { if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) { + if (!allthreads || active) { #pragma unroll - for (int i = 0; i < TileType::M; i++) { + for (int i = 0; i < TileType::M; i++) { #pragma unroll - for (int j = 0; j < TileType::N; j++) { - if (pack.dir == QUDA_IN_PLACE) { - // same as dir == QUDA_FORWARDS - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } else { - arg.Y_atomic.atomicAdd(dim_index,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - Y[i_block0+i][j_block0+j][x_][s_row][s_col]); - - if (pack.dir == QUDA_BACKWARDS) { - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_col,s_row,j0+j,i0+i, - conj(X[i_block0+i][j_block0+j][x_][s_row][s_col])); + for (int j = 0; j < TileType::N; j++) { + if (pack.dir == QUDA_IN_PLACE) { + // same as dir == QUDA_FORWARDS + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); } else { - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } - - if (!arg.bidirectional) { - if (Arg::fineSpin != 1 && s_row == s_col) arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - else arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - -X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } - } // dir == QUDA_IN_PLACE + arg.Y_atomic.atomicAdd(dim_index, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + Y[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + + if (pack.dir == QUDA_BACKWARDS) { + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_col, s_row, j0 + j, i0 + i, + conj(X[i_block0 + i][j_block0 + j][x_][s_row][s_col])); + } else { + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + } + + if (!arg.bidirectional) { + if (Arg::fineSpin != 1 && s_row == s_col) + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + else + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + -X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + } + } // dir == QUDA_IN_PLACE + } } } } } }; - template + template __device__ __host__ void storeCoarseSharedAtomic(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, - int i0, int j0, int parity, 
const Ftor &ftor) + int i0, int j0, int parity, const Ftor &ftor, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; switch (arg.dir) { case QUDA_BACKWARDS: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; case QUDA_FORWARDS: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; case QUDA_IN_PLACE: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; default: break;// do nothing @@ -1605,9 +1609,9 @@ namespace quda { } - template + template __device__ __host__ void computeVUV(const Ftor &ftor, int parity, int x_cb, int i0, int j0, int parity_coarse_, - int coarse_x_cb_) + int coarse_x_cb_, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -1634,7 +1638,7 @@ namespace quda { using Ctype = decltype(make_tile_C, false>(arg.vuvTile)); Ctype vuv[Arg::coarseSpin * Arg::coarseSpin]; - multiplyVUV(vuv, arg, parity, x_cb, i0, j0); + if (!allthreads || active) multiplyVUV(vuv, arg, parity, x_cb, i0, j0); if (isDiagonal && !isFromCoarseClover) { #pragma unroll @@ -1642,8 +1646,8 @@ namespace quda { } if (arg.shared_atomic) - storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, ftor); - else + storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, ftor, active); + else if (!allthreads || active) storeCoarseGlobalAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, arg); } @@ -1721,17 +1725,24 @@ namespace quda { @param[in] parity_c_row parity * output color row @param[in] c_col output coarse color column */ - __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col) + template + __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col, bool active = true) { - int parity, parity_coarse, x_coarse_cb, c_row; - target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); - - if (parity > 1) return; - if (c_row >= arg.vuvTile.M_tiles) return; - if (c_col >= arg.vuvTile.N_tiles) return; - if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - - computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + int parity = 0, parity_coarse = 0, x_coarse_cb = 0, c_row = 0; + if (!allthreads || active) + target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); + + // if (parity > 1) return; + // if (c_row >= arg.vuvTile.M_tiles) return; + // if (c_col >= arg.vuvTile.N_tiles) return; + // if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; + if (parity > 1) active = false; + if (c_row >= arg.vuvTile.M_tiles) active = false; + if (c_col >= arg.vuvTile.N_tiles) active = false; + if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) active = false; + + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, + x_coarse_cb, active); } }; @@ -1751,17 +1762,24 @@ namespace quda { @param[in] parity_c_row parity * output color row @param[in] c_col output coarse color column */ - __device__ __host__ inline void operator()(int x_cb, int 
parity_c_row, int c_col) + template + __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col, bool active = true) { - int parity, parity_coarse, x_coarse_cb, c_row; - target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); - - if (parity > 1) return; - if (c_row >= arg.vuvTile.M_tiles) return; - if (c_col >= arg.vuvTile.N_tiles) return; - if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - - computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + int parity = 0, parity_coarse = 0, x_coarse_cb = 0, c_row = 0; + if (!allthreads || active) + target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); + + // if (parity > 1) return; + // if (c_row >= arg.vuvTile.M_tiles) return; + // if (c_col >= arg.vuvTile.N_tiles) return; + // if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; + if (parity > 1) active = false; + if (c_row >= arg.vuvTile.M_tiles) active = false; + if (c_col >= arg.vuvTile.N_tiles) active = false; + if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) active = false; + + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, + x_coarse_cb, active); } }; diff --git a/include/kernels/color_spinor_pack.cuh b/include/kernels/color_spinor_pack.cuh index e51aa85a4d..bd122f9adf 100644 --- a/include/kernels/color_spinor_pack.cuh +++ b/include/kernels/color_spinor_pack.cuh @@ -209,9 +209,9 @@ namespace quda { } }; - template + template __device__ __host__ inline std::enable_if_t - compute_site_max(const Ftor &, int, int, int, int, int) + compute_site_max(const Ftor &, int, int, int, int, int, bool) { return static_cast(1.0); // dummy return for non-block float } @@ -219,24 +219,27 @@ namespace quda { /** Compute the max element over the spin-color components of a given site. 
*/ - template + template __device__ __host__ inline std::enable_if_t - compute_site_max(const Ftor &ftor, int src_idx, int x_cb, int spinor_parity, int spin_block, int color_block) + compute_site_max(const Ftor &ftor, int src_idx, int x_cb, int spinor_parity, int spin_block, int color_block, + bool active) { using real = typename Ftor::Arg::real; const int Ms = spins_per_thread(Ftor::Arg::nSpin); const int Mc = colors_per_thread(Ftor::Arg::nColor); complex thread_max = {0.0, 0.0}; + if (!allthreads || active) { #pragma unroll - for (int spin_local=0; spin_local z = ftor.arg.in[src_idx](spinor_parity, x_cb, s, c); - thread_max.real(max(thread_max.real(), abs(z.real()))); - thread_max.imag(max(thread_max.imag(), abs(z.imag()))); + for (int color_local = 0; color_local < Mc; color_local++) { + int c = color_block + color_local; + complex z = ftor.arg.in[src_idx](spinor_parity, x_cb, s, c); + thread_max.real(max(thread_max.real(), abs(z.real()))); + thread_max.imag(max(thread_max.imag(), abs(z.imag()))); + } } } @@ -306,7 +309,8 @@ namespace quda { } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ void operator()(int tid, int spin_color_block, int parity) + template + __device__ __host__ void operator()(int tid, int spin_color_block, int parity, bool active = true) { const int Ms = spins_per_thread(Arg::nSpin); const int Mc = colors_per_thread(Arg::nColor); @@ -322,21 +326,22 @@ namespace quda { int src_idx; int x_cb = indexFromFaceIndex(src_idx, dim, dir, ghost_idx, parity, arg); - auto max = compute_site_max(*this, src_idx, x_cb, spinor_parity, spin_block, color_block); + auto max = compute_site_max(*this, src_idx, x_cb, spinor_parity, spin_block, color_block, active); + if (!allthreads || active) { #pragma unroll - for (int spin_local=0; spin_local + __device__ __host__ inline void operator()(int x_cb, int src_flavor, int parity, bool active = true) { using namespace linalg; // for Cholesky const int clover_parity = arg.nParity == 2 ? parity : arg.parity; @@ -214,15 +215,21 @@ namespace quda { const int flavor = src_flavor % 2; int my_flavor_idx = x_cb + flavor * arg.volumeCB; - fermion in = arg.in[src_idx](my_flavor_idx, spinor_parity); - in.toRel(); // change to chiral basis here - + fermion in; int chirality = flavor; // relabel flavor as chirality + Mat A; + if (!allthreads || active) { + in = arg.in[src_idx](my_flavor_idx, spinor_parity); + in.toRel(); // change to chiral basis here + A = arg.clover(x_cb, clover_parity, chirality); + } else { + in = fermion {}; + A = Mat {}; + } + // (C + i mu gamma_5 tau_3 - epsilon tau_1 ) [note: appropriate signs carried in arg.a / arg.b] const complex a(0.0, chirality == 0 ? 
arg.a : -arg.a); - Mat A = arg.clover(x_cb, clover_parity, chirality); - SharedMemoryCache cache {*this}; half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion @@ -251,27 +258,32 @@ namespace quda { out_chi[flavor] += arg.b * in_chi[1 - flavor]; } - if (arg.inverse) { - if (arg.dynamic_clover) { - Mat A2 = A.square(); - A2 += arg.a2_minus_b2; - Cholesky, N> cholesky(A2); + if (!allthreads || active) { + if (arg.inverse) { + if (arg.dynamic_clover) { + Mat A2 = A.square(); + A2 += arg.a2_minus_b2; + Cholesky, N> cholesky(A2); #pragma unroll - for (int flavor = 0; flavor < n_flavor; flavor++) - out_chi[flavor] = static_cast(0.25) * cholesky.backward(cholesky.forward(out_chi[flavor])); - } else { - Mat Ainv = arg.cloverInv(x_cb, clover_parity, chirality); + for (int flavor = 0; flavor < n_flavor; flavor++) + out_chi[flavor] = static_cast(0.25) * cholesky.backward(cholesky.forward(out_chi[flavor])); + } else { + Mat Ainv = arg.cloverInv(x_cb, clover_parity, chirality); #pragma unroll - for (int flavor = 0; flavor < n_flavor; flavor++) - out_chi[flavor] = static_cast(2.0) * (Ainv * out_chi[flavor]); + for (int flavor = 0; flavor < n_flavor; flavor++) + out_chi[flavor] = static_cast(2.0) * (Ainv * out_chi[flavor]); + } } } swizzle(out_chi, chirality); // undo the flavor-chirality swizzle - fermion out = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); - out.toNonRel(); // change basis back - arg.out[src_idx](my_flavor_idx, spinor_parity) = out; + if (!allthreads || active) { + fermion out = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); + out.toNonRel(); // change basis back + + arg.out[src_idx](my_flavor_idx, spinor_parity) = out; + } } }; } diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 086a7def06..f630e66c0d 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -339,7 +339,8 @@ namespace quda { } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb_color_offset, int src_parity, int sMd) + template + __device__ __host__ inline void operator()(int x_cb_color_offset, int src_parity, int sMd, bool active = true) { int x_cb = x_cb_color_offset; int color_offset = 0; @@ -368,11 +369,16 @@ namespace quda { typename CoarseDslashParams::array_t out {}; if (Arg::dslash) { - applyDslash(out, dim, dir, x_cb, src_idx, parity, s, color_block, color_offset, arg); + if (!allthreads || active) { + applyDslash(out, dim, dir, x_cb, src_idx, parity, s, color_block, color_offset, arg); + } target::dispatch(out, dir, dim, *this); } - if (doBulk() && Arg::clover && dir==0 && dim==0) applyClover(out, arg, x_cb, src_idx, parity, s, color_block, color_offset); + if (!allthreads || active) { + if (doBulk() && Arg::clover && dir == 0 && dim == 0) + applyClover(out, arg, x_cb, src_idx, parity, s, color_block, color_offset); + } if (dir==0 && dim==0) { const int my_spinor_parity = (arg.nParity == 2) ? 
parity : 0; @@ -380,13 +386,17 @@ namespace quda { // reduce down to the first group of column-split threads out = warp_combine(out); + if (!allthreads || active) { #pragma unroll - for (int color_local=0; color_local()) arg.out[src_idx](my_spinor_parity, x_cb, s, c) = out[color_local]; - else arg.out[src_idx](my_spinor_parity, x_cb, s, c) += out[color_local]; + for (int color_local = 0; color_local < Mc; color_local++) { + int c = color_block + color_local; // global color index + if (color_offset == 0) { + // if not halo we just store, else we accumulate + if (doBulk()) + arg.out[src_idx](my_spinor_parity, x_cb, s, c) = out[color_local]; + else + arg.out[src_idx](my_spinor_parity, x_cb, s, c) += out[color_local]; + } } } } diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index 46e0ae876a..439cce5433 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -73,8 +73,8 @@ namespace quda template constexpr domainWall4DFusedM5(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_s, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_s, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -82,73 +82,74 @@ namespace quda int src_idx = src_s / arg.Ls; int s = src_s % arg.Ls; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, s, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; Vector stencil_out; - applyWilson(stencil_out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + applyWilson(stencil_out, arg, coord, parity, idx, thread_dim, active, src_idx); + } Vector out; - constexpr bool shared = true; // Use shared memory - // In the following `x_cb` are all passed as `x_cb = 0`, since it will not be used if `shared = true`, and `shared = true` - if (active) { - - /****** - * Apply M5pre - */ - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE) { - constexpr bool sync = false; - out = d5(*this, stencil_out, - my_spinor_parity, 0, s, src_idx); - } + if (allthreads||active) { + /****** + * Apply M5pre + */ + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE) { + constexpr bool sync = false; + out = d5 + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive&&active); + } } int xs = coord.x_cb + s * arg.dc.volume_4d_cb; if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS_M5_INV_DAG) { - /****** - * Apply the two M5inv's: - * this is actually y = 1 * x - kappa_b^2 * m5inv * D4 * in - * out = m5inv-dagger * y - */ - if (active) { - constexpr bool sync = false; - out = variableInv( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); + /****** + * Apply the two M5inv's: + * this is actually y = 1 * x - kappa_b^2 * m5inv * D4 * in + * out = m5inv-dagger * y + */ + if (allthreads||active) { + constexpr bool sync = false; + out = variableInv + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive&&active); } - Vector aggregate_external; - if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); - out = x + arg.a_5[s] * out; - } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector y = arg.y[src_idx](xs, my_spinor_parity); - aggregate_external = xpay ? arg.a_5[s] * out : out; - out = y + aggregate_external; - } + if (!allthreads||alive) { + Vector aggregate_external; + if (xpay && mykernel_type == INTERIOR_KERNEL) { + Vector x = arg.x[src_idx](xs, my_spinor_parity); + out = x + arg.a_5[s] * out; + } else if (mykernel_type != INTERIOR_KERNEL && active) { + Vector y = arg.y[src_idx](xs, my_spinor_parity); + aggregate_external = xpay ? 
arg.a_5[s] * out : out; + out = y + aggregate_external; + } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.y[src_idx](xs, my_spinor_parity) = out; + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.y[src_idx](xs, my_spinor_parity) = out; - if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + aggregate_external; + if (mykernel_type != INTERIOR_KERNEL && active) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + aggregate_external; + } } bool complete = isComplete(arg, coord); - if (complete && active) { - constexpr bool sync = true; - constexpr bool this_dagger = true; - // Then we apply the second m5inv-dag - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); - } + if (allthreads || (complete && active)) { + constexpr bool sync = true; + constexpr bool this_dagger = true; + // Then we apply the second m5inv-dag + auto tmp = variableInv + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; + } } else if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS || Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { @@ -159,25 +160,28 @@ namespace quda * or out = m5mob * x - kappa_b^2 * m5pre *D4 * in (Dslash5Type::DSLASH5_PRE_MOBIUS_M5_MOBIUS) */ - if (active) { - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS) { out = stencil_out; } + if (allthreads || active) { + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS) { out = stencil_out; } - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { - constexpr bool sync = false; - out = d5( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); - } - } + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { + constexpr bool sync = false; + out = d5(*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive && active); + } + } if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); + Vector x; + if (!allthreads || alive) x = arg.x[src_idx](xs, my_spinor_parity); constexpr bool sync_m5mob = Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS ? false : true; - x = d5( - *this, x, my_spinor_parity, 0, s, src_idx); - out = x + arg.a_5[s] * out; + x = d5(*this, x, my_spinor_parity, 0, s, src_idx, alive); + if (!allthreads || alive) out = x + arg.a_5[s] * out; } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + (xpay ? arg.a_5[s] * out : out); + if (!allthreads || alive) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + (xpay ? arg.a_5[s] * out : out); + } } } else { @@ -191,20 +195,22 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS) { // Apply the m5inv. constexpr bool sync = false; - out = variableInv( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); + out = variableInv + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive); } - if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); - out = x + arg.a_5[s] * out; - } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + (xpay ? 
arg.a_5[s] * out : out); - } + if (!allthreads || alive) { + if (xpay && mykernel_type == INTERIOR_KERNEL) { + Vector x = arg.x[src_idx](xs, my_spinor_parity); + out = x + arg.a_5[s] * out; + } else if (mykernel_type != INTERIOR_KERNEL && active) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + (xpay ? arg.a_5[s] * out : out); + } + } bool complete = isComplete(arg, coord); - if (complete && active) { + if (allthreads || (complete && active)) { /****** * First apply M5inv, and then M5pre @@ -212,12 +218,13 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS_M5_PRE) { // Apply the m5inv. constexpr bool sync_m5inv = false; - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); + auto tmp = variableInv + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); // Apply the m5pre. constexpr bool sync_m5pre = true; - out = d5(*this, out, my_spinor_parity, - 0, s, src_idx); + tmp = d5 + (*this, tmp, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; } /****** @@ -226,16 +233,17 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { // Apply the m5pre. constexpr bool sync_m5pre = false; - out = d5(*this, out, my_spinor_parity, - 0, s, src_idx); + auto tmp = d5 + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); // Apply the m5inv. constexpr bool sync_m5inv = true; - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); + tmp = variableInv + (*this, tmp, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; } } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](xs, my_spinor_parity) = out; + if (alive && (mykernel_type != EXTERIOR_KERNEL_ALL || active)) arg.out[src_idx](xs, my_spinor_parity) = out; } }; diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index 9ea0419b8a..25433b46e4 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -215,9 +215,9 @@ namespace quda using Ops = std::conditional_t, NoKernelOps>; }; - template - __device__ __host__ inline Vector d5(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s, int src_idx) + template + __device__ __host__ inline Vector d5(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s, int src_idx, bool alive) { const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; @@ -240,19 +240,21 @@ namespace quda cache.save(in.project(4, proj_dir)); cache.sync(); } - const int fwd_s = (s + 1) % arg.Ls; - const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; - HalfVector half_in; - if constexpr (shared) { - half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity); - } else { - Vector full_in = arg.in[src_idx](fwd_idx, parity); - half_in = full_in.project(4, proj_dir); - } - if (s == arg.Ls - 1) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); - } else { - out += half_in.reconstruct(4, proj_dir); + if (!allthreads || alive) { + const int fwd_s = (s + 1) % arg.Ls; + const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; + HalfVector half_in; + if constexpr (shared) { + half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity); + } else { + Vector full_in = arg.in[src_idx](fwd_idx, parity); + half_in = full_in.project(4, proj_dir); + } + if (s == arg.Ls - 1) { + out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + } 
else { + out += half_in.reconstruct(4, proj_dir); + } } } @@ -263,20 +265,22 @@ namespace quda cache.save(in.project(4, proj_dir)); cache.sync(); } - const int back_s = (s + arg.Ls - 1) % arg.Ls; - const int back_idx = back_s * arg.volume_4d_cb + x_cb; - HalfVector half_in; - if constexpr (shared) { - half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity); - } else { - Vector full_in = arg.in[src_idx](back_idx, parity); - half_in = full_in.project(4, proj_dir); - } - if (s == 0) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); - } else { - out += half_in.reconstruct(4, proj_dir); - } + if (!allthreads || alive) { + const int back_s = (s + arg.Ls - 1) % arg.Ls; + const int back_idx = back_s * arg.volume_4d_cb + x_cb; + HalfVector half_in; + if constexpr (shared) { + half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity); + } else { + Vector full_in = arg.in[src_idx](back_idx, parity); + half_in = full_in.project(4, proj_dir); + } + if (s == 0) { + out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + } else { + out += half_in.reconstruct(4, proj_dir); + } + } } } else { // use_half_vector @@ -291,40 +295,44 @@ namespace quda cache.sync(); } - { // forwards direction - const int fwd_s = (s + 1) % arg.Ls; - const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; - const Vector in - = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity) : arg.in[src_idx](fwd_idx, parity); - constexpr int proj_dir = dagger ? +1 : -1; - if (s == arg.Ls - 1) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); - } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + if (!allthreads || alive) { + { // forwards direction + const int fwd_s = (s + 1) % arg.Ls; + const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; + const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity) : + arg.in[src_idx](fwd_idx, parity); + constexpr int proj_dir = dagger ? +1 : -1; + if (s == arg.Ls - 1) { + out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + } else { + out += in.project(4, proj_dir).reconstruct(4, proj_dir); + } } - } - { // backwards direction - const int back_s = (s + arg.Ls - 1) % arg.Ls; - const int back_idx = back_s * arg.volume_4d_cb + x_cb; - const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity) : - arg.in[src_idx](back_idx, parity); - constexpr int proj_dir = dagger ? -1 : +1; - if (s == 0) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); - } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + { // backwards direction + const int back_s = (s + arg.Ls - 1) % arg.Ls; + const int back_idx = back_s * arg.volume_4d_cb + x_cb; + const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity) : + arg.in[src_idx](back_idx, parity); + constexpr int proj_dir = dagger ? -1 : +1; + if (s == 0) { + out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + } else { + out += in.project(4, proj_dir).reconstruct(4, proj_dir); + } } } } // use_half_vector - if (type == Dslash5Type::DSLASH5_MOBIUS_PRE || type == Dslash5Type::M5_INV_MOBIUS_M5_PRE - || type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { - Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.alpha(s) * out + coeff.beta(s) * diagonal; - } else if (type == Dslash5Type::DSLASH5_MOBIUS) { - Vector diagonal = shared ? 
in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.kappa(s) * out + diagonal; + if (!allthreads || alive) { + if (type == Dslash5Type::DSLASH5_MOBIUS_PRE || type == Dslash5Type::M5_INV_MOBIUS_M5_PRE + || type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { + Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.alpha(s) * out + coeff.beta(s) * diagonal; + } else if (type == Dslash5Type::DSLASH5_MOBIUS) { + Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.kappa(s) * out + diagonal; + } } return out; @@ -346,7 +354,8 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; coeff_type::value, Arg> coeff(arg); @@ -358,22 +367,24 @@ namespace quda constexpr bool sync = false; constexpr bool shared = false; - Vector out = d5(*this, Vector(), parity, x_cb, s, src_idx); - - if (Arg::xpay) { - if (Arg::type == Dslash5Type::DSLASH5_DWF) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + arg.a * out; - } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS_PRE) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + coeff.a(s) * out; - } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.a(s) * x + out; + Vector out = d5(*this, Vector(), parity, x_cb, s, src_idx, alive); + + if (!allthreads || alive) { + if (Arg::xpay) { + if (Arg::type == Dslash5Type::DSLASH5_DWF) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + arg.a * out; + } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS_PRE) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + coeff.a(s) * out; + } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.a(s) * x + out; + } } - } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + } } }; @@ -398,9 +409,9 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template + template __device__ __host__ inline Vector constantInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, - int src_idx) + int src_idx, bool alive) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -421,23 +432,25 @@ namespace quda Vector out; - for (int s = 0; s < arg.Ls; s++) { + if (!allthreads || alive) { + for (int s = 0; s < arg.Ls; s++) { - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - { - int exp = s_ < s ? arg.Ls - s + s_ : s_ - s; - real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast(1.0)); - constexpr int proj_dir = dagger ? -1 : +1; - out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); - } + { + int exp = s_ < s ? arg.Ls - s + s_ : s_ - s; + real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast(1.0)); + constexpr int proj_dir = dagger ? 
-1 : +1; + out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + } - { - int exp = s_ > s ? arg.Ls - s_ + s : s - s_; - real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast(1.0)); - constexpr int proj_dir = dagger ? +1 : -1; - out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + { + int exp = s_ > s ? arg.Ls - s_ + s : s - s_; + real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast(1.0)); + constexpr int proj_dir = dagger ? +1 : -1; + out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + } } } @@ -467,9 +480,9 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template + template __device__ __host__ inline Vector variableInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, - int src_idx) + int src_idx, bool alive) { const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; @@ -486,30 +499,32 @@ namespace quda { // first do R constexpr int proj_dir = dagger ? -1 : +1; - if (shared) { - if (sync) { cache.sync(); } + if constexpr (shared) { + if constexpr (sync) { cache.sync(); } cache.save(in.project(4, proj_dir)); cache.sync(); } - int s = s_; - auto R = coeff.inv(); - HalfVector r; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorR = (s_ < s ? -arg.m_f * R : R); - - if (shared) { - r += factorR * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); - } else { - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); - } - - R *= coeff.kappa(s); - s = (s + arg.Ls - 1) % arg.Ls; - } - - out += r.reconstruct(4, proj_dir); + if (!allthreads || alive) { + int s = s_; + auto R = coeff.inv(); + HalfVector r; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorR = (s_ < s ? -arg.m_f * R : R); + + if (shared) { + r += factorR * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); + } else { + Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + r += factorR * in.project(4, proj_dir); + } + + R *= coeff.kappa(s); + s = (s + arg.Ls - 1) % arg.Ls; + } + + out += r.reconstruct(4, proj_dir); + } } { // second do L @@ -520,24 +535,26 @@ namespace quda cache.sync(); } - int s = s_; - auto L = coeff.inv(); - HalfVector l; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorL = (s_ > s ? -arg.m_f * L : L); - - if (shared) { - l += factorL * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); - } else { - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); - } - - L *= coeff.kappa(s); - s = (s + 1) % arg.Ls; - } - - out += l.reconstruct(4, proj_dir); + if (!allthreads || alive) { + int s = s_; + auto L = coeff.inv(); + HalfVector l; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorL = (s_ > s ? -arg.m_f * L : L); + + if (shared) { + l += factorL * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); + } else { + Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + l += factorL * in.project(4, proj_dir); + } + + L *= coeff.kappa(s); + s = (s + 1) % arg.Ls; + } + + out += l.reconstruct(4, proj_dir); + } } } else { // use_half_vector using Cache = std::conditional_t, const Ftor &>; @@ -548,44 +565,46 @@ namespace quda cache.sync(); } - { // first do R - constexpr int proj_dir = dagger ? -1 : +1; + if (!allthreads || alive) { + { // first do R + constexpr int proj_dir = dagger ? 
-1 : +1; + + int s = s_; + auto R = coeff.inv(); + HalfVector r; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorR = (s_ < s ? -arg.m_f * R : R); - int s = s_; - auto R = coeff.inv(); - HalfVector r; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorR = (s_ < s ? -arg.m_f * R : R); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + r += factorR * in.project(4, proj_dir); - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); + R *= coeff.kappa(s); + s = (s + arg.Ls - 1) % arg.Ls; + } - R *= coeff.kappa(s); - s = (s + arg.Ls - 1) % arg.Ls; + out += r.reconstruct(4, proj_dir); } - out += r.reconstruct(4, proj_dir); - } + { // second do L + constexpr int proj_dir = dagger ? +1 : -1; - { // second do L - constexpr int proj_dir = dagger ? +1 : -1; + int s = s_; + auto L = coeff.inv(); + HalfVector l; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorL = (s_ > s ? -arg.m_f * L : L); - int s = s_; - auto L = coeff.inv(); - HalfVector l; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorL = (s_ > s ? -arg.m_f * L : L); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + l += factorL * in.project(4, proj_dir); - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); + L *= coeff.kappa(s); + s = (s + 1) % arg.Ls; + } - L *= coeff.kappa(s); - s = (s + 1) % arg.Ls; + out += l.reconstruct(4, proj_dir); } - - out += l.reconstruct(4, proj_dir); } } // use_half_vector @@ -618,7 +637,8 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { constexpr int nSpin = 4; using real = typename Arg::real; @@ -628,21 +648,25 @@ namespace quda int src_idx = src_s / arg.Ls; int s = src_s % arg.Ls; - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - Vector out; + Vector in, out; + if (!allthreads || alive) { in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); } constexpr bool sync = false; if constexpr (mobius_m5::var_inverse()) { // zMobius, must call variableInv - out = variableInv(*this, in, parity, x_cb, s, src_idx); + out + = variableInv(*this, in, parity, x_cb, s, src_idx, alive); } else { - out = constantInv(*this, in, parity, x_cb, s, src_idx); + out + = constantInv(*this, in, parity, x_cb, s, src_idx, alive); } - if (Arg::xpay) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + coeff.a(s) * out; - } + if (!allthreads || alive) { + if (Arg::xpay) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + coeff.a(s) * out; + } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + } } }; diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index 49e65da6d7..3e2bb4e647 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -110,7 +110,8 @@ 
namespace quda } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; typedef ColorSpinor Vector; @@ -121,7 +122,7 @@ namespace quda SharedMemoryCache cache {*this}; Vector out; - cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); + if (!allthreads || alive) { cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); } cache.sync(); auto Ls = arg.Ls; @@ -165,11 +166,13 @@ namespace quda } if (Arg::xpay) { // really axpy - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = arg.a * x + out; - } - } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + if (!allthreads || alive) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = arg.a * x + out; + } + } + } + if (!allthreads || alive) { arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; } } }; @@ -196,7 +199,8 @@ namespace quda } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; typedef ColorSpinor Vector; @@ -206,7 +210,7 @@ namespace quda const auto sherman_morrison = arg.sherman_morrison; SharedMemoryCache cache {*this}; - cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); + if (!allthreads || alive) { cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); } cache.sync(); Vector out; @@ -233,10 +237,12 @@ namespace quda } } if (Arg::xpay) { // really axpy - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + arg.a * out; + if (!allthreads || alive) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + arg.a * out; + } } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + if (!allthreads || alive) { arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; } } }; diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index 049f129d14..625bf68778 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -14,7 +14,7 @@ namespace quda static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2; typedef typename clover_mapper::type C; typedef typename mapper::type real; - + const C A; /** the clover field */ real a; /** this is the Wilson-dslash scale factor */ real b; /** this is the chiral twist factor */ @@ -58,8 +58,8 @@ namespace quda out(x) = M*in = a * D * in + (A(x) + i*b*gamma_5*tau_3 + c*tau_1)*x Note this routine only exists in xpay form. */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -67,9 +67,8 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? 
false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); @@ -77,53 +76,64 @@ namespace quda const int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + } if constexpr (mykernel_type == INTERIOR_KERNEL) { if (arg.dd_x.isZero(coord)) { - out = arg.a * out; + if (!allthreads || alive) out = arg.a * out; } else { - // apply the chiral and flavor twists - // use consistent load order across s to ensure better cache locality - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); SharedMemoryCache cache {*this}; - cache.save(x); + Vector tmp; + if (!allthreads || alive) { + // apply the chiral and flavor twists + // use consistent load order across s to ensure better cache locality + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + cache.save(x); - x.toRel(); // switch to chiral basis + x.toRel(); // switch to chiral basis - Vector tmp; #pragma unroll - for (int chirality = 0; chirality < 2; chirality++) { - constexpr int n = Arg::nColor * Arg::nSpin / 2; - HMatrix A = arg.A(coord.x_cb, parity, chirality); - HalfVector x_chi = x.chiral_project(chirality); - HalfVector Ax_chi = A * x_chi; - // i * mu * gamma_5 * tau_3 - const complex b(0.0, (chirality ^ flavor) == 0 ? static_cast(arg.b) : -static_cast(arg.b)); - Ax_chi += b * x_chi; - tmp += Ax_chi.chiral_reconstruct(chirality); + for (int chirality = 0; chirality < 2; chirality++) { + constexpr int n = Arg::nColor * Arg::nSpin / 2; + HMatrix A = arg.A(coord.x_cb, parity, chirality); + HalfVector x_chi = x.chiral_project(chirality); + HalfVector Ax_chi = A * x_chi; + // i * mu * gamma_5 * tau_3 + const complex b(0.0, + (chirality ^ flavor) == 0 ? 
static_cast(arg.b) : -static_cast(arg.b)); + Ax_chi += b * x_chi; + tmp += Ax_chi.chiral_reconstruct(chirality); + } + + tmp.toNonRel(); + // tmp += (c * tau_1) * x } - - tmp.toNonRel(); - // tmp += (c * tau_1) * x cache.sync(); - tmp += arg.c * cache.load_y(target::thread_idx().y + 1 - 2 * flavor); + if (!allthreads || alive) { + tmp += arg.c * cache.load_y(target::thread_idx().y + 1 - 2 * flavor); - // add the Wilson part with normalisation - out = tmp + arg.a * out; + // add the Wilson part with normalisation + out = tmp + arg.a * out; + } } } else if (active) { Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); out = x + arg.a * out; } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if (!allthreads || alive) + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; } }; diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index 4b1a470db7..fb61259d65 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -13,7 +13,7 @@ namespace quda using WilsonArg::nSpin; static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2; static constexpr bool dynamic_clover = clover::dynamic_inverse(); - + typedef typename mapper::type real; typedef typename clover_mapper::type C; const C A; @@ -64,8 +64,8 @@ namespace quda out(x) = M*in = a*(C + i*b*gamma_5*tau_3 + c*tau_1)/(C^2 + b^2 - c^2)*D*x ( xpay == false ) out(x) = M*in = in + a*(C + i*b*gamma_5*tau_3 + c*tau_1)/(C^2 + b^2 - c^2)*D*x ( xpay == true ) */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { using namespace linalg; // for Cholesky typedef typename mapper::type real; @@ -75,98 +75,107 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); - if (mykernel_type != INTERIOR_KERNEL && active) { - // if we're not the interior kernel, then we must sum the partial - Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); - out += x; + if (mykernel_type != INTERIOR_KERNEL && active) { + // if we're not the interior kernel, then we must sum the partial + Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); + out += x; + } } + constexpr int n_flavor = 2; + HalfVector out_chi[n_flavor]; // flavor array of chirally projected fermion if (isComplete(arg, coord) && active) { - out.toRel(); - - constexpr int n_flavor = 2; - HalfVector out_chi[n_flavor]; // flavor array of chirally projected fermion + out.toRel(); #pragma unroll - for (int i = 0; i < n_flavor; i++) out_chi[i] = out.chiral_project(i); - - int chirality = flavor; // relabel flavor as chirality - - SharedMemoryCache cache {*this}; - - auto swizzle = [&](HalfVector x[2], int chirality) { - if (chirality == 0) - cache.save_y(x[1], target::thread_idx().y); - else - cache.save_y(x[0], target::thread_idx().y); - cache.sync(); - if (chirality == 0) - x[1] = cache.load_y(target::thread_idx().y + 1); - else - x[0] = cache.load_y(target::thread_idx().y - 1); - }; - - swizzle(out_chi, chirality); // apply the flavor-chirality swizzle between threads - - // load in the clover matrix - HMat A = arg.A(coord.x_cb, parity, chirality); + for (int i = 0; i < n_flavor; i++) out_chi[i] = out.chiral_project(i); + } - HalfVector A_chi[n_flavor]; + int chirality = flavor; // relabel flavor as chirality + SharedMemoryCache cache {*this}; + auto swizzle = [&](HalfVector x[2], int chirality) { + if (chirality == 0) + cache.save_y(x[1], target::thread_idx().y); + else + cache.save_y(x[0], target::thread_idx().y); + cache.sync(); + if (chirality == 0) + x[1] = cache.load_y(target::thread_idx().y + 1); + else + x[0] = cache.load_y(target::thread_idx().y - 1); + }; + + swizzle(out_chi, chirality); // apply the flavor-chirality swizzle between threads + + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + // load in the clover matrix + HMat A = arg.A(coord.x_cb, parity, chirality); + + HalfVector A_chi[n_flavor]; #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - const complex b(0.0, (chirality^flavor_) == 0 ? arg.b : -arg.b); - A_chi[flavor_] = A * out_chi[flavor_]; - A_chi[flavor_] += b * out_chi[flavor_]; - A_chi[flavor_] += arg.c * out_chi[1 - flavor_]; - } - - if constexpr (Arg::dynamic_clover) { - HMat A2 = A.square(); - A2 += arg.b2_minus_c2; - Cholesky, Arg::nColor * Arg::nSpin / 2> cholesky(A2); + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + const complex b(0.0, (chirality^flavor_) == 0 ? 
arg.b : -arg.b); + A_chi[flavor_] = A * out_chi[flavor_]; + A_chi[flavor_] += b * out_chi[flavor_]; + A_chi[flavor_] += arg.c * out_chi[1 - flavor_]; + } + + if constexpr (Arg::dynamic_clover) { + HMat A2 = A.square(); + A2 += arg.b2_minus_c2; + Cholesky, Arg::nColor * Arg::nSpin / 2> cholesky(A2); #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - out_chi[flavor_] = static_cast(0.25) * cholesky.backward(cholesky.forward(A_chi[flavor_])); - } - } else { - HMat A2inv = arg.A2inv(coord.x_cb, parity, chirality); + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + out_chi[flavor_] = static_cast(0.25) * cholesky.backward(cholesky.forward(A_chi[flavor_])); + } + } else { + HMat A2inv = arg.A2inv(coord.x_cb, parity, chirality); #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - out_chi[flavor_] = static_cast(2.0) * (A2inv * A_chi[flavor_]); - } - } - - swizzle(out_chi, chirality); // undo the flavor-chirality swizzle - Vector tmp = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); - tmp.toNonRel(); // switch back to non-chiral basis - - if (xpay && !arg.dd_x.isZero(coord)) { - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); - out = x + arg.a * tmp; - } else { - // multiplication with a needed here? - out = arg.a * tmp; - } + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + out_chi[flavor_] = static_cast(2.0) * (A2inv * A_chi[flavor_]); + } + } + } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + swizzle(out_chi, chirality); // undo the flavor-chirality swizzle + + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + Vector tmp = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); + tmp.toNonRel(); // switch back to non-chiral basis + + if (xpay && !arg.dd_x.isZero(coord)) { + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + out = x + arg.a * tmp; + } else { + // multiplication with a needed here? + out = arg.a * tmp; + } + } + + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + } } }; } // namespace quda diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 7effb07ae3..8244eb6787 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -64,8 +64,8 @@ namespace quda - with xpay: out(x) = M*in = x + a*(1+i*b*gamma_5 + c*tau_1)D * in */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -73,62 +73,68 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - if (!dagger || Arg::asymmetric) // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); - else // defined in dslash_twisted_mass_preconditioned - applyWilsonTM(out, arg, coord, parity, idx, thread_dim, active, src_idx); - - if (xpay && mykernel_type == INTERIOR_KERNEL && !arg.dd_x.isZero(coord)) { - - if (!dagger || Arg::asymmetric) { // apply inverse twist which is undone below - // use consistent load order across s to ensure better cache locality - Vector x0 = arg.x[src_idx](coord.x_cb + 0 * arg.dc.volume_4d_cb, my_spinor_parity); - Vector x1 = arg.x[src_idx](coord.x_cb + 1 * arg.dc.volume_4d_cb, my_spinor_parity); - if (flavor == 0) - out += arg.a_inv * (x0 + arg.b_inv * x0.igamma(4) + arg.c_inv * x1); - else - out += arg.a_inv * (x1 - arg.b_inv * x1.igamma(4) + arg.c_inv * x0); - } else { - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); - out += x; // just directly add since twist already applied in the dslash + if (!allthreads || alive) { + if (!dagger || Arg::asymmetric) // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + else // defined in dslash_twisted_mass_preconditioned + applyWilsonTM(out, arg, coord, parity, idx, thread_dim, active, src_idx); + + if (xpay && mykernel_type == INTERIOR_KERNEL && !arg.dd_x.isZero(coord)) { + if constexpr (!dagger || Arg::asymmetric) { // apply inverse twist which is undone below + // use consistent load order across s to ensure better cache locality + Vector x0 = arg.x[src_idx](coord.x_cb + 0 * arg.dc.volume_4d_cb, my_spinor_parity); + Vector x1 = arg.x[src_idx](coord.x_cb + 1 * arg.dc.volume_4d_cb, my_spinor_parity); + if (flavor == 0) + out += arg.a_inv * (x0 + arg.b_inv * x0.igamma(4) + arg.c_inv * x1); + else + out += arg.a_inv * (x1 - arg.b_inv * x1.igamma(4) + arg.c_inv * x0); + } else { + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + out += x; // just directly add since twist already applied in the dslash + } + } else if (mykernel_type != INTERIOR_KERNEL && active) { + // if we're not the interior kernel, then we must sum the partial + Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); + out += x; } - - } else if (mykernel_type != INTERIOR_KERNEL && active) { - // if we're not the interior kernel, then we must sum the partial - Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); - out += x; } if constexpr (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in SharedMemoryCache cache {*this}; - if (isComplete(arg, coord) && active) { - // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it - cache.save(out); - } - - cache.sync(); // safe to sync in here since other threads will exit - if (isComplete(arg, coord) && active) { - if (flavor == 0) - out = arg.a * (out + arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y + 1)); - else - out = arg.a * (out - arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y - 1)); - } + if (!allthreads 
|| alive) { + if (isComplete(arg, coord) && active) { + // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it + cache.save(out); + } + } + cache.sync(); // safe to sync here since other threads will exit if allowed, or all be here + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + if (flavor == 0) + out = arg.a * (out + arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y + 1)); + else + out = arg.a * (out - arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y - 1)); + } + } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if (!allthreads || alive) + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; } }; diff --git a/include/kernels/gauge_fix_ovr.cuh b/include/kernels/gauge_fix_ovr.cuh index 1f147742d2..a4c663040c 100644 --- a/include/kernels/gauge_fix_ovr.cuh +++ b/include/kernels/gauge_fix_ovr.cuh @@ -146,7 +146,7 @@ namespace quda { constexpr computeFix(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ inline void operator()(int idx, int mu) + template __device__ inline void operator()(int idx, int mu, bool active = true) { using real = typename Arg::real; using Link = Matrix, 3>; @@ -161,7 +161,7 @@ namespace quda { for (int dr = 0; dr < 4; dr++) p += arg.border[dr]; getCoords(x, idx, arg.X, p + parity); } else { - idx = arg.borderpoints[parity][idx]; // load the lattice site assigment + if (!allthreads || active) idx = arg.borderpoints[parity][idx]; // load the lattice site assigment x[3] = idx / (X[0] * X[1] * X[2]); x[2] = (idx / (X[0] * X[1])) % X[2]; x[1] = (idx / X[0]) % X[1]; @@ -188,7 +188,8 @@ namespace quda { parity = 1 - parity; } idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1; - Link link = arg.u(dim, idx, parity); + Link link; + if (!allthreads || active) link = arg.u(dim, idx, parity); if constexpr (Arg::type == 0) { // 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. @@ -201,11 +202,13 @@ namespace quda { GaugeFixHit_NoAtomicAdd_LessSM(link, arg.relax_boost, mu, *this); } + if (!allthreads || active) arg.u(dim, idx, parity) = link; arg.u(dim, idx, parity) = link; } else if constexpr (Arg::type == 2 || Arg::type == 3) { // 4 threads per lattice site idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1; - Link link = arg.u(mu, idx, parity); + Link link; + if (!allthreads || active) link = arg.u(mu, idx, parity); switch (mu) { case 0: x[0] = (x[0] - 1 + X[0]) % X[0]; break; @@ -214,7 +217,8 @@ namespace quda { case 3: x[3] = (x[3] - 1 + X[3]) % X[3]; break; } int idx1 = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1; - Link link1 = arg.u(mu, idx1, 1 - parity); + Link link1; + if (!allthreads || active) link1 = arg.u(mu, idx1, 1 - parity); if constexpr (Arg::type == 2) { // 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd. 
@@ -227,8 +231,10 @@ namespace quda { GaugeFixHit_NoAtomicAdd_LessSM(link, link1, arg.relax_boost, mu, *this); } - arg.u(mu, idx, parity) = link; - arg.u(mu, idx1, 1 - parity) = link1; + if (!allthreads || active) { + arg.u(mu, idx, parity) = link; + arg.u(mu, idx1, 1 - parity) = link1; + } } } }; diff --git a/include/kernels/madwf_transfer.cuh b/include/kernels/madwf_transfer.cuh index 616d6a40c0..008631ad4e 100644 --- a/include/kernels/madwf_transfer.cuh +++ b/include/kernels/madwf_transfer.cuh @@ -111,7 +111,8 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s The output Ls dimension coordinate */ - __device__ __host__ inline void operator()(int x_cb, int s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int s, int parity, bool active = true) { constexpr bool dagger = Arg::dagger; @@ -132,14 +133,16 @@ namespace quda } cache.sync(); - Vector out; - // t -> s_in, s-> s_out - for (int t = 0; t < Ls_in; t++) { - Vector in = arg.in(t * volume_4d_cb + x_cb, parity); - int wm_index = dagger ? t * Ls_out + s : s * Ls_in + t; - matrix_vector_multiply(out, reinterpret_cast(cache.data())[wm_index], in); + if (!allthreads || active) { + Vector out; + // t -> s_in, s-> s_out + for (int t = 0; t < Ls_in; t++) { + Vector in = arg.in(t * volume_4d_cb + x_cb, parity); + int wm_index = dagger ? t * Ls_out + s : s * Ls_in + t; + matrix_vector_multiply(out, reinterpret_cast(cache.data())[wm_index], in); + } + arg.out(s * volume_4d_cb + x_cb, parity) = out; } - arg.out(s * volume_4d_cb + x_cb, parity) = out; } }; } // namespace madwf_ml diff --git a/include/kernels/multi_blas_core.cuh b/include/kernels/multi_blas_core.cuh index 0c7db44292..0abe2e37ee 100644 --- a/include/kernels/multi_blas_core.cuh +++ b/include/kernels/multi_blas_core.cuh @@ -15,7 +15,8 @@ namespace quda #ifndef QUDA_FAST_COMPILE_REDUCE constexpr bool enable_warp_split() { return false; } #else - constexpr bool enable_warp_split() { return true; } + // constexpr bool enable_warp_split() { return true; } + constexpr bool enable_warp_split() { return false; } #endif /** @@ -64,12 +65,15 @@ namespace quda @param[in,out] arg Argument struct with required meta data (input/output fields, functor, etc.) 
*/ - template struct MultiBlas_ { + // template struct MultiBlas_ { + template struct MultiBlas_ : only_warp_combine, Arg::n / 2>> { + // std::conditional_t { const Arg &arg; constexpr MultiBlas_(const Arg &arg) : arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int i, int k, int parity) + template // true if all threads in group will enter, even if out of range + __device__ __host__ inline void operator()(int i, int k, int parity, bool active = true) { using vec = array, Arg::n/2>; @@ -83,22 +87,24 @@ namespace quda const int l_idx = lane_id / vector_site_width; vec x, y, z, w; - if (l_idx == 0 || warp_split == 1) { - if (arg.f.read.Y) arg.Y[k].load(y, idx, parity); - if (arg.f.read.W) arg.W[k].load(w, idx, parity); - } else { - y = ::quda::zero, Arg::n/2>(); - w = ::quda::zero, Arg::n/2>(); - } + if (!allthreads || active) { + if (l_idx == 0 || warp_split == 1) { + if (arg.f.read.Y) arg.Y[k].load(y, idx, parity); + if (arg.f.read.W) arg.W[k].load(w, idx, parity); + } else { + y = ::quda::zero, Arg::n / 2>(); + w = ::quda::zero, Arg::n / 2>(); + } #pragma unroll - for (int l_ = 0; l_ < Arg::NXZ; l_ += warp_split) { - const int l = l_ + l_idx; - if (l < Arg::NXZ || warp_split == 1) { - if (arg.f.read.X) arg.X[l].load(x, idx, parity); - if (arg.f.read.Z) arg.Z[l].load(z, idx, parity); - - arg.f(x, y, z, w, k, l); + for (int l_ = 0; l_ < Arg::NXZ; l_ += warp_split) { + const int l = l_ + l_idx; + if (l < Arg::NXZ || warp_split == 1) { + if (arg.f.read.X) arg.X[l].load(x, idx, parity); + if (arg.f.read.Z) arg.Z[l].load(z, idx, parity); + + arg.f(x, y, z, w, k, l); + } } } @@ -106,9 +112,11 @@ namespace quda if (arg.f.write.Y) y = warp_combine(y); if (arg.f.write.W) w = warp_combine(w); - if (l_idx == 0 || warp_split == 1) { - if (arg.f.write.Y) arg.Y[k].save(y, idx, parity); - if (arg.f.write.W) arg.W[k].save(w, idx, parity); + if (!allthreads || active) { + if (l_idx == 0 || warp_split == 1) { + if (arg.f.write.Y) arg.Y[k].save(y, idx, parity); + if (arg.f.write.W) arg.W[k].save(w, idx, parity); + } } } }; diff --git a/include/kernels/restrictor.cuh b/include/kernels/restrictor.cuh index a73af4f70f..51a8bc854b 100644 --- a/include/kernels/restrictor.cuh +++ b/include/kernels/restrictor.cuh @@ -139,7 +139,8 @@ namespace quda { constexpr Restrictor(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(dim3 block, dim3 thread) + template + __device__ __host__ inline void operator()(dim3 block, dim3 thread, bool active = true) { int x_fine_offset = thread.x; const int x_coarse = block.x; @@ -149,50 +150,55 @@ namespace quda { const int coarse_color_block = coarse_color_thread * coarse_color_per_thread; vector reduced{0}; - while (x_fine_offset < arg.aggregate_size) { - // all threads with x_fine_offset greater than aggregate_size_cb are second parity - const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; - const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; - const int parity = arg.nParity == 2 ? parity_offset : arg.parity; + if (!allthreads || active) { + while (x_fine_offset < arg.aggregate_size) { + // all threads with x_fine_offset greater than aggregate_size_cb are second parity + const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; + const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; + const int parity = arg.nParity == 2 ? 
parity_offset : arg.parity; - // look-up map is ordered as (coarse-block-id + fine-point-id), - // with fine-point-id parity ordered - const int x_fine_site_id = (x_coarse * 2 + parity) * arg.aggregate_size_cb + x_fine_cb_offset; - const int x_fine = arg.coarse_to_fine[x_fine_site_id]; - const int x_fine_cb = x_fine - parity * arg.in[src_idx].VolumeCB(); + // look-up map is ordered as (coarse-block-id + fine-point-id), + // with fine-point-id parity ordered + const int x_fine_site_id = (x_coarse * 2 + parity) * arg.aggregate_size_cb + x_fine_cb_offset; + const int x_fine = arg.coarse_to_fine[x_fine_site_id]; + const int x_fine_cb = x_fine - parity * arg.in[src_idx].VolumeCB(); - array, Arg::fineSpin * coarse_color_per_thread> tmp{0}; + array, Arg::fineSpin * coarse_color_per_thread> tmp {0}; - rotateCoarseColor(tmp, arg, src_idx, parity, x_fine_cb, coarse_color_block); + rotateCoarseColor(tmp, arg, src_idx, parity, x_fine_cb, coarse_color_block); - // perform any local spin coarsening + // perform any local spin coarsening #pragma unroll - for (int s = 0; s= arg.out[src_idx].VolumeCB() ? 1 : 0; - const int x_coarse_cb = x_coarse - parity_coarse*arg.out[src_idx].VolumeCB(); + if (!allthreads || active) { + if (target::thread_idx().x == 0) { + const int parity_coarse = x_coarse >= arg.out[src_idx].VolumeCB() ? 1 : 0; + const int x_coarse_cb = x_coarse - parity_coarse * arg.out[src_idx].VolumeCB(); #pragma unroll - for (int s = 0; s < Arg::coarseSpin; s++) { + for (int s = 0; s < Arg::coarseSpin; s++) { #pragma unroll - for (int coarse_color_local=0; coarse_color_local std::enable_if_t, T &> constexpr elem(T &a, int i) { return (&a)[i]; } + + template ().x)> + std::enable_if_t, R &> constexpr elem(T &a, int i) + { + return (&a.x)[i]; + } + + template ().x.x), int = 0> + std::enable_if_t, R &> constexpr elem(T &a, int i) + { + return (&a.x.x)[i]; + } + /* Here we use traits to define the greater type used for mixing types of computation involving these types */ diff --git a/include/targets/cuda/shared_memory_helper.h b/include/targets/cuda/shared_memory_helper.h index 69d8c095ce..3b4b46a132 100644 --- a/include/targets/cuda/shared_memory_helper.h +++ b/include/targets/cuda/shared_memory_helper.h @@ -80,8 +80,9 @@ namespace quda /** @brief Constructor for SharedMemory object. */ - template - constexpr SharedMemory(const KernelOps &) : data(cache(get_offset(target::block_dim()))) + template + constexpr SharedMemory(const KernelOps &, const Arg &...arg) : + data(cache(get_offset(target::block_dim(), arg...))) { } diff --git a/include/targets/cuda/target_device.h b/include/targets/cuda/target_device.h index ee7c646172..077504027a 100644 --- a/include/targets/cuda/target_device.h +++ b/include/targets/cuda/target_device.h @@ -32,24 +32,42 @@ namespace quda #ifdef _NVHPC_CUDA // nvc++: run-time dispatch using if target - template