Merge pull request #59 from Eladlev/support_local_gpu
Support inference with a local GPU
Eladlev authored Apr 26, 2024
2 parents 7f373f2 + 1b1bd03 commit 3bd6734
Showing 1 changed file with 5 additions and 0 deletions: utils/config.py
@@ -59,10 +59,15 @@ def get_llm(config: dict):
 
 
     elif config['type'] == 'HuggingFacePipeline':
+        device = config.get('gpu_device', -1)
+        device_map = config.get('device_map', None)
+
         return HuggingFacePipeline.from_model_id(
             model_id=config['name'],
             task="text-generation",
             pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
+            device=device,
+            device_map=device_map
         )
     else:
         raise NotImplementedError("LLM not implemented")
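
For context, here is a minimal sketch of how a caller might exercise the new options through `get_llm` (defined in utils/config.py per the hunk header). The model id and `max_new_tokens` value below are illustrative placeholders, not values taken from the repository.

```python
# Minimal sketch, assuming get_llm is importable from utils/config.py as in
# the diff above. The model id and token budget are hypothetical placeholders.
from utils.config import get_llm

base_config = {
    'type': 'HuggingFacePipeline',
    'name': 'gpt2',          # hypothetical model id
    'max_new_tokens': 50,    # hypothetical token budget
}

# Default behavior: gpu_device falls back to -1 (CPU), device_map to None.
llm_cpu = get_llm(base_config)

# Pin the pipeline to the first local GPU.
llm_gpu = get_llm({**base_config, 'gpu_device': 0})

# Alternatively, let Accelerate decide placement (e.g. shard across devices).
llm_auto = get_llm({**base_config, 'device_map': 'auto'})
```

Depending on the langchain and transformers versions in use, supplying both a concrete device and a device_map can raise an error, so a config would normally set `gpu_device` or `device_map`, not both.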