diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 655325ebe190..16ea8149e3cb 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -1259,6 +1260,46 @@ inline const char* ArgTypeCode2Str(int type_code) { } } +/*! \brief Convert a DLDeviceType to string */ +inline std::string DLDeviceType2Str(DLDeviceType ty) { + switch (ty) { + case DLDeviceType::kDLCPU: + return "cpu"; + case DLDeviceType::kDLCUDA: + return "cuda"; + case DLDeviceType::kDLCUDAHost: + return "cuda-host"; + case DLDeviceType::kDLOpenCL: + return "opencl"; + case DLDeviceType::kDLVulkan: + return "vulkan"; + case DLDeviceType::kDLMetal: + return "metal"; + case DLDeviceType::kDLVPI: + return "vpi"; + case DLDeviceType::kDLROCM: + return "rocm"; + case DLDeviceType::kDLROCMHost: + return "rocm-host"; + case DLDeviceType::kDLCUDAManaged: + return "cuda-managed"; + case DLDeviceType::kDLOneAPI: + return "oneapi"; + case DLDeviceType::kDLWebGPU: + return "webgpu"; + case DLDeviceType::kDLHexagon: + return "hexagon"; + default: + return "Device(" + std::to_string(ty) + ")"; + } + throw; +} + +inline std::ostream& operator<<(std::ostream& os, const DLDevice& device) { + os << DLDeviceType2Str(device.device_type) << ":" << device.device_id; + return os; +} + namespace detail { template