Skip to content

Commit

Permalink
Use static validation before compiling
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582032509
  • Loading branch information
david-lindner committed Nov 13, 2023
1 parent 9de2f88 commit 27ba32b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
29 changes: 23 additions & 6 deletions tracr/compiler/compiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.compiler import validating
from tracr.craft import bases
from tracr.rasp import rasp


COMPILER_BOS = "compiler_bos"
COMPILER_PAD = "compiler_pad"

Expand All @@ -36,13 +38,14 @@ def compile_rasp_to_model(
causal: bool = False,
compiler_bos: str = COMPILER_BOS,
compiler_pad: str = COMPILER_PAD,
mlp_exactness: int = 100) -> assemble.AssembledTransformerModel:
mlp_exactness: int = 100,
) -> assemble.AssembledTransformerModel:
"""Compile a RASP program to transformer weights.
Note that currently not all RASP features are supported. Most unsupported
features are detected at compile time and will cause a NotImplementedError.
However, a few unsupported features cannot be checked at compile time and
can cause silent errors.
can cause silent errors.
See `compiler.validating` for details and a function to quickly check if
a program is compilable with Tracr without needing to compile it.
Expand Down Expand Up @@ -70,12 +73,26 @@ def compile_rasp_to_model(
"""

if compiler_bos in vocab:
raise ValueError("Compiler BOS token must not be present in the vocab. "
f"Found '{compiler_bos}' in {vocab}")
raise ValueError(
"Compiler BOS token must not be present in the vocab. "
f"Found '{compiler_bos}' in {vocab}"
)

if compiler_pad in vocab:
raise ValueError("Compiler PAD token must not be present in the vocab. "
f"Found '{compiler_pad}' in {vocab}")
raise ValueError(
"Compiler PAD token must not be present in the vocab. "
f"Found '{compiler_pad}' in {vocab}"
)

# Perform static validation to fail fast. This catches most programs that
# tracr is unable to compile.
unsupported_exprs = validating.static_validate(program)
if unsupported_exprs:
error_message = "\n".join(
(f"{expr.expr.name}: {expr.reason}" for expr in unsupported_exprs)
)
error_message = f"Unsupported RASP expressions:\n{error_message}"
raise NotImplementedError(error_message)

extracted = rasp_to_graph.extract_rasp_graph(program)
graph, sources, sink = extracted.graph, extracted.sources, extracted.sink
Expand Down
12 changes: 12 additions & 0 deletions tracr/compiler/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,16 @@
vocab={1, 2, 3},
max_seq_len=5,
),
dict(
testcase_name="numerical_tokens",
program=rasp.numerical(rasp.tokens),
vocab={1, 2, 3},
max_seq_len=5,
),
dict(
testcase_name="numerical_indices",
program=rasp.numerical(rasp.indices),
vocab={1, 2, 3},
max_seq_len=5,
),
]

0 comments on commit 27ba32b

Please sign in to comment.