Skip to content

Commit

Permalink
Add support for json_mode unsupported models
Browse files Browse the repository at this point in the history
  • Loading branch information
whoisarpit committed Feb 4, 2025
1 parent 031849e commit b745916
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@


class SimplifiedLLM(Step):
# Models that don't support native JSON mode
JSON_MODE_UNSUPPORTED_MODELS = {
"gemini-2.0-flash-thinking-exp",
# Add other models here as needed
}

def __init__(self, inputs):
super().__init__(inputs)
missing_keys = SimplifiedLLMInputs.__required_keys__.difference(set(inputs.keys()))
Expand All @@ -28,6 +34,7 @@ def __init__(self, inputs):
self.is_json_mode = inputs.get("json", False)
self.json_example = inputs.get("json_example")
self.inputs = inputs
self.is_json_mode_unsupported = inputs.get("model") in self.JSON_MODE_UNSUPPORTED_MODELS

def __record_status_or_raise(self, retry_data: RetryData, step: Step):
if retry_data.retry_count == retry_data.retry_limit or step.status != StepStatus.FAILED:
Expand All @@ -49,15 +56,31 @@ def __json_loads(json_str: str) -> dict:
logger.debug(f"Json to decode: \n{json_str}\nError: \n{e}")
raise e

@staticmethod
def w(text: str) -> str:
try:
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1:
return text[start : end + 1]
return text
except Exception:
return text

def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: RetryData):
call_llm = CallLLM(call_llm_inputs)
call_llm_outputs = call_llm.run()
self.__record_status_or_raise(retry_data, call_llm)

if self.is_json_mode:
json_responses = []

for response in call_llm_outputs.get("openai_responses"):
try:
# For models that don't support JSON mode, extract JSON from the text response first
if self.is_json_mode_unsupported:
response = self.w(response)

json_response = self.__json_loads(response)
json_responses.append(json_response)
except json.JSONDecodeError as e:
Expand Down Expand Up @@ -91,6 +114,14 @@ def run(self) -> dict:
prompts = [dict(role="user", content=self.user)]
if self.system:
prompts.insert(0, dict(role="system", content=self.system))

# Special handling for models that don't support JSON mode
if self.is_json_mode_unsupported and self.is_json_mode and self.json_example:
# Append JSON example to user message
prompts[-1][
"content"
] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(self.json_example, indent=2)}"

prepare_prompt_inputs = dict(
prompt_template=prompts,
prompt_values=self.prompt_values,
Expand All @@ -100,9 +131,14 @@ def run(self) -> dict:
self.set_status(prepare_prompt.status, prepare_prompt.status_message)

model_keys = [key for key in self.inputs.keys() if key.startswith("model_")]
response_format = dict(type="json_object" if self.is_json_mode else "text")
if self.json_example is not None:
response_format = example_json_to_schema(self.json_example)

# Set response format based on model and mode
response_format = None
if not self.is_json_mode_unsupported:
response_format = dict(type="json_object" if self.is_json_mode else "text")
if self.json_example is not None:
response_format = example_json_to_schema(self.json_example)

call_llm_inputs = {
"prompts": prepare_prompt_outputs.get("prompts"),
**{
Expand Down

0 comments on commit b745916

Please sign in to comment.