Skip to content

Commit

Permalink
Add option to clear cache on instaniating model
Browse files Browse the repository at this point in the history
  • Loading branch information
riedgar-ms committed May 1, 2024
1 parent 559b341 commit 9e46101
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
5 changes: 5 additions & 0 deletions guidance/models/_azureai_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
azureai_studio_endpoint: str,
azureai_model_deployment: str,
azureai_studio_key: str,
clear_cache: bool,
):
self._endpoint = azureai_studio_endpoint
self._deployment = azureai_model_deployment
Expand All @@ -34,6 +35,8 @@ def __init__(
/ f"azureaistudio.tokens.{deployment_id}"
)
self.cache = dc.Cache(path)
if clear_cache:
self.cache.clear()

super().__init__(tokenizer, max_streaming_tokens, timeout, compute_log_probs)

Expand Down Expand Up @@ -147,6 +150,7 @@ def __init__(
max_streaming_tokens: int = 1000,
timeout: float = 0.5,
compute_log_probs: bool = False,
clear_cache: bool = False,
):
super().__init__(
AzureAIStudioChatEngine(
Expand All @@ -157,6 +161,7 @@ def __init__(
max_streaming_tokens=max_streaming_tokens,
timeout=timeout,
compute_log_probs=compute_log_probs,
clear_cache=False,
),
echo=echo,
)
3 changes: 3 additions & 0 deletions tests/models/test_azureai_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_azureai_phi3_chat_smoke(rate_limiter):
azureai_studio_endpoint=azureai_studio_endpoint,
azureai_studio_deployment=azureai_studio_deployment,
azureai_studio_key=azureai_studio_key,
clear_cache=True,
)
assert isinstance(lm, models.AzureAIStudioChat)

Expand All @@ -46,6 +47,7 @@ def test_azureai_mistral_chat_smoke(rate_limiter):
azureai_studio_endpoint=azureai_studio_endpoint,
azureai_studio_deployment=azureai_studio_deployment,
azureai_studio_key=azureai_studio_key,
clear_cache=True,
)
assert isinstance(lm, models.AzureAIStudioChat)
lm.engine.cache.clear()
Expand Down Expand Up @@ -74,6 +76,7 @@ def test_azureai_llama3_chat_smoke(rate_limiter):
azureai_studio_endpoint=azureai_studio_endpoint,
azureai_studio_deployment=azureai_studio_deployment,
azureai_studio_key=azureai_studio_key,
clear_cache=True,
)
assert isinstance(lm, models.AzureAIStudioChat)

Expand Down

0 comments on commit 9e46101

Please sign in to comment.