diff --git a/.github/workflows/build-quick-static-n-1.yml b/.github/workflows/build-quick-static-n-1.yml index 3348333b..3eca7843 100644 --- a/.github/workflows/build-quick-static-n-1.yml +++ b/.github/workflows/build-quick-static-n-1.yml @@ -54,6 +54,7 @@ jobs: -D CMAKE_CXX_COMPILER_LAUNCHER=ccache \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_STATIC=0 \ + -D BUILD_L0_LOADER_TESTS=1 \ .. make -j$(nproc) diff --git a/.github/workflows/build-quick-static.yml b/.github/workflows/build-quick-static.yml index 409c536a..1df24aa5 100644 --- a/.github/workflows/build-quick-static.yml +++ b/.github/workflows/build-quick-static.yml @@ -35,6 +35,7 @@ jobs: -D CMAKE_C_COMPILER_LAUNCHER=ccache \ -D CMAKE_CXX_COMPILER_LAUNCHER=ccache \ -D CMAKE_BUILD_TYPE=Release \ + -D BUILD_L0_LOADER_TESTS=1 \ -D BUILD_STATIC=0 \ .. make -j$(nproc) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec691bc8..9bec9a09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,6 @@ # Level zero loader changelog +## v1.26.0 +* Refactor L0 Init to delay loading of driver libraries until flags match the drivers requested. ## v1.25.2 * Enable support for Dynamic Tracing of zer* APIs * Fix issues with zer* apis during validation layer intercepts diff --git a/CMakeLists.txt b/CMakeLists.txt index ecfc8d08..95f2a77e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ if(MSVC AND (MSVC_VERSION LESS 1900)) endif() # This project follows semantic versioning (https://semver.org/) -project(level-zero VERSION 1.25.2) +project(level-zero VERSION 1.26.0) include(GNUInstallDirs) find_package(Git) diff --git a/PRODUCT_GUID.txt b/PRODUCT_GUID.txt index 720776b8..902fec0b 100644 --- a/PRODUCT_GUID.txt +++ b/PRODUCT_GUID.txt @@ -1,2 +1,2 @@ -1.25.2 -738dfd00-2750-4425-9d91-6b68a2590ded \ No newline at end of file +1.26.0 +af2b4113-ecb0-4777-b3ce-d81307ed1156 diff --git a/samples/zello_world/zello_world.cpp b/samples/zello_world/zello_world.cpp index f00bf286..0a5a6318 100644 --- a/samples/zello_world/zello_world.cpp +++ b/samples/zello_world/zello_world.cpp @@ -42,6 +42,7 @@ int main( int argc, char *argv[] ) bool tracing_runtime_enabled = false; bool legacy_init = false; bool tracing_enabled = false; + bool npu_test = false; if( argparse( argc, argv, "-null", "--enable_null_driver" ) ) { putenv_safe( const_cast( "ZE_ENABLE_NULL_DRIVER=1" ) ); @@ -69,13 +70,25 @@ int main( int argc, char *argv[] ) { legacy_init = true; } + if( argparse( argc, argv, "-npu", "--enable_npu" ) ) + { + npu_test = true; + } ze_result_t status; - const ze_device_type_t type = ZE_DEVICE_TYPE_GPU; + ze_device_type_t type = ZE_DEVICE_TYPE_GPU; + if (npu_test) { + std::cout << "NPU Test Enabled. Looking for NPU devices." << std::endl; + type = ZE_DEVICE_TYPE_VPU; + } ze_init_driver_type_desc_t driverTypeDesc = {}; driverTypeDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC; - driverTypeDesc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + if (npu_test) { + driverTypeDesc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + } else { + driverTypeDesc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + } driverTypeDesc.pNext = nullptr; ze_driver_handle_t pDriver = nullptr; diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 1618fd9f..08514f11 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -22,6 +22,17 @@ using namespace loader_driver_ddi; namespace loader { + __${x}dlllocal ze_result_t ${X}_APICALL + ${n}loaderInitDriverDDITables(loader::driver_t *driver) { + ze_result_t result = ZE_RESULT_SUCCESS; + %for tbl in th.get_pfntables(specs, meta, n, tags): + result = ${tbl['export']['name']}FromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + %endfor + return result; + } %for obj in th.extract_objs(specs, r"function"): <% ret_type = obj['return_type'] @@ -65,6 +76,12 @@ namespace loader if(drv.initStatus != ZE_RESULT_SUCCESS) continue; %endif + if (!drv.handle || !drv.ddiInitialized) { + auto res = loader::context->init_driver( drv, flags, nullptr ); + if (res != ZE_RESULT_SUCCESS) { + continue; + } + } %if re.match(r"Init", obj['name']) and namespace == "zes": if (!drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}) { drv.initSysManStatus = ZE_RESULT_ERROR_UNINITIALIZED; @@ -90,6 +107,13 @@ namespace loader %elif re.match(r"\w+DriverGet$", th.make_func_name(n, tags, obj)) or re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)): uint32_t total_driver_handle_count = 0; + %if re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)): + for( auto& drv : loader::context->zeDrivers ) { + if (!drv.handle || !drv.ddiInitialized) { + loader::context->init_driver( drv, 0, desc); + } + } + %endif { std::lock_guard lock(loader::context->sortMutex); @@ -122,15 +146,16 @@ namespace loader %endif { %if not (re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj))) and namespace != "zes": - if(drv.initStatus != ZE_RESULT_SUCCESS) + if(drv.initStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized) continue; %elif namespace == "zes": - if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS) + if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized) continue; %else: if (!drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}) { %if re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)): drv.initDriversStatus = ${X}_RESULT_ERROR_UNINITIALIZED; + result = ${X}_RESULT_ERROR_UNINITIALIZED; %else: drv.initStatus = ${X}_RESULT_ERROR_UNINITIALIZED; %endif @@ -179,7 +204,8 @@ namespace loader for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { uint32_t driver_index = total_driver_handle_count + i; %if namespace != "zes": - drv.zerDriverHandle = phDrivers[ driver_index ]; + if (drv.zerddiInitResult == ZE_RESULT_SUCCESS) + drv.zerDriverHandle = phDrivers[ driver_index ]; if (drv.driverDDIHandleSupportQueried == false) { uint32_t extensionCount = 0; ze_result_t res = drv.dditable.ze.Driver.pfnGetExtensionProperties(phDrivers[ driver_index ], &extensionCount, nullptr); @@ -503,6 +529,52 @@ ${tbl['export']['name']}Legacy() %endfor +%for tbl in th.get_pfntables(specs, meta, n, tags): +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's ${tbl['name']} table +/// with current process' addresses +/// +/// @returns +/// - ::${X}_RESULT_SUCCESS +/// - ::${X}_RESULT_ERROR_UNINITIALIZED +/// - ::${X}_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::${X}_RESULT_ERROR_UNSUPPORTED_VERSION +__${x}dlllocal ${x}_result_t ${X}_APICALL +${tbl['export']['name']}FromDriver(loader::driver_t *driver) +{ + ${x}_result_t result = ${X}_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast<${tbl['pfn']}>( + GET_FUNCTION_PTR( driver->handle, "${tbl['export']['name']}") ); + if(!getTable) + %if th.isNewProcTable(tbl['export']['name']) is True and namespace != "zer": + { + //It is valid to not have this proc addr table + return ${X}_RESULT_SUCCESS; + } + %else: + return driver->initStatus; + %endif + %if tbl['experimental'] is False and namespace != "zer": #//Experimental Tables may not be implemented in driver + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.${n}.${tbl['name']}); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + %if namespace != "zes": + %if tbl['name'] == "Global" and namespace != "zer": + if (driver->dditable.ze.Global.pfnInitDrivers) { + loader::context->initDriversSupport = true; + } + %endif + %endif + %else: + result = getTable( loader::context->ddi_init_version, &driver->dditable.${n}.${tbl['name']}); + %endif + return result; +} +%endfor %for tbl in th.get_pfntables(specs, meta, n, tags): /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's ${tbl['name']} table @@ -534,63 +606,26 @@ ${tbl['export']['name']}( if( loader::context->version < version ) return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ${x}_result_t result = ${X}_RESULT_SUCCESS; - %if tbl['experimental'] is False and namespace != "zer": #//Experimental Tables may not be implemented in driver - bool atLeastOneDriverValid = false; - %endif - // Load the device-driver DDI tables %if namespace != "zes": - for( auto& drv : loader::context->zeDrivers ) + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; %else: - for( auto& drv : *loader::context->sysmanInstanceDrivers ) + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); %endif - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast<${tbl['pfn']}>( - GET_FUNCTION_PTR( drv.handle, "${tbl['export']['name']}") ); - if(!getTable) - %if th.isNewProcTable(tbl['export']['name']) is True and namespace != "zer": - { - atLeastOneDriverValid = true; - //It is valid to not have this proc addr table - continue; - } - %else: - continue; - %endif - %if tbl['experimental'] is False and namespace != "zer": #//Experimental Tables may not be implemented in driver - auto getTableResult = getTable( version, &drv.dditable.${n}.${tbl['name']}); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - %if namespace != "zes": - %if tbl['name'] == "Global" and namespace != "zer": - if (drv.dditable.ze.Global.pfnInitDrivers) { - loader::context->initDriversSupport = true; - } - %endif - %endif - %else: - result = getTable( version, &drv.dditable.${n}.${tbl['name']}); - %endif + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = ${tbl['export']['name']}FromDriver(firstDriver); } - %if tbl['experimental'] is False and namespace != "zer": #//Experimental Tables may not be implemented in driver - if(!atLeastOneDriverValid) - result = ${X}_RESULT_ERROR_UNINITIALIZED; - else - result = ${X}_RESULT_SUCCESS; - %endif - if( ${X}_RESULT_SUCCESS == result ) { %if namespace != "zes": if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) - %else: + %elif namespace == "zes": if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) %endif { diff --git a/scripts/templates/ldrddi.h.mako b/scripts/templates/ldrddi.h.mako index 7c4a3ae5..5885d026 100644 --- a/scripts/templates/ldrddi.h.mako +++ b/scripts/templates/ldrddi.h.mako @@ -20,6 +20,10 @@ from templates import helper as th namespace loader { + /////////////////////////////////////////////////////////////////////////////// + // Forward declaration for driver_t so this header can reference loader::driver_t* + // without requiring inclusion of ze_loader_internal.h (which includes this file). + struct driver_t; /////////////////////////////////////////////////////////////////////////////// %for obj in th.extract_objs(specs, r"handle"): %if 'class' in obj: @@ -32,6 +36,8 @@ namespace loader %endif %endfor + __${x}dlllocal ze_result_t ${X}_APICALL + ${n}loaderInitDriverDDITables(loader::driver_t *driver); } namespace loader_driver_ddi @@ -57,6 +63,8 @@ extern "C" { %for tbl in th.get_pfntables(specs, meta, n, tags): __${x}dlllocal void ${X}_APICALL ${tbl['export']['name']}Legacy(); +__${x}dlllocal ze_result_t ${X}_APICALL +${tbl['export']['name']}FromDriver(loader::driver_t *driver); %endfor #if defined(__cplusplus) diff --git a/scripts/templates/ldrddi_driver_ddi.cpp.mako b/scripts/templates/ldrddi_driver_ddi.cpp.mako index c696cf52..b4a29cab 100644 --- a/scripts/templates/ldrddi_driver_ddi.cpp.mako +++ b/scripts/templates/ldrddi_driver_ddi.cpp.mako @@ -103,8 +103,15 @@ namespace loader_driver_ddi // Check if the default driver supports DDI Handles if (loader::context->defaultZerDriverHandle == nullptr) { %if ret_type == 'ze_result_t': + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE; + } return ${X}_RESULT_ERROR_UNINITIALIZED; %else: + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + error_state::setErrorDesc("ERROR UNSUPPORTED FEATURE"); + return ${failure_return}; + } error_state::setErrorDesc("ERROR UNINITIALIZED"); return ${failure_return}; %endif diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako index f362be45..cd3b145a 100644 --- a/scripts/templates/nullddi.cpp.mako +++ b/scripts/templates/nullddi.cpp.mako @@ -54,13 +54,23 @@ namespace driver // generic implementation %if re.match("Init", obj['name']): %if re.match("InitDrivers", obj['name']): + // Check compile-time definitions first + bool is_npu = false; + bool is_gpu = false; + #ifdef ZEL_NULL_DRIVER_TYPE_NPU + is_npu = true; + #endif + + #ifdef ZEL_NULL_DRIVER_TYPE_GPU + is_gpu = true; + #endif auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); - if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (std::strcmp(driver_type.c_str(), "GPU") == 0 || is_gpu) { if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) { return ${X}_RESULT_ERROR_UNINITIALIZED; } } - if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (std::strcmp(driver_type.c_str(), "NPU") == 0 || is_npu) { if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) { return ${X}_RESULT_ERROR_UNINITIALIZED; } diff --git a/scripts/templates/ze_loader_internal.h.mako b/scripts/templates/ze_loader_internal.h.mako index 556f9dc6..e56dba49 100644 --- a/scripts/templates/ze_loader_internal.h.mako +++ b/scripts/templates/ze_loader_internal.h.mako @@ -70,6 +70,13 @@ namespace loader bool driverDDIHandleSupportQueried = false; ze_driver_handle_t zerDriverHandle = nullptr; bool zerDriverDDISupported = true; + ze_api_version_t versionRequested = ZE_API_VERSION_CURRENT; + bool ddiInitialized = false; + bool customDriver = false; + ze_result_t zeddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zetddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zesddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zerddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; }; using driver_vector_t = std::vector< driver_t >; @@ -98,6 +105,7 @@ namespace loader std::unordered_map sampler_handle_map; ze_api_version_t version = ZE_API_VERSION_CURRENT; ze_api_version_t configured_version = ZE_API_VERSION_CURRENT; + ze_api_version_t ddi_init_version = ZE_API_VERSION_CURRENT; driver_vector_t allDrivers; driver_vector_t zeDrivers; @@ -113,10 +121,9 @@ namespace loader std::vector compVersions; const char *LOADER_COMP_NAME = "loader"; - ze_result_t check_drivers(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); void debug_trace_message(std::string errorMessage, std::string errorValue); ze_result_t init(); - ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly); + ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc); void add_loader_version(); bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly); void driverOrdering(driver_vector_t *drivers); @@ -130,6 +137,7 @@ namespace loader std::atomic sortingInProgress = {false}; std::mutex sortMutex; bool instrumentationEnabled = false; + bool pciOrderingRequested = false; dditable_t tracing_dditable = {}; std::shared_ptr zel_logger; ze_driver_handle_t defaultZerDriverHandle = nullptr; diff --git a/source/drivers/null/CMakeLists.txt b/source/drivers/null/CMakeLists.txt index 34d3d6ab..cf6baae7 100644 --- a/source/drivers/null/CMakeLists.txt +++ b/source/drivers/null/CMakeLists.txt @@ -48,6 +48,67 @@ add_library(ze_null_test2 SHARED ${CMAKE_CURRENT_SOURCE_DIR}/zer_nullddi.cpp ) +if(BUILD_L0_LOADER_TESTS) + # Add fake Intel GPU and NPU drivers for testing driver type initialization + add_library(ze_intel_gpu SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/ze_null.h + ${CMAKE_CURRENT_SOURCE_DIR}/ze_null.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ze_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zet_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zes_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zer_nullddi.cpp + ) + + add_library(ze_intel_npu SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/ze_null.h + ${CMAKE_CURRENT_SOURCE_DIR}/ze_null.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ze_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zet_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zes_nullddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/zer_nullddi.cpp + ) + set_target_properties(ze_intel_gpu PROPERTIES + VERSION "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}" + SOVERSION "${PROJECT_VERSION_MAJOR}" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib_fake" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin_fake" + ) + + set_target_properties(ze_intel_npu PROPERTIES + VERSION "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}" + SOVERSION "${PROJECT_VERSION_MAJOR}" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib_fake" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin_fake" + ) + target_include_directories(ze_intel_gpu + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ) + + target_include_directories(ze_intel_npu + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ) + target_compile_definitions(ze_intel_gpu PUBLIC ZEL_NULL_DRIVER_ID=3 ZEL_NULL_DRIVER_TYPE_GPU=1) + target_compile_definitions(ze_intel_npu PUBLIC ZEL_NULL_DRIVER_ID=4 ZEL_NULL_DRIVER_TYPE_NPU=1) + + # Install fake drivers to separate directory + if(INSTALL_NULL_DRIVER) + install(TARGETS ze_intel_gpu + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}_fake COMPONENT level-zero-devel + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}_fake COMPONENT level-zero + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}_fake COMPONENT level-zero + NAMELINK_COMPONENT level-zero-devel + ) + install(TARGETS ze_intel_npu + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}_fake COMPONENT level-zero-devel + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}_fake COMPONENT level-zero + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}_fake COMPONENT level-zero + NAMELINK_COMPONENT level-zero-devel + ) + endif() +endif() + set_target_properties(ze_null_test1 PROPERTIES VERSION "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}" SOVERSION "${PROJECT_VERSION_MAJOR}" diff --git a/source/drivers/null/ze_null.cpp b/source/drivers/null/ze_null.cpp index 1dd504bb..2418e033 100644 --- a/source/drivers/null/ze_null.cpp +++ b/source/drivers/null/ze_null.cpp @@ -376,9 +376,22 @@ namespace driver ze_device_properties_t deviceProperties = {}; deviceProperties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; deviceProperties.type = ZE_DEVICE_TYPE_GPU; + + // Check compile-time definitions first + #ifdef ZEL_NULL_DRIVER_TYPE_NPU + deviceProperties.type = ZE_DEVICE_TYPE_VPU; + #endif + + #ifdef ZEL_NULL_DRIVER_TYPE_GPU + deviceProperties.type = ZE_DEVICE_TYPE_GPU; + #endif + + // Environment variable can override auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); if (std::strcmp(driver_type.c_str(), "NPU") == 0) { deviceProperties.type = ZE_DEVICE_TYPE_VPU; + } else if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + deviceProperties.type = ZE_DEVICE_TYPE_GPU; } #if defined(_WIN32) strcpy_s( deviceProperties.name, "Null Device" ); diff --git a/source/drivers/null/ze_nullddi.cpp b/source/drivers/null/ze_nullddi.cpp index 4b96006c..ed88ac4f 100644 --- a/source/drivers/null/ze_nullddi.cpp +++ b/source/drivers/null/ze_nullddi.cpp @@ -107,13 +107,23 @@ namespace driver else { // generic implementation + // Check compile-time definitions first + bool is_npu = false; + bool is_gpu = false; + #ifdef ZEL_NULL_DRIVER_TYPE_NPU + is_npu = true; + #endif + + #ifdef ZEL_NULL_DRIVER_TYPE_GPU + is_gpu = true; + #endif auto driver_type = getenv_string( "ZEL_TEST_NULL_DRIVER_TYPE" ); - if (std::strcmp(driver_type.c_str(), "GPU") == 0) { + if (std::strcmp(driver_type.c_str(), "GPU") == 0 || is_gpu) { if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) { return ZE_RESULT_ERROR_UNINITIALIZED; } } - if (std::strcmp(driver_type.c_str(), "NPU") == 0) { + if (std::strcmp(driver_type.c_str(), "NPU") == 0 || is_npu) { if (!(desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) { return ZE_RESULT_ERROR_UNINITIALIZED; } diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index 05984bad..dc5dd2de 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -140,7 +140,6 @@ namespace ze_lib return result; } - bool zeInitDriversSupport = true; ze_api_version_t current_api_version = version; const std::string loader_name = "loader"; for (auto &component : versions) { @@ -152,7 +151,6 @@ namespace ze_lib if (component.component_lib_version.minor < 18) { std::string message = "ze_lib Context Init() Version Does not support zeInitDrivers"; debug_trace_message(message, ""); - zeInitDriversSupport = false; } } else { std::string message = "ze_lib Context Init() Loader version is too new, returning "; @@ -302,76 +300,6 @@ namespace ze_lib } // End DDI Table Inits - // Check which drivers and layers can be init on this system. - if( ZE_RESULT_SUCCESS == result) - { - // Check which drivers support the ze_driver_flag_t specified - // No need to check if only initializing sysman - bool requireDdiReinit = false; - #ifdef L0_STATIC_LOADER_BUILD - if (zeInitDriversSupport) { - typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); - auto loaderDriverCheck = reinterpret_cast( - GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") ); - if (loaderDriverCheck == nullptr) { - std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing, returning "; - debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - return ZE_RESULT_ERROR_UNINITIALIZED; - } - result = loaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly); - } else { - typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); - auto loaderDriverCheck = reinterpret_cast( - GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") ); - if (loaderDriverCheck == nullptr) { - std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing, returning "; - debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - return ZE_RESULT_ERROR_UNINITIALIZED; - } - result = loaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly); - } - #else - result = zelLoaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly); - #endif - if (result != ZE_RESULT_SUCCESS) { - std::string message = "ze_lib Context Init() zelLoaderDriverCheck failed with "; - debug_trace_message(message, to_string(result)); - } - // If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver. - if (requireDdiReinit && loaderContextAccessAllowed) { - // If a user has already called the core apis, then ddi table reinit is not possible due to handles already being read by the user. - if (!sysmanOnly && !ze_lib::context->zeInuse) { - // reInit the ZE DDI Tables - if( ZE_RESULT_SUCCESS == result ) - { - result = zeDdiTableInit(version); - } - // reInit the ZET DDI Tables - if( ZE_RESULT_SUCCESS == result ) - { - result = zetDdiTableInit(version); - } - // reInit the ZER DDI Tables - if (ZE_RESULT_SUCCESS == result) - { - result = zerDdiTableInit(version); - } - // If ze/zet/zer ddi tables have been reinit and no longer use the intercept layer, then handles passed to zelLoaderTranslateHandleInternal do not require translation. - // Setting intercept_enabled==false changes the behavior of zelLoaderTranslateHandleInternal to avoid translation. - // Translation is only required if the intercept layer is enabled for the ZE handle types. - loaderContext->intercept_enabled = false; - } - // If a user has already called the zes/ze apis, then ddi table reinit is not possible due to handles already being read by the user. - if (!(ze_lib::context->zesInuse || ze_lib::context->zeInuse)) { - // reInit the ZES DDI Tables - if( ZE_RESULT_SUCCESS == result ) - { - result = zesDdiTableInit(version); - } - } - } - } - if( ZE_RESULT_SUCCESS == result ) { #ifdef L0_STATIC_LOADER_BUILD diff --git a/source/loader/driver_discovery.h b/source/loader/driver_discovery.h index 22433fb1..f27ba3ad 100644 --- a/source/loader/driver_discovery.h +++ b/source/loader/driver_discovery.h @@ -10,10 +10,17 @@ #include #include +#include "ze_loader_internal.h" namespace loader { -using DriverLibraryPath = std::string; +struct DriverLibraryPath { + std::string path; + bool customDriver; + zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; + DriverLibraryPath(const std::string& p, bool isCustom = false, zel_driver_type_t type = ZEL_DRIVER_TYPE_FORCE_UINT32) + : path(p), customDriver(isCustom), driverType(type) {} +}; std::vector discoverEnabledDrivers(); diff --git a/source/loader/linux/driver_discovery_lin.cpp b/source/loader/linux/driver_discovery_lin.cpp index e4bf88d6..70053b55 100644 --- a/source/loader/linux/driver_discovery_lin.cpp +++ b/source/loader/linux/driver_discovery_lin.cpp @@ -13,6 +13,90 @@ #include #include +#include +#include +#include + +#include +#include +#include +// Helper to split a colon-separated path string +static std::vector splitPaths(const std::string& paths) { + std::vector result; + std::stringstream ss(paths); + std::string item; + while (std::getline(ss, item, ':')) { + if (!item.empty()) result.push_back(item); + } + return result; +} + +// Helper to check if a file exists and is readable +static bool fileExistsReadable(const std::string& path) { + struct stat sb; + return (stat(path.c_str(), &sb) == 0) && (access(path.c_str(), R_OK) == 0); +} + +// Helper to get all library search paths from LD_LIBRARY_PATH, standard locations, and /etc/ld.so.conf +static std::vector getLibrarySearchPaths() { + std::vector paths; + // LD_LIBRARY_PATH + const char* ldLibPath = getenv("LD_LIBRARY_PATH"); + if (ldLibPath) { + auto split = splitPaths(ldLibPath); + paths.insert(paths.end(), split.begin(), split.end()); + } + // Standard locations + paths.push_back("/lib"); + paths.push_back("/usr/lib"); + paths.push_back("/usr/local/lib"); + // /etc/ld.so.conf and included files + std::ifstream ldSoConf("/etc/ld.so.conf"); + if (ldSoConf) { + std::string line; + while (std::getline(ldSoConf, line)) { + if (line.empty()) continue; + if (line.find("include ") == 0) { + std::string pattern = line.substr(8); + // Simple glob: /etc/ld.so.conf.d/*.conf + std::string dir = pattern.substr(0, pattern.find_last_of('/')); + std::string ext = pattern.substr(pattern.find_last_of('.')); + DIR* d = opendir(dir.c_str()); + if (d) { + struct dirent* ent; + while ((ent = readdir(d)) != nullptr) { + std::string fname = ent->d_name; + if (fname.size() > ext.size() && fname.substr(fname.size()-ext.size()) == ext) { + std::ifstream incFile(dir + "/" + fname); + std::string incLine; + while (std::getline(incFile, incLine)) { + if (!incLine.empty() && incLine[0] != '#') + paths.push_back(incLine); + } + } + } + closedir(d); + } + } else if (line[0] != '#') { + paths.push_back(line); + } + } + } + return paths; +} + +// Main function: search for a library file in all known library paths +static bool libraryExistsInSearchPaths(const std::string& filename) { + auto paths = getLibrarySearchPaths(); + for (const auto& dir : paths) { + std::string fullPath = dir + "/" + filename; + if (fileExistsReadable(fullPath)) { + return true; + } + } + return false; +} + namespace loader { static const char *knownDriverNames[] = { @@ -29,15 +113,45 @@ std::vector discoverEnabledDrivers() { // ZE_ENABLE_ALT_DRIVERS is for development/debug only altDrivers = getenv("ZE_ENABLE_ALT_DRIVERS"); if (altDrivers == nullptr) { + // Standard drivers - not custom for (auto path : knownDriverNames) { - enabledDrivers.emplace_back(path); + if (libraryExistsInSearchPaths(path)) { + // Extract the base library name for robust driver type detection + // path is like "libze_intel_gpu.so.1" + std::string libName = path; + + // Remove "lib" prefix if present + if (libName.compare(0, 3, "lib") == 0) { + libName = libName.substr(3); + } + + // Remove file extension + size_t dotPos = libName.find('.'); + if (dotPos != std::string::npos) { + libName = libName.substr(0, dotPos); + } + + // Now match against the core driver name (e.g., "ze_intel_gpu", "ze_intel_npu") + // Check for exact matches or word boundaries to avoid substring false positives + zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; + + // Check more specific patterns first to avoid partial matches + if (libName == "ze_intel_gpu" || libName == "ze_intel_gpu_legacy1") { + driverType = ZEL_DRIVER_TYPE_GPU; + } else if (libName == "ze_intel_vpu" || libName == "ze_intel_npu") { + driverType = ZEL_DRIVER_TYPE_NPU; + } + + enabledDrivers.emplace_back(path, false, driverType); + } } } else { + // Alternative drivers from environment variable - these are custom std::stringstream ss(altDrivers); while (ss.good()) { std::string substr; getline(ss, substr, ','); - enabledDrivers.emplace_back(substr); + enabledDrivers.emplace_back(substr, true); } } return enabledDrivers; diff --git a/source/loader/windows/driver_discovery_win.cpp b/source/loader/windows/driver_discovery_win.cpp index 46f7350c..9ead6f9d 100644 --- a/source/loader/windows/driver_discovery_win.cpp +++ b/source/loader/windows/driver_discovery_win.cpp @@ -33,16 +33,18 @@ std::vector discoverEnabledDrivers() { // ZE_ENABLE_ALT_DRIVERS is for development/debug only envBufferSize = GetEnvironmentVariable("ZE_ENABLE_ALT_DRIVERS", &altDrivers[0], envBufferSize); if (!envBufferSize) { + // Standard drivers discovered from registry - not custom auto displayDrivers = discoverDriversBasedOnDisplayAdapters(GUID_DEVCLASS_DISPLAY); auto computeDrivers = discoverDriversBasedOnDisplayAdapters(GUID_DEVCLASS_COMPUTEACCELERATOR); enabledDrivers.insert(enabledDrivers.end(), displayDrivers.begin(), displayDrivers.end()); enabledDrivers.insert(enabledDrivers.end(), computeDrivers.begin(), computeDrivers.end()); } else { + // Alternative drivers from environment variable - these are custom std::stringstream ss(altDrivers.c_str()); while (ss.good()) { std::string substr; getline(ss, substr, ','); - enabledDrivers.emplace_back(substr); + enabledDrivers.emplace_back(substr, true, ZEL_DRIVER_TYPE_FORCE_UINT32); } } @@ -110,7 +112,7 @@ DriverLibraryPath readDriverPathForDisplayAdapter(DEVINST dnDevNode) { if (CR_SUCCESS != configErr) { assert(false && "CM_Open_DevNode_Key failed"); - return ""; + return DriverLibraryPath("", false); } DWORD regValueType = {}; @@ -133,7 +135,7 @@ DriverLibraryPath readDriverPathForDisplayAdapter(DEVINST dnDevNode) { regOpStatus = RegCloseKey(hkey); assert((ERROR_SUCCESS == regOpStatus) && "RegCloseKey failed"); - return driverPath; + return DriverLibraryPath(driverPath, false); } std::wstring readDisplayAdaptersDeviceIdsList(const GUID rguid) { @@ -193,14 +195,23 @@ std::vector discoverDriversBasedOnDisplayAdapters(const GUID auto driverPath = readDriverPathForDisplayAdapter(devinst); - if (driverPath.empty()) { + if (driverPath.path.empty()) { continue; } - bool alreadyOnTheList = (enabledDrivers.end() != std::find(enabledDrivers.begin(), enabledDrivers.end(), driverPath)); + bool alreadyOnTheList = (enabledDrivers.end() != std::find_if(enabledDrivers.begin(), enabledDrivers.end(), + [&driverPath](const DriverLibraryPath& d) { return d.path == driverPath.path; })); if (alreadyOnTheList) { continue; } + driverPath.customDriver = false; + if (rguid == GUID_DEVCLASS_DISPLAY) { + driverPath.driverType = ZEL_DRIVER_TYPE_GPU; + } else if (rguid == GUID_DEVCLASS_COMPUTEACCELERATOR) { + driverPath.driverType = ZEL_DRIVER_TYPE_NPU; + } else { + driverPath.driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; + } enabledDrivers.push_back(std::move(driverPath)); } diff --git a/source/loader/ze_ldrddi.cpp b/source/loader/ze_ldrddi.cpp index afe63e1c..9b3277d5 100644 --- a/source/loader/ze_ldrddi.cpp +++ b/source/loader/ze_ldrddi.cpp @@ -13,6 +13,131 @@ using namespace loader_driver_ddi; namespace loader { + __zedlllocal ze_result_t ZE_APICALL + zeloaderInitDriverDDITables(loader::driver_t *driver) { + ze_result_t result = ZE_RESULT_SUCCESS; + result = zeGetGlobalProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetRTASBuilderProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetRTASBuilderExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetRTASParallelOperationProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetRTASParallelOperationExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetDriverProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetDriverExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetDeviceProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetDeviceExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetContextProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetCommandQueueProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetCommandListProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetCommandListExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetEventProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetEventExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetEventPoolProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetFenceProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetImageProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetImageExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetKernelProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetKernelExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetMemProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetMemExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetModuleProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetModuleBuildLogProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetPhysicalMemProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetSamplerProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetVirtualMemProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetFabricEdgeExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zeGetFabricVertexExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + return result; + } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for zeInit __zedlllocal ze_result_t ZE_APICALL @@ -28,6 +153,12 @@ namespace loader { if(drv.initStatus != ZE_RESULT_SUCCESS) continue; + if (!drv.handle || !drv.ddiInitialized) { + auto res = loader::context->init_driver( drv, flags, nullptr ); + if (res != ZE_RESULT_SUCCESS) { + continue; + } + } drv.initStatus = drv.dditable.ze.Global.pfnInit( flags ); if(drv.initStatus == ZE_RESULT_SUCCESS) atLeastOneDriverValid = true; @@ -71,7 +202,7 @@ namespace loader for( auto& drv : loader::context->zeDrivers ) { - if(drv.initStatus != ZE_RESULT_SUCCESS) + if(drv.initStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized) continue; if( ( 0 < *pCount ) && ( *pCount == total_driver_handle_count)) @@ -102,7 +233,8 @@ namespace loader { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { uint32_t driver_index = total_driver_handle_count + i; - drv.zerDriverHandle = phDrivers[ driver_index ]; + if (drv.zerddiInitResult == ZE_RESULT_SUCCESS) + drv.zerDriverHandle = phDrivers[ driver_index ]; if (drv.driverDDIHandleSupportQueried == false) { uint32_t extensionCount = 0; ze_result_t res = drv.dditable.ze.Driver.pfnGetExtensionProperties(phDrivers[ driver_index ], &extensionCount, nullptr); @@ -209,6 +341,11 @@ namespace loader ze_result_t result = ZE_RESULT_SUCCESS; uint32_t total_driver_handle_count = 0; + for( auto& drv : loader::context->zeDrivers ) { + if (!drv.handle || !drv.ddiInitialized) { + loader::context->init_driver( drv, 0, desc); + } + } { std::lock_guard lock(loader::context->sortMutex); @@ -225,6 +362,7 @@ namespace loader { if (!drv.dditable.ze.Global.pfnInitDrivers) { drv.initDriversStatus = ZE_RESULT_ERROR_UNINITIALIZED; + result = ZE_RESULT_ERROR_UNINITIALIZED; continue; } @@ -256,7 +394,8 @@ namespace loader { for( uint32_t i = 0; i < library_driver_handle_count; ++i ) { uint32_t driver_index = total_driver_handle_count + i; - drv.zerDriverHandle = phDrivers[ driver_index ]; + if (drv.zerddiInitResult == ZE_RESULT_SUCCESS) + drv.zerDriverHandle = phDrivers[ driver_index ]; if (drv.driverDDIHandleSupportQueried == false) { uint32_t extensionCount = 0; ze_result_t res = drv.dditable.ze.Driver.pfnGetExtensionProperties(phDrivers[ driver_index ], &extensionCount, nullptr); @@ -7641,100 +7780,26 @@ zeGetFabricVertexExpProcAddrTableLegacy() /// - ::ZE_RESULT_ERROR_UNINITIALIZED /// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER /// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION -ZE_DLLEXPORT ze_result_t ZE_APICALL -zeGetGlobalProcAddrTable( - ze_api_version_t version, ///< [in] API version requested - ze_global_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers - ) +__zedlllocal ze_result_t ZE_APICALL +zeGetGlobalProcAddrTableFromDriver(loader::driver_t *driver) { - if( loader::context->zeDrivers.size() < 1 ) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - if( nullptr == pDdiTable ) - return ZE_RESULT_ERROR_INVALID_NULL_POINTER; - - if( loader::context->version < version ) - return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; - ze_result_t result = ZE_RESULT_SUCCESS; - - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Global); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - if (drv.dditable.ze.Global.pfnInitDrivers) { - loader::context->initDriversSupport = true; - } - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; - - if( ZE_RESULT_SUCCESS == result ) - { - if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) - { - // return pointers to loader's DDIs - loader::loaderDispatch->pCore->Global = new ze_global_dditable_t; - if (version >= ZE_API_VERSION_1_0) { - pDdiTable->pfnInit = loader::zeInit; - } - if (version >= ZE_API_VERSION_1_10) { - pDdiTable->pfnInitDrivers = loader::zeInitDrivers; - } - zeGetGlobalProcAddrTableLegacy(); - } - else - { - // return pointers directly to driver's DDIs - *pDdiTable = loader::context->zeDrivers.front().dditable.ze.Global; - } - } - - // If the validation layer is enabled, then intercept the loader's DDIs - if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->validationLayer )) - { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(loader::context->validationLayer, "zeGetGlobalProcAddrTable") ); - if(!getTable) - return ZE_RESULT_ERROR_UNINITIALIZED; - result = getTable( version, pDdiTable ); - } - - // If the API tracing layer is enabled, then intercept the loader's DDIs - if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->tracingLayer )) - { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(loader::context->tracingLayer, "zeGetGlobalProcAddrTable") ); - if(!getTable) - return ZE_RESULT_ERROR_UNINITIALIZED; - ze_global_dditable_t dditable; - memcpy(&dditable, pDdiTable, sizeof(ze_global_dditable_t)); - result = getTable( version, &dditable ); - loader::context->tracing_dditable.ze.Global = dditable; - if ( loader::context->tracingLayerEnabled ) { - result = getTable( version, pDdiTable ); - } + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetGlobalProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Global); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + if (driver->dditable.ze.Global.pfnInitDrivers) { + loader::context->initDriversSupport = true; } - return result; } - /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's RTASBuilder table /// with current process' addresses @@ -7744,126 +7809,23 @@ zeGetGlobalProcAddrTable( /// - ::ZE_RESULT_ERROR_UNINITIALIZED /// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER /// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION -ZE_DLLEXPORT ze_result_t ZE_APICALL -zeGetRTASBuilderProcAddrTable( - ze_api_version_t version, ///< [in] API version requested - ze_rtas_builder_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers - ) +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASBuilderProcAddrTableFromDriver(loader::driver_t *driver) { - if( loader::context->zeDrivers.size() < 1 ) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - if( nullptr == pDdiTable ) - return ZE_RESULT_ERROR_INVALID_NULL_POINTER; - - if( loader::context->version < version ) - return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; - ze_result_t result = ZE_RESULT_SUCCESS; - - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.RTASBuilder); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; - - if( ZE_RESULT_SUCCESS == result ) - { - if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) - { - // return pointers to loader's DDIs - loader::loaderDispatch->pCore->RTASBuilder = new ze_rtas_builder_dditable_t; - if (version >= ZE_API_VERSION_1_13) { - if (loader::context->driverDDIPathDefault) { - pDdiTable->pfnCreateExt = loader_driver_ddi::zeRTASBuilderCreateExt; - } else { - pDdiTable->pfnCreateExt = loader::zeRTASBuilderCreateExt; - } - } - if (version >= ZE_API_VERSION_1_13) { - if (loader::context->driverDDIPathDefault) { - pDdiTable->pfnGetBuildPropertiesExt = loader_driver_ddi::zeRTASBuilderGetBuildPropertiesExt; - } else { - pDdiTable->pfnGetBuildPropertiesExt = loader::zeRTASBuilderGetBuildPropertiesExt; - } - } - if (version >= ZE_API_VERSION_1_13) { - if (loader::context->driverDDIPathDefault) { - pDdiTable->pfnBuildExt = loader_driver_ddi::zeRTASBuilderBuildExt; - } else { - pDdiTable->pfnBuildExt = loader::zeRTASBuilderBuildExt; - } - } - if (version >= ZE_API_VERSION_1_13) { - if (loader::context->driverDDIPathDefault) { - pDdiTable->pfnCommandListAppendCopyExt = loader_driver_ddi::zeRTASBuilderCommandListAppendCopyExt; - } else { - pDdiTable->pfnCommandListAppendCopyExt = loader::zeRTASBuilderCommandListAppendCopyExt; - } - } - if (version >= ZE_API_VERSION_1_13) { - if (loader::context->driverDDIPathDefault) { - pDdiTable->pfnDestroyExt = loader_driver_ddi::zeRTASBuilderDestroyExt; - } else { - pDdiTable->pfnDestroyExt = loader::zeRTASBuilderDestroyExt; - } - } - zeGetRTASBuilderProcAddrTableLegacy(); - } - else - { - // return pointers directly to driver's DDIs - *pDdiTable = loader::context->zeDrivers.front().dditable.ze.RTASBuilder; - } - } - - // If the validation layer is enabled, then intercept the loader's DDIs - if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->validationLayer )) - { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(loader::context->validationLayer, "zeGetRTASBuilderProcAddrTable") ); - if(!getTable) - return ZE_RESULT_ERROR_UNINITIALIZED; - result = getTable( version, pDdiTable ); - } - - // If the API tracing layer is enabled, then intercept the loader's DDIs - if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->tracingLayer )) - { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(loader::context->tracingLayer, "zeGetRTASBuilderProcAddrTable") ); - if(!getTable) - return ZE_RESULT_ERROR_UNINITIALIZED; - ze_rtas_builder_dditable_t dditable; - memcpy(&dditable, pDdiTable, sizeof(ze_rtas_builder_dditable_t)); - result = getTable( version, &dditable ); - loader::context->tracing_dditable.ze.RTASBuilder = dditable; - if ( loader::context->tracingLayerEnabled ) { - result = getTable( version, pDdiTable ); - } - } - + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetRTASBuilderProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.RTASBuilder); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; return result; } - /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's RTASBuilderExp table /// with current process' addresses @@ -7873,38 +7835,916 @@ zeGetRTASBuilderProcAddrTable( /// - ::ZE_RESULT_ERROR_UNINITIALIZED /// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER /// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION -ZE_DLLEXPORT ze_result_t ZE_APICALL -zeGetRTASBuilderExpProcAddrTable( - ze_api_version_t version, ///< [in] API version requested - ze_rtas_builder_exp_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers - ) +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASBuilderExpProcAddrTableFromDriver(loader::driver_t *driver) { - if( loader::context->zeDrivers.size() < 1 ) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - if( nullptr == pDdiTable ) - return ZE_RESULT_ERROR_INVALID_NULL_POINTER; - - if( loader::context->version < version ) - return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; - ze_result_t result = ZE_RESULT_SUCCESS; - - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.RTASBuilderExp); - } - - - if( ZE_RESULT_SUCCESS == result ) + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetRTASBuilderExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.RTASBuilderExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's RTASParallelOperation table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASParallelOperationProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetRTASParallelOperationProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.RTASParallelOperation); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's RTASParallelOperationExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASParallelOperationExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetRTASParallelOperationExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.RTASParallelOperationExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Driver table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetDriverProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetDriverProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Driver); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's DriverExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetDriverExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetDriverExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.DriverExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Device table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetDeviceProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetDeviceProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Device); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's DeviceExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetDeviceExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.DeviceExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Context table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetContextProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetContextProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Context); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandQueue table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandQueueProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetCommandQueueProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.CommandQueue); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandList table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandListProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetCommandListProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.CommandList); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandListExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandListExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetCommandListExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.CommandListExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Event table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetEventProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetEventProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Event); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EventExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetEventExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetEventExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.EventExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EventPool table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetEventPoolProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetEventPoolProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.EventPool); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Fence table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetFenceProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetFenceProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Fence); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Image table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetImageProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetImageProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Image); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's ImageExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetImageExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetImageExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.ImageExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Kernel table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetKernelProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetKernelProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Kernel); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetKernelExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetKernelExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.KernelExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Mem table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetMemProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetMemProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Mem); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MemExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetMemExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetMemExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.MemExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Module table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetModuleProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetModuleProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Module); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's ModuleBuildLog table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetModuleBuildLogProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetModuleBuildLogProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.ModuleBuildLog); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's PhysicalMem table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetPhysicalMemProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetPhysicalMemProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.PhysicalMem); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Sampler table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetSamplerProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetSamplerProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.Sampler); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's VirtualMem table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetVirtualMemProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetVirtualMemProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.ze.VirtualMem); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's FabricEdgeExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetFabricEdgeExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetFabricEdgeExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.FabricEdgeExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's FabricVertexExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zeGetFabricVertexExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zeGetFabricVertexExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.ze.FabricVertexExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Global table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +ZE_DLLEXPORT ze_result_t ZE_APICALL +zeGetGlobalProcAddrTable( + ze_api_version_t version, ///< [in] API version requested + ze_global_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) +{ + if( loader::context->zeDrivers.size() < 1 ) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + if( nullptr == pDdiTable ) + return ZE_RESULT_ERROR_INVALID_NULL_POINTER; + + if( loader::context->version < version ) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + loader::context->ddi_init_version = version; + + ze_result_t result = ZE_RESULT_SUCCESS; + + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetGlobalProcAddrTableFromDriver(firstDriver); + } + + if( ZE_RESULT_SUCCESS == result ) + { + if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) + { + // return pointers to loader's DDIs + loader::loaderDispatch->pCore->Global = new ze_global_dditable_t; + if (version >= ZE_API_VERSION_1_0) { + pDdiTable->pfnInit = loader::zeInit; + } + if (version >= ZE_API_VERSION_1_10) { + pDdiTable->pfnInitDrivers = loader::zeInitDrivers; + } + zeGetGlobalProcAddrTableLegacy(); + } + else + { + // return pointers directly to driver's DDIs + *pDdiTable = loader::context->zeDrivers.front().dditable.ze.Global; + } + } + + // If the validation layer is enabled, then intercept the loader's DDIs + if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->validationLayer )) + { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(loader::context->validationLayer, "zeGetGlobalProcAddrTable") ); + if(!getTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + result = getTable( version, pDdiTable ); + } + + // If the API tracing layer is enabled, then intercept the loader's DDIs + if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->tracingLayer )) + { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(loader::context->tracingLayer, "zeGetGlobalProcAddrTable") ); + if(!getTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + ze_global_dditable_t dditable; + memcpy(&dditable, pDdiTable, sizeof(ze_global_dditable_t)); + result = getTable( version, &dditable ); + loader::context->tracing_dditable.ze.Global = dditable; + if ( loader::context->tracingLayerEnabled ) { + result = getTable( version, pDdiTable ); + } + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's RTASBuilder table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +ZE_DLLEXPORT ze_result_t ZE_APICALL +zeGetRTASBuilderProcAddrTable( + ze_api_version_t version, ///< [in] API version requested + ze_rtas_builder_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) +{ + if( loader::context->zeDrivers.size() < 1 ) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + if( nullptr == pDdiTable ) + return ZE_RESULT_ERROR_INVALID_NULL_POINTER; + + if( loader::context->version < version ) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + loader::context->ddi_init_version = version; + + ze_result_t result = ZE_RESULT_SUCCESS; + + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetRTASBuilderProcAddrTableFromDriver(firstDriver); + } + + if( ZE_RESULT_SUCCESS == result ) + { + if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) + { + // return pointers to loader's DDIs + loader::loaderDispatch->pCore->RTASBuilder = new ze_rtas_builder_dditable_t; + if (version >= ZE_API_VERSION_1_13) { + if (loader::context->driverDDIPathDefault) { + pDdiTable->pfnCreateExt = loader_driver_ddi::zeRTASBuilderCreateExt; + } else { + pDdiTable->pfnCreateExt = loader::zeRTASBuilderCreateExt; + } + } + if (version >= ZE_API_VERSION_1_13) { + if (loader::context->driverDDIPathDefault) { + pDdiTable->pfnGetBuildPropertiesExt = loader_driver_ddi::zeRTASBuilderGetBuildPropertiesExt; + } else { + pDdiTable->pfnGetBuildPropertiesExt = loader::zeRTASBuilderGetBuildPropertiesExt; + } + } + if (version >= ZE_API_VERSION_1_13) { + if (loader::context->driverDDIPathDefault) { + pDdiTable->pfnBuildExt = loader_driver_ddi::zeRTASBuilderBuildExt; + } else { + pDdiTable->pfnBuildExt = loader::zeRTASBuilderBuildExt; + } + } + if (version >= ZE_API_VERSION_1_13) { + if (loader::context->driverDDIPathDefault) { + pDdiTable->pfnCommandListAppendCopyExt = loader_driver_ddi::zeRTASBuilderCommandListAppendCopyExt; + } else { + pDdiTable->pfnCommandListAppendCopyExt = loader::zeRTASBuilderCommandListAppendCopyExt; + } + } + if (version >= ZE_API_VERSION_1_13) { + if (loader::context->driverDDIPathDefault) { + pDdiTable->pfnDestroyExt = loader_driver_ddi::zeRTASBuilderDestroyExt; + } else { + pDdiTable->pfnDestroyExt = loader::zeRTASBuilderDestroyExt; + } + } + zeGetRTASBuilderProcAddrTableLegacy(); + } + else + { + // return pointers directly to driver's DDIs + *pDdiTable = loader::context->zeDrivers.front().dditable.ze.RTASBuilder; + } + } + + // If the validation layer is enabled, then intercept the loader's DDIs + if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->validationLayer )) + { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(loader::context->validationLayer, "zeGetRTASBuilderProcAddrTable") ); + if(!getTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + result = getTable( version, pDdiTable ); + } + + // If the API tracing layer is enabled, then intercept the loader's DDIs + if(( ZE_RESULT_SUCCESS == result ) && ( nullptr != loader::context->tracingLayer )) + { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(loader::context->tracingLayer, "zeGetRTASBuilderProcAddrTable") ); + if(!getTable) + return ZE_RESULT_ERROR_UNINITIALIZED; + ze_rtas_builder_dditable_t dditable; + memcpy(&dditable, pDdiTable, sizeof(ze_rtas_builder_dditable_t)); + result = getTable( version, &dditable ); + loader::context->tracing_dditable.ze.RTASBuilder = dditable; + if ( loader::context->tracingLayerEnabled ) { + result = getTable( version, pDdiTable ); + } + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's RTASBuilderExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +ZE_DLLEXPORT ze_result_t ZE_APICALL +zeGetRTASBuilderExpProcAddrTable( + ze_api_version_t version, ///< [in] API version requested + ze_rtas_builder_exp_dditable_t* pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) +{ + if( loader::context->zeDrivers.size() < 1 ) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + if( nullptr == pDdiTable ) + return ZE_RESULT_ERROR_INVALID_NULL_POINTER; + + if( loader::context->version < version ) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + loader::context->ddi_init_version = version; + + ze_result_t result = ZE_RESULT_SUCCESS; + + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetRTASBuilderExpProcAddrTableFromDriver(firstDriver); + } + + if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) { @@ -8001,30 +8841,15 @@ zeGetRTASParallelOperationProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.RTASParallelOperation); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetRTASParallelOperationProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -8123,21 +8948,16 @@ zeGetRTASParallelOperationExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.RTASParallelOperationExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetRTASParallelOperationExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -8235,30 +9055,15 @@ zeGetDriverProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Driver); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetDriverProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -8388,21 +9193,16 @@ zeGetDriverExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.DriverExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetDriverExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -8479,30 +9279,15 @@ zeGetDeviceProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Device); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetDeviceProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -8734,21 +9519,16 @@ zeGetDeviceExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.DeviceExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetDeviceExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -8825,30 +9605,15 @@ zeGetContextProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Context); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetContextProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -8982,30 +9747,15 @@ zeGetCommandQueueProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.CommandQueue); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetCommandQueueProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -9118,30 +9868,15 @@ zeGetCommandListProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.CommandList); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetCommandListProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -9478,21 +10213,16 @@ zeGetCommandListExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.CommandListExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetCommandListExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -9618,30 +10348,15 @@ zeGetEventProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Event); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetEventProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -9789,21 +10504,16 @@ zeGetEventExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.EventExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetEventExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -9880,30 +10590,15 @@ zeGetEventPoolProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.EventPool); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetEventPoolProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -10030,30 +10725,15 @@ zeGetFenceProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Fence); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetFenceProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -10159,30 +10839,15 @@ zeGetImageProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Image); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetImageProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -10288,21 +10953,16 @@ zeGetImageExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.ImageExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetImageExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -10393,30 +11053,15 @@ zeGetKernelProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Kernel); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetKernelProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -10571,21 +11216,16 @@ zeGetKernelExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.KernelExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetKernelExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -10683,30 +11323,15 @@ zeGetMemProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Mem); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetMemProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -10861,21 +11486,16 @@ zeGetMemExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.MemExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetMemExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -10973,30 +11593,15 @@ zeGetModuleProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Module); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetModuleProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -11130,30 +11735,15 @@ zeGetModuleBuildLogProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.ModuleBuildLog); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetModuleBuildLogProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -11238,30 +11828,15 @@ zeGetPhysicalMemProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.PhysicalMem); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetPhysicalMemProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -11346,30 +11921,15 @@ zeGetSamplerProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.Sampler); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetSamplerProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -11454,30 +12014,15 @@ zeGetVirtualMemProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.ze.VirtualMem); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetVirtualMemProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -11597,21 +12142,16 @@ zeGetFabricEdgeExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.FabricEdgeExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetFabricEdgeExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -11702,21 +12242,16 @@ zeGetFabricVertexExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.ze.FabricVertexExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zeGetFabricVertexExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) diff --git a/source/loader/ze_ldrddi.h b/source/loader/ze_ldrddi.h index 485fef93..331113f8 100644 --- a/source/loader/ze_ldrddi.h +++ b/source/loader/ze_ldrddi.h @@ -11,6 +11,10 @@ namespace loader { + /////////////////////////////////////////////////////////////////////////////// + // Forward declaration for driver_t so this header can reference loader::driver_t* + // without requiring inclusion of ze_loader_internal.h (which includes this file). + struct driver_t; /////////////////////////////////////////////////////////////////////////////// using ze_driver_object_t = object_t < ze_driver_handle_t >; using ze_driver_factory_t = singleton_factory_t < ze_driver_object_t, ze_driver_handle_t >; @@ -75,6 +79,8 @@ namespace loader using ze_rtas_parallel_operation_exp_object_t = object_t < ze_rtas_parallel_operation_exp_handle_t >; using ze_rtas_parallel_operation_exp_factory_t = singleton_factory_t < ze_rtas_parallel_operation_exp_object_t, ze_rtas_parallel_operation_exp_handle_t >; + __zedlllocal ze_result_t ZE_APICALL + zeloaderInitDriverDDITables(loader::driver_t *driver); } namespace loader_driver_ddi @@ -1658,64 +1664,124 @@ extern "C" { __zedlllocal void ZE_APICALL zeGetGlobalProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetGlobalProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetRTASBuilderProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASBuilderProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetRTASBuilderExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASBuilderExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetRTASParallelOperationProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASParallelOperationProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetRTASParallelOperationExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetRTASParallelOperationExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetDriverProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetDriverProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetDriverExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetDriverExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetDeviceProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetDeviceProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetDeviceExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetContextProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetContextProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetCommandQueueProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandQueueProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetCommandListProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandListProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetCommandListExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetCommandListExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetEventProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetEventProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetEventExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetEventExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetEventPoolProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetEventPoolProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetFenceProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetFenceProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetImageProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetImageProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetImageExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetImageExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetKernelProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetKernelProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetKernelExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetKernelExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetMemProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetMemProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetMemExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetMemExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetModuleProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetModuleProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetModuleBuildLogProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetModuleBuildLogProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetPhysicalMemProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetPhysicalMemProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetSamplerProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetSamplerProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetVirtualMemProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetVirtualMemProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetFabricEdgeExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetFabricEdgeExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zeGetFabricVertexExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zeGetFabricVertexExpProcAddrTableFromDriver(loader::driver_t *driver); #if defined(__cplusplus) }; diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index 920c58ad..1b376f39 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -257,6 +257,7 @@ namespace loader for (auto &driver : *drivers) { uint32_t pCount = 0; std::vector driverHandles; + driver.pciOrderingRequested = loader::context->pciOrderingRequested; ze_result_t res = ZE_RESULT_SUCCESS; if (desc && driver.dditable.ze.Global.pfnInitDrivers) { if (driver.initDriversStatus != ZE_RESULT_SUCCESS) { @@ -323,8 +324,9 @@ namespace loader continue; } } else { + res = ZE_RESULT_ERROR_UNINITIALIZED; if (debugTraceEnabled) { - std::string message = "driverSorting " + driver.name + " zeDriverGet and zeInitDrivers not supported, skipping driver"; + std::string message = "driverSorting " + driver.name + " zeDriverGet and zeInitDrivers not supported, skipping driver with error "; debug_trace_message(message, loader::to_string(res)); } continue; @@ -332,7 +334,8 @@ namespace loader for (auto handle : driverHandles) { uint32_t extensionCount = 0; - driver.zerDriverHandle = handle; + if (driver.zerddiInitResult == ZE_RESULT_SUCCESS) + driver.zerDriverHandle = handle; ze_result_t res = driver.dditable.ze.Driver.pfnGetExtensionProperties(handle, &extensionCount, nullptr); if (res != ZE_RESULT_SUCCESS) { if (loader::context->debugTraceEnabled) { @@ -416,6 +419,7 @@ namespace loader } bool integratedGPU = false; bool discreteGPU = false; + bool npu = false; bool other = false; for( auto device : deviceHandles ) { ze_device_properties_t deviceProperties = {}; @@ -435,11 +439,17 @@ namespace loader } else { discreteGPU = true; } + } else if (deviceProperties.type == ZE_DEVICE_TYPE_VPU) { + npu = true; } else { other = true; } } - if (integratedGPU && discreteGPU && other) { + if (driver.driverType == ZEL_DRIVER_TYPE_NPU && npu == false) { + // Driver was forced to NPU but no NPU devices found, skip updating type. + continue; + } + if (integratedGPU && discreteGPU && (other || npu)) { driver.driverType = ZEL_DRIVER_TYPE_MIXED; } else if (integratedGPU && discreteGPU) { driver.driverType = ZEL_DRIVER_TYPE_GPU; @@ -447,6 +457,8 @@ namespace loader driver.driverType = ZEL_DRIVER_TYPE_INTEGRATED_GPU; } else if (discreteGPU) { driver.driverType = ZEL_DRIVER_TYPE_DISCRETE_GPU; + } else if (npu) { + driver.driverType = ZEL_DRIVER_TYPE_NPU; } else if (other) { driver.driverType = ZEL_DRIVER_TYPE_OTHER; } @@ -470,200 +482,103 @@ namespace loader return true; } - /** - * @brief Checks and initializes drivers based on the provided flags and descriptors. - * - * This function performs the following operations: - * 1. If debug tracing is enabled, logs the input parameters. - * 2. If `zeInitDrivers` is not supported by the driver and it is called first, returns `ZE_RESULT_ERROR_UNINITIALIZED`. - * 3. Determines the appropriate driver vector (`zeDrivers` or `zesDrivers`) based on the input parameters. - * 4. Iterates over the drivers and attempts to initialize each driver: - * - If initialization fails and the driver is not in use, removes the driver from the list. - * - If the number of drivers becomes one and interception is not forced, sets the `requireDdiReinit` flag to true. - * - If the initialization fails and `return_first_driver_result` is true, returns the result immediately. - * - If initialization succeeds, marks the driver as in use. - * 5. If no drivers are left, returns `ZE_RESULT_ERROR_UNINITIALIZED`. - * 6. Returns `ZE_RESULT_SUCCESS` if at least one driver is successfully initialized. - * - * @param flags Initialization flags. - * @param desc Driver type descriptor (optional). - * @param globalInitStored Pointer to global DDI table for initialization. - * @param sysmanGlobalInitStored Pointer to Sysman global DDI table for initialization. - * @param requireDdiReinit Pointer to a boolean flag indicating if DDI reinitialization is required. - * @param sysmanOnly Boolean flag indicating if only Sysman drivers should be checked. - * @return `ZE_RESULT_SUCCESS` if at least one driver is successfully initialized, otherwise an appropriate error code. - */ - ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) { + ze_result_t context_t::init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc) { + bool loadDriver = false; if (debugTraceEnabled) { - if (desc) { - std::string message = "check_drivers(" + std::string("desc->flags=") + loader::to_string(desc) + ")"; - debug_trace_message(message, ""); - } else { - std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")"; - debug_trace_message(message, ""); - } + std::string message = "Initializing driver " + driver.name + " with type " + std::to_string(driver.driverType);\ + debug_trace_message(message, ""); } - // If zeInitDrivers is not supported by this driver, but zeInitDrivers is called first, then return uninitialized. - if (desc && !loader::context->initDriversSupport) { - if (debugTraceEnabled) { - std::string message = "zeInitDrivers called first, but not supported by driver, returning uninitialized."; - debug_trace_message(message, ""); + if ((!desc && (flags == 0 || flags & ZE_INIT_FLAG_GPU_ONLY)) || (desc && desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_GPU)) { + if (driver.driverType == ZEL_DRIVER_TYPE_GPU || driver.driverType == ZEL_DRIVER_TYPE_DISCRETE_GPU || driver.driverType == ZEL_DRIVER_TYPE_INTEGRATED_GPU) { + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " found GPU Supported Driver."; + debug_trace_message(message, ""); + } + loadDriver = true; } - return ZE_RESULT_ERROR_UNINITIALIZED; } - - - bool return_first_driver_result=false; - std::string initName = "zeInit"; - driver_vector_t *drivers = &zeDrivers; - // If desc is set, then this is zeInitDrivers. - if (desc) { - initName = "zeInitDrivers"; - } - // If this is sysmanOnly check_drivers, then zesInit is being called and we need to use zesDrivers. - if (sysmanOnly) { - drivers = &zesDrivers; - initName = "zesInit"; - } - if(drivers->size()==1) { - return_first_driver_result=true; - } - bool pciOrderingRequested = getenv_tobool( "ZE_ENABLE_PCI_ID_DEVICE_ORDER" ); - loader::context->instrumentationEnabled = getenv_tobool( "ZET_ENABLE_PROGRAM_INSTRUMENTATION" ); - - for(auto it = drivers->begin(); it != drivers->end(); ) - { - it->pciOrderingRequested = pciOrderingRequested; - std::string freeLibraryErrorValue; - ze_result_t result = init_driver(*it, flags, desc, globalInitStored, sysmanGlobalInitStored, sysmanOnly); - if(result != ZE_RESULT_SUCCESS) { - // If the driver has already been init and handles are to be read, then this driver cannot be removed from the list. - // Also, if any driver supports zeInitDrivers, then no driver can be removed to allow for different sets of drivers. - if (!it->driverInuse && !loader::context->initDriversSupport) { - if (debugTraceEnabled) { - std::string errorMessage = "Check Drivers Failed on " + it->name + " , driver will be removed. " + initName + " failed with "; - debug_trace_message(errorMessage, loader::to_string(result)); - } - it = drivers->erase(it); - // If the number of drivers is now ==1, then we need to reinit the ddi tables to pass through. - // If ZE_ENABLE_LOADER_INTERCEPT is set to 1, then even if drivers were removed, don't reinit the ddi tables. - if (drivers->size() == 1 && !loader::context->forceIntercept) { - *requireDdiReinit = true; - } - } else { - it++; + if ((!desc && (flags == 0 || flags & ZE_INIT_FLAG_VPU_ONLY)) || (desc && desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) { + if (driver.driverType == ZEL_DRIVER_TYPE_NPU || driver.driverType == ZEL_DRIVER_TYPE_OTHER) { + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " found VPU/NPU Supported Driver."; + debug_trace_message(message, ""); } - if(return_first_driver_result) - return result; - } else { - // If this is a single driver system, then the first success for this driver needs to be set. - it->driverInuse = true; - it++; + loadDriver = true; } } - if(drivers->size() == 0) - return ZE_RESULT_ERROR_UNINITIALIZED; + loadDriver = !driver.handle && driver.customDriver ? true : loadDriver; - // Set default driver handle and DDI table to the first driver in the list before sorting. - if (loader::context->zeDrivers.front().zerDriverDDISupported) - loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle; - else - loader::context->defaultZerDriverHandle = nullptr; - loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer; - return ZE_RESULT_SUCCESS; - } - - ze_result_t context_t::init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly) { - if (sysmanOnly) { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(driver.handle, "zesGetGlobalProcAddrTable")); - if(!getTable) { + if (loadDriver && !driver.handle) { + auto handle = LOAD_DRIVER_LIBRARY( driver.name.c_str() ); + if( NULL != handle ) + { + driver.handle = handle; + } else { + std::string loadLibraryErrorValue; + GET_LIBRARY_ERROR(loadLibraryErrorValue); if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + std::string errorMessage = "init driver " + driver.name + " failed, Load Library of " + driver.name + " failed with "; + debug_trace_message(errorMessage, loadLibraryErrorValue); } return ZE_RESULT_ERROR_UNINITIALIZED; } + } - zes_global_dditable_t global; - auto getTableResult = getTable(this->configured_version, &global); - if(getTableResult != ZE_RESULT_SUCCESS) { + if (driver.handle && !driver.ddiInitialized) { + auto res = loader::zeloaderInitDriverDDITables(&driver); + if (res != ZE_RESULT_SUCCESS) { if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable() failed with "; - debug_trace_message(errorMessage, loader::to_string(getTableResult)); + std::string message = "init driver " + driver.name + " failed, zeloaderInitDriverDDITables returned "; + debug_trace_message(message, loader::to_string(res)); } - return ZE_RESULT_ERROR_UNINITIALIZED; + driver.zeddiInitResult = res; + } else { + driver.zeddiInitResult = ZE_RESULT_SUCCESS; } - - if(nullptr == global.pfnInit) { + res = loader::zesloaderInitDriverDDITables(&driver); + if (res != ZE_RESULT_SUCCESS) { if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zesInit function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + std::string message = "init driver " + driver.name + " failed, zesloaderInitDriverDDITables returned "; + debug_trace_message(message, loader::to_string(res)); } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - // Use the previously init ddi table pointer to zesInit to allow for intercept of the zesInit calls - ze_result_t res = sysmanGlobalInitStored->pfnInit(flags); - // Verify that this driver successfully init in the call above. - if (driver.initSysManStatus != ZE_RESULT_SUCCESS) { - res = driver.initSysManStatus; - } - if (debugTraceEnabled) { - std::string message = "init driver " + driver.name + " zesInit(" + loader::to_string(flags) + ") returning "; - debug_trace_message(message, loader::to_string(res)); + driver.zesddiInitResult = res; + } else { + driver.zesddiInitResult = ZE_RESULT_SUCCESS; } - return res; - } else { - if (!desc) { - auto pfnInit = driver.dditable.ze.Global.pfnInit; - if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls - ze_result_t res = globalInitStored->pfnInit(flags); - // Verify that this driver successfully init in the call above. - if (res != ZE_RESULT_SUCCESS || driver.initStatus != ZE_RESULT_SUCCESS) { - if (driver.initStatus != ZE_RESULT_SUCCESS) - res = driver.initStatus; - if (debugTraceEnabled) { - std::string message = "init driver (global ddi) " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning "; - debug_trace_message(message, loader::to_string(res)); - } - return res; + res = loader::zetloaderInitDriverDDITables(&driver); + if (res != ZE_RESULT_SUCCESS) { + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " failed, zetloaderInitDriverDDITables returned "; + debug_trace_message(message, loader::to_string(res)); } + driver.zetddiInitResult = res; } else { - auto pfnInitDrivers = driver.dditable.ze.Global.pfnInitDrivers; - if(nullptr == pfnInitDrivers || globalInitStored->pfnInitDrivers == nullptr) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, pfnInitDrivers function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); - } - return ZE_RESULT_ERROR_UNINITIALIZED; + driver.zetddiInitResult = ZE_RESULT_SUCCESS; + } + res = loader::zerloaderInitDriverDDITables(&driver); + if (res != ZE_RESULT_SUCCESS) { + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " failed, zerloaderInitDriverDDITables returned "; + debug_trace_message(message, loader::to_string(res)); } + driver.zerddiInitResult = res; + driver.zerDriverHandle = nullptr; + } else { + driver.zerddiInitResult = ZE_RESULT_SUCCESS; + } + driver.ddiInitialized = true; + } - uint32_t pCount = 0; - // Use the previously init ddi table pointer to zeInitDrivers to allow for intercept of the zeInitDrivers calls - ze_result_t res = globalInitStored->pfnInitDrivers(&pCount, nullptr, desc); - // Verify that this driver successfully init in the call above. - if (res != ZE_RESULT_SUCCESS || driver.initDriversStatus != ZE_RESULT_SUCCESS) { - if (driver.initDriversStatus != ZE_RESULT_SUCCESS) - res = driver.initDriversStatus; - if (debugTraceEnabled) { - std::string message = "init driver (global ddi) " + driver.name + " zeInitDrivers(" + loader::to_string(desc) + ") returning "; - debug_trace_message(message, loader::to_string(res)); - } - return res; - } + if (!driver.handle && !driver.ddiInitialized) { + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " does not match the requested flags or desc, skipping driver."; + debug_trace_message(message, ""); } - return ZE_RESULT_SUCCESS; + return ZE_RESULT_ERROR_UNINITIALIZED; } + + return ZE_RESULT_SUCCESS; } /////////////////////////////////////////////////////////////////////////////// @@ -672,6 +587,8 @@ namespace loader if (driverEnvironmentQueried) { return ZE_RESULT_SUCCESS; } + loader::context->instrumentationEnabled = getenv_tobool( "ZET_ENABLE_PROGRAM_INSTRUMENTATION" ); + loader::context->pciOrderingRequested = getenv_tobool( "ZE_ENABLE_PCI_ID_DEVICE_ORDER" ); loader::loaderDispatch = new ze_handle_t(); loader::loaderDispatch->pCore = new ze_dditable_driver_t(); loader::loaderDispatch->pCore->version = ZE_API_VERSION_CURRENT; @@ -759,30 +676,39 @@ namespace loader } } - for( auto name : discoveredDrivers ) + for( auto driverInfo : discoveredDrivers ) { - auto handle = LOAD_DRIVER_LIBRARY( name.c_str() ); - if( NULL != handle ) - { - if (debugTraceEnabled) { - std::string message = "Loading Driver " + name + " succeeded"; + if (discoveredDrivers.size() == 1) { + auto handle = LOAD_DRIVER_LIBRARY( driverInfo.path.c_str() ); + if( NULL != handle ) + { + if (debugTraceEnabled) { + std::string message = "Loading Driver " + driverInfo.path + " succeeded"; #if !defined(_WIN32) && !defined(ANDROID) - // TODO: implement same message for windows, move dlinfo to ze_util.h as a macro - struct link_map *dlinfo_map; - if (dlinfo(handle, RTLD_DI_LINKMAP, &dlinfo_map) == 0) { - message += " from: " + std::string(dlinfo_map->l_name); - } + // TODO: implement same message for windows, move dlinfo to ze_util.h as a macro + struct link_map *dlinfo_map; + if (dlinfo(handle, RTLD_DI_LINKMAP, &dlinfo_map) == 0) { + message += " from: " + std::string(dlinfo_map->l_name); + } #endif - debug_trace_message(message, ""); + debug_trace_message(message, ""); + } + allDrivers.emplace_back(); + allDrivers.rbegin()->handle = handle; + allDrivers.rbegin()->name = driverInfo.path; + allDrivers.rbegin()->customDriver = driverInfo.customDriver; + } else if (debugTraceEnabled) { + GET_LIBRARY_ERROR(loadLibraryErrorValue); + std::string errorMessage = "Load Library of " + driverInfo.path + " failed with "; + debug_trace_message(errorMessage, loadLibraryErrorValue); + loadLibraryErrorValue.clear(); } + } else { allDrivers.emplace_back(); - allDrivers.rbegin()->handle = handle; - allDrivers.rbegin()->name = name; - } else if (debugTraceEnabled) { - GET_LIBRARY_ERROR(loadLibraryErrorValue); - std::string errorMessage = "Load Library of " + name + " failed with "; - debug_trace_message(errorMessage, loadLibraryErrorValue); - loadLibraryErrorValue.clear(); + allDrivers.rbegin()->handle = nullptr; + allDrivers.rbegin()->name = driverInfo.path; + allDrivers.rbegin()->customDriver = driverInfo.customDriver; + allDrivers.rbegin()->driverType = driverInfo.driverType; } } if(allDrivers.size()==0){ @@ -861,6 +787,10 @@ namespace loader driverEnvironmentQueried = true; + // Set default driver zer DDI table to the first driver in the list before sorting. + // Leave the zer Driver Handle as nullptr until init when the drivers are sorted and initialized. + loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer; + zel_logger->log_info("zeInit succeeded"); return ZE_RESULT_SUCCESS; }; diff --git a/source/loader/ze_loader_api.cpp b/source/loader/ze_loader_api.cpp index c15b9008..60bd4f73 100644 --- a/source/loader/ze_loader_api.cpp +++ b/source/loader/ze_loader_api.cpp @@ -28,14 +28,14 @@ zeLoaderInit() /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for verifying usable L0 Drivers for Loader to report -/// +/// @deprecated This function is deprecated and will be removed in a future release. /// @returns /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL zelLoaderDriverCheck(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) { - return loader::context->check_drivers(flags, desc, globalInitStored, sysmanGlobalInitStored, requireDdiReinit, sysmanOnly); + return ZE_RESULT_SUCCESS; } /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_api.h b/source/loader/ze_loader_api.h index 876f51ef..eecdbb46 100644 --- a/source/loader/ze_loader_api.h +++ b/source/loader/ze_loader_api.h @@ -28,7 +28,7 @@ zeLoaderInit(); /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for verifying usable L0 Drivers for Loader to report -/// +/// @deprecated This function is deprecated and will be removed in a future release. /// @returns /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED diff --git a/source/loader/ze_loader_internal.h b/source/loader/ze_loader_internal.h index 894d6bf0..8cc4df6c 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -61,6 +61,13 @@ namespace loader bool driverDDIHandleSupportQueried = false; ze_driver_handle_t zerDriverHandle = nullptr; bool zerDriverDDISupported = true; + ze_api_version_t versionRequested = ZE_API_VERSION_CURRENT; + bool ddiInitialized = false; + bool customDriver = false; + ze_result_t zeddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zetddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zesddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; + ze_result_t zerddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED; }; using driver_vector_t = std::vector< driver_t >; @@ -134,6 +141,7 @@ namespace loader std::unordered_map sampler_handle_map; ze_api_version_t version = ZE_API_VERSION_CURRENT; ze_api_version_t configured_version = ZE_API_VERSION_CURRENT; + ze_api_version_t ddi_init_version = ZE_API_VERSION_CURRENT; driver_vector_t allDrivers; driver_vector_t zeDrivers; @@ -149,10 +157,9 @@ namespace loader std::vector compVersions; const char *LOADER_COMP_NAME = "loader"; - ze_result_t check_drivers(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); void debug_trace_message(std::string errorMessage, std::string errorValue); ze_result_t init(); - ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly); + ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc); void add_loader_version(); bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly); void driverOrdering(driver_vector_t *drivers); @@ -166,6 +173,7 @@ namespace loader std::atomic sortingInProgress = {false}; std::mutex sortMutex; bool instrumentationEnabled = false; + bool pciOrderingRequested = false; dditable_t tracing_dditable = {}; std::shared_ptr zel_logger; ze_driver_handle_t defaultZerDriverHandle = nullptr; diff --git a/source/loader/zer_ldrddi.cpp b/source/loader/zer_ldrddi.cpp index f63cce5b..60b9710a 100644 --- a/source/loader/zer_ldrddi.cpp +++ b/source/loader/zer_ldrddi.cpp @@ -13,6 +13,15 @@ using namespace loader_driver_ddi; namespace loader { + __zedlllocal ze_result_t ZE_APICALL + zerloaderInitDriverDDITables(loader::driver_t *driver) { + ze_result_t result = ZE_RESULT_SUCCESS; + result = zerGetGlobalProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + return result; + } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for zerGetLastErrorDescription __zedlllocal ze_result_t ZE_APICALL @@ -114,6 +123,28 @@ zerGetGlobalProcAddrTableLegacy() } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Global table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zerGetGlobalProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zerGetGlobalProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zer.Global); + return result; +} /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Global table /// with current process' addresses @@ -139,21 +170,16 @@ zerGetGlobalProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zerGetGlobalProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zer.Global); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zerGetGlobalProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) diff --git a/source/loader/zer_ldrddi.h b/source/loader/zer_ldrddi.h index 5e4ad2e4..c64eec07 100644 --- a/source/loader/zer_ldrddi.h +++ b/source/loader/zer_ldrddi.h @@ -12,6 +12,12 @@ namespace loader { /////////////////////////////////////////////////////////////////////////////// + // Forward declaration for driver_t so this header can reference loader::driver_t* + // without requiring inclusion of ze_loader_internal.h (which includes this file). + struct driver_t; + /////////////////////////////////////////////////////////////////////////////// + __zedlllocal ze_result_t ZE_APICALL + zerloaderInitDriverDDITables(loader::driver_t *driver); } namespace loader_driver_ddi @@ -43,6 +49,8 @@ extern "C" { __zedlllocal void ZE_APICALL zerGetGlobalProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zerGetGlobalProcAddrTableFromDriver(loader::driver_t *driver); #if defined(__cplusplus) }; diff --git a/source/loader/zer_ldrddi_driver_ddi.cpp b/source/loader/zer_ldrddi_driver_ddi.cpp index c878e4f5..b9d674b9 100644 --- a/source/loader/zer_ldrddi_driver_ddi.cpp +++ b/source/loader/zer_ldrddi_driver_ddi.cpp @@ -29,6 +29,9 @@ namespace loader_driver_ddi // Check if the default driver supports DDI Handles if (loader::context->defaultZerDriverHandle == nullptr) { + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; + } return ZE_RESULT_ERROR_UNINITIALIZED; } auto dditable = reinterpret_cast( loader::context->defaultZerDriverHandle )->pRuntime; @@ -63,6 +66,10 @@ namespace loader_driver_ddi // Check if the default driver supports DDI Handles if (loader::context->defaultZerDriverHandle == nullptr) { + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + error_state::setErrorDesc("ERROR UNSUPPORTED FEATURE"); + return UINT32_MAX; + } error_state::setErrorDesc("ERROR UNINITIALIZED"); return UINT32_MAX; } @@ -102,6 +109,10 @@ namespace loader_driver_ddi // Check if the default driver supports DDI Handles if (loader::context->defaultZerDriverHandle == nullptr) { + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + error_state::setErrorDesc("ERROR UNSUPPORTED FEATURE"); + return nullptr; + } error_state::setErrorDesc("ERROR UNINITIALIZED"); return nullptr; } @@ -141,6 +152,10 @@ namespace loader_driver_ddi // Check if the default driver supports DDI Handles if (loader::context->defaultZerDriverHandle == nullptr) { + if (loader::context->zeDrivers.front().zerddiInitResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { + error_state::setErrorDesc("ERROR UNSUPPORTED FEATURE"); + return nullptr; + } error_state::setErrorDesc("ERROR UNINITIALIZED"); return nullptr; } diff --git a/source/loader/zes_ldrddi.cpp b/source/loader/zes_ldrddi.cpp index ed93eeb6..6fca85b6 100644 --- a/source/loader/zes_ldrddi.cpp +++ b/source/loader/zes_ldrddi.cpp @@ -13,6 +13,107 @@ using namespace loader_driver_ddi; namespace loader { + __zedlllocal ze_result_t ZE_APICALL + zesloaderInitDriverDDITables(loader::driver_t *driver) { + ze_result_t result = ZE_RESULT_SUCCESS; + result = zesGetGlobalProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetDeviceProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetDeviceExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetDriverProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetDriverExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetDiagnosticsProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetEngineProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetFabricPortProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetFanProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetFirmwareProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetFirmwareExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetFrequencyProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetLedProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetMemoryProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetOverclockProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetPerformanceFactorProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetPowerProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetPsuProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetRasProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetRasExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetSchedulerProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetStandbyProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetTemperatureProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zesGetVFManagementExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + return result; + } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for zesInit __zedlllocal ze_result_t ZE_APICALL @@ -28,6 +129,12 @@ namespace loader { if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS) continue; + if (!drv.handle || !drv.ddiInitialized) { + auto res = loader::context->init_driver( drv, flags, nullptr ); + if (res != ZE_RESULT_SUCCESS) { + continue; + } + } if (!drv.dditable.zes.Global.pfnInit) { drv.initSysManStatus = ZE_RESULT_ERROR_UNINITIALIZED; continue; @@ -74,7 +181,7 @@ namespace loader for( auto& drv : *loader::context->sysmanInstanceDrivers ) { - if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS) + if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized) continue; if( ( 0 < *pCount ) && ( *pCount == total_driver_handle_count)) @@ -4927,6 +5034,616 @@ zesGetVFManagementExpProcAddrTableLegacy() } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Global table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetGlobalProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetGlobalProcAddrTable") ); + if(!getTable) + { + //It is valid to not have this proc addr table + return ZE_RESULT_SUCCESS; + } + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Global); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Device table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetDeviceProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetDeviceProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Device); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's DeviceExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetDeviceExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zes.DeviceExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Driver table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetDriverProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetDriverProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Driver); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's DriverExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetDriverExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetDriverExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zes.DriverExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Diagnostics table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetDiagnosticsProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetDiagnosticsProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Diagnostics); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Engine table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetEngineProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetEngineProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Engine); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's FabricPort table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetFabricPortProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetFabricPortProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.FabricPort); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Fan table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetFanProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetFanProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Fan); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Firmware table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetFirmwareProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetFirmwareProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Firmware); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's FirmwareExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetFirmwareExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetFirmwareExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zes.FirmwareExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Frequency table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetFrequencyProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetFrequencyProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Frequency); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Led table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetLedProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetLedProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Led); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Memory table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetMemoryProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetMemoryProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Memory); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Overclock table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetOverclockProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetOverclockProcAddrTable") ); + if(!getTable) + { + //It is valid to not have this proc addr table + return ZE_RESULT_SUCCESS; + } + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Overclock); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's PerformanceFactor table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetPerformanceFactorProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetPerformanceFactorProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.PerformanceFactor); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Power table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetPowerProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetPowerProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Power); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Psu table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetPsuProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetPsuProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Psu); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Ras table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetRasProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetRasProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Ras); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's RasExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetRasExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetRasExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zes.RasExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Scheduler table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetSchedulerProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetSchedulerProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Scheduler); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Standby table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetStandbyProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetStandbyProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Standby); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Temperature table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetTemperatureProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetTemperatureProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zes.Temperature); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's VFManagementExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zesGetVFManagementExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zesGetVFManagementExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zes.VFManagementExp); + return result; +} /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Global table /// with current process' addresses @@ -4952,35 +5669,16 @@ zesGetGlobalProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetGlobalProcAddrTable") ); - if(!getTable) - { - atLeastOneDriverValid = true; - //It is valid to not have this proc addr table - continue; - } - auto getTableResult = getTable( version, &drv.dditable.zes.Global); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetGlobalProcAddrTableFromDriver(firstDriver); } - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -5037,30 +5735,15 @@ zesGetDeviceProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetDeviceProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Device); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetDeviceProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -5374,21 +6057,16 @@ zesGetDeviceExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetDeviceExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zes.DeviceExp); + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetDeviceExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -5463,30 +6141,15 @@ zesGetDriverProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetDriverProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Driver); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetDriverProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -5572,21 +6235,16 @@ zesGetDriverExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetDriverExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zes.DriverExp); + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetDriverExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -5647,30 +6305,15 @@ zesGetDiagnosticsProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetDiagnosticsProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Diagnostics); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetDiagnosticsProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -5746,30 +6389,15 @@ zesGetEngineProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetEngineProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Engine); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetEngineProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -5845,30 +6473,15 @@ zesGetFabricPortProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetFabricPortProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.FabricPort); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetFabricPortProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -5979,30 +6592,15 @@ zesGetFanProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetFanProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Fan); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetFanProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6099,30 +6697,15 @@ zesGetFirmwareProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetFirmwareProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Firmware); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetFirmwareProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6205,21 +6788,16 @@ zesGetFirmwareExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetFirmwareExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zes.FirmwareExp); + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetFirmwareExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -6287,30 +6865,15 @@ zesGetFrequencyProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetFrequencyProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Frequency); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetFrequencyProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6484,30 +7047,15 @@ zesGetLedProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetLedProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Led); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetLedProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6590,30 +7138,15 @@ zesGetMemoryProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetMemoryProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Memory); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetMemoryProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6689,35 +7222,16 @@ zesGetOverclockProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetOverclockProcAddrTable") ); - if(!getTable) - { - atLeastOneDriverValid = true; - //It is valid to not have this proc addr table - continue; - } - auto getTableResult = getTable( version, &drv.dditable.zes.Overclock); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetOverclockProcAddrTableFromDriver(firstDriver); } - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -6834,30 +7348,15 @@ zesGetPerformanceFactorProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetPerformanceFactorProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.PerformanceFactor); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetPerformanceFactorProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -6933,30 +7432,15 @@ zesGetPowerProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetPowerProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Power); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetPowerProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7067,30 +7551,15 @@ zesGetPsuProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetPsuProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Psu); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetPsuProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7159,30 +7628,15 @@ zesGetRasProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetRasProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Ras); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetRasProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7265,21 +7719,16 @@ zesGetRasExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetRasExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zes.RasExp); + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetRasExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) @@ -7347,30 +7796,15 @@ zesGetSchedulerProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetSchedulerProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Scheduler); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetSchedulerProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7481,30 +7915,15 @@ zesGetStandbyProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetStandbyProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Standby); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetStandbyProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7580,30 +7999,15 @@ zesGetTemperatureProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetTemperatureProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zes.Temperature); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetTemperatureProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -7686,21 +8090,16 @@ zesGetVFManagementExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : *loader::context->sysmanInstanceDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zesGetVFManagementExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zes.VFManagementExp); + auto driverCount = loader::context->sysmanInstanceDrivers->size(); + auto firstDriver = &loader::context->sysmanInstanceDrivers->at(0); + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zesGetVFManagementExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept ) diff --git a/source/loader/zes_ldrddi.h b/source/loader/zes_ldrddi.h index a9438284..0e852851 100644 --- a/source/loader/zes_ldrddi.h +++ b/source/loader/zes_ldrddi.h @@ -11,6 +11,10 @@ namespace loader { + /////////////////////////////////////////////////////////////////////////////// + // Forward declaration for driver_t so this header can reference loader::driver_t* + // without requiring inclusion of ze_loader_internal.h (which includes this file). + struct driver_t; /////////////////////////////////////////////////////////////////////////////// using zes_driver_object_t = object_t < zes_driver_handle_t >; using zes_driver_factory_t = singleton_factory_t < zes_driver_object_t, zes_driver_handle_t >; @@ -69,6 +73,8 @@ namespace loader using zes_vf_object_t = object_t < zes_vf_handle_t >; using zes_vf_factory_t = singleton_factory_t < zes_vf_object_t, zes_vf_handle_t >; + __zedlllocal ze_result_t ZE_APICALL + zesloaderInitDriverDDITables(loader::driver_t *driver); } namespace loader_driver_ddi @@ -1264,52 +1270,100 @@ extern "C" { __zedlllocal void ZE_APICALL zesGetGlobalProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetGlobalProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetDeviceProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetDeviceProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetDeviceExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetDriverProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetDriverProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetDriverExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetDriverExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetDiagnosticsProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetDiagnosticsProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetEngineProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetEngineProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetFabricPortProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetFabricPortProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetFanProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetFanProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetFirmwareProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetFirmwareProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetFirmwareExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetFirmwareExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetFrequencyProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetFrequencyProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetLedProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetLedProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetMemoryProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetMemoryProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetOverclockProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetOverclockProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetPerformanceFactorProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetPerformanceFactorProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetPowerProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetPowerProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetPsuProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetPsuProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetRasProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetRasProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetRasExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetRasExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetSchedulerProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetSchedulerProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetStandbyProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetStandbyProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetTemperatureProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetTemperatureProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zesGetVFManagementExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zesGetVFManagementExpProcAddrTableFromDriver(loader::driver_t *driver); #if defined(__cplusplus) }; diff --git a/source/loader/zet_ldrddi.cpp b/source/loader/zet_ldrddi.cpp index 01a9beef..43825040 100644 --- a/source/loader/zet_ldrddi.cpp +++ b/source/loader/zet_ldrddi.cpp @@ -13,6 +13,87 @@ using namespace loader_driver_ddi; namespace loader { + __zedlllocal ze_result_t ZE_APICALL + zetloaderInitDriverDDITables(loader::driver_t *driver) { + ze_result_t result = ZE_RESULT_SUCCESS; + result = zetGetMetricDecoderExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricProgrammableExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricTracerExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetDeviceProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetDeviceExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetContextProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetCommandListProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetCommandListExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetKernelProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetModuleProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetDebugProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricGroupProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricGroupExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricQueryProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricQueryPoolProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetMetricStreamerProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + result = zetGetTracerExpProcAddrTableFromDriver(driver); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + return result; + } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for zetModuleGetDebugInfo __zedlllocal ze_result_t ZE_APICALL @@ -2601,6 +2682,472 @@ zetGetTracerExpProcAddrTableLegacy() } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricDecoderExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricDecoderExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricDecoderExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricDecoderExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricProgrammableExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricProgrammableExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricProgrammableExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricProgrammableExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricTracerExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricTracerExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricTracerExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricTracerExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Device table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetDeviceProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetDeviceProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Device); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's DeviceExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetDeviceExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.DeviceExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Context table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetContextProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetContextProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Context); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandList table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetCommandListProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetCommandListProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.CommandList); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandListExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetCommandListExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetCommandListExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.CommandListExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Kernel table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetKernelProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetKernelProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Kernel); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Module table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetModuleProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetModuleProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Module); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Debug table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetDebugProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetDebugProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Debug); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Metric table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.Metric); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricGroup table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricGroupProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricGroupProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricGroup); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricGroupExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricGroupExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricGroupExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + result = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricGroupExp); + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricQuery table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricQueryProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricQueryProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricQuery); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricQueryPool table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricQueryPoolProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricQueryPoolProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricQueryPool); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's MetricStreamer table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricStreamerProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetMetricStreamerProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.MetricStreamer); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's TracerExp table +/// with current process' addresses +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_UNINITIALIZED +/// - ::ZE_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::ZE_RESULT_ERROR_UNSUPPORTED_VERSION +__zedlllocal ze_result_t ZE_APICALL +zetGetTracerExpProcAddrTableFromDriver(loader::driver_t *driver) +{ + ze_result_t result = ZE_RESULT_SUCCESS; + if(driver->initStatus != ZE_RESULT_SUCCESS) + return driver->initStatus; + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR( driver->handle, "zetGetTracerExpProcAddrTable") ); + if(!getTable) + return driver->initStatus; + auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.zet.TracerExp); + if(getTableResult == ZE_RESULT_SUCCESS) { + loader::context->configured_version = loader::context->ddi_init_version; + } else + driver->initStatus = getTableResult; + return result; +} /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's MetricDecoderExp table /// with current process' addresses @@ -2626,21 +3173,16 @@ zetGetMetricDecoderExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricDecoderExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.MetricDecoderExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricDecoderExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -2715,21 +3257,16 @@ zetGetMetricProgrammableExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricProgrammableExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.MetricProgrammableExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricProgrammableExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -2811,21 +3348,16 @@ zetGetMetricTracerExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricTracerExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.MetricTracerExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricTracerExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -2921,30 +3453,15 @@ zetGetDeviceProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetDeviceProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Device); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetDeviceProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3006,21 +3523,16 @@ zetGetDeviceExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetDeviceExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.DeviceExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetDeviceExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -3102,30 +3614,15 @@ zetGetContextProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetContextProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Context); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetContextProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3187,30 +3684,15 @@ zetGetCommandListProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetCommandListProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.CommandList); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetCommandListProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3293,21 +3775,16 @@ zetGetCommandListExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetCommandListExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.CommandListExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetCommandListExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -3368,30 +3845,15 @@ zetGetKernelProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetKernelProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Kernel); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetKernelProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3453,30 +3915,15 @@ zetGetModuleProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetModuleProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Module); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetModuleProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3538,30 +3985,15 @@ zetGetDebugProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetDebugProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Debug); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetDebugProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3700,30 +4132,15 @@ zetGetMetricProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.Metric); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3792,21 +4209,16 @@ zetGetMetricExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.MetricExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -3881,30 +4293,15 @@ zetGetMetricGroupProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricGroupProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.MetricGroup); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricGroupProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -3980,21 +4377,16 @@ zetGetMetricGroupExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricGroupExpProcAddrTable") ); - if(!getTable) - continue; - result = getTable( version, &drv.dditable.zet.MetricGroupExp); + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricGroupExpProcAddrTableFromDriver(firstDriver); } - if( ZE_RESULT_SUCCESS == result ) { if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept ) @@ -4111,30 +4503,15 @@ zetGetMetricQueryProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricQueryProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.MetricQuery); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricQueryProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -4217,30 +4594,15 @@ zetGetMetricQueryPoolProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricQueryPoolProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.MetricQueryPool); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricQueryPoolProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -4309,30 +4671,15 @@ zetGetMetricStreamerProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetMetricStreamerProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.MetricStreamer); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetMetricStreamerProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { @@ -4408,30 +4755,15 @@ zetGetTracerExpProcAddrTable( if( loader::context->version < version ) return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + loader::context->ddi_init_version = version; + ze_result_t result = ZE_RESULT_SUCCESS; - bool atLeastOneDriverValid = false; - // Load the device-driver DDI tables - for( auto& drv : loader::context->zeDrivers ) - { - if(drv.initStatus != ZE_RESULT_SUCCESS) - continue; - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR( drv.handle, "zetGetTracerExpProcAddrTable") ); - if(!getTable) - continue; - auto getTableResult = getTable( version, &drv.dditable.zet.TracerExp); - if(getTableResult == ZE_RESULT_SUCCESS) { - atLeastOneDriverValid = true; - loader::context->configured_version = version; - } else - drv.initStatus = getTableResult; - } - - if(!atLeastOneDriverValid) - result = ZE_RESULT_ERROR_UNINITIALIZED; - else - result = ZE_RESULT_SUCCESS; + auto driverCount = loader::context->zeDrivers.size(); + auto firstDriver = &loader::context->zeDrivers[0]; + if (driverCount == 1 && firstDriver && !loader::context->forceIntercept) { + result = zetGetTracerExpProcAddrTableFromDriver(firstDriver); + } if( ZE_RESULT_SUCCESS == result ) { diff --git a/source/loader/zet_ldrddi.h b/source/loader/zet_ldrddi.h index c65bc97f..1916aeff 100644 --- a/source/loader/zet_ldrddi.h +++ b/source/loader/zet_ldrddi.h @@ -11,6 +11,10 @@ namespace loader { + /////////////////////////////////////////////////////////////////////////////// + // Forward declaration for driver_t so this header can reference loader::driver_t* + // without requiring inclusion of ze_loader_internal.h (which includes this file). + struct driver_t; /////////////////////////////////////////////////////////////////////////////// using zet_driver_object_t = object_t < zet_driver_handle_t >; using zet_driver_factory_t = singleton_factory_t < zet_driver_object_t, zet_driver_handle_t >; @@ -60,6 +64,8 @@ namespace loader using zet_metric_programmable_exp_object_t = object_t < zet_metric_programmable_exp_handle_t >; using zet_metric_programmable_exp_factory_t = singleton_factory_t < zet_metric_programmable_exp_object_t, zet_metric_programmable_exp_handle_t >; + __zedlllocal ze_result_t ZE_APICALL + zetloaderInitDriverDDITables(loader::driver_t *driver); } namespace loader_driver_ddi @@ -738,42 +744,80 @@ extern "C" { __zedlllocal void ZE_APICALL zetGetMetricDecoderExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricDecoderExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricProgrammableExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricProgrammableExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricTracerExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricTracerExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetDeviceProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetDeviceProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetDeviceExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetDeviceExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetContextProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetContextProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetCommandListProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetCommandListProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetCommandListExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetCommandListExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetKernelProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetKernelProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetModuleProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetModuleProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetDebugProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetDebugProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricGroupProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricGroupProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricGroupExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricGroupExpProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricQueryProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricQueryProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricQueryPoolProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricQueryPoolProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetMetricStreamerProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetMetricStreamerProcAddrTableFromDriver(loader::driver_t *driver); __zedlllocal void ZE_APICALL zetGetTracerExpProcAddrTableLegacy(); +__zedlllocal ze_result_t ZE_APICALL +zetGetTracerExpProcAddrTableFromDriver(loader::driver_t *driver); #if defined(__cplusplus) }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c7b81c1c..7c77e35f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -15,6 +15,18 @@ if(BUILD_STATIC OR NOT WIN32) target_sources(tests PRIVATE driver_ordering_unit_tests.cpp) endif() +# For builds on non-Windows platforms, include init_driver_unit_tests +# The tests rely on ablity to locate fake drivers by specific names which is not possible on Windows. +if(NOT WIN32) + target_sources(tests PRIVATE init_driver_unit_tests.cpp) +endif() + +# Only build init_driver_dynamic_unit_tests for dynamic builds on non-Windows platforms +# as it requires internal loader symbols that are not exported in Windows DLLs +if(NOT BUILD_STATIC AND NOT WIN32) + target_sources(tests PRIVATE init_driver_dynamic_unit_tests.cpp) +endif() + # For static builds, we need to include the loader source files directly # since they are not built into the library when BUILD_STATIC=1 if(BUILD_STATIC) @@ -70,6 +82,37 @@ if(BUILD_STATIC) endif() endif() +# Create fake driver copies for init_driver_unit_tests +if(NOT BUILD_STATIC) + if(MSVC) + add_custom_command(TARGET tests POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $/ze_fake_gpu.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $/ze_fake_npu.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $/ze_fake_vpu.dll + COMMENT "Copying null drivers to fake driver names for init_driver_unit_tests" + ) + else() + add_custom_command(TARGET tests POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${CMAKE_BINARY_DIR}/lib/libze_null.so.1 + ${CMAKE_BINARY_DIR}/lib/libze_fake_gpu.so.1 + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${CMAKE_BINARY_DIR}/lib/libze_null_test1.so.1 + ${CMAKE_BINARY_DIR}/lib/libze_fake_npu.so.1 + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${CMAKE_BINARY_DIR}/lib/libze_null_test2.so.1 + ${CMAKE_BINARY_DIR}/lib/libze_fake_vpu.so.1 + COMMENT "Copying null drivers to fake driver names for init_driver_unit_tests" + ) + endif() +endif() + add_test(NAME tests_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned*) set_property(TEST tests_api PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") add_test(NAME tests_init_gpu_all COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithGPUTypeThenExpectPassWithGPUorAllOnly*) @@ -90,7 +133,7 @@ add_test(NAME tests_both_gpu COMMAND tests --gtest_filter=*GivenLevelZeroLoaderP set_property(TEST tests_both_gpu PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") add_test(NAME tests_both_npu COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithNPUTypes*) set_property(TEST tests_both_npu PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") -add_test(NAME tests_missing_api COMMAND tests --gtest_filter=*GivenZeInitDriversUnsupportedOnTheDriverWhenCallingZeInitDriversThenUninitializedReturned*) +add_test(NAME tests_missing_api COMMAND tests --gtest_filter=*GivenZeInitDriversUnsupportedOnTheDriverWhenCallingZeInitDriversThenUnSupportedReturned*) set_property(TEST tests_missing_api PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") add_test(NAME tests_multi_call_failure COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithTypesUnsupportedWithFailureThenSupportedTypesThenSuccessReturned*) set_property(TEST tests_multi_call_failure PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") @@ -382,8 +425,8 @@ foreach(test_name IN ITEMS set_property(TEST ${test_name}_alt_drivers APPEND PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1;${ALT_DRIVERS_ENV}") endforeach() -add_test(NAME tests_sigle_driver_sysman_vf_management_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingSysManVfApisThenExpectNullDriverIsReachedSuccessfully) -set_property(TEST tests_sigle_driver_sysman_vf_management_api PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") +add_test(NAME tests_single_driver_sysman_vf_management_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingSysManVfApisThenExpectNullDriverIsReachedSuccessfully) +set_property(TEST tests_single_driver_sysman_vf_management_api PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") add_test(NAME tests_multi_driver_sysman_vf_management_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingSysManVfApisThenExpectNullDriverIsReachedSuccessfully) if (MSVC) @@ -597,6 +640,15 @@ set_property(TEST driver_ordering_trim_function PROPERTY ENVIRONMENT "ZE_ENABLE_ add_test(NAME driver_ordering_parse_driver_order COMMAND tests --gtest_filter=DriverOrderingHelperFunctionsTest.ParseDriverOrder_*) set_property(TEST driver_ordering_parse_driver_order PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") +# Init Driver Unit Tests +add_test(NAME init_driver_unit_tests COMMAND tests --gtest_filter=InitDriverUnitTest.*) +if (MSVC) + set_property(TEST init_driver_unit_tests PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1;") +elseif(NOT BUILD_STATIC) + set_property(TEST init_driver_unit_tests PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1;LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib_fake:${CMAKE_BINARY_DIR}/lib/") +else() + set_property(TEST init_driver_unit_tests PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1;") +endif() # These tests are currently not supported on Windows. The reason is that the std::cerr is not being redirected to a pipe in Windows to be then checked against the expected output. if(NOT MSVC) @@ -700,7 +752,7 @@ set_property(TEST test_ze_and_zer_tracing_static PROPERTY ENVIRONMENT "ZE_ENABLE add_test(NAME test_ze_and_zer_tracing_dynamic COMMAND tests --gtest_filter=*TracingParameterizedTest*GivenLoaderWithDynamicTracingEnabledAndBothZeAndZerCallbacksRegisteredWhenCallingBothApisThenBothAreTraced*) set_property(TEST test_ze_and_zer_tracing_dynamic PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") -add_test(NAME test_zer_unsupported_and_ze_tracing_dynamic COMMAND tests --gtest_filter=*TracingParameterizedTest*GivenLoaderWithDynamicTracingEnabledAndZerApisUnsupportedAndBothZeAndZerCallbacksRegisteredWhenCallingBothApisThenTracingWorksForZeOnly*) +add_test(NAME test_zer_unsupported_and_ze_tracing_dynamic COMMAND tests --gtest_filter=*TracingParameterizedTest*GivenLoaderWithDynamicTracingEnabledAndZerApisUnsupportedAndBothZeAndZerCallbacksRegisteredWhenCallingBothApisThenTracingWorksForZeAndZerCallbacksAreStillInvoked*) set_property(TEST test_zer_unsupported_and_ze_tracing_dynamic PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") # ZER API Validation Layer Tests diff --git a/test/init_driver_dynamic_unit_tests.cpp b/test/init_driver_dynamic_unit_tests.cpp new file mode 100644 index 00000000..800ae427 --- /dev/null +++ b/test/init_driver_dynamic_unit_tests.cpp @@ -0,0 +1,255 @@ +/* + * Copyright (C) 2025 Intel Corporation + * SPDX-License-Identifier: MIT + */ + + +#include "test/init_driver_unit_tests_common.h" + + +TEST_F(InitDriverUnitTest, zeInitDriversWithFakeIntelGPUAndNPU_InitGPUFirst) { + // First, initialize GPU drivers using zeInitDrivers with GPU flag + uint32_t driverCount = 0; + ze_init_driver_type_desc_t descGPU = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descGPU.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + descGPU.pNext = nullptr; + + // Call zeInitDrivers to init GPU drivers + ze_result_t result = zeInitDrivers(&driverCount, nullptr, &descGPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + if (driverCount > 0) { + std::vector gpuDrivers(driverCount); + result = zeInitDrivers(&driverCount, gpuDrivers.data(), &descGPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Make sure we can call into GPU drivers + for (uint32_t i = 0; i < driverCount; i++) { + EXPECT_NE(gpuDrivers[i], nullptr); + + // Try to get devices from this GPU driver + uint32_t deviceCount = 0; + ze_result_t devResult = zeDeviceGet(gpuDrivers[i], &deviceCount, nullptr); + // Should succeed (even if 0 devices for null driver) + EXPECT_EQ(devResult, ZE_RESULT_SUCCESS); + } + + // Verify that GPU drivers were initialized in the context + bool foundGPUDriver = false; + for (const auto& driver : loader::context->zeDrivers) { + if (driver.driverType == loader::ZEL_DRIVER_TYPE_DISCRETE_GPU || + driver.driverType == loader::ZEL_DRIVER_TYPE_GPU || + driver.driverType == loader::ZEL_DRIVER_TYPE_INTEGRATED_GPU) { + foundGPUDriver = true; + EXPECT_TRUE(driver.ddiInitialized); + break; + } + } + EXPECT_TRUE(foundGPUDriver); + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithFakeIntelGPUAndNPU_InitNPUFirst) { + // First, initialize NPU drivers using zeInitDrivers with NPU flag + uint32_t driverCount = 0; + ze_init_driver_type_desc_t descNPU = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descNPU.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + descNPU.pNext = nullptr; + + // Call zeInitDrivers to init NPU drivers + ze_result_t result = zeInitDrivers(&driverCount, nullptr, &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + if (driverCount > 0) { + std::vector npuDrivers(driverCount); + result = zeInitDrivers(&driverCount, npuDrivers.data(), &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Make sure we can call into NPU drivers + for (uint32_t i = 0; i < driverCount; i++) { + EXPECT_NE(npuDrivers[i], nullptr); + + // Try to get devices from this NPU driver + uint32_t deviceCount = 0; + ze_result_t devResult = zeDeviceGet(npuDrivers[i], &deviceCount, nullptr); + // Should succeed (even if 0 devices for null driver) + EXPECT_EQ(devResult, ZE_RESULT_SUCCESS); + } + + // Verify that NPU drivers were initialized in the context + bool foundNPUDriver = false; + for (const auto& driver : loader::context->zeDrivers) { + printf("Driver Name: %s, Type: %u, DDI Initialized: %d\n", driver.name.c_str(), driver.driverType, driver.ddiInitialized); + if (driver.driverType == loader::ZEL_DRIVER_TYPE_NPU) { + foundNPUDriver = true; + EXPECT_TRUE(driver.ddiInitialized); + break; + } + } + EXPECT_TRUE(foundNPUDriver); + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithFakeIntelGPUAndNPU_InitGPUThenNPU) { + // Step 1: Initialize GPU drivers first + uint32_t gpuDriverCount = 0; + ze_init_driver_type_desc_t descGPU = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descGPU.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; + descGPU.pNext = nullptr; + + ze_result_t result = zeInitDrivers(&gpuDriverCount, nullptr, &descGPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + std::vector gpuDrivers(gpuDriverCount); + if (gpuDriverCount > 0) { + result = zeInitDrivers(&gpuDriverCount, gpuDrivers.data(), &descGPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Verify we can call GPU driver APIs + for (uint32_t i = 0; i < gpuDriverCount; i++) { + EXPECT_NE(gpuDrivers[i], nullptr); + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(gpuDrivers[i], &deviceCount, nullptr)); + } + } + + // Step 2: Now initialize NPU drivers + uint32_t npuDriverCount = 0; + ze_init_driver_type_desc_t descNPU = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descNPU.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + descNPU.pNext = nullptr; + + result = zeInitDrivers(&npuDriverCount, nullptr, &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + std::vector npuDrivers(npuDriverCount); + if (npuDriverCount > 0) { + result = zeInitDrivers(&npuDriverCount, npuDrivers.data(), &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Verify we can call NPU driver APIs + for (uint32_t i = 0; i < npuDriverCount; i++) { + EXPECT_NE(npuDrivers[i], nullptr); + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(npuDrivers[i], &deviceCount, nullptr)); + } + } + + // Step 3: Verify both GPU and NPU drivers still work + // Ensure GPU drivers still respond to API calls + for (uint32_t i = 0; i < gpuDriverCount; i++) { + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(gpuDrivers[i], &deviceCount, nullptr)); + } + + // Ensure NPU drivers respond to API calls + for (uint32_t i = 0; i < npuDriverCount; i++) { + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(npuDrivers[i], &deviceCount, nullptr)); + } +} + +TEST_F(InitDriverUnitTest, zeInitWithFakeIntelGPU_ThenzeInitDriversWithNPU) { + // Step 1: Use zeInit to initialize (typically gets all drivers) + ze_result_t result = zeInit(0); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Get GPU drivers using zeDriverGet + uint32_t allDriverCount = 0; + result = zeDriverGet(&allDriverCount, nullptr); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + std::vector allDrivers(allDriverCount); + if (allDriverCount > 0) { + result = zeDriverGet(&allDriverCount, allDrivers.data()); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Try calling APIs on the drivers returned by zeDriverGet + for (uint32_t i = 0; i < allDriverCount; i++) { + EXPECT_NE(allDrivers[i], nullptr); + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(allDrivers[i], &deviceCount, nullptr)); + } + } + + // Step 2: Now use zeInitDrivers to get NPU drivers specifically + uint32_t npuDriverCount = 0; + ze_init_driver_type_desc_t descNPU = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descNPU.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + descNPU.pNext = nullptr; + + result = zeInitDrivers(&npuDriverCount, nullptr, &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + std::vector npuDrivers(npuDriverCount); + if (npuDriverCount > 0) { + result = zeInitDrivers(&npuDriverCount, npuDrivers.data(), &descNPU); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Verify we can call NPU driver APIs + for (uint32_t i = 0; i < npuDriverCount; i++) { + EXPECT_NE(npuDrivers[i], nullptr); + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(npuDrivers[i], &deviceCount, nullptr)); + } + } + + // Step 3: Verify original drivers from zeDriverGet still work + for (uint32_t i = 0; i < allDriverCount; i++) { + uint32_t deviceCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGet(allDrivers[i], &deviceCount, nullptr)); + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithAllTypes_VerifyRoutingToCorrectDriver) { + // Initialize all driver types + uint32_t allDriverCount = 0; + ze_init_driver_type_desc_t descAll = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + descAll.flags = UINT32_MAX; // All driver types + descAll.pNext = nullptr; + + ze_result_t result = zeInitDrivers(&allDriverCount, nullptr, &descAll); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + std::vector allDrivers(allDriverCount); + if (allDriverCount > 0) { + result = zeInitDrivers(&allDriverCount, allDrivers.data(), &descAll); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + + // Verify each driver responds to API calls + for (uint32_t i = 0; i < allDriverCount; i++) { + EXPECT_NE(allDrivers[i], nullptr); + + // Call zeDeviceGet to verify routing works + uint32_t deviceCount = 0; + ze_result_t devResult = zeDeviceGet(allDrivers[i], &deviceCount, nullptr); + EXPECT_EQ(devResult, ZE_RESULT_SUCCESS); + + // Try to get driver properties + ze_driver_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DRIVER_PROPERTIES; + ze_result_t propResult = zeDriverGetProperties(allDrivers[i], &props); + // Should succeed for properly initialized drivers + EXPECT_EQ(propResult, ZE_RESULT_SUCCESS); + } + + // Verify that both GPU and NPU drivers are initialized in the context + bool foundGPUDriver = false; + bool foundNPUDriver = false; + for (const auto& driver : loader::context->zeDrivers) { + if (driver.driverType == loader::ZEL_DRIVER_TYPE_DISCRETE_GPU || + driver.driverType == loader::ZEL_DRIVER_TYPE_GPU || + driver.driverType == loader::ZEL_DRIVER_TYPE_INTEGRATED_GPU) { + foundGPUDriver = true; + EXPECT_TRUE(driver.ddiInitialized); + } + if (driver.driverType == loader::ZEL_DRIVER_TYPE_NPU) { + foundNPUDriver = true; + EXPECT_TRUE(driver.ddiInitialized); + } + } + + // At least one type of driver should be found if we have drivers + EXPECT_TRUE(foundGPUDriver || foundNPUDriver); + } +} \ No newline at end of file diff --git a/test/init_driver_unit_tests.cpp b/test/init_driver_unit_tests.cpp new file mode 100644 index 00000000..cbe76d37 --- /dev/null +++ b/test/init_driver_unit_tests.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2025 Intel Corporation + * SPDX-License-Identifier: MIT + */ + + +#include "test/init_driver_unit_tests_common.h" + +// Helper to create a mock null driver with a given name and type +loader::driver_t createNullDriver(const std::string& name, loader::zel_driver_type_t type) { + loader::driver_t driver; + + std::string libraryPath; + auto loaderLibraryPathEnv = getenv_string("ZEL_LIBRARY_PATH"); + if (!loaderLibraryPathEnv.empty()) { + libraryPath = loaderLibraryPathEnv; +#ifdef _WIN32 + libraryPath.append("\\"); +#else + libraryPath.append("/"); +#endif + } + +#ifdef _WIN32 + driver.name = libraryPath + name + ".dll"; +#else + driver.name = libraryPath + "lib" + name + ".so.1"; +#endif + driver.driverType = type; + driver.handle = nullptr; // Simulate null driver + driver.initStatus = ZE_RESULT_SUCCESS; + driver.driverInuse = false; + driver.ddiInitialized = false; + return driver; +} + + +TEST_F(InitDriverUnitTest, InitWithSingleGPUDriver) { + loader::driver_t gpuDriver = createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU); + ze_result_t result = loader::context->init_driver(gpuDriver, ZE_INIT_FLAG_GPU_ONLY, nullptr); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(gpuDriver.ddiInitialized); +} + +TEST_F(InitDriverUnitTest, InitWithSingleNPUDriver) { + loader::driver_t npuDriver = createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU); + ze_init_driver_type_desc_t desc = {}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; + ze_result_t result = loader::context->init_driver(npuDriver, 0, &desc); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(npuDriver.ddiInitialized); +} + +TEST_F(InitDriverUnitTest, InitWithSingleVPUDriver) { + loader::driver_t vpuDriver = createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU); + ze_result_t result = loader::context->init_driver(vpuDriver, ZE_INIT_FLAG_VPU_ONLY, nullptr); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(vpuDriver.ddiInitialized); +} + +TEST_F(InitDriverUnitTest, zeInitWithMultipleDrivers) { + std::vector drivers = { + createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU), + createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU), + createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU) + }; + for (auto& driver : drivers) { + ze_result_t result = loader::context->init_driver(driver, 0, nullptr); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(driver.ddiInitialized); + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDrivers) { + std::vector drivers = { + createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU), + createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU), + createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU) + }; + ze_init_driver_type_desc_t desc = {}; + desc.flags = UINT32_MAX; // Request all driver types + for (auto& driver : drivers) { + ze_result_t result = loader::context->init_driver(driver, 0, &desc); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(driver.ddiInitialized); + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDriversNPURequested) { + std::vector drivers = { + createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU), + createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU), + createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU) + }; + ze_init_driver_type_desc_t desc = {}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; // Request NPU driver types + for (auto& driver : drivers) { + if (driver.driverType == loader::ZEL_DRIVER_TYPE_NPU) { + ze_result_t result = loader::context->init_driver(driver, 0, &desc); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(driver.ddiInitialized); + } else { + ze_result_t result = loader::context->init_driver(driver, 0, &desc); + EXPECT_NE(result, ZE_RESULT_SUCCESS); + EXPECT_FALSE(driver.ddiInitialized); + } + } +} + +TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDriversGPURequested) { + std::vector drivers = { + createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU), + createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU), + createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU) + }; + ze_init_driver_type_desc_t desc = {}; + desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; // Request GPU driver types + for (auto& driver : drivers) { + if (driver.driverType == loader::ZEL_DRIVER_TYPE_DISCRETE_GPU) { + ze_result_t result = loader::context->init_driver(driver, 0, &desc); + EXPECT_EQ(result, ZE_RESULT_SUCCESS); + EXPECT_TRUE(driver.ddiInitialized); + } else { + ze_result_t result = loader::context->init_driver(driver, 0, &desc); + EXPECT_NE(result, ZE_RESULT_SUCCESS); + EXPECT_FALSE(driver.ddiInitialized); + } + } +} + +TEST_F(InitDriverUnitTest, InitWithUnsupportedNullDriverType) { + loader::driver_t otherDriver = createNullDriver("ze_fake_other", loader::ZEL_DRIVER_TYPE_OTHER); + ze_result_t result = loader::context->init_driver(otherDriver, 0, nullptr); + EXPECT_NE(result, ZE_RESULT_SUCCESS); + EXPECT_FALSE(otherDriver.ddiInitialized); +} diff --git a/test/init_driver_unit_tests_common.h b/test/init_driver_unit_tests_common.h new file mode 100644 index 00000000..6775f977 --- /dev/null +++ b/test/init_driver_unit_tests_common.h @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2025 Intel Corporation + * SPDX-License-Identifier: MIT + */ + + +#include "gtest/gtest.h" +#include "source/loader/ze_loader_internal.h" +#include "ze_api.h" +#include +#include +#include +#include + +class InitDriverUnitTest : public ::testing::Test { +protected: + void SetUp() override { + if (!loader::context) { + loader::context = new loader::context_t(); + loader::context->debugTraceEnabled = false; + } + } +}; \ No newline at end of file diff --git a/test/loader_api.cpp b/test/loader_api.cpp index 8fa6102b..e2a70837 100644 --- a/test/loader_api.cpp +++ b/test/loader_api.cpp @@ -212,7 +212,7 @@ TEST( TEST( LoaderInit, - GivenZeInitDriversUnsupportedOnTheDriverWhenCallingZeInitDriversThenUninitializedReturned) { + GivenZeInitDriversUnsupportedOnTheDriverWhenCallingZeInitDriversThenUnSupportedReturned) { uint32_t pInitDriversCount = 0; uint32_t pDriverGetCount = 0; @@ -220,7 +220,7 @@ TEST( desc.flags = UINT32_MAX; desc.pNext = nullptr; putenv_safe( const_cast( "ZEL_TEST_MISSING_API=zeInitDrivers" ) ); - EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInitDrivers(&pInitDriversCount, nullptr, &desc)); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, zeInitDrivers(&pInitDriversCount, nullptr, &desc)); EXPECT_EQ(pInitDriversCount, 0); EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&pDriverGetCount, nullptr)); diff --git a/test/loader_tracing_layer.cpp b/test/loader_tracing_layer.cpp index c004ad7a..3d3445f2 100644 --- a/test/loader_tracing_layer.cpp +++ b/test/loader_tracing_layer.cpp @@ -955,7 +955,7 @@ namespace } TEST_P(TracingParameterizedTest, - GivenLoaderWithDynamicTracingEnabledAndZerApisUnsupportedAndBothZeAndZerCallbacksRegisteredWhenCallingBothApisThenTracingWorksForZeOnly) + GivenLoaderWithDynamicTracingEnabledAndZerApisUnsupportedAndBothZeAndZerCallbacksRegisteredWhenCallingBothApisThenTracingWorksForZeAndZerCallbacksAreStillInvoked) { putenv_safe(const_cast("ZEL_TEST_NULL_DRIVER_DISABLE_ZER_API=1")); InitMethod initMethod = GetParam(); @@ -978,8 +978,9 @@ namespace const char *errorString = nullptr; ze_result_t result = zerGetLastErrorDescription(&errorString); EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, result); - EXPECT_EQ(0, tracingData.getZerPrologueCallCount("zerGetLastErrorDescription")); - EXPECT_EQ(0, tracingData.getZerEpilogueCallCount("zerGetLastErrorDescription")); + // ZER callbacks should still be called in the tracing layer even if the driver ends up not supporting ZER APIs + EXPECT_EQ(1, tracingData.getZerPrologueCallCount("zerGetLastErrorDescription")); + EXPECT_EQ(1, tracingData.getZerEpilogueCallCount("zerGetLastErrorDescription")); callBasicZeApis(drivers); verifyBasicZeApisCalledBothCallbackTypes(1); @@ -993,8 +994,9 @@ namespace errorString = nullptr; result = zerGetLastErrorDescription(&errorString); EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, result); - EXPECT_EQ(0, tracingData.getZerPrologueCallCount("zerGetLastErrorDescription")); - EXPECT_EQ(0, tracingData.getZerEpilogueCallCount("zerGetLastErrorDescription")); + // ZER callbacks should still be called in the tracing layer even if the driver ends up not supporting ZER APIs + EXPECT_EQ(2, tracingData.getZerPrologueCallCount("zerGetLastErrorDescription")); + EXPECT_EQ(2, tracingData.getZerEpilogueCallCount("zerGetLastErrorDescription")); uint32_t deviceCount = 1; std::vector devices(deviceCount);