Skip to content

Commit

Permalink
Changed template interpolation to work with python < 3.11
Browse files Browse the repository at this point in the history
Added a version check that uses `string.Template.get_identifiers` if Python is > 3.11, falls back to
a reimplementation.
  • Loading branch information
jeeger committed Feb 29, 2024
1 parent 39c3393 commit c10f855
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion llm/templates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel
import string
from typing import Optional, Any, Dict, List, Tuple
import platform


class Template(BaseModel):
Expand Down Expand Up @@ -41,10 +42,25 @@ def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[st
return text
# Confirm all variables in text are provided
string_template = string.Template(text)
vars = string_template.get_identifiers()
vars = cls.extract_identifiers(string_template)
missing = [p for p in vars if p not in params]
if missing:
raise cls.MissingVariables(
"Missing variables: {}".format(", ".join(missing))
)
return string_template.substitute(**params)

@classmethod
def extract_identifiers(cls, template: string.Template) -> list[str]:
(major, minor, patchlevel) = platform.python_version_tuple()
if int(major) >= 3 and int(minor) >= 11:
# Added in Python 3.11
return template.get_identifiers()
else:
result = set()
# Adapted from source at https://github.com/python/cpython/blob/86e5e063aba76a7f4fc58f7d06b17b0a4730fd8e/Lib/string.py#L157
for match in template.pattern.finditer(template.template):
named = match.group("named") or match.group("braced")
if named is not None:
result.add(match.group("named"))
return list(result)

0 comments on commit c10f855

Please sign in to comment.