diff --git a/src/xpu/ocl/engine_factory.hpp b/src/xpu/ocl/engine_factory.hpp index f2b6ed1292b..37862dd01aa 100644 --- a/src/xpu/ocl/engine_factory.hpp +++ b/src/xpu/ocl/engine_factory.hpp @@ -46,7 +46,7 @@ class engine_factory_t : public impl::engine_factory_t { std::vector ocl_devices; status_t status = xpu::ocl::get_devices(&ocl_devices, CL_DEVICE_TYPE_GPU); - if (status != status::success) return status; + if (status != status::success) return 0; return ocl_devices.size(); } diff --git a/src/xpu/ocl/utils.cpp b/src/xpu/ocl/utils.cpp index abe8c8ed726..59a8d28da85 100644 --- a/src/xpu/ocl/utils.cpp +++ b/src/xpu/ocl/utils.cpp @@ -190,12 +190,11 @@ static bool is_intel_platform(cl_platform_id platform) { } status_t get_devices(std::vector *devices, - cl_device_type device_type, cl_uint vendor_id /* = 0x8086 */) { + cl_device_type device_type) { cl_uint num_platforms = 0; cl_int err = clGetPlatformIDs(0, nullptr, &num_platforms); - // No platforms - a valid scenario - if (err == CL_PLATFORM_NOT_FOUND_KHR) return status::success; + if (err == CL_PLATFORM_NOT_FOUND_KHR) return status::runtime_error; OCL_CHECK(err); @@ -203,8 +202,6 @@ status_t get_devices(std::vector *devices, OCL_CHECK(clGetPlatformIDs(num_platforms, &platforms[0], nullptr)); for (size_t i = 0; i < platforms.size(); ++i) { - if (!is_intel_platform(platforms[i])) continue; - cl_uint num_devices = 0; cl_int err = clGetDeviceIDs( platforms[i], device_type, 0, nullptr, &num_devices); @@ -219,17 +216,16 @@ status_t get_devices(std::vector *devices, OCL_CHECK(clGetDeviceIDs(platforms[i], device_type, num_devices, &plat_devices[0], nullptr)); - // Use the devices for the requested vendor only. for (size_t j = 0; j < plat_devices.size(); ++j) { - cl_uint v_id; - OCL_CHECK(clGetDeviceInfo(plat_devices[j], CL_DEVICE_VENDOR_ID, - sizeof(cl_uint), &v_id, nullptr)); - if (v_id == vendor_id) { devices->push_back(plat_devices[j]); } + devices->push_back(plat_devices[j]); } } } - // No devices found but still return success - return status::success; + + if (devices->size() != 0) + return status::success; + + return status::runtime_error; } status_t get_devices(std::vector *devices, diff --git a/src/xpu/ocl/utils.hpp b/src/xpu/ocl/utils.hpp index 65364e59e4d..8b9609c0fa0 100644 --- a/src/xpu/ocl/utils.hpp +++ b/src/xpu/ocl/utils.hpp @@ -273,7 +273,7 @@ struct ext_func_t { std::string get_kernel_name(cl_kernel kernel); status_t get_devices(std::vector *devices, - cl_device_type device_type, cl_uint vendor_id = 0x8086); + cl_device_type device_type); status_t get_devices(std::vector *devices, std::vector> *sub_devices,