Merge pull request #1595 from mikeedjones/feat/remove-get_convo-asserts

Remove get_convo asserts from tests

okhat authored Oct 8, 2024
2 parents d403737 + ead37b6 commit cc368f8
Showing 49 changed files with 2,863 additions and 734 deletions.
4 changes: 2 additions & 2 deletions dsp/modules/dummy_lm.py
@@ -5,7 +5,7 @@


# This testing module was moved in PR #735 to patch Arize Phoenix logging
class DummyLM(LM):
class DSPDummyLM(LM):
"""Dummy language model for unit testing purposes."""

def __init__(self, answers: Union[list[str], dict[str, str]], follow_examples: bool = False):
@@ -61,7 +61,7 @@ def basic_request(self, prompt, n=1, **kwargs) -> dict[str, list[dict[str, str]]
},
)

RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
RED, _, RESET = "\033[91m", "\033[92m", "\033[0m"
print("=== DummyLM ===")
print(prompt, end="")
print(f"{RED}{answer}{RESET}")
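As an aside, a minimal sketch of how the renamed class might be exercised, based only on the constructor and basic_request signatures visible in this hunk; the prompt string and the canned answer are illustrative, not taken from the repository's tests.

from dsp.modules.dummy_lm import DSPDummyLM

# Canned answers are replayed instead of calling a real provider; a dict of
# answers is also accepted per the constructor annotation above.
lm = DSPDummyLM(["Paris"])

# basic_request echoes the prompt, prints the chosen answer in red, and
# returns the provider-style dict declared in its return annotation.
response = lm.basic_request("Question: What is the capital of France?\nAnswer:")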
46 changes: 25 additions & 21 deletions dsp/utils/settings.py
@@ -1,8 +1,30 @@
import threading
from copy import deepcopy
from contextlib import contextmanager

from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
lm=None,
adapter=None,
rm=None,
branch_idx=0,
reranker=None,
compiled_lm=None,
force_reuse_cached_compilation=False,
compiling=False,
skip_logprobs=False,
trace=[],
release=0,
bypass_assert=False,
bypass_suggest=False,
assert_failures=0,
suggest_failures=0,
langchain_history=[],
experimental=False,
backoff_time=10,
)


class Settings:
"""DSP configuration settings."""
@@ -25,27 +47,9 @@ def __new__(cls):
# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
# downstream operations like dsp.retrieve would use configs from the defined pipeline.
config = dotdict(
lm=None,
adapter=None,
rm=None,
branch_idx=0,
reranker=None,
compiled_lm=None,
force_reuse_cached_compilation=False,
compiling=False, # TODO: can probably be removed
skip_logprobs=False,
trace=[],
release=0,
bypass_assert=False,
bypass_suggest=False,
assert_failures=0,
suggest_failures=0,
langchain_history=[],
experimental=False,
backoff_time = 10
)
cls._instance.__append(config)

# make a deepcopy of the default config to avoid modifying the default config
cls._instance.__append(deepcopy(DEFAULT_CONFIG))

return cls._instance

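The hunk above moves the hard-coded defaults to a module-level DEFAULT_CONFIG and hands the singleton a deepcopy of it, so mutable members such as trace=[] are never shared with the template. A self-contained sketch of that pattern in plain Python; the names mirror the diff, but none of this is the dsp code itself.

from copy import deepcopy

# Module-level template with a mutable member, mirroring DEFAULT_CONFIG above.
DEFAULT_CONFIG = {"lm": None, "trace": [], "backoff_time": 10}


class Settings:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Deep-copy so later mutations of config["trace"] never leak back
            # into the shared default dictionary.
            cls._instance.config = deepcopy(DEFAULT_CONFIG)
        return cls._instance


settings = Settings()
settings.config["trace"].append("step-1")
assert DEFAULT_CONFIG["trace"] == []  # the template stays pristine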
99 changes: 59 additions & 40 deletions dspy/adapters/chat_adapter.py
@@ -1,14 +1,15 @@
import re
import ast
import json
import re
import textwrap
from typing import get_args, get_origin

from pydantic import TypeAdapter
import pydantic
from pydantic import TypeAdapter

from .base import Adapter
from typing import get_origin, get_args

field_header_pattern = re.compile(r'\[\[ ## (\w+) ## \]\]')
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")


class ChatAdapter(Adapter):
@@ -21,9 +22,11 @@ def format(self, signature, demos, inputs):
# Extract demos where some of the output_fields are not filled in.
incomplete_demos = [demo for demo in demos if not all(k in demo for k in signature.fields)]
complete_demos = [demo for demo in demos if demo not in incomplete_demos]
incomplete_demos = [demo for demo in incomplete_demos \
if any(k in demo for k in signature.input_fields) and \
any(k in demo for k in signature.output_fields)]
incomplete_demos = [
demo
for demo in incomplete_demos
if any(k in demo for k in signature.input_fields) and any(k in demo for k in signature.output_fields)
]

demos = incomplete_demos + complete_demos

Expand All @@ -32,44 +35,52 @@ def format(self, signature, demos, inputs):
for demo in demos:
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

messages.append(format_turn(signature, inputs, role="user"))

return messages

def parse(self, signature, completion, _parse_values=True):
sections = [(None, [])]

for line in completion.splitlines():
match = field_header_pattern.match(line.strip())
if match: sections.append((match.group(1), []))
else: sections[-1][1].append(line)
if match:
sections.append((match.group(1), []))
else:
sections[-1][1].append(line)

sections = [(k, '\n'.join(v).strip()) for k, v in sections]
sections = [(k, "\n".join(v).strip()) for k, v in sections]

fields = {}
for k, v in sections:
if (k not in fields) and (k in signature.output_fields):
try:
fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
except Exception as e:
raise ValueError(f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```")
raise ValueError(
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
)

if fields.keys() != signature.output_fields.keys():
raise ValueError(f"Expected {signature.output_fields.keys()} but got {fields.keys()}")

return fields


def format_blob(blob):
if '\n' not in blob and "«" not in blob and "»" not in blob: return f"«{blob}»"
if "\n" not in blob and "«" not in blob and "»" not in blob:
return f"«{blob}»"

modified_blob = blob.replace('\n', '\n ')
modified_blob = blob.replace("\n", "\n ")
return f"«««\n {modified_blob}\n»»»"


def format_list(items):
if len(items) == 0: return "N/A"
if len(items) == 1: return format_blob(items[0])
if len(items) == 0:
return "N/A"
if len(items) == 1:
return format_blob(items[0])

return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)])

Expand All @@ -89,82 +100,90 @@ def format_fields(fields):
v = _format_field_value(v)
output.append(f"[[ ## {k} ## ]]\n{v}")

return '\n\n'.join(output).strip()
return "\n\n".join(output).strip()


def parse_value(value, annotation):
if annotation is str: return str(value)
if annotation is str:
return str(value)
parsed_value = value
if isinstance(value, str):
try: parsed_value = json.loads(value)
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
try: parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError): parsed_value = value
try:
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value
return TypeAdapter(annotation).validate_python(parsed_value)
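As an aside, the reformatted parse_value above tries JSON first, then a Python literal, and finally falls back to the raw string before validating against the field annotation with pydantic. A standalone sketch of the same chain; the example inputs are made up, not drawn from the repository:

import ast
import json

from pydantic import TypeAdapter


def parse_value_sketch(value: str, annotation):
    # JSON -> Python literal -> raw string, then pydantic validation,
    # mirroring the fallback order in the function above.
    if annotation is str:
        return str(value)
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            parsed = value
    return TypeAdapter(annotation).validate_python(parsed)


print(parse_value_sketch("[1, 2, 3]", list[int]))         # [1, 2, 3]
print(parse_value_sketch("('a', 'b')", tuple[str, str]))  # ('a', 'b') via ast.literal_eval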


def format_turn(signature, values, role, incomplete=False):
content = []

if role == "user":
field_names = signature.input_fields.keys()
if incomplete:
content.append("This is an example of the task, though some input or output fields are not supplied.")
else:
field_names, values = list(signature.output_fields.keys()) + ['completed'], {**values, 'completed': ''}
field_names, values = list(signature.output_fields.keys()) + ["completed"], {**values, "completed": ""}

if not incomplete:
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")

content.append(format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names}))

if role == "user":
content.append("Respond with the corresponding output fields, starting with the field " +
", then ".join(f"`{f}`" for f in signature.output_fields) +
", and then ending with the marker for `completed`.")
content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
+ ", and then ending with the marker for `completed`."
)

return {"role": role, "content": '\n\n'.join(content).strip()}
return {"role": role, "content": "\n\n".join(content).strip()}


def get_annotation_name(annotation):
origin = get_origin(annotation)
args = get_args(annotation)
if origin is None:
if hasattr(annotation, '__name__'):
if hasattr(annotation, "__name__"):
return annotation.__name__
else:
return str(annotation)
else:
args_str = ', '.join(get_annotation_name(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
args_str = ", ".join(get_annotation_name(arg) for arg in args)
return f"{get_annotation_name(origin)}[{args_str}]"


def enumerate_fields(fields):
parts = []
for idx, (k, v) in enumerate(fields.items()):
parts.append(f"{idx+1}. `{k}`")
parts[-1] += f" ({get_annotation_name(v.annotation)})"
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra['desc'] != f'${{{k}}}' else ''
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""

return "\n".join(parts).strip()

return '\n'.join(parts).strip()

def prepare_instructions(signature):
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

parts.append(format_fields({f : f"{{{f}}}" for f in signature.input_fields}))
parts.append(format_fields({f : f"{{{f}}}" for f in signature.output_fields}))
parts.append(format_fields({'completed' : ""}))
parts.append(format_fields({f: f"{{{f}}}" for f in signature.input_fields}))
parts.append(format_fields({f: f"{{{f}}}" for f in signature.output_fields}))
parts.append(format_fields({"completed": ""}))

instructions = textwrap.dedent(signature.instructions)
objective = ('\n' + ' ' * 8).join([''] + instructions.splitlines())
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")

# parts.append("You will receive some input fields in each interaction. " +
# "Respond only with the corresponding output fields, starting with the field " +
# ", then ".join(f"`{f}`" for f in signature.output_fields) +
# ", and then ending with the marker for `completed`.")

return '\n\n'.join(parts).strip()
return "\n\n".join(parts).strip()
(Diffs for the remaining 46 changed files are not shown here.)
