Skip to content

Commit

Permalink
Fix breaking changes due to updated Anthropic SDK (#452)
Browse files Browse the repository at this point in the history
Anthropic just moved their tool use from beta to main so we have to
change the import `from anthropic.types.beta.tools import ToolUseBlock`
to `from anthropic.types import ToolUseBlock`. You cannot run the eval
without this change as things break.

Also, my IDE automatically sorted the imported packages and removed some
extra spaces -- this explains all the other changes.
  • Loading branch information
eitanturok authored Jun 5, 2024
1 parent fade5e4 commit 33cabef
Showing 1 changed file with 14 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from model_handler.handler import BaseHandler
import json
import os
import time

from anthropic import Anthropic
from anthropic.types import TextBlock
from anthropic.types.beta.tools import ToolUseBlock
from model_handler.model_style import ModelStyle
from anthropic.types import TextBlock, ToolUseBlock
from model_handler.claude_prompt_handler import ClaudePromptingHandler
from model_handler.constant import GORILLA_TO_OPENAPI
from model_handler.handler import BaseHandler
from model_handler.model_style import ModelStyle
from model_handler.utils import (
convert_to_tool,
ast_parse,
augment_prompt_by_languge,
convert_to_function_call,
convert_to_tool,
language_specific_pre_processing,
ast_parse,
convert_to_function_call
)
from model_handler.constant import GORILLA_TO_OPENAPI
import os, time, json


class ClaudeFCHandler(BaseHandler):
Expand Down Expand Up @@ -52,7 +54,7 @@ def inference(self, prompt, functions, test_category):
tool_call_outputs.append({content.name: json.dumps(content.input)})
result = tool_call_outputs if tool_call_outputs else text_outputs[0]
return result, {"input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, "latency": latency}

def decode_ast(self,result,language="Python"):
if "FC" not in self.model_name:
decoded_output = ast_parse(result,language)
Expand All @@ -69,7 +71,7 @@ def decode_ast(self,result,language="Python"):
params[key] = str(params[key])
decoded_output.append({name: params})
return decoded_output

def decode_execute(self,result):
if "FC" not in self.model_name:
decoded_output = ast_parse(result)
Expand All @@ -82,4 +84,4 @@ def decode_execute(self,result):
return execution_list
else:
function_call = convert_to_function_call(result)
return function_call
return function_call

0 comments on commit 33cabef

Please sign in to comment.