Frozen ThreadState #107

Open · wants to merge 10 commits into base: master

@DoeringChristian (Contributor) commented Oct 23, 2024

This PR implements the Freezing feature in drjit-core (depends on #106).

It makes the following changes:

  • Adds the freezing API to jit.h: the functions jit_freeze_start, jit_freeze_stop, jit_freeze_replay, jit_freeze_dry_run, jit_freeze_pause, jit_freeze_resume, jit_freeze_abort, and jit_freeze_destroy, as well as the Recording struct.
  • Adds the implementation of the RecordThreadState, a wrapper around either the CUDA or LLVM ThreadState for freezing them.
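
The record-then-replay lifecycle behind these functions can be illustrated with a toy model. This is only a minimal sketch, not the drjit-core implementation: `ToyOp`, `ToyRecording`, `ToyRecorder`, and `toy_replay` are hypothetical names, and the recorded operations are plain integer additions on numbered slots rather than JIT kernels.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical toy operation: slots[dst] = slots[lhs] + slots[rhs].
struct ToyOp { uint32_t lhs, rhs, dst; };

// Hypothetical stand-in for a recording: operations plus slot bookkeeping.
struct ToyRecording {
    uint32_t n_inputs = 0;
    uint32_t n_slots = 0;
    std::vector<ToyOp> ops;
    std::vector<uint32_t> outputs; // slots handed back to the caller
};

class ToyRecorder {
public:
    // Loosely analogous to jit_freeze_start: the inputs occupy the first slots.
    void start(uint32_t n_inputs) {
        m_rec = ToyRecording();
        m_rec.n_inputs = n_inputs;
        m_rec.n_slots = n_inputs;
        m_active = true;
    }

    // Record an addition and return the slot that will hold its result.
    uint32_t add(uint32_t lhs, uint32_t rhs) {
        if (!m_active) throw std::logic_error("add(): not recording");
        if (lhs >= m_rec.n_slots || rhs >= m_rec.n_slots)
            throw std::out_of_range("add(): unknown slot");
        uint32_t dst = m_rec.n_slots++;
        m_rec.ops.push_back(ToyOp{lhs, rhs, dst});
        return dst;
    }

    // Loosely analogous to jit_freeze_stop: mark the output slots and finish.
    ToyRecording stop(std::vector<uint32_t> outputs) {
        if (!m_active) throw std::logic_error("stop(): not recording");
        m_active = false;
        m_rec.outputs = std::move(outputs);
        return std::move(m_rec);
    }

private:
    ToyRecording m_rec;
    bool m_active = false;
};

// Loosely analogous to jit_freeze_replay: run the recorded ops on new inputs.
std::vector<int> toy_replay(const ToyRecording &rec, const std::vector<int> &inputs) {
    if (inputs.size() != rec.n_inputs)
        throw std::invalid_argument("toy_replay(): input count mismatch");
    std::vector<int> slots(rec.n_slots, 0);
    for (std::size_t i = 0; i < inputs.size(); ++i)
        slots[i] = inputs[i];
    for (const ToyOp &op : rec.ops)
        slots[op.dst] = slots[op.lhs] + slots[op.rhs];
    std::vector<int> result;
    for (uint32_t s : rec.outputs)
        result.push_back(slots.at(s));
    return result;
}
```

In this model a recording is made once and can then be replayed with different input values, which is the property the tests in this PR exercise.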

@DoeringChristian force-pushed the frozen-threadstate branch 3 times, most recently from 288194d to bb38a86 (October 23, 2024 13:23)
@DoeringChristian marked this pull request as ready for review (October 25, 2024 10:58)
@DoeringChristian (Contributor Author)

@merlinND, could you please do a preliminary pass over this PR to catch any obvious mistakes? Note that it also includes the commit from #106, which I'll remove once that is merged.

@merlinND (Member) left a comment

Review part 1. It's mostly about documentation / comments / docstrings to make everything super clear.

It's great that you already created unit tests at the drjit-core level!

I still have to look at the big record_ts.* files.

src/init.cpp Outdated
@@ -10,6 +10,7 @@
#include "internal.h"
#include "cuda_ts.h"
#include "llvm_ts.h"
#include "record_ts.h"
Member

Is this include used?

Comment on lines +1292 to +1294
jit_var_inc_ref(o0.index());

outputs[0] = o0.index();
Member

Maybe this can be done with UInt32::steal()? (Not sure)
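
For context on the ownership question raised here, the difference between "stealing" and "borrowing" a reference-counted index can be sketched with a toy handle. This is an illustration under assumptions, not Dr.Jit's array class: `ToyHandle`, `toy_refcounts`, `toy_inc_ref`, and `toy_dec_ref` are hypothetical names, and whether the pattern applies to the test above is exactly what the comment leaves open.

```cpp
#include <cstdint>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a variable table: index -> reference count.
static std::unordered_map<uint32_t, int> toy_refcounts;

inline void toy_inc_ref(uint32_t index) { ++toy_refcounts[index]; }
inline void toy_dec_ref(uint32_t index) { --toy_refcounts[index]; }

// Hypothetical RAII handle that owns exactly one reference to an index.
class ToyHandle {
public:
    ToyHandle() = default;
    ToyHandle(const ToyHandle &) = delete;
    ToyHandle &operator=(const ToyHandle &) = delete;
    ToyHandle(ToyHandle &&o) noexcept : m_index(std::exchange(o.m_index, 0u)) {}
    ~ToyHandle() { if (m_index) toy_dec_ref(m_index); }

    // Take over a reference the caller already owns (no increment).
    static ToyHandle steal(uint32_t index) {
        ToyHandle h;
        h.m_index = index;
        return h;
    }

    // Acquire an additional reference (increment, then own it).
    static ToyHandle borrow(uint32_t index) {
        toy_inc_ref(index);
        return steal(index);
    }

    // Hand the owned reference back to the caller without decrementing.
    uint32_t release() { return std::exchange(m_index, 0u); }

    uint32_t index() const { return m_index; }

private:
    uint32_t m_index = 0; // 0 means "empty"
};
```

In this model, storing a raw index in an output array after an explicit increment corresponds to `borrow(...)` followed by `release()`; `steal(...)` goes the other way and wraps an index whose reference the caller already holds.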

tests/vcall.cpp Outdated

jit_log(LogLevel::Info, "Replay:");
{
BasePtr self = (arange<UInt32>(10) + 1) % 3;
Member

Could you try replaying with a different width?

};

{
BasePtr self = arange<UInt32>(10) % 3;
Member

Could you please add a comment explaining what happens if some instances are not used during recording but we try to use them during replay (i.e. whether or not it's expected to work)?

i0.index(),
};

jit_freeze_start(Backend, inputs, 2);
Member

Please add a comment specifying whether or not this will evaluate some things (e.g. inputs) before starting the recording.
If it doesn't do any evals, then maybe you could move the self.eval(); i0.eval(); just above to make clear that they're needed.

tests/record.cpp Outdated
}

/**
* This tests, weather it is possible to record multiple kernels in parallel.
Member

Suggested change
- * This tests, weather it is possible to record multiple kernels in parallel.
+ * This tests whether it is possible to record multiple independent kernels in the same recording.

tests/record.cpp Outdated
}

/**
* This tests recording and replay of a horizontal reduction operation (hsum).
Member

Suggested change
- * This tests recording and replay of a horizontal reduction operation (hsum).
+ * This tests the recording and replay of a horizontal reduction operation (hsum).

tests/record.cpp Outdated

jit_freeze_start(Backend, inputs, 1);

UInt32 o0 = hsum(i0);
Member

It would be more interesting if there was one operation creating a first kernel, and then hsum() is applied on the output of the first kernel.

Comment on lines +465 to +471
/**
* Basic addition test.
* Supplying a different input should replay the operation, with this input.
* In this case, the input at replay is incremented and should result in an
* incremented output.
*/
TEST_BOTH(9_resized_input) {
Member

This test wouldn't be needed anymore if you changed the width during replay in all of the previous tests.

jit_freeze_destroy(recording);
}

TEST_BOTH(10_input_passthrough) {
Member

Please add a comment describing the test case, as for the previous tests.

@merlinND (Member) left a comment

Review part 2

src/record_ts.h Outdated
///
/// Output variables are only tracked through the outputs array, as this
/// information is only needed when constructing the output variables.
///
Member

Suggested change (remove this line)
- ///

bool reverse;
} prefix_reduce;
/// The bucket count for the mkperm operation
uint32_t bucket_count;
Member

If the bucket count of an mkperm operation needs to change during replay, would that be correctly detected & handled (maybe by re-recording)?

Contributor Author

Good point. As long as it comes from Python, it should be fine. If it is determined by the width of some variable, this might be a problem.
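
One conceivable way to detect a changed bucket count is to compare the recorded value with the current one before replaying, and to re-record on a mismatch. The following is a minimal sketch of that idea only; `ToyMkpermOp` and `ToyFrozenFn` are hypothetical names, and nothing here claims that the PR handles the case this way.

```cpp
#include <cstdint>

// Hypothetical record of an mkperm-like operation.
struct ToyMkpermOp { uint32_t bucket_count; };

// Hypothetical frozen function that caches one recording.
struct ToyFrozenFn {
    bool has_recording = false;
    ToyMkpermOp op{0};
    uint32_t n_recordings = 0;

    // Replays the cached recording if its bucket count still matches;
    // otherwise discards it and records again with the new count.
    uint32_t call(uint32_t bucket_count) {
        if (!has_recording || op.bucket_count != bucket_count) {
            op.bucket_count = bucket_count;
            has_recording = true;
            ++n_recordings;
        }
        return op.bucket_count;
    }
};
```

The cost of this approach is one extra comparison per replay; the harder part, as noted above, is the case where the bucket count is derived from a variable's width and is therefore not visible as an explicit argument.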

src/record_ts.h Outdated
struct {
/// The reduce type of a prefix reduction operation
ReduceOp rtype;
/// Weather a prefix sum operation is exclusive
Member

Suggested change
- /// Weather a prefix sum operation is exclusive
+ /// Whether a prefix sum operation is exclusive

Search & replace for this typo

/// Does this operation use optix?
bool uses_optix = false;
/// A copy of the shader binding table, used by the kernel.
OptixShaderBindingTable *sbt;
Member

Is it a deep copy? What happens if the corresponding scene gets destroyed, could this or any elements of the SBT become dangling pointers?

src/record_ts.h Outdated
OpOutput,
/// This variable is part of the function input
Input,
/// This variable has been captured
Member

Specify what "captured" means. Is it "captured" in the sense of a closure?

src/record_ts.h Outdated
* This is used by the input variables of a kernel.
*/
uint32_t add_variable(const void *ptr, RecordVariable rv) {

Member

Suggested change (remove the blank line)

src/record_ts.h Outdated
}

/// Return the slot index given the data pointer of a variable.
/// This fails if the variable has not been added.
Member

Suggested change
- /// This fails if the variable has not been added.
+ /// This fails if the variable had not been previously added.

src/record_ts.h Outdated
jitc_raise("record(): Varaible at slot s%u was read from by "
"operation o%u, but has not yet been initialized! "
"This can happen if the variable was not part of "
"the input but is used by an recorded operation.",
Member

Explain how the user could make it "part of the input", if possible.

Contributor Author

Not sure, does this make sense?

                jitc_raise("record(): Varaible at slot s%u was read from by "
                           "operation o%u, but has not yet been initialized! "
                           "This can occur if the variable was not part of "
                           "the input but is used by a recorded operation, for "
                           "example if it was not specified as a member in a "
                           "DRJIT_STRUCT but used in the frozen function.",
                           info.slot,
                           (uint32_t) this->m_recording.operations.size());

Member

Okay! Fixed a couple of typos below:

                jitc_raise("record(): Variable at slot s%u was read by "
                           "operation o%u, but it had not yet been initialized! "
                           "This can occur if the variable was not part of "
                           "the input but is used by a recorded operation, for "
                           "example if it was not specified as a member in a "
                           "DRJIT_STRUCT but used in the frozen function.",
                           info.slot,
                           (uint32_t) this->m_recording.operations.size());

src/record_ts.h Outdated
Comment on lines 832 to 833
/// Helper function recording an output access, given the slot and \ref
/// VarType
Member

Suggested change
- /// Helper function recording an output access, given the slot and \ref
- /// VarType
+ /// Helper function recording an output access, given the slot and \ref VarType

and same below

src/record_ts.h Outdated
Comment on lines 855 to 873
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
uint32_t n_inputs);

Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
uint32_t n_outputs);

void jitc_freeze_abort(JitBackend backend);

void jitc_freeze_destroy(Recording *recording);

bool jitc_freeze_pause(JitBackend backend);

bool jitc_freeze_resume(JitBackend backend);

void jitc_freeze_replay(Recording *recording, const uint32_t *inputs,
uint32_t *outputs);

int jitc_freeze_dry_run(Recording *recording, const uint32_t *inputs,
uint32_t *outputs);
Member

Maybe move these to the top to make them more noticeable, since RecordThreadState is so long.

@merlinND (Member) left a comment

Review part 3.
It's a partial review of record_ts.cpp. Unfortunately, I won't have time to review the rest of the file; there's too much code.

this->init = rv.init;

if (init == RecordVarInit::Captured) {
// copy the variable, so that it isn't changed
Member

Would that always match the behavior without recording?
I.e. if there are captured variables and I run the code with and without recording, would I always get the same result?
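
The concern can be made concrete with a toy model in which a captured value is copied at record time. This is a sketch under assumptions: `ToyCapturedRecording`, `toy_record`, and `toy_eager` are hypothetical names, and whether the actual implementation can diverge like this is the open question here.

```cpp
// Hypothetical recording that stores a copy of the captured value.
struct ToyCapturedRecording {
    int captured_copy;

    // Replay uses the value that was copied when the recording was made.
    int replay(int input) const { return input + captured_copy; }
};

// Record: snapshot the captured variable so later changes do not affect it.
ToyCapturedRecording toy_record(const int &captured) {
    return ToyCapturedRecording{captured};
}

// Non-recorded execution: always reads the current value of the variable.
int toy_eager(int input, const int &captured) {
    return input + captured;
}
```

In this model, if the captured variable is modified between recording and replay, `replay` and `toy_eager` return different results for the same input, which is the kind of mismatch the question is probing for.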

Comment on lines +132 to +133
/// Temporary variables used for replaying a recording.
static std::vector<ReplayVariable> replay_variables;
Member

What happens if someone tries replaying two different recordings from two different threads?
(I'm not sure what thread-safety the new DrJit version provides)
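
For illustration, one alternative to a shared static scratch buffer is to let each replay own its temporary state. The sketch below assumes hypothetical names (`ToyReplayVariable`, `ToyReplayContext`, `toy_replay_begin`, `toy_replay_finish`) and says nothing about how drjit-core actually synchronizes replays.

```cpp
#include <vector>

// Hypothetical per-variable replay state.
struct ToyReplayVariable { int value; };

// Hypothetical per-replay context: the scratch buffer is owned by the
// caller instead of living in a static variable shared by all replays.
struct ToyReplayContext {
    std::vector<ToyReplayVariable> variables;
};

// Start a replay by filling this context's private scratch buffer.
void toy_replay_begin(ToyReplayContext &ctx, const std::vector<int> &inputs) {
    ctx.variables.clear();
    for (int v : inputs)
        ctx.variables.push_back(ToyReplayVariable{v});
}

// Finish a replay; here the "result" is just the sum of the variables.
int toy_replay_finish(const ToyReplayContext &ctx) {
    int sum = 0;
    for (const ToyReplayVariable &v : ctx.variables)
        sum += v.value;
    return sum;
}
```

With separate contexts, two replays that interleave (or run on different threads) do not overwrite each other's temporaries; with a single static vector, the second `clear()` would discard the first replay's state.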


// Reconstruct the \ref kernel_params for this launch given the
// allocations when replaying.
kernel_params.clear();
Member

Is there any risk this interferes with an unrelated pending kernel launch?

Comment on lines +209 to +210
// First 3 parameters reserved for: kernel ptr, size, ITT
// identifier
Member

Suggested change
- // First 3 parameters reserved for: kernel ptr, size, ITT
- // identifier
+ // First 3 parameters reserved for: kernel ptr, size, ITT identifier

Comment on lines +505 to +506
jitc_log(LogLevel::Debug, " src.data=%p", src_var.data);
jitc_log(LogLevel::Debug, " dst.data=%p", dst_var.data);
Member

Reorganize a bit so that the 3 logs can be combined?

offsets_var.alloc(backend, bucket_count * 4 + 1,
offsets_info.vtype);

jitc_log(LogLevel::Debug,
Member

Combine all the log calls

"replay(): MemcpyAsync s%u <- s%u [%zu]",
dst_info.slot, src_info.slot, src_var.data_size);

dst_var.alloc(backend, src_var.data_size);
Member

Could it be a problem that the memory is allocated even during a dry-run? (for this and other ops)
Will it be quickly freed again? We don't want to accidentally double memory usage because of dry runs.
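
One possible way to avoid the extra memory, sketched here under assumptions, is to propagate sizes during a dry run without touching memory. `ToyReplayVar` and `toy_memcpy_op` are hypothetical names, and the `dry_run` flag on `alloc` is an invented parameter for illustration, not the signature used in the PR.

```cpp
#include <cstddef>
#include <cstring>
#include <memory>

// Hypothetical replay variable: a size plus an optional buffer.
struct ToyReplayVar {
    std::size_t data_size = 0;
    std::unique_ptr<unsigned char[]> data;

    // Always track the size; only allocate outside of a dry run.
    void alloc(std::size_t size, bool dry_run) {
        data_size = size;
        if (dry_run)
            data.reset();
        else
            data = std::make_unique<unsigned char[]>(size);
    }
};

// Toy memcpy operation: returns false if a real copy lacks a source buffer.
bool toy_memcpy_op(const ToyReplayVar &src, ToyReplayVar &dst, bool dry_run) {
    dst.alloc(src.data_size, dry_run);
    if (dry_run)
        return true; // sizes are propagated, no memory is touched
    if (!src.data)
        return false;
    std::memcpy(dst.data.get(), src.data.get(), src.data_size);
    return true;
}
```

In this model, later operations that only depend on sizes can still be validated during the dry run, while peak memory stays unchanged until the real replay.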


};
break;
case OpType::Aggregate: {
Member

Move contents to a dedicated function since it's a complex case

default:
jitc_fail(
"An operation has been recorded, that is not known to "
"the replay functionality!");
Member

Print the op type to help with debugging
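
A small name-lookup helper is one way to do this. The sketch below is illustrative only: `ToyOpType` lists just a few hypothetical enumerators (the real `OpType` has its own set), and `toy_op_type_name` and `toy_unknown_op_message` are invented helpers that build the message as a string instead of calling `jitc_fail`.

```cpp
#include <cstdint>
#include <string>

// Hypothetical subset of operation types, for illustration only.
enum class ToyOpType : uint32_t { KernelLaunch, MemcpyAsync, Aggregate };

// Map an operation type to a readable name; unknown values get a fallback.
const char *toy_op_type_name(ToyOpType t) {
    switch (t) {
        case ToyOpType::KernelLaunch: return "KernelLaunch";
        case ToyOpType::MemcpyAsync:  return "MemcpyAsync";
        case ToyOpType::Aggregate:    return "Aggregate";
        default:                      return "unknown";
    }
}

// Build a failure message that includes both the name and the raw value.
std::string toy_unknown_op_message(ToyOpType t) {
    return std::string("replay(): operation of type ") + toy_op_type_name(t) +
           " (" + std::to_string(static_cast<uint32_t>(t)) +
           ") is not known to the replay functionality!";
}
```

Including the raw numeric value alongside the name keeps the message useful even when the enumerator is missing from the lookup table.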
