diff --git a/optimum/tpu/distributed_model.py b/optimum/tpu/distributed_model.py index 62a2d35c..4f3eb983 100644 --- a/optimum/tpu/distributed_model.py +++ b/optimum/tpu/distributed_model.py @@ -12,7 +12,7 @@ import torch.multiprocessing as mp from optimum.tpu.modeling import AutoModelForCausalLM -from transformers import PretrainedConfig, AutoConfig +from transformers import PretrainedConfig class ModelCommand(Enum):