Skip to content

Commit 832483a

Browse files
committed
Add from_file class methods to CFG and JsonSchema
1 parent a872758 commit 832483a

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

outlines/types/dsl.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,24 @@ def _display_node(self) -> str:
190190
def __repr__(self):
191191
return f"CFG(definition='{self.definition}')"
192192

193+
def __eq__(self, other):
194+
if not isinstance(other, CFG):
195+
return False
196+
return self.definition == other.definition
197+
198+
@classmethod
199+
def from_file(cls, path: str):
200+
"""Create a CFG instance from a file containing a CFG definition.
201+
202+
Parameters
203+
----------
204+
path : str
205+
The path to the file containing the CFG definition.
206+
"""
207+
with open(path, "r") as f:
208+
definition = f.read()
209+
return cls(definition)
210+
193211

194212
@dataclass
195213
class FSM(Term):
@@ -247,6 +265,20 @@ def __eq__(self, other):
247265
except json.JSONDecodeError:
248266
return self.schema == other.schema
249267

268+
@classmethod
269+
def from_file(cls, path: str):
270+
"""Create a JsonSchema instance from a .json file containing a JSON
271+
schema.
272+
273+
Parameters
274+
----------
275+
path : str
276+
The path to the file containing the JSON schema.
277+
"""
278+
with open(path, "r") as f:
279+
schema = json.load(f)
280+
return cls(schema)
281+
250282

251283
@dataclass
252284
class KleeneStar(Term):

tests/types/test_dsl.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import datetime
22
import json
3+
import os
34
import sys
5+
import tempfile
46
from dataclasses import dataclass
57
from enum import Enum
68
from typing import (
@@ -33,6 +35,7 @@
3335
Sequence,
3436
String,
3537
Term,
38+
CFG,
3639
_handle_dict,
3740
_handle_list,
3841
_handle_literal,
@@ -285,6 +288,40 @@ def test_dsl_display():
285288
)
286289

287290

291+
def test_dsl_cfg_from_file():
292+
grammar_content = """
293+
?start: expression
294+
?expression: term (("+" | "-") term)*
295+
?term: factor (("*" | "/") factor)*
296+
?factor: NUMBER
297+
"""
298+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=True) as temp_file:
299+
temp_file.write(grammar_content)
300+
temp_file.flush()
301+
temp_file_path = temp_file.name
302+
cfg = CFG.from_file(temp_file_path)
303+
assert cfg == CFG(grammar_content)
304+
305+
306+
def test_dsl_json_schema_from_file():
307+
schema_content = """
308+
{
309+
"type": "object",
310+
"properties": {
311+
"name": {
312+
"type": "string"
313+
}
314+
}
315+
}
316+
"""
317+
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=True) as temp_file:
318+
temp_file.write(schema_content)
319+
temp_file.flush()
320+
temp_file_path = temp_file.name
321+
schema = JsonSchema.from_file(temp_file_path)
322+
assert schema == JsonSchema(schema_content)
323+
324+
288325
def test_dsl_python_types_to_terms():
289326
with pytest.raises(RecursionError):
290327
python_types_to_terms(None, 11)

0 commit comments

Comments
 (0)