
Commit 689d4ce

RobinPicard authored and rlouf committed
Add caching to the function cached_create_states_mapping
1 parent 5149c0e commit 689d4ce

File tree: 2 files changed (+58, −0)


outlines/processors/guide.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 )
 
 from outlines import grammars
+from outlines.caching import cache
 from outlines.fsm.parsing import PartialLark, PartialParserState
 
 if TYPE_CHECKING:
@@ -72,6 +73,7 @@ def copy(self):
         return self
 
 
+@cache()
 def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
     return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
 
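
The change is two lines: importing cache and applying @cache() so that cached_create_states_mapping builds the regex-to-FSM states mapping once per distinct (regex_string, tokenizer) pair and reuses the stored result on later calls. As a rough in-memory sketch of what a cache() decorator factory does (illustration only; outlines' actual cache, judging by the caching.get_cache() and stats() calls in the test below, is a persistent cache object, not a plain dict):

# Illustration only: a minimal in-memory stand-in for a cache() decorator
# factory. Outlines' real cache persists results (see get_cache()/stats()
# in the test below); this sketch just shows the memoization shape.
import functools

def cache():
    def decorator(fn):
        memo = {}  # maps call arguments to previously computed results

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Assumes hashable arguments; a real key would need a stable
            # representation of the tokenizer as well as the regex string.
            key = (args, tuple(sorted(kwargs.items())))
            if key not in memo:   # first call: compute and store (a "miss")
                memo[key] = fn(*args, **kwargs)
            return memo[key]      # later calls: reuse the result (a "hit")
        return wrapper
    return decorator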

tests/processors/test_guide.py

Lines changed: 56 additions & 0 deletions
@@ -1,6 +1,23 @@
 import pytest
 
+import llama_cpp
+import transformers
+
+import outlines
 from outlines.processors.guide import CFGGuide, Generate, RegexGuide, StopAtEOSGuide, Write
+from outlines import caching
+
+try:
+    import mlx_lm
+    HAS_MLX = True
+except ImportError:
+    HAS_MLX = False
+
+try:
+    import vllm
+    HAS_VLLM = True
+except ImportError:
+    HAS_VLLM = False
 
 
 def assert_expected_tensor_ids(tensor, ids):
@@ -181,6 +198,45 @@ def convert_token_to_string(self, token):
     assert fsm.is_final_state(state)
 
 
+def test_regex_guide_caching():
+    assert caching._caching_enabled
+
+    cache = caching.get_cache()
+    _, _ = cache.stats(enable=True, reset=True)  # (hits, misses)
+
+    regex = r"[0-9]{3}"
+
+    models = [
+        outlines.from_transformers(
+            transformers.AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini"),
+            transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini")
+        ),
+        outlines.from_llamacpp(llama_cpp.Llama.from_pretrained(
+            repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
+            filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
+        ))
+    ]
+    if HAS_MLX:
+        models.append(outlines.from_mlxlm(
+            *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+        ))
+    if HAS_VLLM:
+        models.append(outlines.from_vllm(vllm.LLM("erwanf/gpt2-mini")))
+
+    for i, model in enumerate(models):
+        # First call for each model should be a miss
+        RegexGuide.from_regex(regex, model.tokenizer)
+        expected_misses = i + 1
+        expected_hits = i
+        assert cache.stats(enable=True, reset=False) == (expected_hits, expected_misses)
+
+        # Second call for each model should be a hit
+        RegexGuide.from_regex(regex, model.tokenizer)
+        expected_misses = i + 1
+        expected_hits = i + 1
+        assert cache.stats(enable=True, reset=False) == (expected_hits, expected_misses)
+
+
 def test_cfg():
     class MockTokenizer:
         vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5}
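
As a quick manual check outside pytest, the same calls the test exercises show the cache at work (a sketch reusing only names that appear in the diff above; assumes the erwanf/gpt2-mini weights can be downloaded):

import transformers
import outlines
from outlines import caching
from outlines.processors.guide import RegexGuide

model = outlines.from_transformers(
    transformers.AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini"),
    transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
)

cache = caching.get_cache()
cache.stats(enable=True, reset=True)                  # start counting (hits, misses)

RegexGuide.from_regex(r"[0-9]{3}", model.tokenizer)   # builds the index: miss
RegexGuide.from_regex(r"[0-9]{3}", model.tokenizer)   # reused from cache: hit
assert cache.stats(enable=True, reset=False) == (1, 1)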
