Skip to content

Commit

Permalink
Fixed memory leaks in rocrand_tests (ROCm#557)
Browse files Browse the repository at this point in the history
* added hipFree to test_rocrand_cpp_basic

* fixed memory leak for test_rocrand_config_dispatch

* fixed a memory leak in test_utils

* changed createGraph to createAndLaunchGraph, as well as fixed stream capture order

* changed default boolean (kaunchGraph, sync)  to be true in createAndLaunchGraph

* added back missing end stream capture

* reformated curlys for consistency

* removed createAndLaunchGraph inside resetGraphHelper
  • Loading branch information
NguyenNhuDi authored Oct 17, 2024
1 parent 70d29fe commit c1640d9
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 84 deletions.
2 changes: 2 additions & 0 deletions test/internal/test_rocrand_config_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ TEST(rocrand_config_dispatch_tests, host_matches_device)

ASSERT_NE(host_arch, rocrand_impl::host::target_arch::invalid);
ASSERT_EQ(host_arch, device_arch);

HIP_CHECK(hipFree(device_arch_ptr));
}

TEST(rocrand_config_dispatch_tests, parse_common_architectures)
Expand Down
6 changes: 4 additions & 2 deletions test/test_rocrand_cpp_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_construction)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}

TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)
Expand Down Expand Up @@ -119,6 +120,7 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}
50 changes: 25 additions & 25 deletions test/test_rocrand_hipgraphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,34 @@ void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t,
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;

gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
generate_fn(generator, data, 1, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
generate_fn(generator, data+1, 2, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
generate_fn(generator, data, size, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down Expand Up @@ -109,34 +109,34 @@ TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, 1)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data+1, 2)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, size)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand All @@ -159,28 +159,28 @@ TEST_P(rocrand_hipgraph_generate_tests, poisson_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(rocrand_generate_poisson(generator, data, 1, 10.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(rocrand_generate_poisson(generator, data + 1, 2, 500.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(rocrand_generate_poisson(generator, data, size, 5000.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down
112 changes: 55 additions & 57 deletions test/test_utils_hipgraphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,64 +28,62 @@
// Note: graphs will not work on the default stream.
namespace test_utils
{
class GraphHelper{
private:
hipGraph_t graph;
hipGraphExec_t graph_instance;
public:

inline void startStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
}

inline void endStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));
}

inline void createAndLaunchGraph(hipStream_t & stream, const bool launchGraph=true, const bool sync=true)
{

endStreamCapture(stream);

HIP_CHECK_NON_VOID(hipGraphInstantiate(&graph_instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(graph_instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Create a new graph
hipGraph_t graph;
HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0));

// Optionally begin stream capture
if (beginCapture)
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));

return graph;
}

inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
{
HIP_CHECK_NON_VOID(hipGraphDestroy(graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance));
}

inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper(graph, instance);

// Create a new graph and optionally begin capture
graph = createGraphHelper(stream, beginCapture);
}

inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false)
{
// End the capture
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));

// Instantiate the graph
hipGraphExec_t instance;
HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));

return instance;
}

inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline void cleanupGraphHelper()
{
HIP_CHECK_NON_VOID(hipGraphDestroy(this->graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(this->graph_instance));
}

inline void resetGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper();

if(beginCapture)
startStreamCapture(stream);
}

inline void launchGraphHelper(hipStream_t& stream,const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(this->graph_instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}
};
} // end namespace test_utils

#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP

0 comments on commit c1640d9

Please sign in to comment.