From 4cccc45ac3916f2127c9cd5c230b4da79b7a07a7 Mon Sep 17 00:00:00 2001 From: Richard Geslot Date: Thu, 26 Sep 2024 15:50:48 +0200 Subject: [PATCH] use a new mode for kernel import: 'usePrecompiledAndBakedKernel' --- .gitignore | 4 ++ Orochi/OrochiUtils.cpp | 35 +++++++++++++++++ Orochi/OrochiUtils.h | 4 ++ ParallelPrimitives/RadixSort.cpp | 60 ++++++++++++++++++++++-------- scripts/convert_binary_to_array.py | 27 ++++++++++++++ 5 files changed, 115 insertions(+), 15 deletions(-) create mode 100644 scripts/convert_binary_to_array.py diff --git a/.gitignore b/.gitignore index b2f0886..240dbfb 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,7 @@ build/ result.xml UnitTest/bitcodes/*.fatbin Test/SimpleD3D12/cache/** + +ParallelPrimitives/cache/KernelArgs.h +ParallelPrimitives/cache/Kernels.h +ParallelPrimitives/cache/oro_compiled_kernels.h diff --git a/Orochi/OrochiUtils.cpp b/Orochi/OrochiUtils.cpp index b47be25..7c2fcd9 100644 --- a/Orochi/OrochiUtils.cpp +++ b/Orochi/OrochiUtils.cpp @@ -558,6 +558,41 @@ oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* so return f; } +oroFunction OrochiUtils::getFunctionFromPrecompiledBinary_asData( const unsigned char* precompData, size_t dataSizeInBytes, const std::string& funcName ) +{ + std::lock_guard lock( m_mutex ); + + const std::string cacheName = OrochiUtilsImpl::getCacheName( "___BAKED_BIN___", funcName ); + if( m_kernelMap.find( cacheName.c_str() ) != m_kernelMap.end() ) + { + return m_kernelMap[cacheName].function; + } + + oroModule module = nullptr; + oroError e = oroModuleLoadData( &module, precompData ); + if ( e != oroSuccess ) + { + // add some verbose info to help debugging missing data + printf("oroModuleLoadData FAILED (error = %d) loading baked precomp data: %s\n", e, funcName.c_str()); + return nullptr; + } + + oroFunction functionOut{}; + e = oroModuleGetFunction( &functionOut, module, funcName.c_str() ); + if ( e != oroSuccess ) + { + // add some verbose info to help debugging missing data + printf("oroModuleGetFunction FAILED (error = %d) loading baked precomp data: %s\n", e, funcName.c_str()); + return nullptr; + } + OROASSERT( e == oroSuccess, 0 ); + + m_kernelMap[cacheName].function = functionOut; + m_kernelMap[cacheName].module = module; + + return functionOut; +} + oroFunction OrochiUtils::getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName ) { std::lock_guard lock( m_mutex ); diff --git a/Orochi/OrochiUtils.h b/Orochi/OrochiUtils.h index 3f281e9..e8ca8cf 100644 --- a/Orochi/OrochiUtils.h +++ b/Orochi/OrochiUtils.h @@ -69,6 +69,10 @@ class OrochiUtils oroFunction getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName ); + // this function is like 'getFunctionFromPrecompiledBinary' but instead of giving a path to a file, we give the data directly. + // ( use the script convert_binary_to_array.py to convert the .hipfb to a C-array. ) + oroFunction getFunctionFromPrecompiledBinary_asData( const unsigned char* data, size_t dataSizeInBytes, const std::string& funcName ); + oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector* opts ); oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector* opts, int numHeaders, const char** headers, const char** includeNames ); oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector* opts, int numHeaders = 0, const char** headers = 0, const char** includeNames = 0, oroModule* loadedModule = 0 ); diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 951df61..ce77424 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -40,24 +40,47 @@ #include #endif -namespace -{ -#if defined( ORO_PRECOMPILED ) -constexpr auto useBitCode = true; +#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) +#include // generate this header with 'convert_binary_to_array.py' #else -constexpr auto useBitCode = false; +const unsigned char oro_compiled_kernels_h[] = ""; +const size_t oro_compiled_kernels_h_size = 0; #endif -#if defined( ORO_PP_LOAD_FROM_STRING ) -constexpr auto useBakeKernel = true; -#else -constexpr auto useBakeKernel = false; -static const char* hip_RadixSortKernels = nullptr; -namespace hip +namespace { -static const char** RadixSortKernelsArgs = nullptr; -static const char** RadixSortKernelsIncludes = nullptr; -} // namespace hip + +// if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode. +#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) + + // this flag means that we bake the precompiled kernels + constexpr auto usePrecompiledAndBakedKernel = true; + + constexpr auto useBitCode = false; + constexpr auto useBakeKernel = false; + +#else + + constexpr auto usePrecompiledAndBakedKernel = false; + + #if defined( ORO_PRECOMPILED ) + constexpr auto useBitCode = true; // this flag means we use the bitcode file + #else + constexpr auto useBitCode = false; + #endif + + #if defined( ORO_PP_LOAD_FROM_STRING ) + constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embeded in the binary ( as a string ) + #else + constexpr auto useBakeKernel = false; + static const char* hip_RadixSortKernels = nullptr; + namespace hip + { + static const char** RadixSortKernelsArgs = nullptr; + static const char** RadixSortKernelsIncludes = nullptr; + } // namespace hip + #endif + #endif static_assert( !( useBitCode && useBakeKernel ), "useBitCode and useBakeKernel cannot coexist" ); @@ -211,9 +234,14 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string opts.push_back( sort_block_size_param.c_str() ); opts.push_back( sort_num_warps_param.c_str() ); + for( const auto& record : records ) { - if constexpr( useBakeKernel ) + if constexpr( usePrecompiledAndBakedKernel ) + { + oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() ); + } + else if constexpr( useBakeKernel ) { oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes ); } @@ -231,6 +259,8 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string printKernelInfo( record.kernelName, oroFunctions[record.kernelType] ); } } + + return; } int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept diff --git a/scripts/convert_binary_to_array.py b/scripts/convert_binary_to_array.py new file mode 100644 index 0000000..baab3b8 --- /dev/null +++ b/scripts/convert_binary_to_array.py @@ -0,0 +1,27 @@ +# convert_binary_to_header.py +import sys +from pathlib import Path + +def binary_to_c_array(bin_file, array_name): + with open(bin_file, 'rb') as f: + binary_data = f.read() + + hex_array = ', '.join(f'0x{b:02x}' for b in binary_data) + c_array = f'const unsigned char {array_name}[] = {{\n {hex_array}\n}};\n' + c_array += f'const size_t {array_name}_size = sizeof({array_name});\n' + return c_array + +if __name__ == "__main__": + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + + bin_file = sys.argv[1] + header_file_path = sys.argv[2] + header_file = Path(header_file_path).name + array_name = header_file.replace('.', '_') + + c_array = binary_to_c_array(bin_file, array_name) + with open(header_file_path, 'w') as f: + f.write("// generated by convert_binary_to_header.py\n") + f.write(c_array)