diff --git a/client/astra_assistants/astra_assistants_event_handler.py b/client/astra_assistants/astra_assistants_event_handler.py index 79c642b..0f76e6c 100644 --- a/client/astra_assistants/astra_assistants_event_handler.py +++ b/client/astra_assistants/astra_assistants_event_handler.py @@ -41,6 +41,7 @@ def on_tool_call_done(self, tool_call): tool_call_results_string = tool_call_results_obj else: tool_call_results_string = self.tool_call_results + print(f"tool_call.id {tool_call.id}") self.tool_output = ToolOutput( tool_call_id=tool_call.id, output=tool_call_results_string diff --git a/client/astra_assistants/astra_assistants_manager.py b/client/astra_assistants/astra_assistants_manager.py index 65067d3..53f6d79 100644 --- a/client/astra_assistants/astra_assistants_manager.py +++ b/client/astra_assistants/astra_assistants_manager.py @@ -67,13 +67,13 @@ def create_assistant(self): model=self.model, tools=tool_functions ) - print("Assistant created:", self.assistant) + logger.debug("Assistant created:", self.assistant) return self.assistant def create_thread(self): # Create and return a new thread thread = self.client.beta.threads.create() - print("Thread generated:", thread) + logger.debug("Thread generated:", thread) return thread def stream_thread(self, content, tool_choice = None, thread_id: str = None, thread = None, additional_instructions = None): diff --git a/client/astra_assistants/async_openai_with_default_key.py b/client/astra_assistants/async_openai_with_default_key.py index 0e029e4..e4a6446 100644 --- a/client/astra_assistants/async_openai_with_default_key.py +++ b/client/astra_assistants/async_openai_with_default_key.py @@ -1,11 +1,11 @@ import logging import os -from openai import OpenAI +from openai import AsyncOpenAI logger = logging.getLogger(__name__) -class AsyncOpenAIWithDefaultKey(OpenAI): +class AsyncOpenAIWithDefaultKey(AsyncOpenAI): def __init__(self, *args, **kwargs): key = os.environ.get("OPENAI_API_KEY", "dummy") if key == "dummy": diff --git a/client/astra_assistants/tools/structured_code/delete.py b/client/astra_assistants/tools/structured_code/delete.py index 2bb9fa5..461491d 100644 --- a/client/astra_assistants/tools/structured_code/delete.py +++ b/client/astra_assistants/tools/structured_code/delete.py @@ -5,6 +5,7 @@ from astra_assistants.tools.structured_code.program_cache import StructuredProgramEntry, ProgramCache, StructuredProgram from astra_assistants.tools.tool_interface import ToolInterface +from astra_assistants.utils import copy_program_from_cache class StructuredEditDelete(BaseModel): @@ -39,13 +40,7 @@ def set_program_id(self, program_id): def call(self, edit: StructuredEditDelete): try: - program = None - for entry in self.program_cache: - if entry.program_id == self.program_id: - program = entry.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) if edit.end_line_number: del program.lines[edit.start_line_number-1:edit.end_line_number] @@ -54,10 +49,11 @@ def call(self, edit: StructuredEditDelete): new_program_id = str(uuid1()) entry = StructuredProgramEntry(program_id=new_program_id, program=program) - self.program_cache.append(entry) + self.program_cache.add(entry) print(f"program after edit: \n{program.to_string()}") return {'program_id': new_program_id, 'output': program} except Exception as e: print(f"Error: {e}") raise e + diff --git a/client/astra_assistants/tools/structured_code/indent.py b/client/astra_assistants/tools/structured_code/indent.py index 526427b..18f776d 100644 --- a/client/astra_assistants/tools/structured_code/indent.py +++ b/client/astra_assistants/tools/structured_code/indent.py @@ -7,6 +7,7 @@ from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgramEntry, StructuredProgram from astra_assistants.tools.tool_interface import ToolInterface +from astra_assistants.utils import copy_program_from_cache ts_language = Language(tspython.language()) parser = Parser(ts_language) @@ -74,7 +75,7 @@ def get_indentation_from_node(node, source_code): else: print("No indentation found around the specified line.") if target_line > 0: - return get_indentation_unit(source_code, target_line-1) + return get_indentation_unit(source_code, target_line - 1) else: print(f"No node found at line {target_line + 1}.") if target_line > 0: @@ -84,9 +85,13 @@ def get_indentation_from_node(node, source_code): class IndentLeftEdit(BaseModel): - thoughts: str = Field(..., description="The message to be described to the user explaining how the indent left edit will work, think step by step.") - start_line_number: int = Field(..., description="Line number where the indent left edit starts (first line is line 1). ALWAYS requried") - end_line_number: Optional[int] = Field(None, description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") + thoughts: str = Field(..., + description="The message to be described to the user explaining how the indent left edit will work, think step by step.") + start_line_number: int = Field(..., + description="Line number where the indent left edit starts (first line is line 1). ALWAYS requried") + end_line_number: Optional[int] = Field(None, + description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") + class Config: schema_extra = { "examples": [ @@ -109,9 +114,13 @@ class Config: class IndentRightEdit(BaseModel): - thoughts: str = Field(..., description="The message to be described to the user explaining how the indent right edit will work, think step by step.") - start_line_number: int = Field(..., description="Line number where the indent right edit starts (first line is line 1). ALWAYS requried") - end_line_number: Optional[int] = Field(None, description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") + thoughts: str = Field(..., + description="The message to be described to the user explaining how the indent right edit will work, think step by step.") + start_line_number: int = Field(..., + description="Line number where the indent right edit starts (first line is line 1). ALWAYS requried") + end_line_number: Optional[int] = Field(None, + description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") + class Config: schema_extra = { "examples": [ @@ -144,22 +153,17 @@ def __init__(self, program_cache: ProgramCache): def set_program_id(self, program_id): self.program_id = program_id - def call(self, edit: IndentRightEdit): try: - program = None - for entry in self.program_cache: - if entry.program_id == self.program_id: - program = entry.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) + print(f"program before edit: \n{program.to_string()}") print(f"edit: {edit}") - indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1) + indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), + edit.start_line_number - 1) - i = edit.start_line_number-1 + i = edit.start_line_number - 1 if edit.end_line_number is not None: while i < edit.end_line_number: program.lines[i] = f"{indentation_unit}{program.lines[i]}" @@ -169,7 +173,7 @@ def call(self, edit: IndentRightEdit): new_program_id = str(uuid1()) entry = StructuredProgramEntry(program_id=new_program_id, program=program) - self.program_cache.append(entry) + self.program_cache.add(entry) print(f"program after edit: \n{program.to_string()}") return {'program_id': new_program_id, 'output': program} except Exception as e: @@ -188,21 +192,16 @@ def __init__(self, program_cache: ProgramCache): def set_program_id(self, program_id): self.program_id = program_id - def call(self, edit: IndentLeftEdit): try: - program = None - for entry in self.program_cache: - if entry.program_id == self.program_id: - program = entry.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) + print(f"program before edit: \n{program.to_string()}") print(f"edit: {edit}") - indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1) - i = edit.start_line_number-1 + indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), + edit.start_line_number - 1) + i = edit.start_line_number - 1 if edit.end_line_number is not None: while i < edit.end_line_number and i < len(program.lines): program.lines[i] = program.lines[i].replace(indentation_unit, "", 1) @@ -212,7 +211,7 @@ def call(self, edit: IndentLeftEdit): new_program_id = str(uuid1()) entry = StructuredProgramEntry(program_id=new_program_id, program=program) - self.program_cache.append(entry) + self.program_cache.add(entry) print(f"program after edit: \n{program.to_string()}") return {'program_id': new_program_id, 'output': program} except Exception as e: diff --git a/client/astra_assistants/tools/structured_code/insert.py b/client/astra_assistants/tools/structured_code/insert.py index 50435fb..fb92be5 100644 --- a/client/astra_assistants/tools/structured_code/insert.py +++ b/client/astra_assistants/tools/structured_code/insert.py @@ -4,6 +4,7 @@ from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram from astra_assistants.tools.tool_interface import ToolInterface +from astra_assistants.utils import copy_program_from_cache class StructuredEditInsert(BaseModel): @@ -32,13 +33,7 @@ def set_program_id(self, program_id): def call(self, edit: StructuredEditInsert): try: - program = None - for entry in self.program_cache: - if entry.program_id == self.program_id: - program = entry.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) instructions = (f"Write some code based on the instructions provided.\n" f"## Instructions:\n" diff --git a/client/astra_assistants/tools/structured_code/program_cache.py b/client/astra_assistants/tools/structured_code/program_cache.py index 97dbff5..6a50dda 100644 --- a/client/astra_assistants/tools/structured_code/program_cache.py +++ b/client/astra_assistants/tools/structured_code/program_cache.py @@ -58,23 +58,45 @@ class Config: arbitrary_types_allowed = True -class ProgramCache(list): +class ProgramCache: def __init__(self, *args): - super().__init__(*args) + self.cache = {} # Dictionary to hold the programs by ID + self.order = [] # List to maintain the insertion order of program IDs self.session_manager = LspSessionManager() - def append(self, item: StructuredProgramEntry) -> None: + def add(self, item: StructuredProgramEntry) -> None: + program_id = item.program_id # Assuming the program has an 'id' attribute self.process(item) - super().append(item) + + # Add or update the cache + self.cache[program_id] = item + + # Track insertion order; remove if already present (for updates), then append at the end + if program_id in self.order: + self.order.remove(program_id) + self.order.append(program_id) def extend(self, iterable: List[StructuredProgramEntry]) -> None: for item in iterable: + program_id = item.program_id self.process(item) - super().extend(iterable) - def insert(self, index: int, item: StructuredProgramEntry) -> None: - self.process(item) - super().insert(index, item) + # Add or update the cache + self.cache[program_id] = item + + # Maintain insertion order + if program_id in self.order: + self.order.remove(program_id) + self.order.append(program_id) + + def get(self, program_id) -> StructuredProgramEntry: + return self.cache.get(program_id, None) + + def get_latest(self) -> StructuredProgramEntry: + if not self.order: + return None # No entries + latest_program_id = self.order[-1] # Get the last program added + return self.cache[latest_program_id] def close(self) -> None: self.session_manager.close() @@ -165,7 +187,6 @@ def get_diagnostics(self, uri, program_str, document_version=1): if document_version == 1: notification = self.session_manager.send_notification("textDocument/didOpen", payload) else: - text_change_event = types.TextDocumentContentChangeEvent_Type2( text=program_str, ) @@ -177,7 +198,7 @@ def get_diagnostics(self, uri, program_str, document_version=1): ) did_change_payload_dict = convert_keys_to_camel_case(converter.unstructure(did_change_payload_obj)) notification = self.session_manager.send_notification("textDocument/didChange", did_change_payload_dict) - assert notification['uri'] == uri, "notification on the wrong file" + assert notification['uri'] == uri, f"Error with notification {notification} uri {uri}" diagnostics = notification["diagnostics"] diags = [] for diagnostic in diagnostics: diff --git a/client/astra_assistants/tools/structured_code/replace.py b/client/astra_assistants/tools/structured_code/replace.py index f78a077..97fb4e6 100644 --- a/client/astra_assistants/tools/structured_code/replace.py +++ b/client/astra_assistants/tools/structured_code/replace.py @@ -4,6 +4,7 @@ from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram from astra_assistants.tools.tool_interface import ToolInterface +from astra_assistants.utils import copy_program_from_cache class StructuredEditReplace(BaseModel): @@ -44,13 +45,7 @@ def set_program_id(self, program_id): def call(self, edit: StructuredEditReplace): try: - program = None - for entry in self.program_cache: - if entry.program_id == self.program_id: - program = entry.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) instructions = (f"Write some code based on the instructions provided.\n" f"## Instructions:\n" diff --git a/client/astra_assistants/tools/structured_code/rewrite.py b/client/astra_assistants/tools/structured_code/rewrite.py index 4ef87e8..d85c2c1 100644 --- a/client/astra_assistants/tools/structured_code/rewrite.py +++ b/client/astra_assistants/tools/structured_code/rewrite.py @@ -4,10 +4,13 @@ from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram from astra_assistants.tools.tool_interface import ToolInterface +from astra_assistants.utils import copy_program_from_cache class StructuredRewrite(BaseModel): - thoughts: str = Field(..., description="The message to be described to the user explaining how the edit will work, think step by step.") + thoughts: str = Field(..., + description="The message to be described to the user explaining how the edit will work, think step by step.") + class Config: schema_extra = { "example": { @@ -27,16 +30,9 @@ def __init__(self, program_cache: ProgramCache): def set_program_id(self, program_id): self.program_id = program_id - def call(self, edit: StructuredRewrite): try: - program = None - for pair in self.program_cache: - if pair.program_id == self.program_id: - program = pair.program.copy() - break - if not program: - raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?") + program = copy_program_from_cache(self.program_id, self.program_cache) instructions = (f"Rewrite the code snippet based on the instructions provided.\n" f"## Instructions:\n" @@ -51,7 +47,8 @@ def call(self, edit: StructuredRewrite): f"{program.to_string()}") print(f"providing instructions: \n{instructions}") - return {'program_id': self.program_id, 'output': instructions, 'tool': self.__class__.__name__, 'edit': edit} + return {'program_id': self.program_id, 'output': instructions, 'tool': self.__class__.__name__, + 'edit': edit} except Exception as e: print(f"Error: {e}") raise e diff --git a/client/astra_assistants/tools/structured_code/util.py b/client/astra_assistants/tools/structured_code/util.py index 4d1bc27..bc21486 100644 --- a/client/astra_assistants/tools/structured_code/util.py +++ b/client/astra_assistants/tools/structured_code/util.py @@ -1,4 +1,5 @@ import re +import traceback from uuid import uuid1 from astra_assistants.tools.structured_code.program_cache import StructuredProgram, StructuredProgramEntry @@ -81,7 +82,7 @@ def is_valid_python_code(code: str) -> bool: def add_program_to_cache(program, program_cache): program_id = str(uuid1()) entry = StructuredProgramEntry(program_id=program_id, program=program) - program_cache.append(entry) + program_cache.add(entry) return program_id @@ -107,6 +108,10 @@ def process_program_with_tool(program, text, tool, edit): program.lines.insert(edit.start_line_number + i, line) i += 1 return program + elif tool == "StructuredCodeFileGenerator": + program = program_str_to_program(text, program.language, program.filename, program.tags, program.description) + program.filename = program.filename.split('.')[0] + '/app.py' + return program else: print(f"no changes for tool {tool}") program = program_str_to_program(text, program.language, program.filename, program.tags, program.description) @@ -114,43 +119,45 @@ def process_program_with_tool(program, text, tool, edit): def add_chunks_to_cache(chunks, cache, function=None): - first_chunk = next(chunks) - assert not isinstance(first_chunk, str) - last_program = None - if "program_id" in first_chunk: - program_id = first_chunk["program_id"] - last_program = None - for cached_program in cache: - if cached_program.program_id == program_id: - last_program = cached_program.program - break - assert last_program is not None - else: - last_program = first_chunk['program_desc'] - # If the tool expects code output in chunks output will be a string - if isinstance(first_chunk["output"], str): - text = "" - for chunk in chunks: - text += chunk - program = None - # tools like file generator don't have edits - if 'tool' in first_chunk and 'edit' in first_chunk: - tool = first_chunk['tool'] - edit = first_chunk['edit'] - text = sanitize_program_str(text, last_program.language) - program = process_program_with_tool(last_program, text, tool, edit) - print(f"edit: \n{edit}\ntext: \n{text}") + try: + first_chunk = next(chunks) + assert not isinstance(first_chunk, str) + if "program_id" in first_chunk: + program_id = first_chunk["program_id"] + last_program = cache.get(program_id).program else: - program = program_str_to_program(text, last_program.language, last_program.filename, last_program.tags, - last_program.description) - print(f"program after edit: \n{program.to_string()}") - program_id = add_program_to_cache(program, cache) - if function is not None: - function(chunks, text) - return {'program_id': program_id, 'output': program} - else: - if function is not None: - function(chunks, first_chunk) - return first_chunk + last_program = first_chunk['program_desc'] + # If the tool expects code output in chunks output will be a string + if isinstance(first_chunk["output"], str): + text = "" + for chunk in chunks: + text += chunk + program = None + # tools like file generator don't have edits + if 'tool' in first_chunk: + edit = None + if 'edit' in first_chunk: + edit = first_chunk['edit'] + tool = first_chunk['tool'] + text = sanitize_program_str(text, last_program.language) + program = process_program_with_tool(last_program, text, tool, edit) + print(f"edit: \n{edit}\ntext: \n{text}") + else: + program = program_str_to_program(text, last_program.language, last_program.filename, last_program.tags, + last_program.description) + print(f"program after edit: \n{program.to_string()}") + program_id = add_program_to_cache(program, cache) + if function is not None: + function(chunks, text) + return {'program_id': program_id, 'output': program} else: - raise Exception(f"No function provided to handle chunks, function required for first_chunk {first_chunk}") + if function is not None: + function(chunks, first_chunk) + return first_chunk + else: + raise Exception(f"No function provided to handle chunks, function required for first_chunk {first_chunk}") + except Exception as e: + print(f"Error: {e}") + trace = traceback.format_exc() + print(trace) + return None diff --git a/client/astra_assistants/tools/tool_interface.py b/client/astra_assistants/tools/tool_interface.py index 05305f5..dc0b6cb 100644 --- a/client/astra_assistants/tools/tool_interface.py +++ b/client/astra_assistants/tools/tool_interface.py @@ -1,3 +1,4 @@ +import json from abc import ABC, abstractmethod import inspect from pydantic import BaseModel, Field @@ -63,9 +64,10 @@ def to_function(self): "type": "function", "function": { "name": self.__class__.__name__, - "description": f"{self.__class__.__name__} function.", + #"description": f"{self.__class__.__name__} function.", + "description": self.call.__doc__ or f"{self.__class__.__name__} function.", "parameters": parameters } } - # print(json.dumps(function)) + print(json.dumps(function)) return function diff --git a/client/astra_assistants/utils.py b/client/astra_assistants/utils.py index 026f720..f3bfc58 100644 --- a/client/astra_assistants/utils.py +++ b/client/astra_assistants/utils.py @@ -51,3 +51,13 @@ def env_var_is_missing(provider: str, env_vars: dict) -> bool: return True return False + +def copy_program_from_cache(program_id, program_cache): + if program_id is None: + raise Exception("You must call set_program_id() before calling call()") + try: + program = program_cache.get(program_id).program.copy() + return program + except Exception as e: + print(f"program_id {program_id} not found in cache: Error: {e}") + raise e \ No newline at end of file diff --git a/client/pyproject.toml b/client/pyproject.toml index 1ea0dd5..028c297 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "astra-assistants" -version = "2.1.0.19" +version = "2.1.1" description = "Astra Assistants API - drop in replacement for OpenAI Assistants, powered by AstraDB" authors = ["phact "] readme = "README.md" diff --git a/client/tests/astra-assistants/test_chat_completion.py b/client/tests/astra-assistants/test_chat_completion.py index 65ed47d..60e28e8 100644 --- a/client/tests/astra-assistants/test_chat_completion.py +++ b/client/tests/astra-assistants/test_chat_completion.py @@ -72,6 +72,7 @@ def print_chat_completion(model, client): def test_chat_completion_gpt4(patched_openai_client): model="gpt-4-1106-preview" + #model="openai/o1-preview" print_chat_completion(model, patched_openai_client) def test_chat_completion_gpt_4o_mini(patched_openai_client): diff --git a/client/tests/tools/conftest.py b/client/tests/tools/conftest.py index b10b78e..c5e97cc 100644 --- a/client/tests/tools/conftest.py +++ b/client/tests/tools/conftest.py @@ -54,3 +54,9 @@ def patched_openai_client(wait_for_server) -> OpenAI: oai = patch(OpenAI()) #oai = OpenAI() return oai + +@pytest.fixture(scope="function") +def openai_client(wait_for_server) -> OpenAI: + oai = OpenAI() + #oai = OpenAI() + return oai diff --git a/client/tests/tools/test_lsp_session.py b/client/tests/tools/test_lsp_session.py index bc9605d..01c7cd7 100644 --- a/client/tests/tools/test_lsp_session.py +++ b/client/tests/tools/test_lsp_session.py @@ -20,5 +20,5 @@ def test_publish_diagnostics(): programs = ProgramCache() programs.append(program) - programs[0].program.to_string(with_line_numbers=False) + programs.get_latest().program.to_string(with_line_numbers=False) programs.close() \ No newline at end of file diff --git a/client/tests/tools/test_structured_code.py b/client/tests/tools/test_structured_code.py index 5a8149c..0d33226 100644 --- a/client/tests/tools/test_structured_code.py +++ b/client/tests/tools/test_structured_code.py @@ -64,8 +64,8 @@ def test_structured_code_raw(patched_openai_client): event_handler = AstraEventHandler(patched_openai_client) event_handler.register_tool(code_replace) - program_id = programs[0].program_id - program = programs[0].program + program_id = programs.get_latest().program_id + program = programs.get_latest().program patched_openai_client.beta.threads.messages.create(thread.id, content=f"nice, now add trigonometric functions to program_id {program_id}: \n{program.to_string()}" , role="user") code_replace.set_program_id(program_id) with patched_openai_client.beta.threads.runs.create_and_stream( @@ -230,7 +230,7 @@ def factorial(n): except Exception as e: print(e) - assert len(programs) == 1 + assert len(programs.order) == 1 code_indent_left.set_program_id(program_id) chunks: ToolOutput = assistant_manager.stream_thread( content="Fix the indentation.", @@ -238,7 +238,7 @@ def factorial(n): ) tool_call_result = next(chunks) - assert len(programs) == 2 + assert len(programs.order) == 2 code_rewriter.set_program_id(tool_call_result['program_id']) chunks: ToolOutput = assistant_manager.stream_thread( @@ -247,7 +247,7 @@ def factorial(n): ) program_id = add_chunks_to_cache(chunks, programs)['program_id'] - assert len(programs) == 3 + assert len(programs.order) == 3 print(program_id) programs.close() @@ -301,7 +301,7 @@ def factorial(n): except Exception as e: print(e) - assert len(programs) == 1 + assert len(programs.order) == 1 code_indent_left.set_program_id(program_id) chunks: ToolOutput = assistant_manager.stream_thread( content="Fix the indentation.", @@ -309,7 +309,7 @@ def factorial(n): ) tool_call_result = next(chunks) - assert len(programs) == 2 + assert len(programs.order) == 2 code_rewriter.set_program_id(tool_call_result['program_id']) chunks: ToolOutput = assistant_manager.stream_thread( @@ -317,9 +317,10 @@ def factorial(n): tool_choice=code_rewriter ) - program_id = add_chunks_to_cache(chunks, programs)['program_id'] - assert program_id == programs[len(programs)-1].program_id - assert len(programs) == 3 + result = add_chunks_to_cache(chunks, programs) + program_id = result['program_id'] + assert program_id == programs.get_latest().program_id + assert len(programs.order) == 3 print(program_id) code_insert.set_program_id(program_id) @@ -329,7 +330,7 @@ def factorial(n): ) program_id = add_chunks_to_cache(chunks, programs)['program_id'] - assert len(programs) == 4 + assert len(programs.order) == 4 print(program_id) code_delete.set_program_id(program_id) @@ -339,7 +340,7 @@ def factorial(n): ) tool_call_result = next(chunks) - assert len(programs) == 5 + assert len(programs.order) == 5 code_rewriter.set_program_id(tool_call_result['program_id']) code_replace.set_program_id(program_id) @@ -349,7 +350,7 @@ def factorial(n): ) program_id = add_chunks_to_cache(chunks, programs)['program_id'] - assert len(programs) == 6 + assert len(programs.order) == 6 print(program_id) programs.close()