Frozen ThreadState #107
Review part 1. It's mostly about documentation / comments / docstrings to make everything super clear.
It's great that you already created unit tests at the drjit-core level!
I still have to look at the big record_ts.* files.
src/init.cpp
Outdated
@@ -10,6 +10,7 @@
#include "internal.h"
#include "cuda_ts.h"
#include "llvm_ts.h"
#include "record_ts.h"
Is this include used?
jit_var_inc_ref(o0.index());

outputs[0] = o0.index();
Maybe this can be done with UInt32::steal()? (Not sure)
tests/vcall.cpp
Outdated
jit_log(LogLevel::Info, "Replay:");
{
    BasePtr self = (arange<UInt32>(10) + 1) % 3;
Could you try replaying with a different width?
};

{
    BasePtr self = arange<UInt32>(10) % 3;
Could you please add a comment explaining what happens if some instances are not used during recording, but we try to use them during replay (i.e., whether or not this is expected to work)?
i0.index(),
};

jit_freeze_start(Backend, inputs, 2);
Please add a comment specifying whether or not this will evaluate some things (e.g. inputs) before starting the recording.
If it doesn't do any evals, then maybe you could move the self.eval(); i0.eval(); just above to make clear that they're needed.
tests/record.cpp
Outdated
}

/**
 * This tests, weather it is possible to record multiple kernels in parallel.
- * This tests, weather it is possible to record multiple kernels in parallel.
+ * This tests whether it is possible to record multiple independent kernels in the same recording.
tests/record.cpp
Outdated
}

/**
 * This tests recording and replay of a horizontal reduction operation (hsum).
- * This tests recording and replay of a horizontal reduction operation (hsum).
+ * This tests the recording and replay of a horizontal reduction operation (hsum).
tests/record.cpp
Outdated
jit_freeze_start(Backend, inputs, 1);

UInt32 o0 = hsum(i0);
It would be more interesting if there was one operation creating a first kernel, and then hsum() is applied on the output of the first kernel.
/**
 * Basic addition test.
 * Supplying a different input should replay the operation, with this input.
 * In this case, the input at replay is incremented and should result in an
 * incremented output.
 */
TEST_BOTH(9_resized_input) {
This test wouldn't be needed anymore if you change the width during replay in all of the previous tests.
jit_freeze_destroy(recording);
}

TEST_BOTH(10_input_passthrough) {
Please add a comment describing the test case, as for the previous tests.
Review part 2
src/record_ts.h
Outdated
///
/// Output variables are only tracked through the outputs array, as this
/// information is only needed when constructing the output variables.
///
/// |
bool reverse;
} prefix_reduce;
/// The bucket count for the mkperm operation
uint32_t bucket_count;
If the bucket count of an mkperm operation needs to change during replay, would that be correctly detected & handled (maybe by re-recording)?
Good point. As long as it comes from Python, it should be fine. If it is determined by the width of some variable, this might be a problem.
src/record_ts.h
Outdated
struct {
/// The reduce type of a prefix reduction operation
ReduceOp rtype;
/// Weather a prefix sum operation is exclusive
- /// Weather a prefix sum operation is exclusive
+ /// Whether a prefix sum operation is exclusive
Search & replace for this typo
/// Does this operation use optix?
bool uses_optix = false;
/// A copy of the shader binding table, used by the kernel.
OptixShaderBindingTable *sbt;
Is it a deep copy? What happens if the corresponding scene gets destroyed, could this or any elements of the SBT become dangling pointers?
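To make the distinction behind this question concrete, here is a minimal host-side sketch; it is not code from the PR. `TableView` and `OwnedTable` are hypothetical stand-ins. The real `OptixShaderBindingTable` refers to device memory, so an actual deep copy would presumably have to duplicate the device-side records, or hold a reference that keeps them alive, rather than copy a host vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for a table that only points at records
// owned by someone else (e.g. a scene). Copying this struct is a shallow copy.
struct TableView {
    const uint32_t *records;
    size_t count;
};

// Owning (deep) copy: duplicates the records so it outlives the source.
struct OwnedTable {
    std::vector<uint32_t> storage;

    explicit OwnedTable(const TableView &v)
        : storage(v.records, v.records + v.count) {}

    // Returns a view into the copy's own storage.
    TableView view() const { return { storage.data(), storage.size() }; }
};

// Builds a deep copy, lets the source go out of scope, then reads the copy.
bool owned_copy_survives_source() {
    OwnedTable copy = [] {
        std::vector<uint32_t> scene_records = { 1, 2, 3 };
        TableView view{ scene_records.data(), scene_records.size() };
        return OwnedTable(view);
    }(); // scene_records is destroyed here; a stored TableView would dangle

    TableView v = copy.view();
    assert(v.records != nullptr);
    return v.count == 3 && v.records[0] == 1 && v.records[2] == 3;
}
```

Had the recording kept only the `TableView`, reading it after the source was destroyed would be undefined behavior, which is the dangling-pointer scenario raised above.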
src/record_ts.h
Outdated
OpOutput,
/// This variable is part of the function input
Input,
/// This variable has been captured
Specify what "captured" means. Is it "captured" in the sense of a closure?
src/record_ts.h
Outdated
 * This is used by the input variables of a kernel.
 */
uint32_t add_variable(const void *ptr, RecordVariable rv) {
src/record_ts.h
Outdated
}

/// Return the slot index given the data pointer of a variable.
/// This fails if the variable has not been added.
- /// This fails if the variable has not been added.
+ /// This fails if the variable had not been previously added.
src/record_ts.h
Outdated
jitc_raise("record(): Varaible at slot s%u was read from by "
"operation o%u, but has not yet been initialized! "
"This can happen if the variable was not part of "
"the input but is used by an recorded operation.",
Explain how the user could make it "part of the input", if possible.
Not sure, does this make sense?
jitc_raise("record(): Varaible at slot s%u was read from by "
"operation o%u, but has not yet been initialized! "
"This can occur if the variable was not part of "
"the input but is used by a recorded operation, for "
"example if it was not specified as a member in a "
"DRJIT_STRUCT but used in the frozen function.",
info.slot,
(uint32_t) this->m_recording.operations.size());
Okay! Fixed a couple of typos below:
jitc_raise("record(): Variable at slot s%u was read by "
"operation o%u, but it had not yet been initialized! "
"This can occur if the variable was not part of "
"the input but is used by a recorded operation, for "
"example if it was not specified as a member in a "
"DRJIT_STRUCT but used in the frozen function.",
info.slot,
(uint32_t) this->m_recording.operations.size());
src/record_ts.h
Outdated
/// Helper function recording an output access, given the slot and \ref
/// VarType
- /// Helper function recording an output access, given the slot and \ref
- /// VarType
+ /// Helper function recording an output access, given the slot and \ref VarType
and same below
src/record_ts.h
Outdated
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
                       uint32_t n_inputs);

Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
                            uint32_t n_outputs);

void jitc_freeze_abort(JitBackend backend);

void jitc_freeze_destroy(Recording *recording);

bool jitc_freeze_pause(JitBackend backend);

bool jitc_freeze_resume(JitBackend backend);

void jitc_freeze_replay(Recording *recording, const uint32_t *inputs,
                        uint32_t *outputs);

int jitc_freeze_dry_run(Recording *recording, const uint32_t *inputs,
                        uint32_t *outputs);
Maybe move these to the top to make them more noticeable, since RecordThreadState is so long.
Review part 3.
It's a partial review of record_ts.cpp, but unfortunately I won't have time to review the rest of the file; there's too much code.
this->init = rv.init;

if (init == RecordVarInit::Captured) {
    // copy the variable, so that it isn't changed
Would that always match the behavior without recording?
I.e., if there are captured variables and I run the code with and without recording, would I always get the same result?
/// Temporary variables used for replaying a recording.
static std::vector<ReplayVariable> replay_variables;
What happens if someone tries replaying two different recordings from two different threads?
(I'm not sure what thread-safety the new DrJit version provides)
// Reconstruct the \ref kernel_params for this launch given the
// allocations when replaying.
kernel_params.clear();
Is there any risk this interferes with an unrelated pending kernel launch?
// First 3 parameters reserved for: kernel ptr, size, ITT
// identifier
- // First 3 parameters reserved for: kernel ptr, size, ITT
- // identifier
+ // First 3 parameters reserved for: kernel ptr, size, ITT identifier
jitc_log(LogLevel::Debug, " src.data=%p", src_var.data);
jitc_log(LogLevel::Debug, " dst.data=%p", dst_var.data);
Reorganize a bit so that the 3 logs can be combined?
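One way to read this suggestion, sketched under assumptions: build a single message that carries the slots and both data pointers, then log it once. `format_copy_log` is a hypothetical helper, and the message layout is invented for illustration; in the PR the same format string could presumably be passed straight to one `jitc_log` call.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>

// Hypothetical helper: formats what would otherwise be three separate debug
// lines (operation header, src pointer, dst pointer) into one message.
std::string format_copy_log(unsigned dst_slot, unsigned src_slot,
                            const void *src_data, const void *dst_data) {
    char buf[160];
    int n = std::snprintf(buf, sizeof(buf),
                          "replay(): copy s%u <- s%u: src.data=%p, dst.data=%p",
                          dst_slot, src_slot, src_data, dst_data);
    // The buffer is sized generously; guard against truncation anyway.
    assert(n > 0 && static_cast<size_t>(n) < sizeof(buf));
    return std::string(buf, static_cast<size_t>(n));
}
```

A single message also keeps the three pieces of information together when debug output from several operations is interleaved.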
offsets_var.alloc(backend, bucket_count * 4 + 1,
                  offsets_info.vtype);

jitc_log(LogLevel::Debug,
Combine all the log calls
"replay(): MemcpyAsync s%u <- s%u [%zu]",
dst_info.slot, src_info.slot, src_var.data_size);

dst_var.alloc(backend, src_var.data_size);
Could it be a problem that the memory is allocated even during a dry-run? (for this and other ops)
Will it be quickly freed again? We don't want to accidentally double memory usage because of dry runs.
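A minimal sketch of the alternative this question hints at, assuming a dry run only needs sizes and not the data itself (this thread does not establish whether that holds for every operation in the PR). `SlotSketch` and `alloc_slot` are hypothetical names, not the PR's `ReplayVariable` API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>

// Hypothetical stand-in for a replay slot.
struct SlotSketch {
    size_t size = 0;                 // size in bytes, tracked in every mode
    std::unique_ptr<uint8_t[]> data; // backing memory, only set in a real replay
};

// Records the requested size; allocates memory only when not in a dry run.
void alloc_slot(SlotSketch &slot, size_t bytes, bool dry_run) {
    assert(bytes > 0);
    slot.size = bytes;
    if (dry_run)
        slot.data.reset(); // keep the bookkeeping, hold no memory
    else
        slot.data = std::make_unique<uint8_t[]>(bytes);
}
```

With this split, a dry run can still validate sizes and slot usage without temporarily increasing memory consumption.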
};
break;
case OpType::Aggregate: {
Move contents to a dedicated function since it's a complex case
default:
    jitc_fail(
        "An operation has been recorded, that is not known to "
        "the replay functionality!");
Print the op type to help with debugging
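A possible shape for this, as a sketch only: map the enum to a readable name and include both the name and the raw value in the failure message. `OpTypeSketch` is a hypothetical two-entry subset (only `Aggregate` and a memcpy operation are visible in this thread), and `op_type_name` and `unknown_op_message` are illustrative helpers rather than existing drjit-core functions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>

// Hypothetical subset of the recorded operation types.
enum class OpTypeSketch { MemcpyAsync, Aggregate, Count };

// Returns a readable name, or "Unknown" for values without a case.
const char *op_type_name(OpTypeSketch type) {
    switch (type) {
        case OpTypeSketch::MemcpyAsync: return "MemcpyAsync";
        case OpTypeSketch::Aggregate:   return "Aggregate";
        default:                        return "Unknown";
    }
}

// Builds the failure message, including the name and the raw enum value so
// that even an unnamed value can be identified.
std::string unknown_op_message(OpTypeSketch type) {
    char buf[160];
    int n = std::snprintf(buf, sizeof(buf),
                          "replay(): operation type %s (%u) is not known to "
                          "the replay functionality!",
                          op_type_name(type), static_cast<unsigned>(type));
    assert(n > 0 && static_cast<size_t>(n) < sizeof(buf));
    return std::string(buf, static_cast<size_t>(n));
}
```

Printing the raw value alongside the name matters most in the `default` branch, where the name lookup is the part most likely to be missing an entry.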
This PR implements the Freezing feature in drjit-core (depends on #106). It makes the following changes:
jit.h, including the following functions: jit_freeze_start, jit_freeze_stop, jit_freeze_replay, jit_freeze_dry_run, jit_freeze_pause, jit_freeze_resume, jit_freeze_abort and jit_freeze_destroy, as well as the Recording struct.
RecordThreadState, a wrapper around either the CUDA or LLVM ThreadState for freezing them.