Support inference with local GPU
Eladlev committed Apr 26, 2024
1 parent 4f39702 commit 1b1bd03
Showing 1 changed file with 5 additions and 0 deletions.

utils/config.py
@@ -57,10 +57,15 @@ def get_llm(config: dict):


     elif config['type'] == 'HuggingFacePipeline':
+        device = config.get('gpu_device', -1)
+        device_map = config.get('device_map', None)
+
         return HuggingFacePipeline.from_model_id(
             model_id=config['name'],
             task="text-generation",
             pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
+            device=device,
+            device_map=device_map
         )
     else:
         raise NotImplementedError("LLM not implemented")
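A config dict exercising the two new options might look like the following sketch. The model id and token count are illustrative; `gpu_device` defaults to -1 (CPU) and `device_map` to None when absent, matching the `config.get` calls in the diff.

```python
# Hypothetical config illustrating the new fields added in this commit.
config = {
    'type': 'HuggingFacePipeline',
    'name': 'gpt2',            # example Hugging Face model id
    'max_new_tokens': 128,
    'gpu_device': 0,           # CUDA device index; omit or use -1 to stay on CPU
    'device_map': None,        # or "auto" to let accelerate place the weights
}

# Mirror the lookups get_llm performs:
device = config.get('gpu_device', -1)
device_map = config.get('device_map', None)
```

Note that `device` and `device_map` are alternative placement mechanisms in the underlying transformers pipeline, so typically only one of them is set to a non-default value.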
