Skip to content

Commit

Permalink
Merge pull request #2065 from Agenta-AI/oss-feature/age-667-project-2…
Browse files Browse the repository at this point in the history
…-update-scope-in-database_routers

[feature] Projects - Checkpoint 2
  • Loading branch information
aakrem authored Sep 28, 2024
2 parents 86ca203 + 805d430 commit 964f4cd
Show file tree
Hide file tree
Showing 40 changed files with 1,677 additions and 987 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import uuid
import traceback
from typing import Optional


import click
from sqlalchemy.future import select
from sqlalchemy import create_engine, delete
from sqlalchemy.orm import sessionmaker, Session

from agenta_backend.models.deprecated_models import (
DeprecatedEvaluatorConfigDB,
DeprecatedAppDB,
)


BATCH_SIZE = 1000


def get_app_db(session: Session, app_id: str) -> Optional[DeprecatedAppDB]:
query = session.execute(select(DeprecatedAppDB).filter_by(id=uuid.UUID(app_id)))
return query.scalars().first()


def update_evaluators_with_app_name():
engine = create_engine(os.getenv("POSTGRES_URI"))
sync_session = sessionmaker(engine, expire_on_commit=False)

with sync_session() as session:
try:
offset = 0
while True:
records = (
session.execute(
select(DeprecatedEvaluatorConfigDB)
.filter(DeprecatedEvaluatorConfigDB.app_id.isnot(None))
.offset(offset)
.limit(BATCH_SIZE)
)
.scalars()
.all()
)
if not records:
break

# Update records with app_name as prefix
for record in records:
evaluator_config_app = get_app_db(
session=session, app_id=str(record.app_id)
)
if record.app_id is not None and evaluator_config_app is not None:
record.name = f"{record.name} ({evaluator_config_app.app_name})"

session.commit()
offset += BATCH_SIZE

# Delete deprecated evaluator configs with app_id as None
session.execute(
delete(DeprecatedEvaluatorConfigDB).where(
DeprecatedEvaluatorConfigDB.app_id.is_(None)
)
)
session.commit()
except Exception as e:
session.rollback()
click.echo(
click.style(
f"ERROR updating evaluator config names: {traceback.format_exc()}",
fg="red",
)
)
raise e
Original file line number Diff line number Diff line change
@@ -1,41 +1,86 @@
import os
import traceback
from typing import Sequence


import click
from sqlalchemy.future import select
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from agenta_backend.models.db_models import ProjectDB
from sqlalchemy.orm import sessionmaker, Session

from agenta_backend.models.db_models import (
ProjectDB,
AppDB,
AppVariantDB,
AppVariantRevisionsDB,
VariantBaseDB,
DeploymentDB,
ImageDB,
AppEnvironmentDB,
AppEnvironmentRevisionDB,
EvaluationScenarioDB,
EvaluationDB,
EvaluatorConfigDB,
HumanEvaluationDB,
HumanEvaluationScenarioDB,
TestSetDB,
)


BATCH_SIZE = 1000
MODELS = [
AppDB,
AppVariantDB,
AppVariantRevisionsDB,
VariantBaseDB,
DeploymentDB,
ImageDB,
AppEnvironmentDB,
AppEnvironmentRevisionDB,
EvaluationScenarioDB,
EvaluationDB,
EvaluatorConfigDB,
HumanEvaluationDB,
HumanEvaluationScenarioDB,
TestSetDB,
]


def get_default_projects(session):
query = session.execute(select(ProjectDB).filter_by(is_default=True))
return query.scalars().all()


def check_for_multiple_default_projects(session: Session) -> Sequence[ProjectDB]:
default_projects = get_default_projects(session)
if len(default_projects) > 1:
raise ValueError(
"Multiple default projects found. Please ensure only one exists."
)
return default_projects


def create_default_project():
PROJECT_NAME = "Default Project"
engine = create_engine(os.getenv("POSTGRES_URI"))
sync_session = sessionmaker(engine, expire_on_commit=False)

with sync_session() as session:
try:
default_projects = get_default_projects(session)
if len(default_projects) > 1:
raise ValueError(
"Multiple default projects found. Please ensure only one exists."
)

default_projects = check_for_multiple_default_projects(session)
if len(default_projects) == 0:
new_project = ProjectDB(project_name=PROJECT_NAME, is_default=True)
session.add(new_project)
session.commit()

except Exception as e:
session.rollback()
click.echo(click.style(f"ERROR: {traceback.format_exc()}", fg="red"))
click.echo(
click.style(
f"ERROR creating default project: {traceback.format_exc()}",
fg="red",
)
)
raise e


Expand All @@ -45,18 +90,13 @@ def remove_default_project():

