Skip to content

Commit

Permalink
use a new mode for kernel import: 'usePrecompiledAndBakedKernel'
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardGe committed Sep 26, 2024
1 parent 58db9c1 commit 4cccc45
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 15 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ build/
result.xml
UnitTest/bitcodes/*.fatbin
Test/SimpleD3D12/cache/**

ParallelPrimitives/cache/KernelArgs.h
ParallelPrimitives/cache/Kernels.h
ParallelPrimitives/cache/oro_compiled_kernels.h
35 changes: 35 additions & 0 deletions Orochi/OrochiUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,41 @@ oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* so
return f;
}

// Loads a kernel function from a precompiled binary that is already in memory
// ( e.g. a C-array baked into the executable ) instead of reading it from a file.
//
// - precompData     : pointer to the precompiled binary handed to oroModuleLoadData.
// - dataSizeInBytes : size of 'precompData' in bytes; only used to reject empty input.
// - funcName        : name of the kernel function to look up in the loaded module.
//
// Returns the function handle, or nullptr if no data was given or if loading the
// module / resolving the function failed ( a message is printed in those cases ).
//
// Successful results are stored in m_kernelMap under a key built from the fixed tag
// "___BAKED_BIN___" and funcName, so a later call with the same funcName returns the
// cached function without looking at 'precompData' again.
oroFunction OrochiUtils::getFunctionFromPrecompiledBinary_asData( const unsigned char* precompData, size_t dataSizeInBytes, const std::string& funcName )
{
	std::lock_guard<std::recursive_mutex> lock( m_mutex );

	const std::string cacheName = OrochiUtilsImpl::getCacheName( "___BAKED_BIN___", funcName );

	// single lookup: reuse the iterator instead of searching the map a second time
	const auto cached = m_kernelMap.find( cacheName );
	if( cached != m_kernelMap.end() )
	{
		return cached->second.function;
	}

	// nothing to load: fail early with a clear message rather than handing an empty buffer to the driver
	if( precompData == nullptr || dataSizeInBytes == 0 )
	{
		printf( "getFunctionFromPrecompiledBinary_asData FAILED (no baked precomp data provided): %s\n", funcName.c_str() );
		return nullptr;
	}

	oroModule module = nullptr;
	oroError e = oroModuleLoadData( &module, precompData );
	if( e != oroSuccess )
	{
		// add some verbose info to help debugging missing data
		printf( "oroModuleLoadData FAILED (error = %d) loading baked precomp data: %s\n", static_cast<int>( e ), funcName.c_str() );
		return nullptr;
	}

	oroFunction functionOut{};
	e = oroModuleGetFunction( &functionOut, module, funcName.c_str() );
	if( e != oroSuccess )
	{
		// add some verbose info to help debugging missing data
		printf( "oroModuleGetFunction FAILED (error = %d) loading baked precomp data: %s\n", static_cast<int>( e ), funcName.c_str() );

		// the module is not stored in the cache on this path, so release it here to avoid leaking it
		oroModuleUnload( module );
		return nullptr;
	}

	// operator[] creates the cache entry; keep a reference so the key is only resolved once
	auto& entry = m_kernelMap[cacheName];
	entry.function = functionOut;
	entry.module = module;

	return functionOut;
}

oroFunction OrochiUtils::getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName )
{
std::lock_guard<std::recursive_mutex> lock( m_mutex );
Expand Down
4 changes: 4 additions & 0 deletions Orochi/OrochiUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class OrochiUtils

oroFunction getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName );

// Same purpose as 'getFunctionFromPrecompiledBinary', but instead of giving a path to a file, the precompiled binary is given directly as an in-memory buffer.
// ( use the script convert_binary_to_array.py to convert the .hipfb to a C-array. )
// - data            : pointer to the precompiled binary.
// - dataSizeInBytes : size of 'data' in bytes.
// - funcName        : name of the kernel function to get from the binary.
// Returns the function, or nullptr if the module or the function could not be loaded.
// The result is cached under a key derived from funcName, so calling again with the same funcName returns the cached function.
oroFunction getFunctionFromPrecompiledBinary_asData( const unsigned char* data, size_t dataSizeInBytes, const std::string& funcName );

oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* opts );
oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders, const char** headers, const char** includeNames );
oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders = 0, const char** headers = 0, const char** includeNames = 0, oroModule* loadedModule = 0 );
Expand Down
60 changes: 45 additions & 15 deletions ParallelPrimitives/RadixSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,47 @@
#include <dlfcn.h>
#endif

namespace
{
#if defined( ORO_PRECOMPILED )
constexpr auto useBitCode = true;
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
#include <ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
#else
constexpr auto useBitCode = false;
const unsigned char oro_compiled_kernels_h[] = "";
const size_t oro_compiled_kernels_h_size = 0;
#endif

#if defined( ORO_PP_LOAD_FROM_STRING )
constexpr auto useBakeKernel = true;
#else
constexpr auto useBakeKernel = false;
static const char* hip_RadixSortKernels = nullptr;
namespace hip
namespace
{
static const char** RadixSortKernelsArgs = nullptr;
static const char** RadixSortKernelsIncludes = nullptr;
} // namespace hip

// Selection of the kernel loading mode, driven by two preprocessor symbols: ORO_PRECOMPILED and ORO_PP_LOAD_FROM_STRING.
// It defines three compile-time flags: usePrecompiledAndBakedKernel, useBitCode and useBakeKernel.
// If both preprocessor symbols are defined, this activates the 'usePrecompiledAndBakedKernel' mode.
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )

// this flag means that we bake the precompiled kernels
constexpr auto usePrecompiledAndBakedKernel = true;

// the two other modes are switched off here, so the static_assert below ( useBitCode and useBakeKernel cannot coexist ) still holds
constexpr auto useBitCode = false;
constexpr auto useBakeKernel = false;

#else

// at most one of the two symbols is defined: fall back to the individual modes
constexpr auto usePrecompiledAndBakedKernel = false;

#if defined( ORO_PRECOMPILED )
constexpr auto useBitCode = true; // this flag means we use the bitcode file
#else
constexpr auto useBitCode = false;
#endif

#if defined( ORO_PP_LOAD_FROM_STRING )
constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embedded in the binary ( as a string )
#else
constexpr auto useBakeKernel = false;
// null placeholders declared only when ORO_PP_LOAD_FROM_STRING is not defined
// NOTE(review): presumably they keep the references in the 'useBakeKernel' branch of compileKernels compiling when the baked headers are absent - confirm
static const char* hip_RadixSortKernels = nullptr;
namespace hip
{
static const char** RadixSortKernelsArgs = nullptr;
static const char** RadixSortKernelsIncludes = nullptr;
} // namespace hip
#endif

#endif

static_assert( !( useBitCode && useBakeKernel ), "useBitCode and useBakeKernel cannot coexist" );
Expand Down Expand Up @@ -211,9 +234,14 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
opts.push_back( sort_block_size_param.c_str() );
opts.push_back( sort_num_warps_param.c_str() );


for( const auto& record : records )
{
if constexpr( useBakeKernel )
if constexpr( usePrecompiledAndBakedKernel )
{
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() );
}
else if constexpr( useBakeKernel )
{
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
}
Expand All @@ -231,6 +259,8 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
}
}

return;
}

int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept
Expand Down
27 changes: 27 additions & 0 deletions scripts/convert_binary_to_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# convert_binary_to_array.py
import sys
from pathlib import Path

def binary_to_c_array(bin_file, array_name):
    """Return C source text that embeds the content of a binary file as a byte array.

    The generated text defines two symbols:
      - ``const unsigned char <array_name>[]`` holding every byte of the file,
      - ``const size_t <array_name>_size`` set to ``sizeof(<array_name>)``.

    Args:
        bin_file: path of the binary file to read.
        array_name: C identifier used for the array (and as prefix of the size constant).

    Returns:
        The C declarations as a single string.

    Raises:
        OSError: if ``bin_file`` cannot be opened or read.
        ValueError: if ``bin_file`` is empty. An empty file would produce
            ``<array_name>[] = { }``, which is not a valid array definition.
    """
    with open(bin_file, 'rb') as f:
        binary_data = f.read()

    if not binary_data:
        raise ValueError(f"'{bin_file}' is empty: cannot generate a C array from it")

    # one '0x..' literal per byte, all on a single line
    hex_array = ', '.join(f'0x{b:02x}' for b in binary_data)
    c_array = f'const unsigned char {array_name}[] = {{\n {hex_array}\n}};\n'
    c_array += f'const size_t {array_name}_size = sizeof({array_name});\n'
    return c_array

if __name__ == "__main__":
    # Command-line entry point:
    #   <script> <input_binary_file> <output_header_file>
    # Reads the binary file and writes a header that defines the byte array.
    if len(sys.argv) != 3:
        # usage errors go to stderr so they do not mix with regular output
        print(f"Usage: {sys.argv[0]} <input_binary_file> <output_header_file>", file=sys.stderr)
        sys.exit(1)

    bin_file = sys.argv[1]
    header_file_path = sys.argv[2]

    # The array is named after the header file name (directories dropped),
    # e.g. 'oro_compiled_kernels.h' -> 'oro_compiled_kernels_h'.
    # Every character that is not valid in a C identifier becomes '_', so names
    # containing '-', spaces, etc. still give a header that compiles.
    header_file = Path(header_file_path).name
    array_name = ''.join(
        ch if (ord(ch) < 128 and ch.isalnum()) or ch == '_' else '_'
        for ch in header_file
    )
    # a C identifier cannot start with a digit
    if array_name[:1].isdigit():
        array_name = '_' + array_name

    c_array = binary_to_c_array(bin_file, array_name)
    with open(header_file_path, 'w') as f:
        # banner names this script (convert_binary_to_array.py)
        f.write("// generated by convert_binary_to_array.py\n")
        f.write(c_array)

0 comments on commit 4cccc45

Please sign in to comment.