Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

memory leak fix for rocRAND unit test #571

Merged
merged 12 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion scripts/copyright-date/check-copyright.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ if $forkdiff; then
source_commit="remotes/$remote/HEAD"

# don't use fork-point for finding fork point (lol)
# see: https://stackoverflow.com/a/53981615
diff_hash="$(git merge-base "$source_commit" "$branch")"
fi

Expand Down
2 changes: 2 additions & 0 deletions test/internal/test_rocrand_config_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ TEST(rocrand_config_dispatch_tests, host_matches_device)

ASSERT_NE(host_arch, rocrand_impl::host::target_arch::invalid);
ASSERT_EQ(host_arch, device_arch);

HIP_CHECK(hipFree(device_arch_ptr));
}

TEST(rocrand_config_dispatch_tests, parse_common_architectures)
Expand Down
6 changes: 4 additions & 2 deletions test/test_rocrand_cpp_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_construction)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}

TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)
Expand Down Expand Up @@ -119,6 +120,7 @@ TYPED_TEST(rocrand_cpp_basic_tests, move_assignment)

float actual;
HIP_CHECK(hipMemcpy(&actual, d_data, sizeof(actual), hipMemcpyDeviceToHost));

ASSERT_EQ(expected, actual);

HIP_CHECK(hipFree(d_data));
}
50 changes: 25 additions & 25 deletions test/test_rocrand_hipgraphs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,34 @@ void test_float(std::function<rocrand_status(rocrand_generator, float*, size_t,
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;

gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
generate_fn(generator, data, 1, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
generate_fn(generator, data+1, 2, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
generate_fn(generator, data, size, mean, stddev)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down Expand Up @@ -109,34 +109,34 @@ TEST_P(rocrand_hipgraph_generate_tests, uniform_float_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, 1)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(
rocrand_generate_uniform(generator, data+1, 2)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(
rocrand_generate_uniform(generator, data, size)
);

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand All @@ -159,28 +159,28 @@ TEST_P(rocrand_hipgraph_generate_tests, poisson_test)
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
rocrand_set_stream(generator, stream);

hipGraphExec_t graph_instance;
hipGraph_t graph = test_utils::createGraphHelper(stream);
test_utils::GraphHelper gHelper;
gHelper.startStreamCapture(stream);

// Any sizes
ROCRAND_CHECK(rocrand_generate_poisson(generator, data, 1, 10.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

// Any alignment
ROCRAND_CHECK(rocrand_generate_poisson(generator, data + 1, 2, 500.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
test_utils::resetGraphHelper(graph, graph_instance, stream);
gHelper.createAndLaunchGraph(stream);
gHelper.resetGraphHelper(stream);

ROCRAND_CHECK(rocrand_generate_poisson(generator, data, size, 5000.0));

graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipFree(data));
ROCRAND_CHECK(rocrand_destroy_generator(generator));
test_utils::cleanupGraphHelper(graph, graph_instance);
gHelper.cleanupGraphHelper();
HIP_CHECK(hipStreamDestroy(stream));
}

Expand Down
112 changes: 55 additions & 57 deletions test/test_utils_hipgraphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,64 +28,62 @@
// Note: graphs will not work on the default stream.
namespace test_utils
{
class GraphHelper{
private:
hipGraph_t graph;
hipGraphExec_t graph_instance;
public:

inline void startStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
}

inline void endStreamCapture(hipStream_t & stream)
{
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));
}

inline void createAndLaunchGraph(hipStream_t & stream, const bool launchGraph=true, const bool sync=true)
{

endStreamCapture(stream);

HIP_CHECK_NON_VOID(hipGraphInstantiate(&graph_instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(graph_instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Create a new graph
hipGraph_t graph;
HIP_CHECK_NON_VOID(hipGraphCreate(&graph, 0));

// Optionally begin stream capture
if (beginCapture)
HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));

return graph;
}

inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
{
HIP_CHECK_NON_VOID(hipGraphDestroy(graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance));
}

inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper(graph, instance);

// Create a new graph and optionally begin capture
graph = createGraphHelper(stream, beginCapture);
}

inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false)
{
// End the capture
HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));

// Instantiate the graph
hipGraphExec_t instance;
HIP_CHECK_NON_VOID(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0));

// Optionally launch the graph
if (launchGraph)
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally synchronize the stream when we're done
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));

return instance;
}

inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}

inline void cleanupGraphHelper()
{
HIP_CHECK_NON_VOID(hipGraphDestroy(this->graph));
HIP_CHECK_NON_VOID(hipGraphExecDestroy(this->graph_instance));
}

inline void resetGraphHelper(hipStream_t& stream, const bool beginCapture=true)
{
// Destroy the old graph and instance
cleanupGraphHelper();

if(beginCapture)
startStreamCapture(stream);
}

inline void launchGraphHelper(hipStream_t& stream,const bool sync=false)
{
HIP_CHECK_NON_VOID(hipGraphLaunch(this->graph_instance, stream));

// Optionally sync after the launch
if (sync)
HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
}
};
} // end namespace test_utils

#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP
Loading