Skip to content

Commit

Permalink
Merge pull request #325 from explosion/fix/model-init
Browse files Browse the repository at this point in the history
Fix Llama 2 model init and Anthropic requests
  • Loading branch information
rmitsch authored Oct 13, 2023
2 parents 3a8b28e + 1240146 commit 89b5b65
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
10 changes: 1 addition & 9 deletions spacy_llm/models/hf/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def init_model(self) -> Any:
return transformers.pipeline(
"text-generation",
model=self._name,
use_auth_token=True,
return_full_text=False,
**self._config_init,
)
Expand All @@ -48,14 +47,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove

@staticmethod
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
return (
{
**default_cfg_init,
"trust_remote_code": True,
},
default_cfg_run,
)
return HuggingFace.compile_default_configs()


@registry.llm_models("spacy.Llama2.v1")
Expand Down
1 change: 1 addition & 0 deletions spacy_llm/models/rest/anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
headers = {
**self._credentials,
"model": self._name,
"anthropic_version": self._config.get("anthropic_version", "2023-06-01"),
"Content-Type": "application/json",
}

Expand Down
6 changes: 3 additions & 3 deletions spacy_llm/tests/models/test_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"""


@pytest.mark.skip(reason="CI runner needs more GPU memory")
# @pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init():
Expand All @@ -52,7 +52,7 @@ def test_init():
)


@pytest.mark.skip(reason="CI runner needs more GPU memory")
# @pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
Expand All @@ -62,7 +62,7 @@ def test_init_from_config():
torch.cuda.empty_cache()


@pytest.mark.skip(reason="CI runner needs more GPU memory")
# @pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_invalid_model():
Expand Down

0 comments on commit 89b5b65

Please sign in to comment.