From b984a94dd20363b85ba0e192140cb6c9ed6cf624 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 16:02:29 -0800 Subject: [PATCH 01/18] wip --- letta/orm/__init__.py | 1 + letta/orm/source.py | 54 +++++++++++++ letta/schemas/letta_base.py | 4 + letta/schemas/source.py | 52 ++++++------- letta/server/server.py | 2 + letta/services/source_manager.py | 90 ++++++++++++++++++++++ tests/test_managers.py | 127 ++++++++++++++++++++++++++++++- 7 files changed, 301 insertions(+), 29 deletions(-) create mode 100644 letta/orm/source.py create mode 100644 letta/services/source_manager.py diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index c95da85b65..b69737ac65 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,4 +1,5 @@ from letta.orm.base import Base from letta.orm.organization import Organization +from letta.orm.source import Source from letta.orm.tool import Tool from letta.orm.user import User diff --git a/letta/orm/source.py b/letta/orm/source.py new file mode 100644 index 0000000000..d4905082e8 --- /dev/null +++ b/letta/orm/source.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import JSON, TypeDecorator, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.source import Source as PydanticSource + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom type for storing EmbeddingConfig as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + # return vars(value) + if isinstance(value, EmbeddingConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return EmbeddingConfig(**value) + return value + + +class Source(OrganizationMixin, SqlalchemyBase): + """A source represents an embedded text passage""" + + __tablename__ = "source" + __pydantic_model__ = PydanticSource + + # Add unique constraint on (name, _organization_id) + # An organization should not have multiple sources with the same name + __table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_source_name_organization"),) + + name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False) + description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the source") + embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") + # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index f2b2b09f9d..7bbe15f79b 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime from logging import getLogger from typing import Optional from uuid import UUID @@ -14,6 +15,9 @@ class LettaBase(BaseModel): """Base schema for Letta schemas (does not include model provider schemas, e.g. OpenAI)""" + created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.") + updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.") + model_config = ConfigDict( # allows you to use the snake or camelcase names in your code (ie user_id or userId) populate_by_name=True, diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 8f816ad701..8eb4f98d53 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -1,12 +1,9 @@ -from datetime import datetime from typing import Optional -from fastapi import UploadFile -from pydantic import BaseModel, Field +from pydantic import Field from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import LettaBase -from letta.utils import get_utc_time class BaseSource(LettaBase): @@ -15,15 +12,6 @@ class BaseSource(LettaBase): """ __id_prefix__ = "source" - description: Optional[str] = Field(None, description="The description of the source.") - embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.") - # NOTE: .metadata is a reserved attribute on SQLModel - metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") - - -class SourceCreate(BaseSource): - name: str = Field(..., description="The name of the source.") - description: Optional[str] = Field(None, description="The description of the source.") class Source(BaseSource): @@ -34,29 +22,41 @@ class Source(BaseSource): id (str): The ID of the source name (str): The name of the source. embedding_config (EmbeddingConfig): The embedding configuration used by the source. - created_at (datetime): The creation date of the source. user_id (str): The ID of the user that created the source. metadata_ (dict): Metadata associated with the source. description (str): The description of the source. """ - id: str = BaseSource.generate_id_field() + id: str = Field(..., description="The id of the source.") name: str = Field(..., description="The name of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") - created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.") - user_id: str = Field(..., description="The ID of the user that created the source.") + organization_id: str = Field(..., description="The ID of the organization that created the source.") + # metadata fields + created_by_id: str = Field(..., description="The id of the user that made this Tool.") + last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.") -class SourceUpdate(BaseSource): - id: str = Field(..., description="The ID of the source.") - name: Optional[str] = Field(None, description="The name of the source.") +class SourceCreate(BaseSource): + """ + Schema for creating a new Source. + """ + + # required + name: str = Field(..., description="The name of the source.") + embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") + + # optional + description: Optional[str] = Field(None, description="The description of the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") -class UploadFileToSourceRequest(BaseModel): - file: UploadFile = Field(..., description="The file to upload.") +class SourceUpdate(BaseSource): + """ + Schema for updating an existing Source. + """ -class UploadFileToSourceResponse(BaseModel): - source: Source = Field(..., description="The source the file was uploaded to.") - added_passages: int = Field(..., description="The number of passages added to the source.") - added_documents: int = Field(..., description="The number of files added to the source.") + name: Optional[str] = Field(None, description="The name of the source.") + description: Optional[str] = Field(None, description="The description of the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") diff --git a/letta/server/server.py b/letta/server/server.py index 9a4b318e8f..c44706d3a2 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -83,6 +83,7 @@ from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.organization_manager import OrganizationManager +from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import create_random_username, json_dumps, json_loads @@ -248,6 +249,7 @@ def __init__( self.organization_manager = OrganizationManager() self.user_manager = UserManager() self.tool_manager = ToolManager() + self.source_manager = SourceManager() # Make default user and org if init_with_default_org_and_user: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py new file mode 100644 index 0000000000..5859609597 --- /dev/null +++ b/letta/services/source_manager.py @@ -0,0 +1,90 @@ +from typing import List, Optional + +from letta.orm.errors import NoResultFound +from letta.orm.organization import Organization as OrganizationModel +from letta.orm.source import Source as SourceModel +from letta.schemas.source import Source as PydanticSource +from letta.schemas.source import SourceCreate, SourceUpdate +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types, printd + + +class SourceManager: + """Manager class to handle business logic related to Sources.""" + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_source(self, source_create: SourceCreate, actor: PydanticUser) -> PydanticSource: + """Create a new source based on the SourceCreate schema.""" + with self.session_maker() as session: + create_data = source_create.model_dump() + source = SourceModel(**create_data, organization_id=actor.organization_id) + source.create(session, actor=actor) + return source.to_pydantic() + + @enforce_types + def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: + """Update a source by its ID with the given SourceUpdate object.""" + with self.session_maker() as session: + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + + # get update dictionary + update_data = source_update.model_dump(exclude_unset=True, exclude_none=True) + # Remove redundant update fields + update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(source, key, value) + source.update(db_session=session, actor=actor) + else: + printd( + f"`update_source` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={source.name}, but found existing source with nothing to update." + ) + + return source.to_pydantic() + + @enforce_types + def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: + """Delete a source by its ID.""" + with self.session_maker() as session: + source = SourceModel.read(db_session=session, identifier=source_id) + source.delete(db_session=session, actor=actor) + return source.to_pydantic() + + @enforce_types + def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticSource]: + """List all sources with optional pagination.""" + with self.session_maker() as session: + sources = SourceModel.list( + db_session=session, + cursor=cursor, + limit=limit, + _organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id), + ) + return [source.to_pydantic() for source in sources] + + @enforce_types + def get_source_by_id(self, source_id: str, actor: PydanticUser) -> PydanticSource: + """Retrieve a source by its ID.""" + with self.session_maker() as session: + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + return source.to_pydantic() + + @enforce_types + def get_source_by_name(self, source_name: str, actor: PydanticUser) -> PydanticSource: + """Retrieve a source by its name.""" + with self.session_maker() as session: + sources = SourceModel.list( + db_session=session, + name=source_name, + _organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id), + limit=1, + ) + if not sources: + raise NoResultFound(f"Source with name '{source_name}' not found.") + return sources[0].to_pydantic() diff --git a/tests/test_managers.py b/tests/test_managers.py index c8232963ea..c300850a5a 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,9 +3,8 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.orm.organization import Organization -from letta.orm.tool import Tool -from letta.orm.user import User +from letta.orm import Organization, Source, Tool, User +from letta.schemas.source import SourceCreate from letta.schemas.tool import ToolCreate, ToolUpdate from letta.services.organization_manager import OrganizationManager @@ -19,6 +18,7 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(User)) # Clear all records from the user table session.execute(delete(Organization)) # Clear all records from the organization table @@ -357,3 +357,124 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture): tools = server.tool_manager.list_tools(actor=user) assert len(tools) == 0 + + +# ====================================================================================================================== +# Source Manager Tests +# ====================================================================================================================== + + +def test_create_source(server: SyncServer, actor): + """Test creating a new source.""" + source_create = SourceCreate( + name="Test Source", description="This is a test source.", metadata_={"type": "test"}, embedding_config=None + ) + source = server.source_manager.create_source(source_create=source_create, actor=actor) + + # Assertions to check the created source + assert source.name == source_create.name + assert source.description == source_create.description + assert source.metadata_ == source_create.metadata_ + assert source.organization_id == actor.organization_id + + +# def test_update_source(source_manager, actor): +# """Test updating an existing source.""" +# source_create = SourceCreate(name="Original Source", description="Original description") +# source = source_manager.create_source(source_create=source_create, actor=actor) +# +# # Update the source +# update_data = SourceUpdate( +# name="Updated Source", +# description="Updated description", +# metadata_={"type": "updated"} +# ) +# updated_source = source_manager.update_source(source_id=source.id, source_update=update_data, actor=actor) +# +# # Assertions to verify update +# assert updated_source.name == update_data.name +# assert updated_source.description == update_data.description +# assert updated_source.metadata_ == update_data.metadata_ +# +# +# def test_delete_source(source_manager, actor): +# """Test deleting a source.""" +# source_create = SourceCreate(name="To Delete", description="This source will be deleted.") +# source = source_manager.create_source(source_create=source_create, actor=actor) +# +# # Delete the source +# deleted_source = source_manager.delete_source(source_id=source.id, actor=actor) +# +# # Assertions to verify deletion +# assert deleted_source.id == source.id +# assert deleted_source.is_deleted +# +# # Verify that the source no longer appears in list_sources +# sources = source_manager.list_sources(actor=actor) +# assert len(sources) == 0 +# +# +# def test_list_sources(source_manager, actor): +# """Test listing sources with pagination.""" +# # Create multiple sources +# source_manager.create_source(SourceCreate(name="Source 1"), actor=actor) +# source_manager.create_source(SourceCreate(name="Source 2"), actor=actor) +# +# # List sources without pagination +# sources = source_manager.list_sources(actor=actor) +# assert len(sources) == 2 +# +# # List sources with pagination +# paginated_sources = source_manager.list_sources(actor=actor, limit=1) +# assert len(paginated_sources) == 1 +# +# # Ensure cursor-based pagination works +# next_page = source_manager.list_sources(actor=actor, cursor=paginated_sources[-1].id, limit=1) +# assert len(next_page) == 1 +# assert next_page[0].name != paginated_sources[0].name +# +# +# def test_get_source_by_id(source_manager, actor): +# """Test retrieving a source by ID.""" +# source_create = SourceCreate(name="Retrieve by ID", description="Test source for ID retrieval") +# source = source_manager.create_source(source_create=source_create, actor=actor) +# +# # Retrieve the source by ID +# retrieved_source = source_manager.get_source_by_id(source_id=source.id, actor=actor) +# +# # Assertions to verify the retrieved source matches the created one +# assert retrieved_source.id == source.id +# assert retrieved_source.name == source.name +# assert retrieved_source.description == source.description +# +# +# def test_get_source_by_name(source_manager, actor): +# """Test retrieving a source by name.""" +# source_create = SourceCreate(name="Unique Source", description="Test source for name retrieval") +# source = source_manager.create_source(source_create=source_create, actor=actor) +# +# # Retrieve the source by name +# retrieved_source = source_manager.get_source_by_name(source_name=source.name, actor=actor) +# +# # Assertions to verify the retrieved source matches the created one +# assert retrieved_source.name == source.name +# assert retrieved_source.description == source.description +# +# +# def test_update_source_no_changes(source_manager, actor): +# """Test update_source with no actual changes to verify logging and response.""" +# source_create = SourceCreate(name="No Change Source", description="No changes") +# source = source_manager.create_source(source_create=source_create, actor=actor) +# +# # Attempt to update the source with identical data +# update_data = SourceUpdate( +# id=source.id, +# name="No Change Source", +# description="No changes" +# ) +# updated_source = source_manager.update_source(source_id=source.id, source_update=update_data, actor=actor) +# +# # Assertions to ensure the update returned the source but made no modifications +# assert updated_source.id == source.id +# assert updated_source.name == source.name +# assert updated_source.description == source.description From 54a57ee120b53d297c2f763dcb9afa8aca857401 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 16:16:50 -0800 Subject: [PATCH 02/18] Finish manager tests --- letta/llm_api/google_ai.py | 2 - letta/orm/organization.py | 2 +- letta/providers.py | 1 - letta/schemas/source.py | 2 + tests/test_managers.py | 312 +++++++++++++++++++------------------ 5 files changed, 165 insertions(+), 154 deletions(-) diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py index 5d4e1798ab..57071a2396 100644 --- a/letta/llm_api/google_ai.py +++ b/letta/llm_api/google_ai.py @@ -95,10 +95,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = try: response = requests.get(url, headers=headers) - printd(f"response = {response}") response.raise_for_status() # Raises HTTPError for 4XX/5XX status response = response.json() # convert to dict from string - printd(f"response.json = {response}") # Grab the models out model_list = response["models"] diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 51e87e8aa0..0df9bf7805 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -21,10 +21,10 @@ class Organization(SqlalchemyBase): users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") + sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") # TODO: Map these relationships later when we actually make these models # below is just a suggestion # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") - # sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/providers.py b/letta/providers.py index 6fa98327f3..63bbe4752d 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -462,7 +462,6 @@ def list_llm_models(self) -> List[LLMConfig]: response = openai_get_model_list(self.base_url, api_key=None) configs = [] - print(response) for model in response["data"]: configs.append( LLMConfig( diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 8eb4f98d53..f6042eaddc 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -29,8 +29,10 @@ class Source(BaseSource): id: str = Field(..., description="The id of the source.") name: str = Field(..., description="The name of the source.") + description: Optional[str] = Field(None, description="The description of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") organization_id: str = Field(..., description="The ID of the organization that created the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") # metadata fields created_by_id: str = Field(..., description="The id of the user that made this Tool.") diff --git a/tests/test_managers.py b/tests/test_managers.py index c300850a5a..4378754770 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4,7 +4,8 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import Organization, Source, Tool, User -from letta.schemas.source import SourceCreate +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.source import SourceCreate, SourceUpdate from letta.schemas.tool import ToolCreate, ToolUpdate from letta.services.organization_manager import OrganizationManager @@ -13,6 +14,18 @@ from letta.schemas.user import UserCreate, UserUpdate from letta.server.server import SyncServer +# Test Constants +EMBEDDING_CONFIG = EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="letta-free", + embedding_dim=1024, + embedding_chunk_size=300, + azure_endpoint=None, + azure_version=None, + azure_deployment=None, +) + @pytest.fixture(autouse=True) def clear_tables(server: SyncServer): @@ -26,7 +39,28 @@ def clear_tables(server: SyncServer): @pytest.fixture -def tool_fixture(server: SyncServer): +def default_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_default_organization() + yield org + + +@pytest.fixture +def default_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_default_user(org_id=default_organization.id) + yield user + + +@pytest.fixture +def other_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_user(UserCreate(name="other", organization_id=default_organization.id)) + yield user + + +@pytest.fixture +def tool_fixture(server: SyncServer, default_user, default_organization): """Fixture to create a tool with default settings and clean up after the test.""" def print_tool(message: str): @@ -40,24 +74,22 @@ def print_tool(message: str): print(message) return message + # Set up tool details source_code = parse_source_code(print_tool) source_type = "python" description = "test_description" tags = ["test"] - org = server.organization_manager.create_default_organization() - user = server.user_manager.create_default_user() - other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id)) tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type) derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name) derived_name = derived_json_schema["name"] tool_create.json_schema = derived_json_schema tool_create.name = derived_name - tool = server.tool_manager.create_tool(tool_create, actor=user) + tool = server.tool_manager.create_tool(tool_create, actor=default_user) - # Yield the created tool, organization, and user for use in tests - yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool_create} + # Yield the created tool for use in tests + yield {"tool": tool, "tool_create": tool_create} @pytest.fixture(scope="module") @@ -169,15 +201,13 @@ def test_update_user(server: SyncServer): # ====================================================================================================================== # Tool Manager Tests # ====================================================================================================================== -def test_create_tool(server: SyncServer, tool_fixture): +def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization): tool = tool_fixture["tool"] tool_create = tool_fixture["tool_create"] - user = tool_fixture["user"] - org = tool_fixture["organization"] # Assertions to ensure the created tool matches the expected values - assert tool.created_by_id == user.id - assert tool.organization_id == org.id + assert tool.created_by_id == default_user.id + assert tool.organization_id == default_organization.id assert tool.description == tool_create.description assert tool.tags == tool_create.tags assert tool.source_code == tool_create.source_code @@ -185,12 +215,11 @@ def test_create_tool(server: SyncServer, tool_fixture): assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name) -def test_get_tool_by_id(server: SyncServer, tool_fixture): +def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Fetch the tool by ID using the manager method - fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id @@ -201,54 +230,51 @@ def test_get_tool_by_id(server: SyncServer, tool_fixture): assert fetched_tool.source_type == tool.source_type -def test_get_tool_with_actor(server: SyncServer, tool_fixture): +def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Fetch the tool by name and organization ID - fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=user) + fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id assert fetched_tool.name == tool.name - assert fetched_tool.created_by_id == user.id + assert fetched_tool.created_by_id == default_user.id assert fetched_tool.description == tool.description assert fetched_tool.tags == tool.tags assert fetched_tool.source_code == tool.source_code assert fetched_tool.source_type == tool.source_type -def test_list_tools(server: SyncServer, tool_fixture): +def test_list_tools(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # List tools (should include the one created by the fixture) - tools = server.tool_manager.list_tools(actor=user) + tools = server.tool_manager.list_tools(actor=default_user) # Assertions to check that the created tool is listed assert len(tools) == 1 assert any(t.id == tool.id for t in tools) -def test_update_tool_by_id(server: SyncServer, tool_fixture): +def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] updated_description = "updated_description" # Create a ToolUpdate object to modify the tool's description tool_update = ToolUpdate(description=updated_description) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful assert updated_tool.description == updated_description -def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture): +def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user): def counter_tool(counter: int): """ Args: @@ -264,7 +290,6 @@ def counter_tool(counter: int): # Test begins tool = tool_fixture["tool"] - user = tool_fixture["user"] og_json_schema = tool_fixture["tool_create"].json_schema source_code = parse_source_code(counter_tool) @@ -273,10 +298,10 @@ def counter_tool(counter: int): tool_update = ToolUpdate(source_code=source_code) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful, and json_schema is updated as well assert updated_tool.source_code == source_code @@ -287,7 +312,7 @@ def counter_tool(counter: int): assert updated_tool.name == new_schema["name"] -def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture): +def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user): def counter_tool(counter: int): """ Args: @@ -303,7 +328,6 @@ def counter_tool(counter: int): # Test begins tool = tool_fixture["tool"] - user = tool_fixture["user"] og_json_schema = tool_fixture["tool_create"].json_schema source_code = parse_source_code(counter_tool) @@ -313,10 +337,10 @@ def counter_tool(counter: int): tool_update = ToolUpdate(name=name, source_code=source_code) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful, and json_schema is updated as well assert updated_tool.source_code == source_code @@ -327,10 +351,8 @@ def counter_tool(counter: int): assert updated_tool.name == name -def test_update_tool_multi_user(server: SyncServer, tool_fixture): +def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] - other_user = tool_fixture["other_user"] updated_description = "updated_description" # Create a ToolUpdate object to modify the tool's description @@ -342,20 +364,19 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture): # Check that the created_by and last_updated_by fields are correct # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) assert updated_tool.last_updated_by_id == other_user.id - assert updated_tool.created_by_id == user.id + assert updated_tool.created_by_id == default_user.id -def test_delete_tool_by_id(server: SyncServer, tool_fixture): +def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Delete the tool using the manager method - server.tool_manager.delete_tool_by_id(tool.id, actor=user) + server.tool_manager.delete_tool_by_id(tool.id, actor=default_user) - tools = server.tool_manager.list_tools(actor=user) + tools = server.tool_manager.list_tools(actor=default_user) assert len(tools) == 0 @@ -364,117 +385,108 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture): # ====================================================================================================================== -def test_create_source(server: SyncServer, actor): +def test_create_source(server: SyncServer, default_user): """Test creating a new source.""" source_create = SourceCreate( - name="Test Source", description="This is a test source.", metadata_={"type": "test"}, embedding_config=None + name="Test Source", description="This is a test source.", metadata_={"type": "test"}, embedding_config=EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source_create=source_create, actor=actor) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Assertions to check the created source assert source.name == source_create.name assert source.description == source_create.description assert source.metadata_ == source_create.metadata_ - assert source.organization_id == actor.organization_id - - -# def test_update_source(source_manager, actor): -# """Test updating an existing source.""" -# source_create = SourceCreate(name="Original Source", description="Original description") -# source = source_manager.create_source(source_create=source_create, actor=actor) -# -# # Update the source -# update_data = SourceUpdate( -# name="Updated Source", -# description="Updated description", -# metadata_={"type": "updated"} -# ) -# updated_source = source_manager.update_source(source_id=source.id, source_update=update_data, actor=actor) -# -# # Assertions to verify update -# assert updated_source.name == update_data.name -# assert updated_source.description == update_data.description -# assert updated_source.metadata_ == update_data.metadata_ -# -# -# def test_delete_source(source_manager, actor): -# """Test deleting a source.""" -# source_create = SourceCreate(name="To Delete", description="This source will be deleted.") -# source = source_manager.create_source(source_create=source_create, actor=actor) -# -# # Delete the source -# deleted_source = source_manager.delete_source(source_id=source.id, actor=actor) -# -# # Assertions to verify deletion -# assert deleted_source.id == source.id -# assert deleted_source.is_deleted -# -# # Verify that the source no longer appears in list_sources -# sources = source_manager.list_sources(actor=actor) -# assert len(sources) == 0 -# -# -# def test_list_sources(source_manager, actor): -# """Test listing sources with pagination.""" -# # Create multiple sources -# source_manager.create_source(SourceCreate(name="Source 1"), actor=actor) -# source_manager.create_source(SourceCreate(name="Source 2"), actor=actor) -# -# # List sources without pagination -# sources = source_manager.list_sources(actor=actor) -# assert len(sources) == 2 -# -# # List sources with pagination -# paginated_sources = source_manager.list_sources(actor=actor, limit=1) -# assert len(paginated_sources) == 1 -# -# # Ensure cursor-based pagination works -# next_page = source_manager.list_sources(actor=actor, cursor=paginated_sources[-1].id, limit=1) -# assert len(next_page) == 1 -# assert next_page[0].name != paginated_sources[0].name -# -# -# def test_get_source_by_id(source_manager, actor): -# """Test retrieving a source by ID.""" -# source_create = SourceCreate(name="Retrieve by ID", description="Test source for ID retrieval") -# source = source_manager.create_source(source_create=source_create, actor=actor) -# -# # Retrieve the source by ID -# retrieved_source = source_manager.get_source_by_id(source_id=source.id, actor=actor) -# -# # Assertions to verify the retrieved source matches the created one -# assert retrieved_source.id == source.id -# assert retrieved_source.name == source.name -# assert retrieved_source.description == source.description -# -# -# def test_get_source_by_name(source_manager, actor): -# """Test retrieving a source by name.""" -# source_create = SourceCreate(name="Unique Source", description="Test source for name retrieval") -# source = source_manager.create_source(source_create=source_create, actor=actor) -# -# # Retrieve the source by name -# retrieved_source = source_manager.get_source_by_name(source_name=source.name, actor=actor) -# -# # Assertions to verify the retrieved source matches the created one -# assert retrieved_source.name == source.name -# assert retrieved_source.description == source.description -# -# -# def test_update_source_no_changes(source_manager, actor): -# """Test update_source with no actual changes to verify logging and response.""" -# source_create = SourceCreate(name="No Change Source", description="No changes") -# source = source_manager.create_source(source_create=source_create, actor=actor) -# -# # Attempt to update the source with identical data -# update_data = SourceUpdate( -# id=source.id, -# name="No Change Source", -# description="No changes" -# ) -# updated_source = source_manager.update_source(source_id=source.id, source_update=update_data, actor=actor) -# -# # Assertions to ensure the update returned the source but made no modifications -# assert updated_source.id == source.id -# assert updated_source.name == source.name -# assert updated_source.description == source.description + assert source.organization_id == default_user.organization_id + + +def test_update_source(server: SyncServer, default_user): + """Test updating an existing source.""" + source_create = SourceCreate(name="Original Source", description="Original description", embedding_config=EMBEDDING_CONFIG) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) + + # Update the source + update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"}) + updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) + + # Assertions to verify update + assert updated_source.name == update_data.name + assert updated_source.description == update_data.description + assert updated_source.metadata_ == update_data.metadata_ + + +def test_delete_source(server: SyncServer, default_user): + """Test deleting a source.""" + source_create = SourceCreate(name="To Delete", description="This source will be deleted.", embedding_config=EMBEDDING_CONFIG) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) + + # Delete the source + deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_source.id == source.id + + # Verify that the source no longer appears in list_sources + sources = server.source_manager.list_sources(actor=default_user) + assert len(sources) == 0 + + +def test_list_sources(server: SyncServer, default_user): + """Test listing sources with pagination.""" + # Create multiple sources + server.source_manager.create_source(SourceCreate(name="Source 1", embedding_config=EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source(SourceCreate(name="Source 2", embedding_config=EMBEDDING_CONFIG), actor=default_user) + + # List sources without pagination + sources = server.source_manager.list_sources(actor=default_user) + assert len(sources) == 2 + + # List sources with pagination + paginated_sources = server.source_manager.list_sources(actor=default_user, limit=1) + assert len(paginated_sources) == 1 + + # Ensure cursor-based pagination works + next_page = server.source_manager.list_sources(actor=default_user, cursor=paginated_sources[-1].id, limit=1) + assert len(next_page) == 1 + assert next_page[0].name != paginated_sources[0].name + + +def test_get_source_by_id(server: SyncServer, default_user): + """Test retrieving a source by ID.""" + source_create = SourceCreate(name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=EMBEDDING_CONFIG) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) + + # Retrieve the source by ID + retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user) + + # Assertions to verify the retrieved source matches the created one + assert retrieved_source.id == source.id + assert retrieved_source.name == source.name + assert retrieved_source.description == source.description + + +def test_get_source_by_name(server: SyncServer, default_user): + """Test retrieving a source by name.""" + source_create = SourceCreate(name="Unique Source", description="Test source for name retrieval", embedding_config=EMBEDDING_CONFIG) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) + + # Retrieve the source by name + retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user) + + # Assertions to verify the retrieved source matches the created one + assert retrieved_source.name == source.name + assert retrieved_source.description == source.description + + +def test_update_source_no_changes(server: SyncServer, default_user): + """Test update_source with no actual changes to verify logging and response.""" + source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=EMBEDDING_CONFIG) + source = server.source_manager.create_source(source_create=source_create, actor=default_user) + + # Attempt to update the source with identical data + update_data = SourceUpdate(name="No Change Source", description="No changes") + updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) + + # Assertions to ensure the update returned the source but made no modifications + assert updated_source.id == source.id + assert updated_source.name == source.name + assert updated_source.description == source.description From 48ca876ae6db4b64bcda9d88f7ed2ac39da0972a Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 17:16:47 -0800 Subject: [PATCH 03/18] finish --- letta/agent.py | 6 +- letta/client/client.py | 14 ++-- letta/data_sources/connectors.py | 2 +- letta/metadata.py | 81 +------------------ letta/schemas/letta_base.py | 4 - letta/schemas/source.py | 5 +- letta/server/rest_api/routers/v1/sources.py | 21 +++-- letta/server/server.py | 87 ++++++--------------- letta/services/source_manager.py | 20 ++++- tests/test_local_client.py | 2 +- tests/test_managers.py | 45 ++++++----- tests/test_server.py | 2 +- 12 files changed, 96 insertions(+), 193 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index a529772347..2f1bfbd6ee 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -46,6 +46,8 @@ from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics +from letta.services.source_manager import SourceManager +from letta.services.user_manager import UserManager from letta.system import ( get_heartbeat, get_initial_boot_messages, @@ -1273,7 +1275,7 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig): def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore): """Attach data with name `source_name` to the agent from source_connector.""" # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory - + user = UserManager().get_user_by_id(self.agent_state.user_id) filters = {"user_id": self.agent_state.user_id, "source_id": source_id} size = source_connector.size(filters) page_size = 100 @@ -1301,7 +1303,7 @@ def attach_source(self, source_id: str, source_connector: StorageConnector, ms: self.persistence_manager.archival_memory.storage.save() # attach to agent - source = ms.get_source(source_id=source_id) + source = SourceManager().get_source_by_id(source_id=source_id, actor=user) assert source is not None, f"Source {source_id} not found in metadata store" ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id) diff --git a/letta/client/client.py b/letta/client/client.py index f4eb2211b6..b6a84f26ce 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2439,7 +2439,7 @@ def create_source(self, name: str) -> Source: source (Source): Created source """ request = SourceCreate(name=name) - return self.server.create_source(request=request, user_id=self.user_id) + return self.server.source_manager.create_source(request=request, actor=actor) def delete_source(self, source_id: str): """ @@ -2450,7 +2450,7 @@ def delete_source(self, source_id: str): """ # TODO: delete source data - self.server.delete_source(source_id=source_id, user_id=self.user_id) + self.server.delete_source(source_id=source_id, actor=self.user) def get_source(self, source_id: str) -> Source: """ @@ -2462,7 +2462,7 @@ def get_source(self, source_id: str) -> Source: Returns: source (Source): Source """ - return self.server.get_source(source_id=source_id, user_id=self.user_id) + return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) def get_source_id(self, source_name: str) -> str: """ @@ -2474,7 +2474,7 @@ def get_source_id(self, source_name: str) -> str: Returns: source_id (str): ID of the source """ - return self.server.get_source_id(source_name=source_name, user_id=self.user_id) + return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None): """ @@ -2507,7 +2507,7 @@ def list_sources(self) -> List[Source]: sources (List[Source]): List of sources """ - return self.server.list_all_sources(user_id=self.user_id) + return self.server.list_all_sources(actor=self.user) def list_attached_sources(self, agent_id: str) -> List[Source]: """ @@ -2547,8 +2547,8 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source: source (Source): Updated source """ # TODO should the arg here just be "source_update: Source"? - request = SourceUpdate(id=source_id, name=name) - return self.server.update_source(request=request, user_id=self.user_id) + request = SourceUpdate(name=name) + return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user) # archival memory diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index f729c8ad10..91d8c2a55a 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -47,7 +47,7 @@ def load_data( passage_store: StorageConnector, file_metadata_store: StorageConnector, ): - """Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id.""" + """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id.""" embedding_config = source.embedding_config # embedding model diff --git a/letta/metadata.py b/letta/metadata.py index 3f9eea5a98..4c74f34cbd 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -292,40 +292,6 @@ def to_record(self) -> AgentState: return agent_state -class SourceModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __tablename__ = "sources" - __table_args__ = {"extend_existing": True} - - # Assuming passage_id is the primary key - # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - id = Column(String, primary_key=True) - user_id = Column(String, nullable=False) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - embedding_config = Column(EmbeddingConfigColumn) - description = Column(String) - metadata_ = Column(JSON) - Index(__tablename__ + "_idx_user", user_id), - - # TODO: add num passages - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Source: - return Source( - id=self.id, - user_id=self.user_id, - name=self.name, - created_at=self.created_at, - embedding_config=self.embedding_config, - description=self.description, - metadata_=self.metadata_, - ) - - class AgentSourceMappingModel(Base): """Stores mapping between agent -> source""" @@ -496,14 +462,6 @@ def create_agent(self, agent: AgentState): session.add(AgentModel(**fields)) session.commit() - @enforce_types - def create_source(self, source: Source): - with self.session_maker() as session: - if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0: - raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}") - session.add(SourceModel(**vars(source))) - session.commit() - @enforce_types def create_block(self, block: Block): with self.session_maker() as session: @@ -521,6 +479,7 @@ def create_block(self, block: Block): ): raise ValueError(f"Block with name {block.template_name} already exists") + session.add(BlockModel(**vars(block))) session.commit() @@ -534,12 +493,6 @@ def update_agent(self, agent: AgentState): session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) session.commit() - @enforce_types - def update_source(self, source: Source): - with self.session_maker() as session: - session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source)) - session.commit() - @enforce_types def update_block(self, block: Block): with self.session_maker() as session: @@ -589,29 +542,12 @@ def delete_agent(self, agent_id: str): session.commit() - @enforce_types - def delete_source(self, source_id: str): - with self.session_maker() as session: - # delete from sources table - session.query(SourceModel).filter(SourceModel.id == source_id).delete() - - # delete any mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete() - - session.commit() - @enforce_types def list_agents(self, user_id: str) -> List[AgentState]: with self.session_maker() as session: results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() return [r.to_record() for r in results] - @enforce_types - def list_sources(self, user_id: str) -> List[Source]: - with self.session_maker() as session: - results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all() - return [r.to_record() for r in results] - @enforce_types def get_agent( self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None @@ -628,21 +564,6 @@ def get_agent( assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result return results[0].to_record() - @enforce_types - def get_source( - self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None - ) -> Optional[Source]: - with self.session_maker() as session: - if source_id: - results = session.query(SourceModel).filter(SourceModel.id == source_id).all() - else: - assert user_id is not None and source_name is not None - results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - @enforce_types def get_block(self, block_id: str) -> Optional[Block]: with self.session_maker() as session: diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index 7bbe15f79b..f2b2b09f9d 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -1,5 +1,4 @@ import uuid -from datetime import datetime from logging import getLogger from typing import Optional from uuid import UUID @@ -15,9 +14,6 @@ class LettaBase(BaseModel): """Base schema for Letta schemas (does not include model provider schemas, e.g. OpenAI)""" - created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.") - updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.") - model_config = ConfigDict( # allows you to use the snake or camelcase names in your code (ie user_id or userId) populate_by_name=True, diff --git a/letta/schemas/source.py b/letta/schemas/source.py index f6042eaddc..747dfce06f 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional from pydantic import Field @@ -37,6 +38,8 @@ class Source(BaseSource): # metadata fields created_by_id: str = Field(..., description="The id of the user that made this Tool.") last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.") + created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.") + updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.") class SourceCreate(BaseSource): @@ -46,7 +49,7 @@ class SourceCreate(BaseSource): # required name: str = Field(..., description="The name of the source.") - embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") + embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") # optional description: Optional[str] = Field(None, description="The description of the source.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 388fa3e09e..e1e2f793d3 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -36,7 +36,7 @@ def get_source( """ actor = server.get_user_or_default(user_id=user_id) - return server.get_source(source_id=source_id, user_id=actor.id) + return server.source(source_id=source_id, user_id=actor.id) @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name") @@ -50,8 +50,8 @@ def get_source_id_by_name( """ actor = server.get_user_or_default(user_id=user_id) - source_id = server.get_source_id(source_name=source_name, user_id=actor.id) - return source_id + source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor) + return source.id @router.get("/", response_model=List[Source], operation_id="list_sources") @@ -64,7 +64,7 @@ def list_sources( """ actor = server.get_user_or_default(user_id=user_id) - return server.list_all_sources(user_id=actor.id) + return server.list_all_sources(actor=actor) @router.post("/", response_model=Source, operation_id="create_source") @@ -78,7 +78,7 @@ def create_source( """ actor = server.get_user_or_default(user_id=user_id) - return server.create_source(request=source, user_id=actor.id) + return server.source_manager.create_source(source_create=source, actor=actor) @router.patch("/{source_id}", response_model=Source, operation_id="update_source") @@ -92,10 +92,7 @@ def update_source( Update the name or documentation of an existing data source. """ actor = server.get_user_or_default(user_id=user_id) - - assert source.id == source_id, "Source ID in path must match ID in request body" - - return server.update_source(request=source, user_id=actor.id) + return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) @router.delete("/{source_id}", response_model=None, operation_id="delete_source") @@ -109,7 +106,7 @@ def delete_source( """ actor = server.get_user_or_default(user_id=user_id) - server.delete_source(source_id=source_id, user_id=actor.id) + server.delete_source(source_id=source_id, actor=actor) @router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source") @@ -124,7 +121,7 @@ def attach_source_to_agent( """ actor = server.get_user_or_default(user_id=user_id) - source = server.ms.get_source(source_id=source_id, user_id=actor.id) + source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id) return source @@ -158,7 +155,7 @@ def upload_file_to_source( """ actor = server.get_user_or_default(user_id=user_id) - source = server.ms.get_source(source_id=source_id, user_id=actor.id) + source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." bytes = file.file.read() diff --git a/letta/server/server.py b/letta/server/server.py index c44706d3a2..afbf7b3b56 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -78,7 +78,7 @@ from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.organization import Organization from letta.schemas.passage import Passage -from letta.schemas.source import Source, SourceCreate, SourceUpdate +from letta.schemas.source import Source from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -1135,18 +1135,6 @@ def get_agent_id(self, name: str, user_id: str): return None return agent_state.id - def get_source(self, source_id: str, user_id: str) -> Source: - existing_source = self.ms.get_source(source_id=source_id, user_id=user_id) - if not existing_source: - raise ValueError("Source does not exist") - return existing_source - - def get_source_id(self, source_name: str, user_id: str) -> str: - existing_source = self.ms.get_source(source_name=source_name, user_id=user_id) - if not existing_source: - raise ValueError("Source does not exist") - return existing_source.id - def get_agent(self, user_id: str, agent_id: Optional[str] = None, agent_name: Optional[str] = None): """Get the agent state""" return self.ms.get_agent(agent_id=agent_id, agent_name=agent_name, user_id=user_id) @@ -1526,44 +1514,12 @@ def delete_api_key(self, api_key: str) -> APIKey: self.ms.delete_api_key(api_key=api_key) return api_key_obj - def create_source(self, request: SourceCreate, user_id: str) -> Source: # TODO: add other fields - """Create a new data source""" - source = Source( - name=request.name, - user_id=user_id, - embedding_config=self.list_embedding_models()[0], # TODO: require providing this - ) - self.ms.create_source(source) - assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}" - return source - - def update_source(self, request: SourceUpdate, user_id: str) -> Source: - """Update an existing data source""" - if not request.id: - existing_source = self.ms.get_source(source_name=request.name, user_id=user_id) - else: - existing_source = self.ms.get_source(source_id=request.id) - if not existing_source: - raise ValueError("Source does not exist") - - # override updated fields - if request.name: - existing_source.name = request.name - if request.metadata_: - existing_source.metadata_ = request.metadata_ - if request.description: - existing_source.description = request.description - - self.ms.update_source(existing_source) - return existing_source - - def delete_source(self, source_id: str, user_id: str): + def delete_source(self, source_id: str, actor: User): """Delete a data source""" - source = self.ms.get_source(source_id=source_id, user_id=user_id) - self.ms.delete_source(source_id) + self.source_manager.delete_source(source_id=source_id, actor=actor) # delete data from passage store - passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id) passage_store.delete({"source_id": source_id}) # TODO: delete data from agent passage stores (?) @@ -1605,7 +1561,7 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Jo # try: from letta.data_sources.connectors import DirectoryConnector - source = self.ms.get_source(source_id=source_id) + source = self.source_manager.get_source_by_id(source_id=source_id) connector = DirectoryConnector(input_files=[file_path]) num_passages, num_documents = self.load_data(user_id=source.user_id, source_name=source.name, connector=connector) # except Exception as e: @@ -1641,7 +1597,8 @@ def load_data( # TODO: this should be implemented as a batch job or at least async, since it may take a long time # load data from a data source into the document store - source = self.ms.get_source(source_name=source_name, user_id=user_id) + user = self.user_manager.get_user_by_id(user_id=user_id) + source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) if source is None: raise ValueError(f"Data source {source_name} does not exist for user {user_id}") @@ -1662,9 +1619,13 @@ def attach_source_to_agent( source_name: Optional[str] = None, ) -> Source: # attach a data source to an agent - data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name) - if data_source is None: - raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}") + user = self.user_manager.get_user_by_id(user_id=user_id) + if source_id: + data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + elif source_name: + data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) + else: + raise ValueError(f"Need to provide at least source_id or source_name to find the source.") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) @@ -1685,12 +1646,14 @@ def detach_source_from_agent( source_id: Optional[str] = None, source_name: Optional[str] = None, ) -> Source: - if not source_id: - assert source_name is not None, "source_name must be provided if source_id is not" - source = self.ms.get_source(source_name=source_name, user_id=user_id) - source_id = source.id + user = self.user_manager.get_user_by_id(user_id=user_id) + if source_id: + source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + elif source_name: + source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) else: - source = self.ms.get_source(source_id=source_id) + raise ValueError(f"Need to provide at least source_id or source_name to find the source.") + source_id = source.id # delete all Passage objects with source_id==source_id from agent's archival memory agent = self._get_or_load_agent(agent_id=agent_id) @@ -1715,17 +1678,17 @@ def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passag warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] - def list_all_sources(self, user_id: str) -> List[Source]: + def list_all_sources(self, actor: User) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" - sources = self.ms.list_sources(user_id=user_id) + sources = self.source_manager.list_sources(actor=actor) # Add extra metadata to the sources sources_with_metadata = [] for source in sources: # count number of passages - passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id) num_passages = passage_conn.size({"source_id": source.id}) # TODO: add when files table implemented @@ -1739,7 +1702,7 @@ def list_all_sources(self, user_id: str) -> List[Source]: attached_agents = [ { "id": str(a_id), - "name": self.ms.get_agent(user_id=user_id, agent_id=a_id).name, + "name": self.ms.get_agent(user_id=actor.id, agent_id=a_id).name, } for a_id in agent_ids ] diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 5859609597..3cc6961681 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,6 +3,7 @@ from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.source import Source as SourceModel +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceCreate, SourceUpdate from letta.schemas.user import User as PydanticUser @@ -12,6 +13,18 @@ class SourceManager: """Manager class to handle business logic related to Sources.""" + # This is used when no embedding config is provided + DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="letta-free", + embedding_dim=1024, + embedding_chunk_size=300, + azure_endpoint=None, + azure_version=None, + azure_deployment=None, + ) + def __init__(self): from letta.server.server import db_context @@ -21,6 +34,10 @@ def __init__(self): def create_source(self, source_create: SourceCreate, actor: PydanticUser) -> PydanticSource: """Create a new source based on the SourceCreate schema.""" with self.session_maker() as session: + # Provide default embedding config if not given + if not source_create.embedding_config: + source_create.embedding_config = self.DEFAULT_EMBEDDING_CONFIG + create_data = source_create.model_dump() source = SourceModel(**create_data, organization_id=actor.organization_id) source.create(session, actor=actor) @@ -68,8 +85,9 @@ def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: ) return [source.to_pydantic() for source in sources] + # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types - def get_source_by_id(self, source_id: str, actor: PydanticUser) -> PydanticSource: + def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> PydanticSource: """Retrieve a source by its ID.""" with self.session_maker() as session: source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 246abd710e..2ffd26f779 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -420,7 +420,7 @@ def test_tools_from_langchain(client: LocalClient): exec(source_code, {}, local_scope) func = local_scope[tool.name] - expected_content = "Albert Einstein ( EYEN-styne; German:" + expected_content = "Albert Einstein" assert expected_content in func(query="Albert Einstein") diff --git a/tests/test_managers.py b/tests/test_managers.py index 4378754770..d5377e7687 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4,28 +4,16 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import Organization, Source, Tool, User -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import SourceCreate, SourceUpdate from letta.schemas.tool import ToolCreate, ToolUpdate from letta.services.organization_manager import OrganizationManager +from letta.services.source_manager import SourceManager utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.user import UserCreate, UserUpdate from letta.server.server import SyncServer -# Test Constants -EMBEDDING_CONFIG = EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_model="letta-free", - embedding_dim=1024, - embedding_chunk_size=300, - azure_endpoint=None, - azure_version=None, - azure_deployment=None, -) - @pytest.fixture(autouse=True) def clear_tables(server: SyncServer): @@ -388,7 +376,10 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): def test_create_source(server: SyncServer, default_user): """Test creating a new source.""" source_create = SourceCreate( - name="Test Source", description="This is a test source.", metadata_={"type": "test"}, embedding_config=EMBEDDING_CONFIG + name="Test Source", + description="This is a test source.", + metadata_={"type": "test"}, + embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) @@ -401,7 +392,9 @@ def test_create_source(server: SyncServer, default_user): def test_update_source(server: SyncServer, default_user): """Test updating an existing source.""" - source_create = SourceCreate(name="Original Source", description="Original description", embedding_config=EMBEDDING_CONFIG) + source_create = SourceCreate( + name="Original Source", description="Original description", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Update the source @@ -416,7 +409,9 @@ def test_update_source(server: SyncServer, default_user): def test_delete_source(server: SyncServer, default_user): """Test deleting a source.""" - source_create = SourceCreate(name="To Delete", description="This source will be deleted.", embedding_config=EMBEDDING_CONFIG) + source_create = SourceCreate( + name="To Delete", description="This source will be deleted.", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Delete the source @@ -433,8 +428,12 @@ def test_delete_source(server: SyncServer, default_user): def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources - server.source_manager.create_source(SourceCreate(name="Source 1", embedding_config=EMBEDDING_CONFIG), actor=default_user) - server.source_manager.create_source(SourceCreate(name="Source 2", embedding_config=EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source( + SourceCreate(name="Source 1", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG), actor=default_user + ) + server.source_manager.create_source( + SourceCreate(name="Source 2", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG), actor=default_user + ) # List sources without pagination sources = server.source_manager.list_sources(actor=default_user) @@ -452,7 +451,9 @@ def test_list_sources(server: SyncServer, default_user): def test_get_source_by_id(server: SyncServer, default_user): """Test retrieving a source by ID.""" - source_create = SourceCreate(name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=EMBEDDING_CONFIG) + source_create = SourceCreate( + name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Retrieve the source by ID @@ -466,7 +467,9 @@ def test_get_source_by_id(server: SyncServer, default_user): def test_get_source_by_name(server: SyncServer, default_user): """Test retrieving a source by name.""" - source_create = SourceCreate(name="Unique Source", description="Test source for name retrieval", embedding_config=EMBEDDING_CONFIG) + source_create = SourceCreate( + name="Unique Source", description="Test source for name retrieval", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Retrieve the source by name @@ -479,7 +482,7 @@ def test_get_source_by_name(server: SyncServer, default_user): def test_update_source_no_changes(server: SyncServer, default_user): """Test update_source with no actual changes to verify logging and response.""" - source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=EMBEDDING_CONFIG) + source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Attempt to update the source with identical data diff --git a/tests/test_server.py b/tests/test_server.py index 529e439728..9ca7358340 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -117,7 +117,7 @@ def test_user_message_memory(server, user_id, agent_id): @pytest.mark.order(3) def test_load_data(server, user_id, agent_id): # create source - source = server.create_source(SourceCreate(name="test_source"), user_id=user_id) + source = server.source_manager.create_source(SourceCreate(name="test_source"), actor=server.default_user) # load data archival_memories = [ From 12a913c433e3a0d40873a160892d597420edc744 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 17:27:03 -0800 Subject: [PATCH 04/18] Add upgrade script --- .../versions/fcb07dfbc98f_add_source_table.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 alembic/versions/fcb07dfbc98f_add_source_table.py diff --git a/alembic/versions/fcb07dfbc98f_add_source_table.py b/alembic/versions/fcb07dfbc98f_add_source_table.py new file mode 100644 index 0000000000..8a0cb2bdf5 --- /dev/null +++ b/alembic/versions/fcb07dfbc98f_add_source_table.py @@ -0,0 +1,67 @@ +"""Add source table + +Revision ID: fcb07dfbc98f +Revises: eff245f340f9 +Create Date: 2024-11-04 17:25:39.841950 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op +from letta.orm.source import EmbeddingConfigColumn + +# revision identifiers, used by Alembic. +revision: str = "fcb07dfbc98f" +down_revision: Union[str, None] = "eff245f340f9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "source", + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("_organization_id", sa.String(), nullable=False), + sa.Column("_id", sa.String(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["_organization_id"], + ["organization._id"], + ), + sa.PrimaryKeyConstraint("_id"), + sa.UniqueConstraint("name", "_organization_id", name="uix_source_name_organization"), + ) + op.drop_index("sources_idx_user", table_name="sources") + op.drop_table("sources") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "sources", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column("description", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="sources_pkey"), + ) + op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) + op.drop_table("source") + # ### end Alembic commands ### From 06336ef99d8aee2fb3efbddbe640288fc7087318 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 17:45:15 -0800 Subject: [PATCH 05/18] Try to fix some unit tests --- letta/client/client.py | 2 +- letta/data_sources/connectors.py | 4 ++-- letta/metadata.py | 14 ++------------ letta/server/rest_api/routers/v1/sources.py | 2 +- letta/server/server.py | 6 ++++-- 5 files changed, 10 insertions(+), 18 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index b6a84f26ce..8b90fff938 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1229,7 +1229,7 @@ def update_source(self, source_id: str, name: Optional[str] = None) -> Source: Returns: source (Source): Updated source """ - request = SourceUpdate(id=source_id, name=name) + request = SourceUpdate(name=name) response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update source: {response.text}") diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 91d8c2a55a..f9fb3d2af1 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -88,7 +88,7 @@ def load_data( file_id=file_metadata.id, source_id=source.id, metadata_=passage_metadata, - user_id=source.user_id, + user_id=source.created_by_id, embedding_config=source.embedding_config, embedding=embedding, ) @@ -155,7 +155,7 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]: for metadata in extract_metadata_from_files(files): yield FileMetadata( - user_id=source.user_id, + user_id=source.created_by_id, source_id=source.id, file_name=metadata.get("file_name"), file_path=metadata.get("file_path"), diff --git a/letta/metadata.py b/letta/metadata.py index 4c74f34cbd..d16800db8b 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -29,7 +29,6 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.source import Source from letta.schemas.tool_rule import ( BaseToolRule, InitToolRule, @@ -618,19 +617,10 @@ def attach_source(self, user_id: str, agent_id: str, source_id: str): session.commit() @enforce_types - def list_attached_sources(self, agent_id: str) -> List[Source]: + def list_attached_source_ids(self, agent_id: str) -> List[str]: with self.session_maker() as session: results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all() - - sources = [] - # make sure source exists - for r in results: - source = self.get_source(source_id=r.source_id) - if source: - sources.append(source) - else: - printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.") - return sources + return [r.source_id for r in results] @enforce_types def list_attached_agents(self, source_id: str) -> List[str]: diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index e1e2f793d3..be44bc9c1a 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -36,7 +36,7 @@ def get_source( """ actor = server.get_user_or_default(user_id=user_id) - return server.source(source_id=source_id, user_id=actor.id) + return server.source_manager.get_source_by_id(source_id=source_id, actor=actor) @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name") diff --git a/letta/server/server.py b/letta/server/server.py index afbf7b3b56..eaa6f30ae8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1563,7 +1563,7 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Jo source = self.source_manager.get_source_by_id(source_id=source_id) connector = DirectoryConnector(input_files=[file_path]) - num_passages, num_documents = self.load_data(user_id=source.user_id, source_name=source.name, connector=connector) + num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) # except Exception as e: # # job failed with error # error = str(e) @@ -1668,7 +1668,9 @@ def detach_source_from_agent( def list_attached_sources(self, agent_id: str) -> List[Source]: # list all attached sources to an agent - return self.ms.list_attached_sources(agent_id) + source_ids = self.ms.list_attached_source_ids(agent_id) + + return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids] def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]: # list all attached sources to an agent From e7074002e277e9a89ba7da9543093e99ac07c294 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 17:54:32 -0800 Subject: [PATCH 06/18] Delete sources before tests for test_client --- tests/test_client.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index 4cab823d20..e496343751 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,10 +6,12 @@ import pytest from dotenv import load_dotenv +from sqlalchemy import delete from letta import create_client from letta.client.client import LocalClient, RESTClient from letta.constants import DEFAULT_PRESET +from letta.orm import Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageStreamStatus @@ -82,6 +84,16 @@ def client(request): yield client +@pytest.fixture(autouse=True) +def clear_tables(): + """Fixture to clear the organization table before each test.""" + from letta.server.server import db_context + + with db_context() as session: + session.execute(delete(Source)) + session.commit() + + # Fixture for test agent @pytest.fixture(scope="module") def agent(client: Union[LocalClient, RESTClient]): From d96195f5f34e057b5fa4516967f0d9bd422cdeeb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Mon, 4 Nov 2024 18:06:28 -0800 Subject: [PATCH 07/18] add defaults --- letta/client/client.py | 6 +++--- letta/schemas/source.py | 2 +- letta/services/source_manager.py | 16 -------------- tests/test_managers.py | 37 +++++++++++++++++--------------- 4 files changed, 24 insertions(+), 37 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index 8b90fff938..9b6a6c7e7c 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2428,7 +2428,7 @@ def list_jobs(self): def list_active_jobs(self): return self.server.list_active_jobs(user_id=self.user_id) - def create_source(self, name: str) -> Source: + def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: """ Create a source @@ -2438,8 +2438,8 @@ def create_source(self, name: str) -> Source: Returns: source (Source): Created source """ - request = SourceCreate(name=name) - return self.server.source_manager.create_source(request=request, actor=actor) + request = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) + return self.server.source_manager.create_source(request=request, actor=self.user) def delete_source(self, source_id: str): """ diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 747dfce06f..75362211b1 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -49,7 +49,7 @@ class SourceCreate(BaseSource): # required name: str = Field(..., description="The name of the source.") - embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") + embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") # optional description: Optional[str] = Field(None, description="The description of the source.") diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 3cc6961681..d0404d110b 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,7 +3,6 @@ from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.source import Source as SourceModel -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceCreate, SourceUpdate from letta.schemas.user import User as PydanticUser @@ -13,18 +12,6 @@ class SourceManager: """Manager class to handle business logic related to Sources.""" - # This is used when no embedding config is provided - DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_model="letta-free", - embedding_dim=1024, - embedding_chunk_size=300, - azure_endpoint=None, - azure_version=None, - azure_deployment=None, - ) - def __init__(self): from letta.server.server import db_context @@ -35,9 +22,6 @@ def create_source(self, source_create: SourceCreate, actor: PydanticUser) -> Pyd """Create a new source based on the SourceCreate schema.""" with self.session_maker() as session: # Provide default embedding config if not given - if not source_create.embedding_config: - source_create.embedding_config = self.DEFAULT_EMBEDDING_CONFIG - create_data = source_create.model_dump() source = SourceModel(**create_data, organization_id=actor.organization_id) source.create(session, actor=actor) diff --git a/tests/test_managers.py b/tests/test_managers.py index d5377e7687..2c2d36b3ec 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4,16 +4,27 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import Organization, Source, Tool, User +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import SourceCreate, SourceUpdate from letta.schemas.tool import ToolCreate, ToolUpdate from letta.services.organization_manager import OrganizationManager -from letta.services.source_manager import SourceManager utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.user import UserCreate, UserUpdate from letta.server.server import SyncServer +DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="letta-free", + embedding_dim=1024, + embedding_chunk_size=300, + azure_endpoint=None, + azure_version=None, + azure_deployment=None, +) + @pytest.fixture(autouse=True) def clear_tables(server: SyncServer): @@ -379,7 +390,7 @@ def test_create_source(server: SyncServer, default_user): name="Test Source", description="This is a test source.", metadata_={"type": "test"}, - embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG, + embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) @@ -392,9 +403,7 @@ def test_create_source(server: SyncServer, default_user): def test_update_source(server: SyncServer, default_user): """Test updating an existing source.""" - source_create = SourceCreate( - name="Original Source", description="Original description", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG - ) + source_create = SourceCreate(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Update the source @@ -409,9 +418,7 @@ def test_update_source(server: SyncServer, default_user): def test_delete_source(server: SyncServer, default_user): """Test deleting a source.""" - source_create = SourceCreate( - name="To Delete", description="This source will be deleted.", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG - ) + source_create = SourceCreate(name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Delete the source @@ -428,12 +435,8 @@ def test_delete_source(server: SyncServer, default_user): def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources - server.source_manager.create_source( - SourceCreate(name="Source 1", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG), actor=default_user - ) - server.source_manager.create_source( - SourceCreate(name="Source 2", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG), actor=default_user - ) + server.source_manager.create_source(SourceCreate(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source(SourceCreate(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) # List sources without pagination sources = server.source_manager.list_sources(actor=default_user) @@ -452,7 +455,7 @@ def test_list_sources(server: SyncServer, default_user): def test_get_source_by_id(server: SyncServer, default_user): """Test retrieving a source by ID.""" source_create = SourceCreate( - name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) @@ -468,7 +471,7 @@ def test_get_source_by_id(server: SyncServer, default_user): def test_get_source_by_name(server: SyncServer, default_user): """Test retrieving a source by name.""" source_create = SourceCreate( - name="Unique Source", description="Test source for name retrieval", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG + name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) source = server.source_manager.create_source(source_create=source_create, actor=default_user) @@ -482,7 +485,7 @@ def test_get_source_by_name(server: SyncServer, default_user): def test_update_source_no_changes(server: SyncServer, default_user): """Test update_source with no actual changes to verify logging and response.""" - source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=SourceManager.DEFAULT_EMBEDDING_CONFIG) + source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG) source = server.source_manager.create_source(source_create=source_create, actor=default_user) # Attempt to update the source with identical data From 48abcba2f26a655f529dbc26db6a4f8f5d2eab7c Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 12:13:26 -0800 Subject: [PATCH 08/18] organize imports --- tests/test_managers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_managers.py b/tests/test_managers.py index 4a2e94170b..8a6fb85cb6 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4,9 +4,6 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import Organization, Source, Tool, User -from letta.orm.organization import Organization -from letta.orm.tool import Tool -from letta.orm.user import User from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig From 902338d90d49015408dc49e46ebbe920336d1628 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 12:20:42 -0800 Subject: [PATCH 09/18] add default id to source create --- letta/schemas/source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/letta/schemas/source.py b/letta/schemas/source.py index f7d3ebe26b..ebff8bf3b4 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -48,7 +48,7 @@ class SourceCreate(BaseSource): """ # required - id: str = Field(BaseSource.generate_id_field(), description="The id of the source.") + id: str = BaseSource.generate_id_field() name: str = Field(..., description="The name of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") From 55d4d2dfe39ae0bd6d592a1564c40c781aac114d Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 12:30:43 -0800 Subject: [PATCH 10/18] merge main --- .../versions/fcb07dfbc98f_add_source_table.py | 67 ------------------- 1 file changed, 67 deletions(-) delete mode 100644 alembic/versions/fcb07dfbc98f_add_source_table.py diff --git a/alembic/versions/fcb07dfbc98f_add_source_table.py b/alembic/versions/fcb07dfbc98f_add_source_table.py deleted file mode 100644 index 8a0cb2bdf5..0000000000 --- a/alembic/versions/fcb07dfbc98f_add_source_table.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Add source table - -Revision ID: fcb07dfbc98f -Revises: eff245f340f9 -Create Date: 2024-11-04 17:25:39.841950 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -from alembic import op -from letta.orm.source import EmbeddingConfigColumn - -# revision identifiers, used by Alembic. -revision: str = "fcb07dfbc98f" -down_revision: Union[str, None] = "eff245f340f9" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "source", - sa.Column("name", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False), - sa.Column("metadata_", sa.JSON(), nullable=True), - sa.Column("_organization_id", sa.String(), nullable=False), - sa.Column("_id", sa.String(), nullable=False), - sa.Column("deleted", sa.Boolean(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), - sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), - sa.Column("_created_by_id", sa.String(), nullable=True), - sa.Column("_last_updated_by_id", sa.String(), nullable=True), - sa.ForeignKeyConstraint( - ["_organization_id"], - ["organization._id"], - ), - sa.PrimaryKeyConstraint("_id"), - sa.UniqueConstraint("name", "_organization_id", name="uix_source_name_organization"), - ) - op.drop_index("sources_idx_user", table_name="sources") - op.drop_table("sources") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "sources", - sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), - sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column("description", sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint("id", name="sources_pkey"), - ) - op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) - op.drop_table("source") - # ### end Alembic commands ### From 798643e1f5d82a9b7d93f3d98b7e99ae4d941e6d Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 13:36:06 -0800 Subject: [PATCH 11/18] Add migration script --- .../cda66b6cb0d6_move_sources_to_orm.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 alembic/versions/cda66b6cb0d6_move_sources_to_orm.py diff --git a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py new file mode 100644 index 0000000000..b247a902dd --- /dev/null +++ b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py @@ -0,0 +1,66 @@ +"""Move sources to orm + +Revision ID: cda66b6cb0d6 +Revises: b6d7ca024aa9 +Create Date: 2024-11-07 13:29:57.186107 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cda66b6cb0d6" +down_revision: Union[str, None] = "b6d7ca024aa9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("sources", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("sources", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("sources", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("sources", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + + # Data migration step: + op.add_column("sources", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE sources + SET organization_id = users.organization_id + FROM users + WHERE sources.user_id = users.id + """ + ) + + # Set `organization_id` as non-nullable after population + op.alter_column("sources", "organization_id", nullable=False) + + op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.drop_index("sources_idx_user", table_name="sources") + op.create_unique_constraint("uix_sources_name_organization", "sources", ["name", "organization_id"]) + op.create_foreign_key(None, "sources", "organizations", ["organization_id"], ["id"]) + op.drop_column("sources", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("sources", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint(None, "sources", type_="foreignkey") + op.drop_constraint("uix_sources_name_organization", "sources", type_="unique") + op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) + op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.drop_column("sources", "organization_id") + op.drop_column("sources", "_last_updated_by_id") + op.drop_column("sources", "_created_by_id") + op.drop_column("sources", "is_deleted") + op.drop_column("sources", "updated_at") + # ### end Alembic commands ### From 1f3bee47cd7d3b40df2e731b2285b307353ecaaa Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 14:30:57 -0800 Subject: [PATCH 12/18] Fix tests --- letta/client/client.py | 12 ++++--- letta/schemas/source.py | 3 +- letta/server/rest_api/routers/v1/sources.py | 5 +-- letta/services/organization_manager.py | 25 +++++++------ letta/services/source_manager.py | 10 +++--- tests/test_managers.py | 39 +++++++++++---------- tests/test_server.py | 8 +++-- 7 files changed, 55 insertions(+), 47 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index 726ed7e4ec..57ff182f78 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -238,7 +238,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True) -> J def delete_file_from_source(self, source_id: str, file_id: str) -> None: raise NotImplementedError - def create_source(self, name: str) -> Source: + def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: raise NotImplementedError def delete_source(self, source_id: str): @@ -1188,7 +1188,7 @@ def delete_file_from_source(self, source_id: str, file_id: str) -> None: if response.status_code not in [200, 204]: raise ValueError(f"Failed to delete tool: {response.text}") - def create_source(self, name: str) -> Source: + def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: """ Create a source @@ -1198,7 +1198,7 @@ def create_source(self, name: str) -> Source: Returns: source (Source): Created source """ - payload = {"name": name} + payload = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers) response_json = response.json() return Source(**response_json) @@ -2463,8 +2463,10 @@ def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = Returns: source (Source): Created source """ - request = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) - return self.server.source_manager.create_source(request=request, actor=self.user) + source = Source( + name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id + ) + return self.server.source_manager.create_source(source=source, actor=self.user) def delete_source(self, source_id: str): """ diff --git a/letta/schemas/source.py b/letta/schemas/source.py index ebff8bf3b4..f7ea2439f4 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -32,7 +32,7 @@ class Source(BaseSource): name: str = Field(..., description="The name of the source.") description: Optional[str] = Field(None, description="The description of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") - organization_id: str = Field(..., description="The ID of the organization that created the source.") + organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.") metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") # metadata fields @@ -48,7 +48,6 @@ class SourceCreate(BaseSource): """ # required - id: str = BaseSource.generate_id_field() name: str = Field(..., description="The name of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index be44bc9c1a..58047f1296 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -69,7 +69,7 @@ def list_sources( @router.post("/", response_model=Source, operation_id="create_source") def create_source( - source: SourceCreate, + source_create: SourceCreate, server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -77,8 +77,9 @@ def create_source( Create a new data source. """ actor = server.get_user_or_default(user_id=user_id) + source = Source(**source_create.model_dump()) - return server.source_manager.create_source(source_create=source, actor=actor) + return server.source_manager.create_source(source=source, actor=actor) @router.patch("/{source_id}", response_model=Source, operation_id="update_source") diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 1832c58013..1b7f18b6d8 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -27,18 +27,26 @@ def get_default_organization(self) -> PydanticOrganization: return self.get_organization_by_id(self.DEFAULT_ORG_ID) @enforce_types - def get_organization_by_id(self, org_id: str) -> PydanticOrganization: + def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]: """Fetch an organization by ID.""" with self.session_maker() as session: try: organization = OrganizationModel.read(db_session=session, identifier=org_id) return organization.to_pydantic() except NoResultFound: - raise ValueError(f"Organization with id {org_id} not found.") + return None @enforce_types def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated.""" + org = self.get_organization_by_id(pydantic_org.id) + if org: + return org + else: + return self._create_organization(pydantic_org=pydantic_org) + + @enforce_types + def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: with self.session_maker() as session: org = OrganizationModel(**pydantic_org.model_dump()) org.create(session) @@ -47,16 +55,7 @@ def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrg @enforce_types def create_default_organization(self) -> PydanticOrganization: """Create the default organization.""" - with self.session_maker() as session: - # Try to get it first - try: - org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID) - # If it doesn't exist, make it - except NoResultFound: - org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID) - org.create(session) - - return org.to_pydantic() + return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)) @enforce_types def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: @@ -73,7 +72,7 @@ def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" with self.session_maker() as session: organization = OrganizationModel.read(db_session=session, identifier=org_id) - organization.delete(session) + organization.hard_delete(session) @enforce_types def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 82090bd7a3..6d4569ce31 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,7 +3,7 @@ from letta.orm.errors import NoResultFound from letta.orm.source import Source as SourceModel from letta.schemas.source import Source as PydanticSource -from letta.schemas.source import SourceCreate, SourceUpdate +from letta.schemas.source import SourceUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, printd @@ -17,12 +17,12 @@ def __init__(self): self.session_maker = db_context @enforce_types - def create_source(self, source_create: SourceCreate, actor: PydanticUser) -> PydanticSource: - """Create a new source based on the SourceCreate schema.""" + def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: + """Create a new source based on the PydanticSource schema.""" with self.session_maker() as session: # Provide default embedding config if not given - create_data = source_create.model_dump() - source = SourceModel(**create_data, organization_id=actor.organization_id) + source.organization_id = actor.organization_id + source = SourceModel(**source.model_dump(exclude_none=True)) source.create(session, actor=actor) return source.to_pydantic() diff --git a/tests/test_managers.py b/tests/test_managers.py index 8a6fb85cb6..542b7d86e2 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -9,7 +9,8 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization as PydanticOrganization -from letta.schemas.source import SourceCreate, SourceUpdate +from letta.schemas.source import Source as PydanticSource +from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate from letta.services.organization_manager import OrganizationManager @@ -425,25 +426,25 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): def test_create_source(server: SyncServer, default_user): """Test creating a new source.""" - source_create = SourceCreate( + source_pydantic = PydanticSource( name="Test Source", description="This is a test source.", metadata_={"type": "test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Assertions to check the created source - assert source.name == source_create.name - assert source.description == source_create.description - assert source.metadata_ == source_create.metadata_ + assert source.name == source_pydantic.name + assert source.description == source_pydantic.description + assert source.metadata_ == source_pydantic.metadata_ assert source.organization_id == default_user.organization_id def test_update_source(server: SyncServer, default_user): """Test updating an existing source.""" - source_create = SourceCreate(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Update the source update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"}) @@ -457,8 +458,10 @@ def test_update_source(server: SyncServer, default_user): def test_delete_source(server: SyncServer, default_user): """Test deleting a source.""" - source_create = SourceCreate(name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source_pydantic = PydanticSource( + name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Delete the source deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) @@ -474,8 +477,8 @@ def test_delete_source(server: SyncServer, default_user): def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources - server.source_manager.create_source(SourceCreate(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) - server.source_manager.create_source(SourceCreate(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) # List sources without pagination sources = server.source_manager.list_sources(actor=default_user) @@ -493,10 +496,10 @@ def test_list_sources(server: SyncServer, default_user): def test_get_source_by_id(server: SyncServer, default_user): """Test retrieving a source by ID.""" - source_create = SourceCreate( + source_pydantic = PydanticSource( name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Retrieve the source by ID retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user) @@ -509,10 +512,10 @@ def test_get_source_by_id(server: SyncServer, default_user): def test_get_source_by_name(server: SyncServer, default_user): """Test retrieving a source by name.""" - source_create = SourceCreate( + source_pydantic = PydanticSource( name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Retrieve the source by name retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user) @@ -524,8 +527,8 @@ def test_get_source_by_name(server: SyncServer, default_user): def test_update_source_no_changes(server: SyncServer, default_user): """Test update_source with no actual changes to verify logging and response.""" - source_create = SourceCreate(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG) - source = server.source_manager.create_source(source_create=source_create, actor=default_user) + source_pydantic = PydanticSource(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Attempt to update the source with identical data update_data = SourceUpdate(name="No Change Source", description="No changes") diff --git a/tests/test_server.py b/tests/test_server.py index aa6efd82dd..bd68e45ab7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,6 +8,8 @@ from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole +from .test_managers import DEFAULT_EMBEDDING_CONFIG + utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent @@ -24,7 +26,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.message import Message -from letta.schemas.source import SourceCreate +from letta.schemas.source import Source from letta.server.server import SyncServer from .utils import DummyDataConnector @@ -117,7 +119,9 @@ def test_user_message_memory(server, user_id, agent_id): @pytest.mark.order(3) def test_load_data(server, user_id, agent_id): # create source - source = server.source_manager.create_source(SourceCreate(name="test_source"), actor=server.default_user) + source = server.source_manager.create_source( + Source(name="test_source", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=server.default_user + ) # load data archival_memories = [ From 1900e9f0ef726e4fc75c36a6c54b4d7562cc0bbb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 14:34:13 -0800 Subject: [PATCH 13/18] Fix client --- letta/client/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/letta/client/client.py b/letta/client/client.py index 57ff182f78..519902ba47 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1198,7 +1198,8 @@ def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = Returns: source (Source): Created source """ - payload = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) + source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) + payload = source_create.model_dump() response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers) response_json = response.json() return Source(**response_json) From ae527417bc7d89e4dd4fc1092b1ddcb6c2576192 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 16:54:51 -0800 Subject: [PATCH 14/18] Make source creation safe from duplication --- letta/services/source_manager.py | 33 ++++++++++++++++++++------------ tests/test_managers.py | 15 +++++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 6d4569ce31..ce52bb142c 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -19,12 +19,17 @@ def __init__(self): @enforce_types def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: """Create a new source based on the PydanticSource schema.""" - with self.session_maker() as session: - # Provide default embedding config if not given - source.organization_id = actor.organization_id - source = SourceModel(**source.model_dump(exclude_none=True)) - source.create(session, actor=actor) - return source.to_pydantic() + # Try getting the source first by id or name + db_source = self.get_source_by_id(source.id, actor=actor) or self.get_source_by_name(source.name, actor=actor) + if db_source: + return db_source + else: + with self.session_maker() as session: + # Provide default embedding config if not given + source.organization_id = actor.organization_id + source = SourceModel(**source.model_dump(exclude_none=True)) + source.create(session, actor=actor) + return source.to_pydantic() @enforce_types def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: @@ -70,14 +75,17 @@ def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types - def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> PydanticSource: + def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: """Retrieve a source by its ID.""" with self.session_maker() as session: - source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) - return source.to_pydantic() + try: + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + return source.to_pydantic() + except NoResultFound: + return None @enforce_types - def get_source_by_name(self, source_name: str, actor: PydanticUser) -> PydanticSource: + def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]: """Retrieve a source by its name.""" with self.session_maker() as session: sources = SourceModel.list( @@ -87,5 +95,6 @@ def get_source_by_name(self, source_name: str, actor: PydanticUser) -> PydanticS limit=1, ) if not sources: - raise NoResultFound(f"Source with name '{source_name}' not found.") - return sources[0].to_pydantic() + return None + else: + return sources[0].to_pydantic() diff --git a/tests/test_managers.py b/tests/test_managers.py index 542b7d86e2..a49d8d3719 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -441,6 +441,21 @@ def test_create_source(server: SyncServer, default_user): assert source.organization_id == default_user.organization_id +def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user): + """Test creating a new source.""" + name = "Test Source" + source_pydantic = PydanticSource( + name=name, + description="This is a test source.", + metadata_={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + assert source.id == same_source.id + + def test_update_source(server: SyncServer, default_user): """Test updating an existing source.""" source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG) From 909cf830150bff18baf5c750ce60b73ec5bfc8fb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 7 Nov 2024 17:17:37 -0800 Subject: [PATCH 15/18] make optional for now for web migration --- letta/schemas/source.py | 3 ++- tests/test_managers.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/letta/schemas/source.py b/letta/schemas/source.py index f7ea2439f4..0a458dfded 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -49,7 +49,8 @@ class SourceCreate(BaseSource): # required name: str = Field(..., description="The name of the source.") - embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") + # TODO: @matt, make this required after shub makes the FE changes + embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") # optional description: Optional[str] = Field(None, description="The description of the source.") diff --git a/tests/test_managers.py b/tests/test_managers.py index a49d8d3719..7dcb730929 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -447,10 +447,16 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul source_pydantic = PydanticSource( name=name, description="This is a test source.", - metadata_={"type": "test"}, + metadata_={"type": "medical"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source_pydantic = PydanticSource( + name=name, + description="This is a different test source.", + metadata_={"type": "legal"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user) assert source.id == same_source.id From 90039926f1466671d09c6b643cdf8309ad4d2bb6 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 8 Nov 2024 15:24:36 -0800 Subject: [PATCH 16/18] hard delete --- letta/services/source_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index ce52bb142c..9e290a7f68 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -58,7 +58,7 @@ def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: """Delete a source by its ID.""" with self.session_maker() as session: source = SourceModel.read(db_session=session, identifier=source_id) - source.delete(db_session=session, actor=actor) + source.hard_delete(db_session=session, actor=actor) return source.to_pydantic() @enforce_types From deaade35551c3fe77d2c8cc2e9cf867f3df000da Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 12 Nov 2024 09:41:26 -0800 Subject: [PATCH 17/18] Remove unique constraint on org --- alembic/versions/cda66b6cb0d6_move_sources_to_orm.py | 2 -- letta/orm/source.py | 6 +----- letta/services/source_manager.py | 6 +++--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py index b247a902dd..f46bef6b4d 100644 --- a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py +++ b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py @@ -45,7 +45,6 @@ def upgrade() -> None: op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) op.drop_index("sources_idx_user", table_name="sources") - op.create_unique_constraint("uix_sources_name_organization", "sources", ["name", "organization_id"]) op.create_foreign_key(None, "sources", "organizations", ["organization_id"], ["id"]) op.drop_column("sources", "user_id") # ### end Alembic commands ### @@ -55,7 +54,6 @@ def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column("sources", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) op.drop_constraint(None, "sources", type_="foreignkey") - op.drop_constraint("uix_sources_name_organization", "sources", type_="unique") op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) op.drop_column("sources", "organization_id") diff --git a/letta/orm/source.py b/letta/orm/source.py index 79bf795d46..e8a7ed47a7 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional -from sqlalchemy import JSON, TypeDecorator, UniqueConstraint +from sqlalchemy import JSON, TypeDecorator from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -40,10 +40,6 @@ class Source(SqlalchemyBase, OrganizationMixin): __tablename__ = "sources" __pydantic_model__ = PydanticSource - # Add unique constraint on (name, organization_id) - # An organization should not have multiple sources with the same name - __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_sources_name_organization"),) - name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False) description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the source") embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.") diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 9e290a7f68..e09bddd959 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -19,8 +19,8 @@ def __init__(self): @enforce_types def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: """Create a new source based on the PydanticSource schema.""" - # Try getting the source first by id or name - db_source = self.get_source_by_id(source.id, actor=actor) or self.get_source_by_name(source.name, actor=actor) + # Try getting the source first by id + db_source = self.get_source_by_id(source.id, actor=actor) if db_source: return db_source else: @@ -58,7 +58,7 @@ def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: """Delete a source by its ID.""" with self.session_maker() as session: source = SourceModel.read(db_session=session, identifier=source_id) - source.hard_delete(db_session=session, actor=actor) + source.delete(db_session=session, actor=actor) return source.to_pydantic() @enforce_types From 07ea01d4c76d4296872bc13b8d00156ae67110dd Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 12 Nov 2024 09:44:11 -0800 Subject: [PATCH 18/18] Adjust tests --- tests/test_managers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_managers.py b/tests/test_managers.py index 7dcb730929..8436d7ba6d 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -459,7 +459,8 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul ) same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user) - assert source.id == same_source.id + assert source.name == same_source.name + assert source.id != same_source.id def test_update_source(server: SyncServer, default_user):