diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash
index 82fa3e26a2..c058c38a06 100644
--- a/.github/scripts/utils_build.bash
+++ b/.github/scripts/utils_build.bash
@@ -132,7 +132,7 @@ __conda_install_gcc () {
   # https://gcc.gnu.org/onlinedocs/libstdc++/manual/status.html#manual.intro.status.iso
   #
   # shellcheck disable=SC2155
-  local gcc_version="${GCC_VERSION:-11.4.0}"
+  local gcc_version="${GCC_VERSION:-14.3.0}"
 
   echo "[INSTALL] Installing GCC (${gcc_version}, ${COMPILER_ARCHNAME}) through Conda ..."
   # shellcheck disable=SC2086
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0e96f6d050..5023bb013f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,7 @@ include(CheckCXXCompilerFlag)
 # Install libraries into correct locations on all platforms
 include(GNUInstallDirs)
 
-# Check AVX
+# Check instruction sets flags
 set(CHECK_AVX_COMPILE ON)
 include(${MODULE_PATH}/FindAVX.cmake)
 # For some reason, AVX flag values end up being packed into a single quoted
@@ -69,6 +69,7 @@ include(${MODULE_PATH}/FindAVX.cmake)
 # the flags here back into list-string format
 separate_arguments(CXX_AVX2_FLAGS)
 separate_arguments(CXX_AVX512_FLAGS)
+include(${MODULE_PATH}/FindARM.cmake)
 
 # Load Python
 find_package(Python)
diff --git a/cmake/modules/FindARM.cmake b/cmake/modules/FindARM.cmake
new file mode 100644
index 0000000000..76b6145980
--- /dev/null
+++ b/cmake/modules/FindARM.cmake
@@ -0,0 +1,34 @@
+
+# Determine the best available Arm SVE compiler flag.
+#
+# Probes the C++ compiler with check_cxx_compiler_flag() for
+# -march=armv8-a+sve2 and -march=armv8-a+sve, preferring SVE2 when both work.
+#
+# Arguments:
+#   variable - name of the variable to set in the caller's scope; it receives
+#              the selected -march flag, or is unset when neither flag is
+#              accepted by the compiler.
+function(get_sve_compiler_flags variable)
+  # Start from a clean slate so a same-named variable inherited from the
+  # calling scope cannot leak into the result
+  unset(_sve_flags)
+
+  # Look for presence of flags
+  check_cxx_compiler_flag(-march=armv8-a+sve COMPILER_SUPPORTS_SVE)
+  check_cxx_compiler_flag(-march=armv8-a+sve2 COMPILER_SUPPORTS_SVE2)
+
+  if(COMPILER_SUPPORTS_SVE2)
+    BLOCK_PRINT(
+      "The compiler supports SVE2 instructions, setting SVE2 compilation flag"
+    )
+    set(_sve_flags "-march=armv8-a+sve2")
+  elseif(COMPILER_SUPPORTS_SVE)
+    BLOCK_PRINT(
+      "The compiler supports SVE instructions, setting SVE compilation flag"
+    )
+    set(_sve_flags "-march=armv8-a+sve")
+  endif()
+
+  # Set output variable in parent scope
+  set(${variable} ${_sve_flags} PARENT_SCOPE)
+endfunction()
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index df3d7c3453..f684426b32 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -106,8 +106,9 @@ project(
   VERSION 1.3.0
   LANGUAGES ${project_languages})
 
-# AVX Flags Setup - must be set AFTER project declaration
+# Instruction sets flags setup - must be set AFTER project declaration
 include(${CMAKEMODULES}/FindAVX.cmake)
+include(${CMAKEMODULES}/FindARM.cmake)
 
 # PyTorch Dependencies Setup
 include(${CMAKEMODULES}/PyTorchSetup.cmake)
diff --git a/fbgemm_gpu/cmake/Fbgemm.cmake b/fbgemm_gpu/cmake/Fbgemm.cmake
index f6c2a0841a..c2353180be 100644
--- a/fbgemm_gpu/cmake/Fbgemm.cmake
+++ b/fbgemm_gpu/cmake/Fbgemm.cmake
@@ -18,7 +18,22 @@ set(fbgemm_sources_normal
   "${FBGEMM}/src/Utils.cc")
 
 if(NOT DISABLE_FBGEMM_AUTOVEC)
-  list(APPEND fbgemm_sources_normal "${FBGEMM}/src/EmbeddingSpMDMAutovec.cc" "${FBGEMM}/src/EmbeddingStatsTracker.cc")
+  list(APPEND fbgemm_sources_normal
+    "${FBGEMM}/src/EmbeddingSpMDMAutovec.cc"
+    "${FBGEMM}/src/EmbeddingStatsTracker.cc")
+endif()
+
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64")
+  list(APPEND fbgemm_sources_normal
+    "${FBGEMM}/src/QuantUtilsNeon.cc")
+
+  # Set SVE flags for autovec if available
+  get_sve_compiler_flags(sve_compiler_flags)
+  if(sve_compiler_flags)
+    set_source_files_properties(${fbgemm_sources_normal}
+      PROPERTIES COMPILE_OPTIONS
+      "${sve_compiler_flags}")
+  endif()
 endif()
 
 set(fbgemm_sources_avx2
@@ -42,11 +57,13 @@ if(CXX_AVX512_FOUND)
 endif()
 
 set(fbgemm_sources ${fbgemm_sources_normal})
+
 if(CXX_AVX2_FOUND)
   set(fbgemm_sources ${fbgemm_sources} ${fbgemm_sources_avx2})
 endif()
+
 if(CXX_AVX512_FOUND)
   set(fbgemm_sources ${fbgemm_sources}