Skip to content

Commit

Permalink
client v2.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Sep 16, 2024
1 parent 0c1d860 commit b8d968a
Show file tree
Hide file tree
Showing 17 changed files with 162 additions and 131 deletions.
1 change: 1 addition & 0 deletions client/astra_assistants/astra_assistants_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def on_tool_call_done(self, tool_call):
tool_call_results_string = tool_call_results_obj
else:
tool_call_results_string = self.tool_call_results
print(f"tool_call.id {tool_call.id}")
self.tool_output = ToolOutput(
tool_call_id=tool_call.id,
output=tool_call_results_string
Expand Down
4 changes: 2 additions & 2 deletions client/astra_assistants/astra_assistants_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def create_assistant(self):
model=self.model,
tools=tool_functions
)
print("Assistant created:", self.assistant)
logger.debug("Assistant created:", self.assistant)
return self.assistant

def create_thread(self):
# Create and return a new thread
thread = self.client.beta.threads.create()
print("Thread generated:", thread)
logger.debug("Thread generated:", thread)
return thread

def stream_thread(self, content, tool_choice = None, thread_id: str = None, thread = None, additional_instructions = None):
Expand Down
4 changes: 2 additions & 2 deletions client/astra_assistants/async_openai_with_default_key.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import os
from openai import OpenAI
from openai import AsyncOpenAI

logger = logging.getLogger(__name__)


class AsyncOpenAIWithDefaultKey(OpenAI):
class AsyncOpenAIWithDefaultKey(AsyncOpenAI):
def __init__(self, *args, **kwargs):
key = os.environ.get("OPENAI_API_KEY", "dummy")
if key == "dummy":
Expand Down
12 changes: 4 additions & 8 deletions client/astra_assistants/tools/structured_code/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from astra_assistants.tools.structured_code.program_cache import StructuredProgramEntry, ProgramCache, StructuredProgram
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import copy_program_from_cache


class StructuredEditDelete(BaseModel):
Expand Down Expand Up @@ -39,13 +40,7 @@ def set_program_id(self, program_id):

def call(self, edit: StructuredEditDelete):
try:
program = None
for entry in self.program_cache:
if entry.program_id == self.program_id:
program = entry.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

if edit.end_line_number:
del program.lines[edit.start_line_number-1:edit.end_line_number]
Expand All @@ -54,10 +49,11 @@ def call(self, edit: StructuredEditDelete):

new_program_id = str(uuid1())
entry = StructuredProgramEntry(program_id=new_program_id, program=program)
self.program_cache.append(entry)
self.program_cache.add(entry)
print(f"program after edit: \n{program.to_string()}")
return {'program_id': new_program_id, 'output': program}

except Exception as e:
print(f"Error: {e}")
raise e

57 changes: 28 additions & 29 deletions client/astra_assistants/tools/structured_code/indent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgramEntry, StructuredProgram
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import copy_program_from_cache

ts_language = Language(tspython.language())
parser = Parser(ts_language)
Expand Down Expand Up @@ -74,7 +75,7 @@ def get_indentation_from_node(node, source_code):
else:
print("No indentation found around the specified line.")
if target_line > 0:
return get_indentation_unit(source_code, target_line-1)
return get_indentation_unit(source_code, target_line - 1)
else:
print(f"No node found at line {target_line + 1}.")
if target_line > 0:
Expand All @@ -84,9 +85,13 @@ def get_indentation_from_node(node, source_code):


