API for renderer to cache OptiX ptx (#1938)
This change adds a caching layer to the OptiX PTX pipeline that can skip LLVM generation and optimization. This helps improve scene build times for re-renders of scenes with large shader counts.

We generate a hash key from the optimized shadergroup and depend on the renderer to provide a cache backend implementation. There is no overhead if the renderer doesn't explicitly opt in to PTX caching.
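
As a rough illustration of what a renderer-side backend might look like, here is a minimal in-memory sketch. `CacheHooks` is a hypothetical stand-in for the relevant part of `RendererServices` (it uses `std::string_view` so the example compiles without OSL), and `InMemoryCacheRenderer` is an invented name. The `"optix_ptx_cache"` feature string and the `cache_insert`/`cache_get` signatures follow the diff in this commit; everything else is an assumption.

```cpp
#include <mutex>
#include <string>
#include <string_view>
#include <unordered_map>

// Hypothetical stand-in for the hooks this commit adds to RendererServices.
// The defaults are no-ops, so a renderer that ignores them pays nothing.
struct CacheHooks {
    virtual ~CacheHooks() = default;
    virtual int supports(std::string_view /*feature*/) const { return false; }
    virtual void cache_insert(std::string_view /*cachename*/,
                              std::string_view /*key*/,
                              std::string_view /*value*/) const {}
    virtual bool cache_get(std::string_view /*cachename*/,
                           std::string_view /*key*/,
                           std::string& /*value*/) const { return false; }
};

// Invented example backend: keeps entries in a process-local map.
class InMemoryCacheRenderer final : public CacheHooks {
public:
    // Opting in: the shading system checks both feature strings.
    int supports(std::string_view feature) const override {
        return feature == "OptiX" || feature == "optix_ptx_cache";
    }

    void cache_insert(std::string_view cachename, std::string_view key,
                      std::string_view value) const override {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_entries[full_key(cachename, key)] = std::string(value);
    }

    bool cache_get(std::string_view cachename, std::string_view key,
                   std::string& value) const override {
        std::lock_guard<std::mutex> lock(m_mutex);
        auto it = m_entries.find(full_key(cachename, key));
        if (it == m_entries.end())
            return false;  // miss: caller falls back to the LLVM path
        value = it->second;
        return true;
    }

private:
    // Combine cache name and key so several named caches can share one map.
    static std::string full_key(std::string_view cachename,
                                std::string_view key) {
        std::string k(cachename);
        k += '/';
        k += key;
        return k;
    }

    // The hooks are const, so the storage has to be mutable.
    mutable std::mutex m_mutex;
    mutable std::unordered_map<std::string, std::string> m_entries;
};
```

A map like this only helps within a single process (for example, re-renders in an interactive session); sharing across runs needs persistent storage.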

A simple backend will be added to testshade/testrender in a follow-up PR.

As background information:

The sequence at runtime is: group oso ->1-> runtime optimization by liboslexec ->2-> JIT to PTX via LLVM ->3-> driver converts PTX to actual executable GPU code for that hardware.

The "OptiX Cache" (part of OptiX and the driver) speeds up step (3) by skipping that last step for optimized/JITed shaders it has encountered before.

This PR adds another cache to step (2), managed by OSL and/or the renderer internals, to allow you to skip the bulk of the work for that step for optimized shaders you've encountered already.
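
The lookup-then-compile pattern for step (2) can be sketched in isolation as follows. This is a simplified illustration, not the code in this commit: `PtxCache`, `ptx_for_group`, and the `jit_to_ptx` callback are hypothetical names, and in the actual change the lookup lives in `ShadingSystemImpl::optimize_group` while the insert happens at the end of `BackendLLVM::run`.

```cpp
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical cache type: key -> PTX text.
using PtxCache = std::unordered_map<std::string, std::string>;

// Return the PTX for `key`, running the expensive JIT callback only on a miss.
std::string ptx_for_group(PtxCache& cache, const std::string& key,
                          const std::function<std::string()>& jit_to_ptx)
{
    if (auto it = cache.find(key); it != cache.end())
        return it->second;  // hit: LLVM generation and optimization skipped

    std::string ptx = jit_to_ptx();  // miss: pay the full cost once
    cache.emplace(key, ptx);         // remember it for the next encounter
    return ptx;
}
```

The key has to be computed from the optimized shader group before the lookup, which is why the commit generates it right after runtime optimization finishes.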

You still pay full price the very first time a shader is encountered, and that leads to terrible TTFP (time to first pixel). But this should take a big bite out of that in practice, since it's very common to have encountered most shader configurations before. If the renderer's implementation stores the cache persistently on disk or in a real database, it will be shared from run to run and possibly from user to user.
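
One possible shape for a persistent backend is one file per cache entry. The sketch below is an assumption-laden illustration: `DiskPtxCache`, the directory layout, and the `.ptx` suffix are all invented, and it relies on the cache key being usable as a file name (the key generated in this commit replaces `/` and `:` in the group name). A renderer's `cache_insert` and `cache_get` overrides could forward to something like this.

```cpp
#include <filesystem>
#include <fstream>
#include <iterator>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>

// Invented example: stores each entry as <root>/<cachename>/<key>.ptx.
class DiskPtxCache {
public:
    explicit DiskPtxCache(std::filesystem::path root) : m_root(std::move(root)) {}

    void insert(std::string_view cachename, std::string_view key,
                std::string_view value) const {
        namespace fs = std::filesystem;
        std::error_code ec;
        const fs::path final_path = entry_path(cachename, key);
        fs::create_directories(final_path.parent_path(), ec);
        if (ec)
            return;  // caching is best-effort; failures are not fatal

        // Write to a temporary file, then rename, so readers are unlikely
        // to observe a half-written entry. Concurrent writers of the same
        // key share this temp name; a real backend would need to handle that.
        fs::path tmp_path = final_path;
        tmp_path += ".tmp";
        {
            std::ofstream out(tmp_path, std::ios::binary | std::ios::trunc);
            out.write(value.data(), static_cast<std::streamsize>(value.size()));
            if (!out) {
                out.close();
                fs::remove(tmp_path, ec);
                return;
            }
        }
        fs::rename(tmp_path, final_path, ec);
    }

    bool get(std::string_view cachename, std::string_view key,
             std::string& value) const {
        std::ifstream in(entry_path(cachename, key), std::ios::binary);
        if (!in)
            return false;  // no such entry
        value.assign(std::istreambuf_iterator<char>(in),
                     std::istreambuf_iterator<char>());
        return true;
    }

private:
    std::filesystem::path entry_path(std::string_view cachename,
                                     std::string_view key) const {
        return m_root / std::string(cachename) / (std::string(key) + ".ptx");
    }

    std::filesystem::path m_root;
};
```

Pointing `root` at a shared location is what would make entries reusable from run to run or between users. Eviction, locking, and invalidation across OSL or LLVM version changes are left out here and would need a deliberate policy.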

---------

Signed-off-by: Chris Hellmuth <[email protected]>
chellmuth authored Feb 11, 2025
1 parent 4ce094f commit d0d4a06
Showing 10 changed files with 160 additions and 37 deletions.
13 changes: 13 additions & 0 deletions src/include/OSL/rendererservices.h
@@ -580,6 +580,19 @@ class OSLEXECPUBLIC RendererServices {
}
};

// Default no-op implementations of the caching api.
// Currently used for caching optix ptx before llvm generation.
virtual void cache_insert(string_view cachename, string_view key,
string_view value) const
{
}

virtual bool cache_get(string_view cachename, string_view key,
std::string& value) const
{
return false;
}

/// A renderer may choose to support batched execution by providing pointers
/// to objects satisfying the BatchedRendererServices<WidthOf<#>> interface
/// for specific batch sizes.
1 change: 1 addition & 0 deletions src/liboslexec/backendllvm.h
@@ -524,6 +524,7 @@ class BackendLLVM final : public OSOProcessorBase {

/// Return whether or not we are compiling for an OptiX-based renderer.
bool use_optix() { return m_use_optix; }
bool use_optix_cache() { return shadingsys().use_optix_cache(); }

/// Return if we should compile against free function versions of Renderer Service.
bool use_rs_bitcode() { return m_use_rs_bitcode; }
19 changes: 19 additions & 0 deletions src/liboslexec/instance.cpp
@@ -848,6 +848,25 @@ ShaderGroup::setup_interactive_arena(cspan<uint8_t> paramblock)



void
ShaderGroup::generate_optix_cache_key(string_view code)
{
const uint64_t ir_key = Strutil::strhash(code);

std::string safegroup;
safegroup = Strutil::replace(name(), "/", "_", true);
safegroup = Strutil::replace(safegroup, ":", "_", true);

// Cache key includes the groupname in addition to the serialized IR.
// This is because the groupname makes its way into the ptx's direct callable name,
// but isn't included in the serialization.
std::string cache_key = fmtformat("cache-osl-ptx-{}-{}", safegroup, ir_key);

m_optix_cache_key = cache_key;
}



std::string
ShaderGroup::serialize() const
{
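
The key construction in `generate_optix_cache_key` can be reproduced outside OSL as a small standalone sketch. Note the assumptions: `fnv1a64` is a stand-in hash chosen only because it is short and deterministic (the commit itself uses `Strutil::strhash`, which will produce different values), and `sanitize_group_name` and `make_cache_key` are hypothetical helper names.

```cpp
#include <cstdint>
#include <string>
#include <string_view>

// Stand-in 64-bit hash (FNV-1a). NOT the hash used by the commit.
std::uint64_t fnv1a64(std::string_view s)
{
    std::uint64_t h = 14695981039346656037ull;  // FNV offset basis
    for (unsigned char c : s) {
        h ^= c;
        h *= 1099511628211ull;  // FNV prime
    }
    return h;
}

// Replace '/' and ':' so the group name is safe inside a key or file name.
std::string sanitize_group_name(std::string_view name)
{
    std::string out(name);
    for (char& c : out)
        if (c == '/' || c == ':')
            c = '_';
    return out;
}

// Same layout as the commit's key: "cache-osl-ptx-<group>-<hash of IR>".
std::string make_cache_key(std::string_view groupname,
                           std::string_view serialized_ir)
{
    return "cache-osl-ptx-" + sanitize_group_name(groupname) + "-"
           + std::to_string(fnv1a64(serialized_ir));
}
```

Including the group name matters because two groups with identical serialized IR would otherwise share an entry, yet the PTX embeds the group name in its direct callable name.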
21 changes: 13 additions & 8 deletions src/liboslexec/llvm_gen.cpp
@@ -426,10 +426,13 @@ LLVMGEN(llvm_gen_printf_legacy)
}
#endif

// Some ops prepend things
if (op.opname() == op_error || op.opname() == op_warning) {
s = fmtformat("Shader {} [{}]: {}", op.opname(),
rop.inst()->shadername(), s);
// TODO: optix cache should handle ustrings generated during llvm-gen
if (!rop.use_optix_cache()) {
// Some ops prepend things
if (op.opname() == op_error || op.opname() == op_warning) {
s = fmtformat("Shader {} [{}]: {}", op.opname(),
rop.inst()->shadername(), s);
}
}

// Now go back and put the new format string in its place
@@ -709,10 +712,12 @@ LLVMGEN(llvm_gen_print_fmt)
}
}
}
// Some ops prepend things
if (op.opname() == op_error || op.opname() == op_warning) {
s = fmtformat("Shader {} [{}]: {}", op.opname(),
rop.inst()->shadername(), s);
if (!rop.use_optix_cache()) {
// Some ops prepend things
if (op.opname() == op_error || op.opname() == op_warning) {
s = fmtformat("Shader {} [{}]: {}", op.opname(),
rop.inst()->shadername(), s);
}
}
ustring s_ustring(s.c_str());
call_args.push_back(rop.llvm_const_hash(s_ustring));
8 changes: 8 additions & 0 deletions src/liboslexec/llvm_instance.cpp
@@ -2481,6 +2481,14 @@ BackendLLVM::run()
group().llvm_compiled_layer(nlayers - 1));
}

if (shadingsys().use_optix_cache()) {
std::string cache_key = group().optix_cache_key();
renderer()->cache_insert(
"optix_ptx", cache_key,
optix_cache_wrap(group().m_llvm_ptx_compiled_version,
group().llvm_groupdata_size()));
}

// We are destroying the entire module below,
// no reason to bother destroying individual functions
#if 0
27 changes: 27 additions & 0 deletions src/liboslexec/oslexec.cpp
@@ -51,5 +51,32 @@ shadertype_from_name(string_view name)
}



std::string
optix_cache_wrap(string_view ptx, size_t groupdata_size)
{
// Cache string is the ptx file with groupdata size on top as a comment.
// This way the cache string is a valid ptx program, which can be useful
// for debugging.
return fmtformat("// {}\n{}", groupdata_size, ptx);
}



void
optix_cache_unwrap(string_view cache_value, std::string& ptx,
size_t& groupdata_size)
{
size_t groupdata_end_index = cache_value.find('\n');
if (groupdata_end_index != std::string::npos) {
constexpr int offset = 3; // Account for the "// " prefix
std::string groupdata_string
= cache_value.substr(offset, groupdata_end_index - offset);
groupdata_size = std::stoll(groupdata_string);

ptx = cache_value.substr(groupdata_end_index + 1);
}
}

}; // namespace pvt
OSL_NAMESPACE_END
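
The cache value format above (groupdata size as a leading `// N` comment line, followed by the PTX) can be exercised with a standalone round-trip sketch. This is an illustrative variant, not the committed code: `wrap_ptx` and `unwrap_ptx` are hypothetical names, it uses only the standard library, and unlike the original `optix_cache_unwrap` it returns `false` on a malformed header instead of leaving the outputs untouched or throwing from `std::stoll`.

```cpp
#include <charconv>
#include <cstddef>
#include <string>
#include <string_view>
#include <system_error>

// Prepend the groupdata size as a PTX comment line.
std::string wrap_ptx(std::string_view ptx, std::size_t groupdata_size)
{
    std::string out = "// ";
    out += std::to_string(groupdata_size);
    out += '\n';
    out += ptx;
    return out;
}

// Split a cache value back into size and PTX. Returns false if the header
// is missing or is not a plain decimal number.
bool unwrap_ptx(std::string_view cache_value, std::string& ptx,
                std::size_t& groupdata_size)
{
    constexpr std::string_view prefix = "// ";
    if (cache_value.substr(0, prefix.size()) != prefix)
        return false;
    const std::size_t eol = cache_value.find('\n');
    if (eol == std::string_view::npos)
        return false;

    const std::string_view digits
        = cache_value.substr(prefix.size(), eol - prefix.size());
    const char* first = digits.data();
    const char* last  = first + digits.size();
    std::size_t parsed = 0;
    const auto [end, ec] = std::from_chars(first, last, parsed);
    if (ec != std::errc() || end != last)
        return false;  // empty, non-numeric, or trailing junk

    groupdata_size = parsed;
    ptx.assign(cache_value.substr(eol + 1));
    return true;
}
```

Validating the header is worth considering for a persistent cache, since entries on disk can be truncated or written by a different version; treating a bad entry as a miss lets the normal JIT path take over.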
20 changes: 17 additions & 3 deletions src/liboslexec/oslexec_pvt.h
@@ -81,6 +81,12 @@ struct PerThreadInfo {

namespace pvt {

void
optix_cache_unwrap(string_view cache_value, std::string& ptx,
size_t& groupdata_size);
std::string
optix_cache_wrap(string_view ptx, size_t groupdata_size);

// forward definitions
class ShadingSystemImpl;
class ShaderInstance;
@@ -632,6 +638,7 @@ class ShadingSystemImpl {
TextureSystem* texturesys() const { return m_texturesys; }

bool use_optix() const { return m_use_optix; }
bool use_optix_cache() const { return m_use_optix_cache; }
bool debug_nan() const { return m_debugnan; }
bool debug_uninit() const { return m_debug_uninit; }
bool lockgeom_default() const { return m_lockgeom_default; }
@@ -954,9 +961,10 @@ class ShadingSystemImpl {
std::vector<ustring> m_raytypes; ///< Names of ray types
std::vector<ustring> m_renderer_outputs; ///< Names of renderer outputs
std::vector<SymLocationDesc> m_symlocs;
int m_max_local_mem_KB; ///< Local storage can a shader use
int m_compile_report; ///< Print compilation report?
bool m_use_optix; ///< This is an OptiX-based renderer
int m_max_local_mem_KB; ///< Local storage can a shader use
int m_compile_report; ///< Print compilation report?
bool m_use_optix; ///< This is an OptiX-based renderer
bool m_use_optix_cache; ///< Renderer-enabled caching for OptiX ptx
int m_max_optix_groupdata_alloc; ///< Maximum OptiX groupdata buffer allocation
bool m_buffer_printf; ///< Buffer/batch printf output?
bool m_no_noise; ///< Substitute trivial noise calls
@@ -1843,6 +1851,10 @@ class ShaderGroup {
void name(ustring name) { m_name = name; }
ustring name() const { return m_name; }

// Generate and memoize the cache key so we don't calculate it twice
void generate_optix_cache_key(string_view code);
std::string optix_cache_key() const { return m_optix_cache_key; }

std::string serialize() const;

void lock() const { m_mutex.lock(); }
@@ -2046,6 +2058,8 @@ class ShaderGroup {
atomic_ll m_executions { 0 }; ///< Number of times the group executed
atomic_ll m_stat_total_shading_time_ticks { 0 }; // Shading time (ticks)

std::string m_optix_cache_key;

// PTX assembly for compiled ShaderGroup
std::string m_llvm_ptx_compiled_version;

15 changes: 15 additions & 0 deletions src/liboslexec/runtimeoptimize.cpp
@@ -3124,6 +3124,21 @@ RuntimeOptimizer::printinst(std::ostream& out) const



std::string
RuntimeOptimizer::serialize()
{
std::ostringstream ss {};
int nlayers = (int)group().nlayers();
for (int layer = 0; layer < nlayers; ++layer) {
set_inst(layer);
printinst(ss);
}

return ss.str();
}



void
RuntimeOptimizer::run()
{
2 changes: 2 additions & 0 deletions src/liboslexec/runtimeoptimize.h
@@ -460,6 +460,8 @@ class RuntimeOptimizer final : public OSOProcessorBase {
fmtformat(fmt, std::forward<Args>(args)...));
}

std::string serialize();

private:
int m_optimize; ///< Current optimization level
bool m_opt_simplify_param; ///< Turn instance params into const?
71 changes: 45 additions & 26 deletions src/liboslexec/shadingsys.cpp
@@ -1129,6 +1129,7 @@ ShadingSystemImpl::ShadingSystemImpl(RendererServices* renderer,
, m_max_local_mem_KB(2048)
, m_compile_report(0)
, m_use_optix(renderer->supports("OptiX"))
, m_use_optix_cache(m_use_optix && renderer->supports("optix_ptx_cache"))
, m_max_optix_groupdata_alloc(0)
, m_buffer_printf(true)
, m_no_noise(false)
@@ -3801,6 +3802,9 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
}
group.m_optimized = true;

if (use_optix_cache())
group.generate_optix_cache_key(rop.serialize());

spin_lock stat_lock(m_stat_mutex);
if (!need_jit) {
m_stat_opt_locking_time += locking_time;
@@ -3812,34 +3816,49 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
}

if (need_jit) {
BackendLLVM lljitter(*this, group, ctx);
lljitter.run();

// NOTE: it is now possible to optimize and not JIT
// which would leave the cleanup to happen
// when the ShadingSystem is destroyed

// Only cleanup when are not batching or if
// the batch jit has already happened,
// as it requires the ops so we can't delete them yet!
if (((renderer()->batched(WidthOf<16>()) == nullptr)
&& (renderer()->batched(WidthOf<8>()) == nullptr)
&& (renderer()->batched(WidthOf<4>()) == nullptr))
|| group.batch_jitted()) {
group_post_jit_cleanup(group);
bool cached = false;
if (use_optix_cache()) {
std::string cache_key = group.optix_cache_key();

std::string cache_value;
if (renderer()->cache_get("optix_ptx", cache_key, cache_value)) {
cached = true;
optix_cache_unwrap(cache_value,
group.m_llvm_ptx_compiled_version,
group.m_llvm_groupdata_size);
}
}

group.m_jitted = true;
spin_lock stat_lock(m_stat_mutex);
m_stat_opt_locking_time += locking_time;
m_stat_optimization_time += timer();
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
lljitter.m_llvm_local_mem);
if (!cached) {
BackendLLVM lljitter(*this, group, ctx);
lljitter.run();

// NOTE: it is now possible to optimize and not JIT
// which would leave the cleanup to happen
// when the ShadingSystem is destroyed

// Only cleanup when are not batching or if
// the batch jit has already happened,
// as it requires the ops so we can't delete them yet!
if (((renderer()->batched(WidthOf<16>()) == nullptr)
&& (renderer()->batched(WidthOf<8>()) == nullptr)
&& (renderer()->batched(WidthOf<4>()) == nullptr))
|| group.batch_jitted()) {
group_post_jit_cleanup(group);
}

group.m_jitted = true;
spin_lock stat_lock(m_stat_mutex);
m_stat_opt_locking_time += locking_time;
m_stat_optimization_time += timer();
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
lljitter.m_llvm_local_mem);
}
}

if (ctx_allocated) {
