From c8d357149019a18f9b5ded9fcbaf85781b494bea Mon Sep 17 00:00:00 2001 From: Dan Hoeflinger Date: Thu, 16 May 2024 14:12:48 -0400 Subject: [PATCH] improvements for device_vector --- help_function/src/onedpl_test_sort_by_key.cpp | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/help_function/src/onedpl_test_sort_by_key.cpp b/help_function/src/onedpl_test_sort_by_key.cpp index 4f34c3901..09f583af8 100644 --- a/help_function/src/onedpl_test_sort_by_key.cpp +++ b/help_function/src/onedpl_test_sort_by_key.cpp @@ -201,40 +201,22 @@ int main() { { // Test Two, test calls to dpct::sort using device vectors - dpct::device_vector keys_vec(10); - dpct::device_vector values_vec(10); - std::vector keys_data{4, 8, 5, 3, 0, 9, 7, 2, 1, 6}; std::vector values_data{13, 16, 17, 11, 19, 14, 12, 18, 10, 15}; - dpct::get_default_queue().submit([&](sycl::handler& h) { - h.memcpy(keys_vec.data(), keys_data.data(), 10 * sizeof(int)); - }); - - dpct::get_default_queue().submit([&](sycl::handler& h) { - h.memcpy(values_vec.data(), values_data.data(), 10 * sizeof(int)); - }); - dpct::get_default_queue().wait(); + dpct::device_vector keys_vec(keys_data); + dpct::device_vector values_vec(values_data); auto keys_it = keys_vec.begin(); auto keys_it_end = keys_vec.end(); auto values_it = values_vec.begin(); { // call algorithm - dpct::sort(oneapi::dpl::execution::make_device_policy<>(dpct::get_default_queue()), keys_it, keys_it_end, values_it); + dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it_end, values_it); // keys is now = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} // values is now = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14} } - dpct::get_default_queue().submit([&](sycl::handler& h) { - h.memcpy(keys_data.data(), keys_vec.data(), 10 * sizeof(int)); - }); - - dpct::get_default_queue().submit([&](sycl::handler& h) { - h.memcpy(values_data.data(), values_vec.data(), 10 * sizeof(int)); - }); - dpct::get_default_queue().wait(); - { int check_keys[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; int check_values[10] = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14}; @@ -243,8 +225,8 @@ int main() { // check that values and keys are correct for (int i = 0; i != 10; ++i) { - num_failing += ASSERT_EQUAL(test_name, values_data[i], check_values[i]); - num_failing += ASSERT_EQUAL(test_name, keys_data[i], check_keys[i]); + num_failing += ASSERT_EQUAL(test_name, values_vec[i], check_values[i]); + num_failing += ASSERT_EQUAL(test_name, keys_vec[i], check_keys[i]); } failed_tests += test_passed(num_failing, test_name);