Skip to content

Commit

Permalink
changed createGraph to createAndLaunchGraph, as well as fixed stream …
Browse files Browse the repository at this point in the history
…capture order
  • Loading branch information
NguyenNhuDi committed Oct 3, 2024
1 parent 9709b37 commit d958870
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 17 deletions.
23 changes: 13 additions & 10 deletions test/test_rocrand_hipgraphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t,

test_utils::GraphHelper gHelper;

gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
generate_fn(generator, data, 1, mean, stddev)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
generate_fn(generator, data+1, 2, mean, stddev)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
generate_fn(generator, data, size, mean, stddev)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
Expand Down Expand Up @@ -109,28 +110,29 @@ TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test)
rocrand_set_stream(generator, stream);

test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, 1)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data+1, 2)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, size)
);

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
Expand Down Expand Up @@ -158,22 +160,23 @@ TEST_P(rocrand_hipgraph_generate_tests, poisson_test)
rocrand_set_stream(generator, stream);

test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(rocrand_generate_poisson(generator, data, 1, 10.0));

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(rocrand_generate_poisson(generator, data + 1, 2, 500.0));

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(rocrand_generate_poisson(generator, data, size, 5000.0));

gHelper.createGraph(stream, true, true, true);
gHelper.createAndLaunchGraph(stream, true, true);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
Expand Down
20 changes: 13 additions & 7 deletions test/test_utils_hipgraphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@ namespace test_utils
hipGraph_t graph;
hipGraphExec_t graph_instance;
public:
void createGraph(hipStream_t & stream, const bool beginCapture = true, const bool launchGraph=false, const bool sync=false){
if (beginCapture)
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));

// End the capture

inline void startStreamCapture(hipStream_t & stream){
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
}

inline void endStreamCapture(hipStream_t & stream){
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));
}

inline void createAndLaunchGraph(hipStream_t & stream, const bool launchGraph=false, const bool sync=false){

HIP_CHECK_NON_VOID(hipGraphInstantiate(&graph_instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
Expand All @@ -62,8 +66,10 @@ namespace test_utils
// Destroy the old graph and instance
cleanupGraphHelper();

// Create a new graph and optionally begin capture
createGraph(stream, beginCapture);
if(beginCapture)
startStreamCapture(stream);

createAndLaunchGraph(stream, true, true);
}

inline void launchGraphHelper(hipStream_t& stream,const bool sync=false)
Expand Down

0 comments on commit d958870

Please sign in to comment.