feat: Move Source to ORM model #1979

Merged: 20 commits, Nov 12, 2024
64 changes: 64 additions & 0 deletions alembic/versions/cda66b6cb0d6_move_sources_to_orm.py
@@ -0,0 +1,64 @@
"""Move sources to orm

Revision ID: cda66b6cb0d6
Revises: b6d7ca024aa9
Create Date: 2024-11-07 13:29:57.186107

"""

from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cda66b6cb0d6"
down_revision: Union[str, None] = "b6d7ca024aa9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("sources", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
op.add_column("sources", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
op.add_column("sources", sa.Column("_created_by_id", sa.String(), nullable=True))
op.add_column("sources", sa.Column("_last_updated_by_id", sa.String(), nullable=True))

# Data migration step:
op.add_column("sources", sa.Column("organization_id", sa.String(), nullable=True))
# Populate `organization_id` based on `user_id`
# Use a raw SQL query to update the organization_id
op.execute(
"""
UPDATE sources
SET organization_id = users.organization_id
FROM users
WHERE sources.user_id = users.id
"""
)

# Set `organization_id` as non-nullable after population
op.alter_column("sources", "organization_id", nullable=False)

op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.drop_index("sources_idx_user", table_name="sources")
op.create_foreign_key(None, "sources", "organizations", ["organization_id"], ["id"])
op.drop_column("sources", "user_id")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("sources", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
op.drop_constraint(None, "sources", type_="foreignkey")
op.create_index("sources_idx_user", "sources", ["user_id"], unique=False)
op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.drop_column("sources", "organization_id")
op.drop_column("sources", "_last_updated_by_id")
op.drop_column("sources", "_created_by_id")
op.drop_column("sources", "is_deleted")
op.drop_column("sources", "updated_at")
# ### end Alembic commands ###
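
The upgrade path backfills sources.organization_id from each row's owning user before tightening the column to NOT NULL. Below is a minimal pre-flight sketch, not part of the PR, that reports rows the raw UPDATE above could not backfill (and which would therefore make the alter_column(nullable=False) step fail); it assumes a reachable database URL in the DATABASE_URL environment variable.

import os

import sqlalchemy as sa

# Hypothetical check: find sources whose user_id has no matching users row,
# or whose user has no organization_id -- those rows would stay NULL after
# the UPDATE in upgrade() and block the NOT NULL constraint.
engine = sa.create_engine(os.environ["DATABASE_URL"])
with engine.connect() as conn:
    orphaned = conn.execute(
        sa.text(
            """
            SELECT s.id
            FROM sources AS s
            LEFT JOIN users AS u ON s.user_id = u.id
            WHERE u.id IS NULL OR u.organization_id IS NULL
            """
        )
    ).fetchall()

if orphaned:
    print(f"{len(orphaned)} sources rows cannot be backfilled:", [row[0] for row in orphaned])
else:
    print("All sources rows can be assigned an organization_id.")
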
6 changes: 4 additions & 2 deletions letta/agent.py
@@ -46,6 +46,8 @@
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.source_manager import SourceManager
from letta.services.user_manager import UserManager
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
@@ -1311,7 +1313,7 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig):
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory

user = UserManager().get_user_by_id(self.agent_state.user_id)
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
size = source_connector.size(filters)
page_size = 100
@@ -1339,7 +1341,7 @@ def attach_source(self, source_id: str, source_connector: StorageConnector, ms:
self.persistence_manager.archival_memory.storage.save()

# attach to agent
source = ms.get_source(source_id=source_id)
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)

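The attach_source hunk above shows the new actor-scoped lookup pattern: the MetadataStore call is replaced by service managers that take the acting user. A minimal sketch of that pattern in isolation, using the same imports as the diff; the IDs are placeholders, not real values.

from letta.services.source_manager import SourceManager
from letta.services.user_manager import UserManager

# Resolve the acting user first; source lookups are now scoped to an actor.
user = UserManager().get_user_by_id("user-00000000")  # placeholder ID
source = SourceManager().get_source_by_id(source_id="source-00000000", actor=user)

# agent.py keeps an explicit check that the lookup found something.
assert source is not None, "source not found for this actor"
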
29 changes: 16 additions & 13 deletions letta/client/client.py
@@ -238,7 +238,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True) -> J
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
raise NotImplementedError

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
raise NotImplementedError

def delete_source(self, source_id: str):
@@ -1188,7 +1188,7 @@ def delete_file_from_source(self, source_id: str, file_id: str) -> None:
if response.status_code not in [200, 204]:
raise ValueError(f"Failed to delete tool: {response.text}")

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source

@@ -1198,7 +1198,8 @@ def create_source(self, name: str) -> Source:
Returns:
source (Source): Created source
"""
payload = {"name": name}
source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config)
payload = source_create.model_dump()
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
response_json = response.json()
return Source(**response_json)
@@ -1253,7 +1254,7 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
Returns:
source (Source): Updated source
"""
request = SourceUpdate(id=source_id, name=name)
request = SourceUpdate(name=name)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update source: {response.text}")
@@ -2453,7 +2454,7 @@ def list_jobs(self):
def list_active_jobs(self):
return self.server.list_active_jobs(user_id=self.user_id)

def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source

@@ -2463,8 +2464,10 @@ def create_source(self, name: str) -> Source:
Returns:
source (Source): Created source
"""
request = SourceCreate(name=name)
return self.server.create_source(request=request, user_id=self.user_id)
source = Source(
name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id
)
return self.server.source_manager.create_source(source=source, actor=self.user)

def delete_source(self, source_id: str):
"""
@@ -2475,7 +2478,7 @@ def delete_source(self, source_id: str):
"""

# TODO: delete source data
self.server.delete_source(source_id=source_id, user_id=self.user_id)
self.server.delete_source(source_id=source_id, actor=self.user)

def get_source(self, source_id: str) -> Source:
"""
@@ -2487,7 +2490,7 @@ def get_source(self, source_id: str) -> Source:
Returns:
source (Source): Source
"""
return self.server.get_source(source_id=source_id, user_id=self.user_id)
return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)

def get_source_id(self, source_name: str) -> str:
"""
@@ -2499,7 +2502,7 @@ def get_source_id(self, source_name: str) -> str:
Returns:
source_id (str): ID of the source
"""
return self.server.get_source_id(source_name=source_name, user_id=self.user_id)
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id

def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
"""
@@ -2532,7 +2535,7 @@ def list_sources(self) -> List[Source]:
sources (List[Source]): List of sources
"""

return self.server.list_all_sources(user_id=self.user_id)
return self.server.list_all_sources(actor=self.user)

def list_attached_sources(self, agent_id: str) -> List[Source]:
"""
@@ -2572,8 +2575,8 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
source (Source): Updated source
"""
# TODO should the arg here just be "source_update: Source"?
request = SourceUpdate(id=source_id, name=name)
return self.server.update_source(request=request, user_id=self.user_id)
request = SourceUpdate(name=name)
return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user)

# archival memory

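On the client side, create_source now accepts an optional EmbeddingConfig, SourceUpdate no longer carries the source id, and name-based lookups go through source_manager.get_source_by_name. A hedged usage sketch of the updated client surface follows; create_client() and the source names are assumptions, not shown in this diff.

from letta import create_client  # assumed client factory, not part of this diff

client = create_client()

# embedding_config is optional; when omitted the client falls back to its default.
source = client.create_source(name="papers")

# Renames go through SourceUpdate(name=...); the payload no longer carries the id.
renamed = client.update_source(source_id=source.id, name="papers-v2")

# Name-based lookup resolves to an id via source_manager.get_source_by_name.
assert client.get_source_id(source_name="papers-v2") == renamed.id
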
6 changes: 3 additions & 3 deletions letta/data_sources/connectors.py
@@ -47,7 +47,7 @@ def load_data(
passage_store: StorageConnector,
file_metadata_store: StorageConnector,
):
"""Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
embedding_config = source.embedding_config

# embedding model
@@ -88,7 +88,7 @@ def load_data(
file_id=file_metadata.id,
source_id=source.id,
metadata_=passage_metadata,
user_id=source.user_id,
user_id=source.created_by_id,
embedding_config=source.embedding_config,
embedding=embedding,
)
@@ -155,7 +155,7 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]:

for metadata in extract_metadata_from_files(files):
yield FileMetadata(
user_id=source.user_id,
user_id=source.created_by_id,
source_id=source.id,
file_name=metadata.get("file_name"),
file_path=metadata.get("file_path"),
2 changes: 0 additions & 2 deletions letta/llm_api/google_ai.py
@@ -95,10 +95,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =

try:
response = requests.get(url, headers=headers)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")

# Grab the models out
model_list = response["models"]
95 changes: 3 additions & 92 deletions letta/metadata.py
@@ -29,7 +29,6 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.source import Source
from letta.schemas.tool_rule import (
BaseToolRule,
InitToolRule,
@@ -292,40 +291,6 @@ def to_record(self) -> AgentState:
return agent_state


class SourceModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""

__tablename__ = "sources"
__table_args__ = {"extend_existing": True}

# Assuming passage_id is the primary key
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
embedding_config = Column(EmbeddingConfigColumn)
description = Column(String)
metadata_ = Column(JSON)
Index(__tablename__ + "_idx_user", user_id),

# TODO: add num passages

def __repr__(self) -> str:
return f"<Source(passage_id='{self.id}', name='{self.name}')>"

def to_record(self) -> Source:
return Source(
id=self.id,
user_id=self.user_id,
name=self.name,
created_at=self.created_at,
embedding_config=self.embedding_config,
description=self.description,
metadata_=self.metadata_,
)


class AgentSourceMappingModel(Base):
"""Stores mapping between agent -> source"""

@@ -497,14 +462,6 @@ def create_agent(self, agent: AgentState):
session.add(AgentModel(**fields))
session.commit()

@enforce_types
def create_source(self, source: Source):
with self.session_maker() as session:
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
session.add(SourceModel(**vars(source)))
session.commit()

@enforce_types
def create_block(self, block: Block):
with self.session_maker() as session:
@@ -522,6 +479,7 @@ def create_block(self, block: Block):
):

raise ValueError(f"Block with name {block.template_name} already exists")

session.add(BlockModel(**vars(block)))
session.commit()

@@ -536,12 +494,6 @@ def update_agent(self, agent: AgentState):
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.commit()

@enforce_types
def update_source(self, source: Source):
with self.session_maker() as session:
session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
session.commit()

@enforce_types
def update_block(self, block: Block):
with self.session_maker() as session:
@@ -591,29 +543,12 @@ def delete_agent(self, agent_id: str):

session.commit()

@enforce_types
def delete_source(self, source_id: str):
with self.session_maker() as session:
# delete from sources table
session.query(SourceModel).filter(SourceModel.id == source_id).delete()

# delete any mappings
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()

session.commit()

@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
with self.session_maker() as session:
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def get_agent(
self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
@@ -630,21 +565,6 @@ def get_agent(
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
return results[0].to_record()

@enforce_types
def get_source(
self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
) -> Optional[Source]:
with self.session_maker() as session:
if source_id:
results = session.query(SourceModel).filter(SourceModel.id == source_id).all()
else:
assert user_id is not None and source_name is not None
results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
with self.session_maker() as session:
@@ -699,19 +619,10 @@ def attach_source(self, user_id: str, agent_id: str, source_id: str):
session.commit()

@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
def list_attached_source_ids(self, agent_id: str) -> List[str]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

sources = []
# make sure source exists
for r in results:
source = self.get_source(source_id=r.source_id)
if source:
sources.append(source)
else:
printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.")
return sources
return [r.source_id for r in results]

@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
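
list_attached_sources in the metadata store is narrowed to list_attached_source_ids, which returns bare ids; resolving them into Source objects is now the caller's job via SourceManager. A minimal sketch of how a caller might combine the two; the MetadataStore construction and the ids are assumptions for illustration.

from letta.metadata import MetadataStore
from letta.services.source_manager import SourceManager
from letta.services.user_manager import UserManager

ms = MetadataStore()  # assumed default construction; real call sites may pass config
actor = UserManager().get_user_by_id("user-00000000")  # placeholder ID

# The metadata store now returns ids only...
source_ids = ms.list_attached_source_ids(agent_id="agent-00000000")  # placeholder ID

# ...and the ORM-backed SourceManager resolves them per actor.
sources = [SourceManager().get_source_by_id(source_id=sid, actor=actor) for sid in source_ids]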