From 7e01f402d79488cb5379b6088653ff6a3030cc17 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 25 Oct 2024 13:19:59 -0500 Subject: [PATCH] [MigraphX] Fix potential synchronization problem when ORT_ENABLE_STREAM is true (#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream`. ### Motivation and Context `hipMemcpy` uses the default stream, which may not be synchronized with the current stream when ORT_ENABLE_STREAM is defined. --- onnxruntime/core/providers/migraphx/gpu_data_transfer.cc | 2 +- .../core/providers/migraphx/migraphx_execution_provider.cc | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 94480c308b99f..51625b83b8f61 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -57,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle()))); } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle()))); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e41cd577b0b21..dca38480434fe 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1445,7 +1445,11 @@ Status MIGraphXExecutionProvider::Compile(const 
std::vector<FusedNodeAndGraph>& std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()}; auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice)); + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + gpu_res.data(), + res_shape.bytes(), + hipMemcpyDeviceToDevice, + static_cast<hipStream_t>(rocm_stream))); } } };