with sync_session() as session:
try:
default_projects = get_default_projects(session)
default_projects = check_for_multiple_default_projects(session)
if len(default_projects) == 0:
click.echo(
click.style("No default project found to remove.", fg="yellow")
)
return

if len(default_projects) > 1:
raise ValueError(
"Multiple default projects found. Please ensure only one exists."
)

session.delete(default_projects[0])
session.commit()
click.echo(click.style("Default project removed successfully.", fg="green"))
Expand All @@ -65,3 +105,84 @@ def remove_default_project():
session.rollback()
click.echo(click.style(f"ERROR: {traceback.format_exc()}", fg="red"))
raise e


def add_project_id_to_db_entities():
engine = create_engine(os.getenv("POSTGRES_URI"))
sync_session = sessionmaker(engine, expire_on_commit=False)

with sync_session() as session:
try:
default_project = check_for_multiple_default_projects(session)[0]
for model in MODELS:
offset = 0
while True:
records = (
session.execute(
select(model)
.where(model.project_id == None)
.offset(offset)
.limit(BATCH_SIZE)
)
.scalars()
.all()
)
if not records:
break

# Update records with default project_id
for record in records:
record.project_id = default_project.id

session.commit()
offset += BATCH_SIZE

except Exception as e:
session.rollback()
click.echo(
click.style(
f"ERROR adding project_id to db entities: {traceback.format_exc()}",
fg="red",
)
)
raise e


def remove_project_id_from_db_entities():
engine = create_engine(os.getenv("POSTGRES_URI"))
sync_session = sessionmaker(engine, expire_on_commit=False)

with sync_session() as session:
try:
for model in MODELS:
offset = 0
while True:
records = (
session.execute(
select(model)
.where(model.project_id != None)
.offset(offset)
.limit(BATCH_SIZE)
)
.scalars()
.all()
)
if not records:
break

# Update records project_id column with None
for record in records:
record.project_id = None

session.commit()
offset += BATCH_SIZE

except Exception as e:
session.rollback()
click.echo(
click.style(
f"ERROR removing project_id to db entities: {traceback.format_exc()}",
fg="red",
)
)
raise e
33 changes: 31 additions & 2 deletions agenta-backend/agenta_backend/migrations/postgres/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

import click
import asyncpg

from sqlalchemy import inspect, text, Engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine

from alembic import command
from alembic.config import Config
from sqlalchemy import inspect, text
from alembic.script import ScriptDirectory
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine

from agenta_backend.utils.common import isCloudEE, isCloudDev

Expand Down Expand Up @@ -173,3 +174,31 @@ async def check_if_templates_table_exist():
await engine.dispose()

return True


def unique_constraint_exists(
engine: Engine, table_name: str, constraint_name: str
) -> bool:
"""
The function checks if a unique constraint with a specific name exists on a table in a PostgreSQL
database.
Args:
- engine (Engine): instance of a database engine that represents a connection to a database.
- table_name (str): name of the table to check the existence of the unique constraint.
- constraint_name (str): name of the unique constraint to check for existence.
Returns:
- returns a boolean value indicating whether a unique constraint with the specified `constraint_name` exists in the table.
"""

with engine.connect() as conn:
result = conn.execute(
text(
f"""
SELECT conname FROM pg_constraint
WHERE conname = '{constraint_name}' AND conrelid = '{table_name}'::regclass;
"""
)
)
return result.fetchone() is not None
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Update evaluators names with app name as prefix
Revision ID: 22d29365f5fc
Revises: 6cfe239894fb
Create Date: 2024-09-16 11:38:33.886908
"""

from typing import Sequence, Union

from agenta_backend.migrations.postgres.data_migrations.applications import (
update_evaluators_with_app_name,
)


# revision identifiers, used by Alembic.
revision: str = "22d29365f5fc"
down_revision: Union[str, None] = "6cfe239894fb"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### custom command ###
update_evaluators_with_app_name()
# ### end custom command ###


def downgrade() -> None:
# ### custom command ###
pass
# ### end custom command ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""add default project to scoped model entities
Revision ID: 55bdd2e9a465
Revises: c00a326c625a
Create Date: 2024-09-12 21:56:38.701088
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

from agenta_backend.migrations.postgres.data_migrations.projects import (
add_project_id_to_db_entities,
remove_project_id_from_db_entities,
)


# revision identifiers, used by Alembic.
revision: str = "55bdd2e9a465"
down_revision: Union[str, None] = "c00a326c625a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### custom command ###
add_project_id_to_db_entities()
# ### end custom command ###


def downgrade() -> None:
# ### custom command ###
remove_project_id_from_db_entities()
# ### end custom command ###
Loading

0 comments on commit 964f4cd

Please sign in to comment.