|
1 | 1 | import pytest
|
2 | 2 |
|
| 3 | +import llama_cpp |
| 4 | +import transformers |
| 5 | + |
| 6 | +import outlines |
3 | 7 | from outlines.processors.guide import CFGGuide, Generate, RegexGuide, StopAtEOSGuide, Write
|
| 8 | +from outlines import caching |
| 9 | + |
| 10 | +try: |
| 11 | + import mlx_lm |
| 12 | + HAS_MLX = True |
| 13 | +except ImportError: |
| 14 | + HAS_MLX = False |
| 15 | + |
| 16 | +try: |
| 17 | + import vllm |
| 18 | + HAS_VLLM = True |
| 19 | +except ImportError: |
| 20 | + HAS_VLLM = False |
4 | 21 |
|
5 | 22 |
|
6 | 23 | def assert_expected_tensor_ids(tensor, ids):
|
@@ -181,6 +198,45 @@ def convert_token_to_string(self, token):
|
181 | 198 | assert fsm.is_final_state(state)
|
182 | 199 |
|
183 | 200 |
|
def test_regex_guide_caching():
    """RegexGuide.from_regex is cached per (regex, tokenizer): the first
    call for each tokenizer misses, every repeat call hits."""
    assert caching._caching_enabled

    cache = caching.get_cache()
    # Turn on stat tracking and zero the (hits, misses) counters.
    cache.stats(enable=True, reset=True)

    pattern = r"[0-9]{3}"

    hf_model = transformers.AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini")
    hf_tokenizer = transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini")
    llama_model = llama_cpp.Llama.from_pretrained(
        repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
        filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
    )

    models = [
        outlines.from_transformers(hf_model, hf_tokenizer),
        outlines.from_llamacpp(llama_model),
    ]
    # Optional backends are only exercised when their packages import cleanly.
    if HAS_MLX:
        models.append(
            outlines.from_mlxlm(*mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit"))
        )
    if HAS_VLLM:
        models.append(outlines.from_vllm(vllm.LLM("erwanf/gpt2-mini")))

    for index, model in enumerate(models):
        # An unseen tokenizer must add exactly one miss...
        RegexGuide.from_regex(pattern, model.tokenizer)
        assert cache.stats(enable=True, reset=False) == (index, index + 1)

        # ...and the identical repeat call must add exactly one hit.
        RegexGuide.from_regex(pattern, model.tokenizer)
        assert cache.stats(enable=True, reset=False) == (index + 1, index + 1)
| 239 | + |
184 | 240 | def test_cfg():
|
185 | 241 | class MockTokenizer:
|
186 | 242 | vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5}
|
|
0 commit comments