From c1bdeb0679b3b11e88a5135f5da14a4b238690c7 Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 30 Sep 2024 21:21:22 -0400 Subject: [PATCH] use xrtDeviceHandle --- .github/workflows/ci-linux.yml | 7 ++- .github/workflows/ci-windows.yml | 6 +-- .../driver/xrt/direct_allocator.cc | 9 ++-- .../driver/xrt/direct_allocator.h | 2 +- .../driver/xrt/direct_command_buffer.cc | 1 - .../driver/xrt/native_executable.cc | 16 ++++--- .../driver/xrt/native_executable.h | 6 +-- .../driver/xrt/nop_executable_cache.cc | 8 ++-- .../driver/xrt/nop_executable_cache.h | 2 +- .../src/iree-amd-aie/driver/xrt/xrt_device.cc | 36 +++++++++++----- .../src/iree-amd-aie/driver/xrt/xrt_device.h | 3 +- .../src/iree-amd-aie/driver/xrt/xrt_driver.cc | 43 +++++-------------- tests/conftest.py | 16 ++++--- tests/test_matmul.py | 9 ++-- 14 files changed, 82 insertions(+), 82 deletions(-) diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index 1ac6b7940..f9f69c11b 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -150,12 +150,11 @@ jobs: source .venv/bin/activate pip install -r tests/requirements.txt - - name : E2E comparison of AIE to llvm-cpu run: | source .venv/bin/activate source /opt/xilinx/xrt/setup.sh - for i in {1..100}; do + for i in {1..50}; do echo "run $i" python build_tools/ci/cpu_comparison/run.py \ test_aie_vs_cpu \ @@ -196,7 +195,7 @@ jobs: sudo prlimit -lunlimited --pid $$ source .venv/bin/activate source /opt/xilinx/xrt/setup.sh - for i in {1..100}; do + for i in {1..50}; do echo "run $i" bash build_tools/ci/run_matmul_test.sh \ test_matmuls \ @@ -210,7 +209,7 @@ jobs: run: | source .venv/bin/activate source /opt/xilinx/xrt/setup.sh - for i in {1..100}; do + for i in {1..50}; do echo "run $i" pytest -v tests \ --capture=tee-sys \ diff --git a/.github/workflows/ci-windows.yml b/.github/workflows/ci-windows.yml index b041351b7..ef31c478a 100644 --- a/.github/workflows/ci-windows.yml +++ b/.github/workflows/ci-windows.yml @@ -169,7 +169,7 @@ jobs: shell: bash run: | source .venv/Scripts/activate - for i in {1..100}; do + for i in {1..50}; do echo "run $i" bash build_tools/ci/run_matmul_test.sh \ /c/test_matmuls \ @@ -181,7 +181,7 @@ jobs: shell: bash run: | source .venv/Scripts/activate - for i in {1..100}; do + for i in {1..50}; do echo "run $i" python build_tools/ci/cpu_comparison/run.py \ /c/test_aie_vs_cpu \ @@ -194,7 +194,7 @@ jobs: ls $env:XILINX_XRT .\.venv\Scripts\Activate.ps1 mkdir temp - for ($i = 1; $i -le 100; $i++) { + for ($i = 1; $i -le 50; $i++) { echo "run $i" pytest tests -sv ` --basetemp=$PWD\temp ` diff --git a/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.cc b/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.cc index 641f08c9c..be2370d2f 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.cc @@ -28,7 +28,7 @@ typedef struct iree_hal_xrt_allocator_t { // The device that this allocator is attached to. iree_hal_device_t* base_device; - xrt::device device; + xrtDeviceHandle device_hdl; iree_allocator_t host_allocator; @@ -46,7 +46,7 @@ static iree_hal_xrt_allocator_t* iree_hal_xrt_allocator_cast( } iree_status_t iree_hal_xrt_allocator_create( - iree_hal_device_t* base_device, xrt::device device, + iree_hal_device_t* base_device, xrtDeviceHandle device_hdl, iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(base_device); IREE_ASSERT_ARGUMENT(out_allocator); @@ -61,7 +61,7 @@ iree_status_t iree_hal_xrt_allocator_create( &allocator->resource); allocator->base_device = base_device; iree_hal_device_retain(base_device); - allocator->device = device; + allocator->device_hdl = device_hdl; allocator->host_allocator = host_allocator; *out_allocator = (iree_hal_allocator_t*)allocator; @@ -171,7 +171,8 @@ static iree_status_t iree_hal_xrt_allocator_allocate_buffer( std::unique_ptr xrt_buffer; try { - xrt_buffer = std::make_unique(allocator->device, allocation_size, + xrt::device device(xrtDeviceToXclDevice(allocator->device_hdl)); + xrt_buffer = std::make_unique(device, allocation_size, XRT_BO_FLAGS_HOST_ONLY, group_id); } catch (...) { IREE_TRACE_ZONE_END(z0); diff --git a/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.h b/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.h index 39a0f3e10..104bb2e2b 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.h +++ b/runtime/src/iree-amd-aie/driver/xrt/direct_allocator.h @@ -17,7 +17,7 @@ extern "C" { // Creates an XRT memory allocator. iree_status_t iree_hal_xrt_allocator_create( - iree_hal_device_t* base_device, xrt::device device, + iree_hal_device_t* base_device, xrtDeviceHandle device_hdl, iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); #ifdef __cplusplus diff --git a/runtime/src/iree-amd-aie/driver/xrt/direct_command_buffer.cc b/runtime/src/iree-amd-aie/driver/xrt/direct_command_buffer.cc index 5785f3484..770527e93 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/direct_command_buffer.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/direct_command_buffer.cc @@ -292,7 +292,6 @@ static iree_status_t iree_hal_xrt_direct_command_buffer_push_descriptor_set( IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &binding->buffer)); - std::unique_ptr sub_buffer; current_bindings[i] = iree_hal_xrt_buffer_handle( iree_hal_buffer_allocated_buffer(binding->buffer)); current_offsets[i] = diff --git a/runtime/src/iree-amd-aie/driver/xrt/native_executable.cc b/runtime/src/iree-amd-aie/driver/xrt/native_executable.cc index c3f8c7aa4..6d37d9e53 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/native_executable.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/native_executable.cc @@ -99,7 +99,8 @@ static iree_status_t iree_amd_aie_hal_xrt_native_executable_flatbuffer_verify( } iree_status_t iree_hal_xrt_native_executable_create( - xrt::device device, const iree_hal_executable_params_t* executable_params, + xrtDeviceHandle device_hdl, + const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); @@ -180,24 +181,27 @@ iree_status_t iree_hal_xrt_native_executable_create( // XRT API needs this vector and cant actually read a void*. std::vector xclbinVector( xclbin_fb, xclbin_fb + flatbuffers_string_len(xclbin_fb)); + xrt::xclbin xclbin; try { - params->xclbin = xrt::xclbin(xclbinVector); + xclbin = xrt::xclbin(xclbinVector); } catch (std::exception& e) { return iree_make_status(IREE_STATUS_INTERNAL, "XCLBIN load error: %s", e.what()); } + xrt::device device(xrtDeviceToXclDevice(device_hdl)); + IREE_ASSERT(device, "failed to find device"); + try { - device.register_xclbin(params->xclbin); + device.register_xclbin(xclbin); } catch (std::exception& e) { return iree_make_status(IREE_STATUS_INTERNAL, "XCLBIN register error: %s", e.what()); } try { - params->context = - xrt::hw_context(device, params->xclbin.get_uuid(), - xrt::hw_context::access_mode::exclusive); + params->context = xrt::hw_context( + device, xclbin.get_uuid(), xrt::hw_context::access_mode::exclusive); } catch (std::exception& e) { return iree_make_status(IREE_STATUS_INTERNAL, "xrt::hw_context context: %s", e.what()); diff --git a/runtime/src/iree-amd-aie/driver/xrt/native_executable.h b/runtime/src/iree-amd-aie/driver/xrt/native_executable.h index 141bbebca..bc01b9d23 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/native_executable.h +++ b/runtime/src/iree-amd-aie/driver/xrt/native_executable.h @@ -7,7 +7,7 @@ #ifndef IREE_AMD_AIE_DRIVER_XRT_NATIVE_EXECUTABLE_H_ #define IREE_AMD_AIE_DRIVER_XRT_NATIVE_EXECUTABLE_H_ -#include +#include #include "iree/base/api.h" #include "iree/base/tracing.h" @@ -22,7 +22,6 @@ extern "C" { // Object and launch parameters for a compute kernel. typedef struct iree_hal_xrt_kernel_params_t { xrt::hw_context context; - xrt::xclbin xclbin; // The kernel code object. xrt::kernel kernel; // Instruction buffer argument to the kernel. @@ -37,7 +36,8 @@ typedef struct iree_hal_xrt_kernel_params_t { // |out_executable| must be released by the caller (see // iree_hal_executable_release). iree_status_t iree_hal_xrt_native_executable_create( - xrt::device device, const iree_hal_executable_params_t* executable_params, + xrtDeviceHandle device_hdl, + const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable); // Returns the kernel launch parameters for the given |entry_point|. diff --git a/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.cc b/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.cc index 655133e61..3120e5a49 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.cc @@ -17,7 +17,7 @@ typedef struct iree_hal_xrt_nop_executable_cache_t { // at offset 0. iree_hal_resource_t resource; - xrt::device device; + xrtDeviceHandle device_hdl; iree_allocator_t host_allocator; } iree_hal_xrt_nop_executable_cache_t; @@ -35,7 +35,7 @@ iree_hal_xrt_nop_executable_cache_cast( } iree_status_t iree_hal_xrt_nop_executable_cache_create( - xrt::device device, iree_string_view_t identifier, + xrtDeviceHandle device_hdl, iree_string_view_t identifier, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(out_executable_cache); @@ -49,7 +49,7 @@ iree_status_t iree_hal_xrt_nop_executable_cache_create( iree_hal_resource_initialize(&iree_hal_xrt_nop_executable_cache_vtable, &executable_cache->resource); executable_cache->host_allocator = host_allocator; - executable_cache->device = device; + executable_cache->device_hdl = device_hdl; *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; IREE_TRACE_ZONE_END(z0); @@ -82,7 +82,7 @@ static iree_status_t iree_hal_xrt_nop_executable_cache_prepare_executable( iree_hal_xrt_nop_executable_cache_t* executable_cache = iree_hal_xrt_nop_executable_cache_cast(base_executable_cache); return iree_hal_xrt_native_executable_create( - executable_cache->device, executable_params, + executable_cache->device_hdl, executable_params, executable_cache->host_allocator, out_executable); } diff --git a/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.h b/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.h index 9a84f9e6d..5362f98af 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.h +++ b/runtime/src/iree-amd-aie/driver/xrt/nop_executable_cache.h @@ -22,7 +22,7 @@ extern "C" { // |out_executable_cache| must be released by the caller (see // iree_hal_executable_cache_release). iree_status_t iree_hal_xrt_nop_executable_cache_create( - xrt::device device, iree_string_view_t identifier, + xrtDeviceHandle device_hdl, iree_string_view_t identifier, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache); diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc index 7b9a36f78..03aa86c9f 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.cc @@ -6,6 +6,7 @@ #include "iree-amd-aie/driver/xrt/xrt_device.h" +#include "experimental/xrt_system.h" #include "iree-amd-aie/driver/xrt/direct_allocator.h" #include "iree-amd-aie/driver/xrt/direct_command_buffer.h" #include "iree-amd-aie/driver/xrt/nop_executable_cache.h" @@ -32,7 +33,7 @@ typedef struct iree_hal_xrt_device_t { iree_allocator_t host_allocator; iree_hal_allocator_t* device_allocator; - xrt::device device; + xrtDeviceHandle device_hdl; } iree_hal_xrt_device_t; namespace { @@ -52,17 +53,30 @@ void iree_hal_xrt_device_params_initialize( } static iree_status_t iree_hal_xrt_device_create_internal( - iree_string_view_t identifier, xrt::device xrt_device, - const iree_hal_xrt_device_params_t* params, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_xrt_device_t* device = nullptr; iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; IREE_RETURN_IF_ERROR( iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + try { + if (IREE_UNLIKELY(xrt::system::enumerate_devices() == 0)) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "No XRT devices found"); + } + } catch (std::exception& e) { + return iree_make_status(IREE_STATUS_INTERNAL, + "xrt::system::enumerate_devices failed: %s", + e.what()); + } + + xrtDeviceHandle device_hdl = xrtDeviceOpen(0); + IREE_ASSERT(device_hdl, "failed to open xrt device"); + iree_status_t status = - iree_hal_xrt_allocator_create((iree_hal_device_t*)device, xrt_device, + iree_hal_xrt_allocator_create((iree_hal_device_t*)device, device_hdl, host_allocator, &device->device_allocator); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_xrt_device_vtable, @@ -74,7 +88,7 @@ static iree_status_t iree_hal_xrt_device_create_internal( &device->block_pool); device->host_allocator = host_allocator; - device->device = xrt_device; + device->device_hdl = device_hdl; device->params = *params; *out_device = (iree_hal_device_t*)device; } else { @@ -85,13 +99,12 @@ static iree_status_t iree_hal_xrt_device_create_internal( iree_status_t iree_hal_xrt_device_create( iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params, - xrt::device device, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(out_device); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_xrt_device_create_internal( - identifier, device, params, host_allocator, out_device); + identifier, params, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; @@ -104,7 +117,9 @@ static void iree_hal_xrt_device_destroy(iree_hal_device_t* base_device) { iree_hal_allocator_release(device->device_allocator); iree_arena_block_pool_deinitialize(&device->block_pool); + xrtDeviceHandle device_hdl = device->device_hdl; iree_allocator_free(host_allocator, device); + (void)xrtDeviceClose(device_hdl); IREE_TRACE_ZONE_END(z0); } @@ -201,7 +216,8 @@ static iree_status_t iree_hal_xrt_device_create_executable_cache( iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_xrt_device_t* device = iree_hal_xrt_device_cast(base_device); return iree_hal_xrt_nop_executable_cache_create( - device->device, identifier, device->host_allocator, out_executable_cache); + device->device_hdl, identifier, device->host_allocator, + out_executable_cache); } static iree_status_t iree_hal_xrt_device_import_file( diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h index f23610dbd..aa77fdeb7 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_device.h @@ -21,8 +21,7 @@ extern "C" { // |out_device| must be released by the caller (see iree_hal_device_release). iree_status_t iree_hal_xrt_device_create( iree_string_view_t identifier, const iree_hal_xrt_device_params_t* params, - xrt::device device, iree_allocator_t host_allocator, - iree_hal_device_t** out_device); + iree_allocator_t host_allocator, iree_hal_device_t** out_device); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc b/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc index 2f8a1134f..ca868860a 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc +++ b/runtime/src/iree-amd-aie/driver/xrt/xrt_driver.cc @@ -6,13 +6,10 @@ #include "iree-amd-aie/driver/xrt/xrt_device.h" #include "iree/base/api.h" -#include "iree/base/target_platform.h" #include "iree/base/tracing.h" #include "iree/hal/api.h" // XRT includes -#include "experimental/xrt_system.h" -#include "xrt.h" #include "xrt/xrt_device.h" #include "xrt/xrt_kernel.h" @@ -40,7 +37,7 @@ typedef struct iree_hal_xrt_driver_t { // Parameters used to control device behavior. iree_hal_xrt_device_params_t device_params; - xrt::device device; + xrtDeviceHandle device_hdl; } iree_hal_xrt_driver_t; @@ -79,23 +76,6 @@ iree_status_t iree_hal_xrt_driver_create_internal( (char*)driver + iree_sizeof_struct(*driver)); driver->device_params = *device_params; - try { - if (IREE_UNLIKELY(xrt::system::enumerate_devices() == 0)) { - return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, - "No XRT devices found"); - } - } catch (std::exception& e) { - return iree_make_status(IREE_STATUS_INTERNAL, - "xrt::system::enumerate_devices failed: %s", - e.what()); - } - // Get handle to xrt device - try { - driver->device = xrt::device(0); - } catch (std::exception& e) { - return iree_make_status(IREE_STATUS_INTERNAL, "xrt::device(0) failed: %s", - e.what()); - } *out_driver = reinterpret_cast(driver); return iree_ok_status(); } @@ -130,10 +110,10 @@ static iree_status_t iree_hal_xrt_driver_dump_device_info( iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, iree_string_builder_t* builder) { iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); - xrt::device device = driver->device; + xrtDeviceHandle device_hdl = driver->device_hdl; IREE_RETURN_IF_ERROR( iree_string_builder_append_cstring(builder, "\n- Platform:")); - + xrt::device device(xrtDeviceToXclDevice(device_hdl)); std::string platform_info = device.get_info(); const char* platform_info_str = platform_info.c_str(); if (platform_info_str) { @@ -149,7 +129,7 @@ static iree_status_t iree_hal_xrt_driver_dump_device_info( // |out_device_info| must point to valid memory and additional data will be // appended to |buffer_ptr| and the new pointer is returned. static iree_status_t iree_hal_xrt_populate_device_info( - xrt::device device, uint8_t* buffer_ptr, uint8_t** out_buffer_ptr, + xrtDeviceHandle device_hdl, uint8_t* buffer_ptr, uint8_t** out_buffer_ptr, iree_hal_device_info_t* out_device_info) { *out_buffer_ptr = buffer_ptr; @@ -157,6 +137,7 @@ static iree_status_t iree_hal_xrt_populate_device_info( // We currenly only work with one XRT device and its device id is 0. out_device_info->device_id = 0; + xrt::device device(xrtDeviceToXclDevice(device_hdl)); std::string device_name = device.get_info(); const size_t name_len = strlen(device_name.c_str()); if (name_len >= IREE_HAL_XRT_MAX_DEVICE_NAME_LENGTH) { @@ -177,7 +158,7 @@ static iree_status_t iree_hal_xrt_driver_query_available_devices( iree_host_size_t* out_device_info_count, iree_hal_device_info_t** out_device_infos) { iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); - xrt::device device = driver->device; + xrtDeviceHandle device_hdl = driver->device_hdl; // Allocate the return infos and populate with the devices. iree_hal_device_info_t* device_infos = nullptr; iree_host_size_t single_info_size = @@ -190,7 +171,7 @@ static iree_status_t iree_hal_xrt_driver_query_available_devices( // Append all path and name strings at the end of the struct. uint8_t* buffer_ptr = (uint8_t*)device_infos + sizeof(iree_hal_device_info_t); iree_status_t status = iree_hal_xrt_populate_device_info( - device, buffer_ptr, &buffer_ptr, device_infos); + device_hdl, buffer_ptr, &buffer_ptr, device_infos); if (iree_status_is_ok(status)) { // We currenly only work with one XRT device. *out_device_info_count = 1; @@ -209,9 +190,8 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_id( iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); iree_string_view_t device_name = iree_make_cstring_view("xrt"); - iree_status_t status = - iree_hal_xrt_device_create(device_name, &driver->device_params, - driver->device, host_allocator, out_device); + iree_status_t status = iree_hal_xrt_device_create( + device_name, &driver->device_params, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; @@ -226,9 +206,8 @@ static iree_status_t iree_hal_xrt_driver_create_device_by_path( iree_hal_xrt_driver_t* driver = iree_hal_xrt_driver_cast(base_driver); iree_string_view_t device_name = iree_make_cstring_view("xrt"); - iree_status_t status = - iree_hal_xrt_device_create(device_name, &driver->device_params, - driver->device, host_allocator, out_device); + iree_status_t status = iree_hal_xrt_device_create( + device_name, &driver->device_params, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; diff --git a/tests/conftest.py b/tests/conftest.py index e1b3f17b4..3a2a5f76b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,17 @@ -import os from contextlib import contextmanager from pathlib import Path import numpy as np import pytest -from iree.runtime import VmModule +from ml_dtypes import bfloat16 from iree.compiler import ir from iree.compiler._mlir_libs import get_dialect_registry from iree.compiler.api import Session, Output, Source from iree.compiler.extras import types as T -from ml_dtypes import bfloat16 +from iree.runtime import VmModule from iree.runtime import get_driver, Config, SystemContext - for t in [ "i8", "i16", @@ -99,8 +97,13 @@ def session_module(iree_session, tmp_path) -> ir.Module: yield iree_session, module_op +@pytest.fixture(scope="session") +def device(device="xrt") -> ir.Module: + yield get_driver(device).create_default_device() + + @contextmanager -def invokable_module(session, module, device="xrt") -> VmModule: +def invokable_module(session, module, device) -> VmModule: source = Source.wrap_buffer(session, str(module).encode()) inv = session.invocation() inv.parse_source(source) @@ -108,8 +111,7 @@ def invokable_module(session, module, device="xrt") -> VmModule: compiled_flatbuffer = Output.open_membuffer() inv.output_vm_bytecode(compiled_flatbuffer) - driver = get_driver(device) - config = Config(device=driver.create_default_device()) + config = Config(device=device) ctx = SystemContext(config=config) vm_module = VmModule.copy_buffer(ctx.instance, compiled_flatbuffer.map_memory()) ctx.add_vm_module(vm_module) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 9cd2f0d7c..658572e19 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("target_backend", ["amd-aie"]) -def test_smol_matmul(session_module, target_backend): +def test_smol_matmul(session_module, target_backend, device): session, module = session_module @func(T.tensor(32, 16, T.i8()), T.tensor(16, 32, T.i8())) @@ -23,7 +23,7 @@ def matmul_i8_i32(lhs, rhs): arg0 = np.ones((32, 16), dtype=np.int8) arg1 = np.ones((16, 32), dtype=np.int8) - with invokable_module(session, module) as module: + with invokable_module(session, module, device) as module: results = module[matmul_i8_i32.__name__](arg0, arg1).to_host() assert np.array_equal(results, arg0 @ arg1) @@ -98,7 +98,6 @@ def matmul(lhs, rhs): (128, 256, 128), ] - small_i8_shapes_small = [ (64, 64, 64), (128, 256, 128), @@ -136,6 +135,7 @@ def test_matmul( lower_to_aie_pipeline, tile_pipeline, num_repeat_runs, + device, ): session, module = session_module @@ -146,8 +146,9 @@ def test_matmul( acc_type = mlir_type_to_np_dtype(acc_type) arg0 = np.ones((M, K), dtype=lhs_rhs_type) arg1 = np.ones((K, N), dtype=lhs_rhs_type) - with invokable_module(session, module) as module: + with invokable_module(session, module, device) as module: for i in range(num_repeat_runs): + print(f"{matmul_name} run {i}") results = module[matmul_name](arg0, arg1).to_host() assert np.array_equal( results, (arg0.astype(acc_type) @ arg1.astype(acc_type))