class IndentLeftEdit(BaseModel):
thoughts: str = Field(..., description="The message to be described to the user explaining how the indent left edit will work, think step by step.")
start_line_number: int = Field(..., description="Line number where the indent left edit starts (first line is line 1). ALWAYS requried")
end_line_number: Optional[int] = Field(None, description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")
thoughts: str = Field(...,
description="The message to be described to the user explaining how the indent left edit will work, think step by step.")
start_line_number: int = Field(...,
description="Line number where the indent left edit starts (first line is line 1). ALWAYS requried")
end_line_number: Optional[int] = Field(None,
description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")

class Config:
schema_extra = {
"examples": [
Expand All @@ -109,9 +114,13 @@ class Config:


class IndentRightEdit(BaseModel):
thoughts: str = Field(..., description="The message to be described to the user explaining how the indent right edit will work, think step by step.")
start_line_number: int = Field(..., description="Line number where the indent right edit starts (first line is line 1). ALWAYS requried")
end_line_number: Optional[int] = Field(None, description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")
thoughts: str = Field(...,
description="The message to be described to the user explaining how the indent right edit will work, think step by step.")
start_line_number: int = Field(...,
description="Line number where the indent right edit starts (first line is line 1). ALWAYS requried")
end_line_number: Optional[int] = Field(None,
description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")

class Config:
schema_extra = {
"examples": [
Expand Down Expand Up @@ -144,22 +153,17 @@ def __init__(self, program_cache: ProgramCache):
def set_program_id(self, program_id):
self.program_id = program_id


def call(self, edit: IndentRightEdit):
try:
program = None
for entry in self.program_cache:
if entry.program_id == self.program_id:
program = entry.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

print(f"program before edit: \n{program.to_string()}")
print(f"edit: {edit}")

indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1)
indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False),
edit.start_line_number - 1)

i = edit.start_line_number-1
i = edit.start_line_number - 1
if edit.end_line_number is not None:
while i < edit.end_line_number:
program.lines[i] = f"{indentation_unit}{program.lines[i]}"
Expand All @@ -169,7 +173,7 @@ def call(self, edit: IndentRightEdit):

new_program_id = str(uuid1())
entry = StructuredProgramEntry(program_id=new_program_id, program=program)
self.program_cache.append(entry)
self.program_cache.add(entry)
print(f"program after edit: \n{program.to_string()}")
return {'program_id': new_program_id, 'output': program}
except Exception as e:
Expand All @@ -188,21 +192,16 @@ def __init__(self, program_cache: ProgramCache):
def set_program_id(self, program_id):
self.program_id = program_id


def call(self, edit: IndentLeftEdit):
try:
program = None
for entry in self.program_cache:
if entry.program_id == self.program_id:
program = entry.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

print(f"program before edit: \n{program.to_string()}")
print(f"edit: {edit}")

indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1)
i = edit.start_line_number-1
indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False),
edit.start_line_number - 1)
i = edit.start_line_number - 1
if edit.end_line_number is not None:
while i < edit.end_line_number and i < len(program.lines):
program.lines[i] = program.lines[i].replace(indentation_unit, "", 1)
Expand All @@ -212,7 +211,7 @@ def call(self, edit: IndentLeftEdit):

new_program_id = str(uuid1())
entry = StructuredProgramEntry(program_id=new_program_id, program=program)
self.program_cache.append(entry)
self.program_cache.add(entry)
print(f"program after edit: \n{program.to_string()}")
return {'program_id': new_program_id, 'output': program}
except Exception as e:
Expand Down
9 changes: 2 additions & 7 deletions client/astra_assistants/tools/structured_code/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import copy_program_from_cache


class StructuredEditInsert(BaseModel):
Expand Down Expand Up @@ -32,13 +33,7 @@ def set_program_id(self, program_id):

def call(self, edit: StructuredEditInsert):
try:
program = None
for entry in self.program_cache:
if entry.program_id == self.program_id:
program = entry.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

instructions = (f"Write some code based on the instructions provided.\n"
f"## Instructions:\n"
Expand Down
41 changes: 31 additions & 10 deletions client/astra_assistants/tools/structured_code/program_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,45 @@ class Config:
arbitrary_types_allowed = True


class ProgramCache(list):
class ProgramCache:
def __init__(self, *args):
super().__init__(*args)
self.cache = {} # Dictionary to hold the programs by ID
self.order = [] # List to maintain the insertion order of program IDs
self.session_manager = LspSessionManager()

def append(self, item: StructuredProgramEntry) -> None:
def add(self, item: StructuredProgramEntry) -> None:
program_id = item.program_id # Assuming the program has an 'id' attribute
self.process(item)
super().append(item)

# Add or update the cache
self.cache[program_id] = item

# Track insertion order; remove if already present (for updates), then append at the end
if program_id in self.order:
self.order.remove(program_id)
self.order.append(program_id)

def extend(self, iterable: List[StructuredProgramEntry]) -> None:
for item in iterable:
program_id = item.program_id
self.process(item)
super().extend(iterable)

def insert(self, index: int, item: StructuredProgramEntry) -> None:
self.process(item)
super().insert(index, item)
# Add or update the cache
self.cache[program_id] = item

# Maintain insertion order
if program_id in self.order:
self.order.remove(program_id)
self.order.append(program_id)

def get(self, program_id) -> StructuredProgramEntry:
return self.cache.get(program_id, None)

def get_latest(self) -> StructuredProgramEntry:
if not self.order:
return None # No entries
latest_program_id = self.order[-1] # Get the last program added
return self.cache[latest_program_id]

def close(self) -> None:
self.session_manager.close()
Expand Down Expand Up @@ -165,7 +187,6 @@ def get_diagnostics(self, uri, program_str, document_version=1):
if document_version == 1:
notification = self.session_manager.send_notification("textDocument/didOpen", payload)
else:

text_change_event = types.TextDocumentContentChangeEvent_Type2(
text=program_str,
)
Expand All @@ -177,7 +198,7 @@ def get_diagnostics(self, uri, program_str, document_version=1):
)
did_change_payload_dict = convert_keys_to_camel_case(converter.unstructure(did_change_payload_obj))
notification = self.session_manager.send_notification("textDocument/didChange", did_change_payload_dict)
assert notification['uri'] == uri, "notification on the wrong file"
assert notification['uri'] == uri, f"Error with notification {notification} uri {uri}"
diagnostics = notification["diagnostics"]
diags = []
for diagnostic in diagnostics:
Expand Down
9 changes: 2 additions & 7 deletions client/astra_assistants/tools/structured_code/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import copy_program_from_cache


class StructuredEditReplace(BaseModel):
Expand Down Expand Up @@ -44,13 +45,7 @@ def set_program_id(self, program_id):

def call(self, edit: StructuredEditReplace):
try:
program = None
for entry in self.program_cache:
if entry.program_id == self.program_id:
program = entry.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

instructions = (f"Write some code based on the instructions provided.\n"
f"## Instructions:\n"
Expand Down
17 changes: 7 additions & 10 deletions client/astra_assistants/tools/structured_code/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

from astra_assistants.tools.structured_code.program_cache import ProgramCache, StructuredProgram
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import copy_program_from_cache


class StructuredRewrite(BaseModel):
thoughts: str = Field(..., description="The message to be described to the user explaining how the edit will work, think step by step.")
thoughts: str = Field(...,
description="The message to be described to the user explaining how the edit will work, think step by step.")

class Config:
schema_extra = {
"example": {
Expand All @@ -27,16 +30,9 @@ def __init__(self, program_cache: ProgramCache):
def set_program_id(self, program_id):
self.program_id = program_id


def call(self, edit: StructuredRewrite):
try:
program = None
for pair in self.program_cache:
if pair.program_id == self.program_id:
program = pair.program.copy()
break
if not program:
raise Exception(f"Program id {self.program_id} not found, did you forget to call set_program_id()?")
program = copy_program_from_cache(self.program_id, self.program_cache)

instructions = (f"Rewrite the code snippet based on the instructions provided.\n"
f"## Instructions:\n"
Expand All @@ -51,7 +47,8 @@ def call(self, edit: StructuredRewrite):
f"{program.to_string()}")
print(f"providing instructions: \n{instructions}")

return {'program_id': self.program_id, 'output': instructions, 'tool': self.__class__.__name__, 'edit': edit}
return {'program_id': self.program_id, 'output': instructions, 'tool': self.__class__.__name__,
'edit': edit}
except Exception as e:
print(f"Error: {e}")
raise e
Loading

0 comments on commit b8d968a

Please sign in to comment.