diff --git a/lmdeploy/lite/apis/gptq.py b/lmdeploy/lite/apis/gptq.py index 0e67913b4..658be4c36 100644 --- a/lmdeploy/lite/apis/gptq.py +++ b/lmdeploy/lite/apis/gptq.py @@ -98,7 +98,7 @@ def auto_gptq(model: str, quantize_config, revision=revision, torch_dtype=torch_dtype, - trust_remote_code=True) + trust_remote_code=True).cuda() # quantize model, the examples should be list of dict whose keys # can only be "input_ids" and "attention_mask"