diff --git a/higgsfield/llama/llama.py b/higgsfield/llama/llama.py
index 90e6068..63cbbdf 100644
--- a/higgsfield/llama/llama.py
+++ b/higgsfield/llama/llama.py
@@ -59,14 +59,14 @@ def __init__(
         if not checkpoint_path:
             if cpu_init_rank0:
                 if rank == 0:
-                    model = LlamaForCausalLM.from_pretrained(model_name)
+                    model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
                 else:
-                    llama_config = LlamaConfig.from_pretrained(model_name)
+                    llama_config = LlamaConfig.from_pretrained(model_name, use_cache=False)
                     with torch.device('meta'):
                         model = LlamaForCausalLM(llama_config)
             else:
-                model = LlamaForCausalLM.from_pretrained(model_name)
+                model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
         else:
             if not cpu_init_rank0:
                 print("Ignoring cpu_init_rank0=False while loading model from checkpoint path")
 
@@ -298,4 +298,4 @@ def __init__(
             precision,
             cpu_init_rank0,
             cpu_offload,
-        )
\ No newline at end of file
+        )
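
For context, the pattern this diff touches is worth spelling out: passing use_cache=False to from_pretrained disables the KV cache, which is unused during training, and the meta-device branch lets only rank 0 materialize real weights in CPU RAM while other ranks build an empty shell. Below is a minimal standalone sketch of that pattern, not the repository's actual code; the function name build_model is hypothetical, and it assumes torch.distributed is already initialized so dist.get_rank() is valid.

    # Sketch only: illustrates the rank-0 CPU init pattern, under the
    # assumptions stated above. build_model is a hypothetical name.
    import torch
    import torch.distributed as dist
    from transformers import LlamaConfig, LlamaForCausalLM

    def build_model(model_name: str, cpu_init_rank0: bool = True):
        rank = dist.get_rank()
        if cpu_init_rank0:
            if rank == 0:
                # Only rank 0 loads the full weights into CPU RAM.
                # use_cache=False turns off the KV cache, which is not
                # needed for training.
                model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
            else:
                # Other ranks construct the model on the meta device from
                # the config alone, so no weight tensors are allocated;
                # a wrapper such as FSDP can later broadcast real
                # parameters from rank 0.
                config = LlamaConfig.from_pretrained(model_name, use_cache=False)
                with torch.device("meta"):
                    model = LlamaForCausalLM(config)
        else:
            # Every rank loads the full weights (simpler, but uses
            # world_size times the CPU memory).
            model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False)
        return model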