From 108536492e8ed76174640836a091e0f20e7bb5f4 Mon Sep 17 00:00:00 2001 From: ali <117142933+muhammad-ali-e@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:21:04 +0530 Subject: [PATCH] V2/remove feature flag (#795) * removed multi_tenancy_v2 feature flag * minor additions for v2 * Commit pdm.lock changes * minor changes to remove variable that only necessory for migration * Update backend/pyproject.toml Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com> * Update backend/sample.env Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com> * minor changes in sample.env --------- Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com> Co-authored-by: muhammad-ali-e Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> --- backend/account_v2/templates/index.html | 2 +- backend/backend/constants.py | 1 - backend/backend/settings/base.py | 390 ++++++------------ .../connector_processor.py | 14 +- backend/connector_processor/views.py | 9 +- .../file_management/file_management_helper.py | 8 +- backend/file_management/views.py | 12 +- .../v2/management/commands/migrate_to_v2.py | 5 +- backend/permissions/permission.py | 9 +- backend/sample.env | 4 +- backend/scheduler/helper.py | 22 +- backend/scheduler/serializer.py | 8 +- backend/scheduler/tasks.py | 80 +--- backend/utils/constants.py | 2 +- backend/utils/models/organization_mixin.py | 9 +- backend/utils/serializer_utils.py | 9 +- backend/utils/user_context.py | 43 +- backend/utils/user_session.py | 16 +- backend/workflow_manager/urls.py | 12 +- platform-service/sample.env | 4 +- .../unstract/platform_service/constants.py | 2 +- .../platform_service/controller/platform.py | 139 +++---- .../helper/adapter_instance.py | 23 +- .../platform_service/helper/prompt_studio.py | 24 +- prompt-service/sample.env | 4 +- .../authentication_middleware.py | 17 +- .../src/unstract/prompt_service/constants.py | 2 +- .../src/unstract/prompt_service/helper.py | 45 +- 28 files changed, 245 insertions(+), 670 deletions(-) diff --git a/backend/account_v2/templates/index.html b/backend/account_v2/templates/index.html index ffa0b6085..7b2cbec84 100644 --- a/backend/account_v2/templates/index.html +++ b/backend/account_v2/templates/index.html @@ -6,6 +6,6 @@

Welcome Guest

-

Login

+

Login

diff --git a/backend/backend/constants.py b/backend/backend/constants.py index 365222e30..26d944d9a 100644 --- a/backend/backend/constants.py +++ b/backend/backend/constants.py @@ -33,5 +33,4 @@ class UrlPathConstants: class FeatureFlag: """Temporary feature flags.""" - MULTI_TENANCY_V2 = "multi_tenancy_v2" APP_DEPLOYMENT = "app_deployment" diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index 4f0253509..090aa7054 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -17,9 +17,6 @@ from dotenv import find_dotenv, load_dotenv from utils.common_utils import CommonUtils -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - missing_settings = [] @@ -157,193 +154,96 @@ def get_required_setting( CSRF_TRUSTED_ORIGINS = [WEB_APP_ORIGIN_URL] CORS_ALLOW_ALL_ORIGINS = False -if not check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - LOGGING = { - "version": 1, - "disable_existing_loggers": False, - "filters": { - "request_id": {"()": "log_request_id.filters.RequestIDFilter"}, - "tenant_context": {"()": "django_tenants.log.TenantContextFilter"}, - }, - "formatters": { - "enriched": { - "format": ( - "%(levelname)s : [%(asctime)s] [%(schema_name)s:%(domain_url)s]" - "{module:%(module)s process:%(process)d " - "thread:%(thread)d request_id:%(request_id)s} :- %(message)s" - ), - }, - "verbose": { - "format": "[%(asctime)s] %(levelname)s %(name)s: %(message)s", - "datefmt": "%d/%b/%Y %H:%M:%S", - }, - "simple": { - "format": "{levelname} {message}", - "style": "{", - }, - }, - "handlers": { - "console": { - "level": DEFAULT_LOG_LEVEL, # Set the desired logging level here - "class": "logging.StreamHandler", - "filters": ["request_id", "tenant_context"], - "formatter": "enriched", - }, - }, - "root": { - "handlers": ["console"], - "level": DEFAULT_LOG_LEVEL, - # Set the desired logging level here as well - }, - } - SHARED_APPS = ( - # Multitenancy - "django_tenants", - "corsheaders", - # For the organization model - "account", - "account_usage", - # Django apps should go below this line - "django.contrib.admin", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.sessions", - "django.contrib.messages", - "django.contrib.staticfiles", - "django.contrib.admindocs", - # Third party apps should go below this line, - "rest_framework", - # Connector OAuth - "connector_auth", - "social_django", - # Doc generator - "drf_yasg", - "docs", - # Plugins - "plugins", - "feature_flag", - "django_celery_beat", - ) - TENANT_APPS = ( - # your tenant-specific apps - "django.contrib.admin", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.messages", - "django.contrib.staticfiles", - "tenant_account", - "project", - "prompt", - "connector", - "adapter_processor", - "file_management", - "workflow_manager.endpoint", - "workflow_manager.workflow", - "tool_instance", - "pipeline", - "platform_settings", - "api", - "prompt_studio.prompt_profile_manager", - "prompt_studio.prompt_studio", - "prompt_studio.prompt_studio_core", - "prompt_studio.prompt_studio_registry", - "prompt_studio.prompt_studio_output_manager", - "prompt_studio.prompt_studio_document_manager", - "prompt_studio.prompt_studio_index_manager", - "usage", - "notification", - ) -else: - LOGGING = { - "version": 1, - "disable_existing_loggers": False, - "filters": { - "request_id": {"()": "log_request_id.filters.RequestIDFilter"}, +LOGGING = { + "version": 1, + "disable_existing_loggers": False, + "filters": { + "request_id": {"()": "log_request_id.filters.RequestIDFilter"}, + }, + "formatters": { + "enriched": { + "format": ( + "%(levelname)s : [%(asctime)s]" + "{module:%(module)s process:%(process)d " + "thread:%(thread)d request_id:%(request_id)s} :- %(message)s" + ), }, - "formatters": { - "enriched": { - "format": ( - "%(levelname)s : [%(asctime)s]" - "{module:%(module)s process:%(process)d " - "thread:%(thread)d request_id:%(request_id)s} :- %(message)s" - ), - }, - "verbose": { - "format": "[%(asctime)s] %(levelname)s %(name)s: %(message)s", - "datefmt": "%d/%b/%Y %H:%M:%S", - }, - "simple": { - "format": "{levelname} {message}", - "style": "{", - }, + "verbose": { + "format": "[%(asctime)s] %(levelname)s %(name)s: %(message)s", + "datefmt": "%d/%b/%Y %H:%M:%S", }, - "handlers": { - "console": { - "level": DEFAULT_LOG_LEVEL, # Set the desired logging level here - "class": "logging.StreamHandler", - "filters": ["request_id"], - "formatter": "enriched", - }, + "simple": { + "format": "{levelname} {message}", + "style": "{", }, - "root": { - "handlers": ["console"], - "level": DEFAULT_LOG_LEVEL, - # Set the desired logging level here as well + }, + "handlers": { + "console": { + "level": DEFAULT_LOG_LEVEL, # Set the desired logging level here + "class": "logging.StreamHandler", + "filters": ["request_id"], + "formatter": "enriched", }, - } - SHARED_APPS = ( - # Multitenancy - # "django_tenants", - "corsheaders", - # For the organization model - "account_v2", - "account_usage", - # Django apps should go below this line - "django.contrib.admin", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.sessions", - "django.contrib.messages", - "django.contrib.staticfiles", - "django.contrib.admindocs", - # Third party apps should go below this line, - "rest_framework", - # Connector OAuth - # "connector_auth", - "social_django", - # Doc generator - "drf_yasg", - "docs", - # Plugins - "plugins", - "feature_flag", - "django_celery_beat", - ) - v2_apps = ( - "migrating.v2", - "connector_auth_v2", - "tenant_account_v2", - "connector_v2", - "adapter_processor_v2", - "file_management", - "workflow_manager.endpoint_v2", - "workflow_manager.workflow_v2", - "tool_instance_v2", - "pipeline_v2", - "platform_settings_v2", - "api_v2", - "usage_v2", - "notification_v2", - "prompt_studio.prompt_profile_manager_v2", - "prompt_studio.prompt_studio_v2", - "prompt_studio.prompt_studio_core_v2", - "prompt_studio.prompt_studio_registry_v2", - "prompt_studio.prompt_studio_output_manager_v2", - "prompt_studio.prompt_studio_document_manager_v2", - "prompt_studio.prompt_studio_index_manager_v2", - ) - SHARED_APPS += v2_apps - TENANT_APPS = [] + }, + "root": { + "handlers": ["console"], + "level": DEFAULT_LOG_LEVEL, + # Set the desired logging level here as well + }, +} +SHARED_APPS = ( + # Multitenancy + # "django_tenants", + "corsheaders", + # For the organization model + "account_v2", + "account_usage", + # Django apps should go below this line + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "django.contrib.admindocs", + # Third party apps should go below this line, + "rest_framework", + # Connector OAuth + # "connector_auth", + "social_django", + # Doc generator + "drf_yasg", + "docs", + # Plugins + "plugins", + "feature_flag", + "django_celery_beat", +) +v2_apps = ( + "migrating.v2", + "connector_auth_v2", + "tenant_account_v2", + "connector_v2", + "adapter_processor_v2", + "file_management", + "workflow_manager.endpoint_v2", + "workflow_manager.workflow_v2", + "tool_instance_v2", + "pipeline_v2", + "platform_settings_v2", + "api_v2", + "usage_v2", + "notification_v2", + "prompt_studio.prompt_profile_manager_v2", + "prompt_studio.prompt_studio_v2", + "prompt_studio.prompt_studio_core_v2", + "prompt_studio.prompt_studio_registry_v2", + "prompt_studio.prompt_studio_output_manager_v2", + "prompt_studio.prompt_studio_document_manager_v2", + "prompt_studio.prompt_studio_index_manager_v2", +) +SHARED_APPS += v2_apps +TENANT_APPS = [] INSTALLED_APPS = list(SHARED_APPS) + [ app for app in TENANT_APPS if app not in SHARED_APPS @@ -358,92 +258,44 @@ def get_required_setting( PUBLIC_ORG_ID = "public" -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - # Middleware Configuration - TENANT_MIDDLEWARE = "middleware.organization_middleware.OrganizationMiddleware" - CUSTOM_AUTH_MIDDLEWARE = "account_v2.custom_auth_middleware.CustomAuthMiddleware" +# Middleware Configuration +TENANT_MIDDLEWARE = "middleware.organization_middleware.OrganizationMiddleware" +CUSTOM_AUTH_MIDDLEWARE = "account_v2.custom_auth_middleware.CustomAuthMiddleware" - # Pipeline Functions - SOCIAL_AUTH_PIPELINE_USER_AUTH = ( - "connector_auth_v2.pipeline.common.check_user_exists" - ) - SOCIAL_AUTH_PIPELINE_CACHE_CRED = ( - "connector_auth_v2.pipeline.common.cache_oauth_creds" - ) +# Pipeline Functions +SOCIAL_AUTH_PIPELINE_USER_AUTH = "connector_auth_v2.pipeline.common.check_user_exists" +SOCIAL_AUTH_PIPELINE_CACHE_CRED = "connector_auth_v2.pipeline.common.cache_oauth_creds" - # Routing Configuration - ROOT_URLCONF = "backend.base_urls" - - # DB Configuration - DB_ENGINE = "backend.custom_db" - - # Models - AUTH_USER_MODEL = "account_v2.User" - - # Social Authentication - SOCIAL_AUTH_USER_MODEL = "account_v2.User" - SOCIAL_AUTH_STORAGE = "connector_auth_v2.models.ConnectorDjangoStorage" - - # Namespaces - SOCIAL_AUTH_URL_NAMESPACE = "public:social" - LOGIN_CALLBACK_URL_NAMESPACE = "public:callback" - DATABASES = { - "default": { - "ENGINE": DB_ENGINE, - "NAME": f"{DB_NAME}", - "USER": f"{DB_USER}", - "HOST": f"{DB_HOST}", - "PASSWORD": f"{DB_PASSWORD}", - "PORT": f"{DB_PORT}", - "ATOMIC_REQUESTS": ATOMIC_REQUESTS, - "OPTIONS": { - "application_name": os.environ.get("APPLICATION_NAME", ""), - }, - } - } -else: - # Middleware Configuration - TENANT_MIDDLEWARE = "django_tenants.middleware.TenantSubfolderMiddleware" - CUSTOM_AUTH_MIDDLEWARE = "account.custom_auth_middleware.CustomAuthMiddleware" - - # Pipeline Functions - SOCIAL_AUTH_PIPELINE_USER_AUTH = "connector_auth.pipeline.common.check_user_exists" - SOCIAL_AUTH_PIPELINE_CACHE_CRED = "connector_auth.pipeline.common.cache_oauth_creds" - - # Routing Configuration - PUBLIC_SCHEMA_URLCONF = "backend.public_urls" - ROOT_URLCONF = "backend.urls" - - # DB Configuration - DB_ENGINE = "django_tenants.postgresql_backend" - DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",) - - # Models - AUTH_USER_MODEL = "account.User" - TENANT_MODEL = "account.Organization" - TENANT_DOMAIN_MODEL = "account.Domain" - - # Social Authentication - SOCIAL_AUTH_USER_MODEL = "account.User" - SOCIAL_AUTH_STORAGE = "connector_auth.models.ConnectorDjangoStorage" - - # Namespaces - SOCIAL_AUTH_URL_NAMESPACE = "social" - LOGIN_CALLBACK_URL_NAMESPACE = "callback" - DATABASES = { - "default": { - "ENGINE": DB_ENGINE, - "NAME": f"{DB_NAME}", - "USER": f"{DB_USER}", - "HOST": f"{DB_HOST}", - "PASSWORD": f"{DB_PASSWORD}", - "PORT": f"{DB_PORT}", - "ATOMIC_REQUESTS": ATOMIC_REQUESTS, - "OPTIONS": { - "application_name": os.environ.get("APPLICATION_NAME", ""), - }, - } +# Routing Configuration +ROOT_URLCONF = "backend.base_urls" + +# DB Configuration +DB_ENGINE = "backend.custom_db" + +# Models +AUTH_USER_MODEL = "account_v2.User" + +# Social Authentication +SOCIAL_AUTH_USER_MODEL = "account_v2.User" +SOCIAL_AUTH_STORAGE = "connector_auth_v2.models.ConnectorDjangoStorage" + +# Namespaces +SOCIAL_AUTH_URL_NAMESPACE = "public:social" +LOGIN_CALLBACK_URL_NAMESPACE = "public:callback" +DATABASES = { + "default": { + "ENGINE": DB_ENGINE, + "NAME": f"{DB_NAME}", + "USER": f"{DB_USER}", + "HOST": f"{DB_HOST}", + "PASSWORD": f"{DB_PASSWORD}", + "PORT": f"{DB_PORT}", + "ATOMIC_REQUESTS": ATOMIC_REQUESTS, + "OPTIONS": { + "application_name": os.environ.get("APPLICATION_NAME", ""), + }, } +} MIDDLEWARE = [ "log_request_id.middleware.RequestIDMiddleware", diff --git a/backend/connector_processor/connector_processor.py b/backend/connector_processor/connector_processor.py index 0e554de3d..1dac1e326 100644 --- a/backend/connector_processor/connector_processor.py +++ b/backend/connector_processor/connector_processor.py @@ -3,6 +3,8 @@ import logging from typing import Any, Optional +from connector_auth_v2.constants import ConnectorAuthKey +from connector_auth_v2.pipeline.common import ConnectorAuthHelper from connector_processor.constants import ConnectorKeys from connector_processor.exceptions import ( InValidConnectorId, @@ -10,24 +12,14 @@ OAuthTimeOut, TestConnectorInputError, ) +from connector_v2.constants import ConnectorInstanceKey as CIKey -from backend.constants import FeatureFlag from backend.exceptions import UnstractFSException from unstract.connectors.base import UnstractConnector from unstract.connectors.connectorkit import Connectorkit from unstract.connectors.enums import ConnectorMode from unstract.connectors.exceptions import ConnectorError, FSAccessDeniedError from unstract.connectors.filesystems.ucs import UnstractCloudStorage -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from connector_auth_v2.constants import ConnectorAuthKey - from connector_auth_v2.pipeline.common import ConnectorAuthHelper - from connector_v2.constants import ConnectorInstanceKey as CIKey -else: - from connector.constants import ConnectorInstanceKey as CIKey - from connector_auth.constants import ConnectorAuthKey - from connector_auth.pipeline.common import ConnectorAuthHelper logger = logging.getLogger(__name__) diff --git a/backend/connector_processor/views.py b/backend/connector_processor/views.py index 20c9277a7..7afd70849 100644 --- a/backend/connector_processor/views.py +++ b/backend/connector_processor/views.py @@ -2,6 +2,7 @@ from connector_processor.constants import ConnectorKeys from connector_processor.exceptions import IdIsMandatory, InValidType from connector_processor.serializers import TestConnectorSerializer +from connector_v2.constants import ConnectorInstanceKey as CIKey from django.http.request import HttpRequest from django.http.response import HttpResponse from rest_framework import status @@ -12,14 +13,6 @@ from rest_framework.versioning import URLPathVersioning from rest_framework.viewsets import GenericViewSet -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from connector_v2.constants import ConnectorInstanceKey as CIKey -else: - from connector.constants import ConnectorInstanceKey as CIKey - @api_view(("GET",)) def get_connector_schema(request: HttpRequest) -> HttpResponse: diff --git a/backend/file_management/file_management_helper.py b/backend/file_management/file_management_helper.py index 42118c039..0919cf1e5 100644 --- a/backend/file_management/file_management_helper.py +++ b/backend/file_management/file_management_helper.py @@ -6,6 +6,7 @@ from typing import Any import magic +from connector_v2.models import ConnectorInstance from django.conf import settings from django.http import StreamingHttpResponse from file_management.exceptions import ( @@ -23,15 +24,8 @@ from fsspec import AbstractFileSystem from pydrive2.files import ApiRequestError -from backend.constants import FeatureFlag from unstract.connectors.filesystems import connectors as fs_connectors from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from connector_v2.models import ConnectorInstance -else: - from connector.models import ConnectorInstance logger = logging.getLogger(__name__) diff --git a/backend/file_management/views.py b/backend/file_management/views.py index 576ec5e56..e608ad89b 100644 --- a/backend/file_management/views.py +++ b/backend/file_management/views.py @@ -1,6 +1,7 @@ import logging from typing import Any +from connector_v2.models import ConnectorInstance from django.http import HttpRequest from file_management.exceptions import ( ConnectorInstanceNotFound, @@ -15,24 +16,15 @@ FileUploadSerializer, ) from oauth2client.client import HttpAccessTokenRefreshError +from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager from rest_framework import serializers, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.versioning import URLPathVersioning from utils.user_session import UserSessionUtils -from backend.constants import FeatureFlag from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.local_storage.local_storage import LocalStorageFS -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from connector_v2.models import ConnectorInstance - from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager - -else: - from connector.models import ConnectorInstance - from prompt_studio.prompt_studio_document_manager.models import DocumentManager logger = logging.getLogger(__name__) diff --git a/backend/migrating/v2/management/commands/migrate_to_v2.py b/backend/migrating/v2/management/commands/migrate_to_v2.py index 06521abfc..9eb59ce55 100644 --- a/backend/migrating/v2/management/commands/migrate_to_v2.py +++ b/backend/migrating/v2/management/commands/migrate_to_v2.py @@ -495,7 +495,10 @@ def handle(self, *args, **options): migrator.migrate(public_schema_migrations) if not schemas_to_migrate: - logger.info("Migration not run since SCHEMAS_TO_MIGRATE env seems empty.") + logger.info( + "Migration not run since SCHEMAS_TO_MIGRATE env seems empty." + "Set the value as `_ALL_` to migrate complete data" + ) return else: schemas_to_migrate = schemas_to_migrate.split(",") diff --git a/backend/permissions/permission.py b/backend/permissions/permission.py index 98bda6a17..02d62acb7 100644 --- a/backend/permissions/permission.py +++ b/backend/permissions/permission.py @@ -1,17 +1,10 @@ from typing import Any +from adapter_processor_v2.models import AdapterInstance from rest_framework import permissions from rest_framework.request import Request from rest_framework.views import APIView -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from adapter_processor_v2.models import AdapterInstance -else: - from adapter_processor.models import AdapterInstance - class IsOwner(permissions.BasePermission): """Custom permission to only allow owners of an object.""" diff --git a/backend/sample.env b/backend/sample.env index 4ed0d8778..fcf2f3fc7 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -16,6 +16,7 @@ DB_USER='unstract_dev' DB_PASSWORD='unstract_pass' DB_NAME='unstract_db' DB_PORT=5432 +DB_SCHEMA="unstract" # Redis REDIS_HOST="unstract-redis" @@ -140,9 +141,6 @@ CELERY_BROKER_URL = "redis://unstract-redis:6379" # Indexing flag to prevent re-index INDEXING_FLAG_TTL=1800 -# V2 Configurations -DB_SCHEMA="unstract_v2" - # Notification Timeout in Seconds NOTIFICATION_TIMEOUT=5 diff --git a/backend/scheduler/helper.py b/backend/scheduler/helper.py index 630ffe6b4..2588d82b8 100644 --- a/backend/scheduler/helper.py +++ b/backend/scheduler/helper.py @@ -1,7 +1,7 @@ import logging from typing import Any -from django.db import connection +from pipeline_v2.models import Pipeline from rest_framework.serializers import ValidationError from scheduler.constants import SchedulerConstants as SC from scheduler.exceptions import JobDeletionError, JobSchedulingError @@ -12,19 +12,10 @@ disable_task, enable_task, ) +from utils.user_context import UserContext +from workflow_manager.workflow_v2.constants import WorkflowKey +from workflow_manager.workflow_v2.serializers import ExecuteWorkflowSerializer -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from pipeline_v2.models import Pipeline - from utils.user_context import UserContext - from workflow_manager.workflow_v2.constants import WorkflowKey - from workflow_manager.workflow_v2.serializers import ExecuteWorkflowSerializer -else: - from pipeline.models import Pipeline - from workflow_manager.workflow.constants import WorkflowKey - from workflow_manager.workflow.serializers import ExecuteWorkflowSerializer logger = logging.getLogger(__name__) @@ -50,10 +41,7 @@ def _schedule_task_job(pipeline: Pipeline, job_data: Any) -> None: workflow_id = serializer.get_workflow_id(serializer.validated_data) # TODO: Remove unused argument in execute_pipeline_task execution_action = serializer.get_execution_action(serializer.validated_data) - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - organization_id = UserContext.get_organization_identifier() - else: - organization_id = connection.tenant.schema_name + organization_id = UserContext.get_organization_identifier() create_or_update_periodic_task( cron_string=cron_string, diff --git a/backend/scheduler/serializer.py b/backend/scheduler/serializer.py index fa4666214..4b39a6d44 100644 --- a/backend/scheduler/serializer.py +++ b/backend/scheduler/serializer.py @@ -1,17 +1,11 @@ import logging from typing import Any +from pipeline_v2.manager import PipelineManager from rest_framework import serializers from scheduler.constants import SchedulerConstants as SC -from backend.constants import FeatureFlag from backend.constants import FieldLengthConstants as FieldLength -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from pipeline_v2.manager import PipelineManager -else: - from pipeline.manager import PipelineManager logger = logging.getLogger(__name__) diff --git a/backend/scheduler/tasks.py b/backend/scheduler/tasks.py index 49e92b6ce..e0c89e04c 100644 --- a/backend/scheduler/tasks.py +++ b/backend/scheduler/tasks.py @@ -3,26 +3,13 @@ import traceback from typing import Any, Optional +from account_v2.subscription_loader import load_plugins, validate_etl_run from celery import shared_task from django_celery_beat.models import CrontabSchedule, PeriodicTask -from django_tenants.utils import get_tenant_model, tenant_context - -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from account_v2.subscription_loader import load_plugins, validate_etl_run - from pipeline_v2.models import Pipeline - from pipeline_v2.pipeline_processor import PipelineProcessor - from utils.user_context import UserContext - from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper -else: - from account.models import Organization - from account.subscription_loader import load_plugins, validate_etl_run - from pipeline.models import Pipeline - from pipeline.pipeline_processor import PipelineProcessor - from workflow_manager.workflow.workflow_helper import WorkflowHelper - +from pipeline_v2.models import Pipeline +from pipeline_v2.pipeline_processor import PipelineProcessor +from utils.user_context import UserContext +from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper logger = logging.getLogger(__name__) subscription_loader = load_plugins() @@ -77,58 +64,11 @@ def execute_pipeline_task( with_logs: Any, name: Any, ) -> None: - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - execute_pipeline_task_v2( - organization_id=org_schema, - pipeline_id=pipepline_id, - pipeline_name=name, - ) - return - logger.info(f"Executing pipeline name: {name}") - try: - tenant: Organization = ( - get_tenant_model().objects.filter(schema_name=org_schema).first() - ) - with tenant_context(tenant): - pipeline = PipelineProcessor.fetch_pipeline( - pipeline_id=pipepline_id, check_active=True - ) - workflow = pipeline.workflow - logger.info( - f"Executing pipeline: {pipeline}, " - f"workflow: {workflow}, pipeline name: {name}" - ) - if ( - subscription_loader - and subscription_loader[0] - and not validate_etl_run(org_schema) - ): - try: - logger.info( - f"Subscription expired for '{org_schema}', " - f"disabling pipeline: {pipepline_id}" - ) - disable_task(pipepline_id) - except Exception as e: - logger.warning( - f"Failed to disable task: {pipepline_id}. Error: {e}" - ) - return - PipelineProcessor.initialize_pipeline_sync(pipeline_id=pipepline_id) - PipelineProcessor.update_pipeline( - pipepline_id, Pipeline.PipelineStatus.INPROGRESS - ) - execution_response = WorkflowHelper.complete_execution( - workflow=workflow, pipeline_id=pipepline_id - ) - execution_response.remove_result_metadata_keys() - logger.info(f"Execution response: {execution_response}") - logger.info(f"Execution completed for pipeline: {name}") - except Exception as e: - logger.error( - f"Failed to execute pipeline: {name}. Error: {e}" - f"\n\n'''{traceback.format_exc()}```" - ) + execute_pipeline_task_v2( + organization_id=org_schema, + pipeline_id=pipepline_id, + pipeline_name=name, + ) def execute_pipeline_task_v2( diff --git a/backend/utils/constants.py b/backend/utils/constants.py index d524bf430..a92f397a4 100644 --- a/backend/utils/constants.py +++ b/backend/utils/constants.py @@ -18,7 +18,7 @@ class Account: class FeatureFlag: """Temporary feature flags.""" - MULTI_TENANCY_V2 = "multi_tenancy_v2" + pass class Common: diff --git a/backend/utils/models/organization_mixin.py b/backend/utils/models/organization_mixin.py index 18d496c9d..1e178fe0c 100644 --- a/backend/utils/models/organization_mixin.py +++ b/backend/utils/models/organization_mixin.py @@ -1,11 +1,8 @@ # TODO:V2 class from account_v2.models import Organization from django.db import models -from utils.constants import FeatureFlag from utils.user_context import UserContext -from unstract.flags.feature_flag import check_feature_flag_status - class DefaultOrganizationMixin(models.Model): organization = models.ForeignKey( @@ -28,7 +25,5 @@ def save(self, *args, **kwargs): class DefaultOrganizationManagerMixin(models.Manager): def get_queryset(self): - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - organization = UserContext.get_organization() - return super().get_queryset().filter(organization=organization) - return super().get_queryset() + organization = UserContext.get_organization() + return super().get_queryset().filter(organization=organization) diff --git a/backend/utils/serializer_utils.py b/backend/utils/serializer_utils.py index fdecc7fae..e5782e0f5 100644 --- a/backend/utils/serializer_utils.py +++ b/backend/utils/serializer_utils.py @@ -1,16 +1,9 @@ from typing import Any +from account_v2.models import User from rest_framework.request import Request from utils.constants import Account -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from account_v2.models import User -else: - from account.models import User - class SerializerUtils: @staticmethod diff --git a/backend/utils/user_context.py b/backend/utils/user_context.py index 1be52f02a..7f2b80dbd 100644 --- a/backend/utils/user_context.py +++ b/backend/utils/user_context.py @@ -1,25 +1,15 @@ from typing import Optional -from django.db import connection +from account_v2.models import Organization from django.db.utils import ProgrammingError -from utils.constants import Account, FeatureFlag +from utils.constants import Account from utils.local_context import StateStore -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from account_v2.models import Organization -else: - from account.models import Organization - class UserContext: @staticmethod def get_organization_identifier() -> str: - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - organization_id = StateStore.get(Account.ORGANIZATION_ID) - else: - organization_id = connection.tenant.schema_name + organization_id = StateStore.get(Account.ORGANIZATION_ID) return organization_id @staticmethod @@ -28,19 +18,16 @@ def set_organization_identifier(organization_identifier: str) -> None: @staticmethod def get_organization() -> Optional[Organization]: - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - organization_id = StateStore.get(Account.ORGANIZATION_ID) - try: - organization: Organization = Organization.objects.get( - organization_id=organization_id - ) - except Organization.DoesNotExist: - return None - except ProgrammingError: - # Handle cases where the database schema might not be fully set up, - # especially during the execution of management commands - # other than runserver - return None - else: - organization: Organization = connection.tenant + organization_id = StateStore.get(Account.ORGANIZATION_ID) + try: + organization: Organization = Organization.objects.get( + organization_id=organization_id + ) + except Organization.DoesNotExist: + return None + except ProgrammingError: + # Handle cases where the database schema might not be fully set up, + # especially during the execution of management commands + # other than runserver + return None return organization diff --git a/backend/utils/user_session.py b/backend/utils/user_session.py index 972f2a28f..c5c5a9452 100644 --- a/backend/utils/user_session.py +++ b/backend/utils/user_session.py @@ -3,24 +3,16 @@ from django.conf import settings from django.db import connection from django.http import HttpRequest -from utils.constants import FeatureFlag - -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from tenant_account_v2.models import OrganizationMember -else: - from tenant_account.models import OrganizationMember +from tenant_account_v2.models import OrganizationMember class UserSessionUtils: @staticmethod def get_organization_id(request: HttpRequest) -> Optional[str]: session_org_id = request.session.get("organization") - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - requested_org_id = request.organization_id - if requested_org_id and (session_org_id != requested_org_id): - return None + requested_org_id = request.organization_id + if requested_org_id and (session_org_id != requested_org_id): + return None return session_org_id @staticmethod diff --git a/backend/workflow_manager/urls.py b/backend/workflow_manager/urls.py index 7140cc174..42c972973 100644 --- a/backend/workflow_manager/urls.py +++ b/backend/workflow_manager/urls.py @@ -1,14 +1,6 @@ from django.urls import include, path - -from backend.constants import FeatureFlag -from unstract.flags.feature_flag import check_feature_flag_status - -if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - from workflow_manager.endpoint_v2 import urls as endpoint_urls - from workflow_manager.workflow_v2 import urls as workflow_urls -else: - from workflow_manager.endpoint import urls as endpoint_urls - from workflow_manager.workflow import urls as workflow_urls +from workflow_manager.endpoint_v2 import urls as endpoint_urls +from workflow_manager.workflow_v2 import urls as workflow_urls urlpatterns = [ path("endpoint/", include(endpoint_urls)), diff --git a/platform-service/sample.env b/platform-service/sample.env index c58e4421c..4723f83cc 100644 --- a/platform-service/sample.env +++ b/platform-service/sample.env @@ -14,6 +14,7 @@ PG_BE_PORT=5432 PG_BE_USERNAME=unstract_dev PG_BE_PASSWORD=unstract_pass PG_BE_DATABASE=unstract_db +DB_SCHEMA="unstract" # Encryption Key @@ -32,7 +33,4 @@ MODEL_PRICES_URL="https://raw.githubusercontent.com/BerriAI/litellm/main/model_p MODEL_PRICES_TTL_IN_DAYS=7 MODEL_PRICES_FILE_PATH="/tmp/model_prices.json" -# V2 Configurations -DB_SCHEMA="unstract_v2" - LOG_LEVEL=INFO diff --git a/platform-service/src/unstract/platform_service/constants.py b/platform-service/src/unstract/platform_service/constants.py index a9f87437f..4a443e800 100644 --- a/platform-service/src/unstract/platform_service/constants.py +++ b/platform-service/src/unstract/platform_service/constants.py @@ -1,7 +1,7 @@ class FeatureFlag: """Temporary feature flags.""" - MULTI_TENANCY_V2 = "multi_tenancy_v2" + pass class DBTable: diff --git a/platform-service/src/unstract/platform_service/controller/platform.py b/platform-service/src/unstract/platform_service/controller/platform.py index eb0ef3718..3381c6029 100644 --- a/platform-service/src/unstract/platform_service/controller/platform.py +++ b/platform-service/src/unstract/platform_service/controller/platform.py @@ -8,7 +8,7 @@ from flask import Blueprint, Request from flask import current_app as app from flask import jsonify, make_response, request -from unstract.platform_service.constants import DBTable, DBTableV2, FeatureFlag +from unstract.platform_service.constants import DBTable, DBTableV2 from unstract.platform_service.env import Env from unstract.platform_service.exceptions import APIError from unstract.platform_service.extensions import db @@ -18,8 +18,6 @@ from unstract.platform_service.helper.cost_calculation import CostCalculationHelper from unstract.platform_service.helper.prompt_studio import PromptStudioRequestHelper -from unstract.flags.feature_flag import check_feature_flag_status - platform_bp = Blueprint("platform", __name__) @@ -47,14 +45,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def get_account_from_bearer_token(token: Optional[str]) -> str: - query = "SELECT organization_id FROM account_platformkey WHERE key=%s" - organization = execute_query(query, (token,)) - query_org = "SELECT schema_name FROM account_organization WHERE id=%s" - schema_name: str = execute_query(query_org, (organization,)) - return schema_name - - def get_organization_from_bearer_token(token: str) -> tuple[Optional[int], str]: """Fetch organization by platform key. @@ -64,21 +54,17 @@ def get_organization_from_bearer_token(token: str) -> tuple[Optional[int], str]: Returns: tuple[int, str]: organization uid and organization identifier """ - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - query = f""" - SELECT organization_id FROM "{Env.DB_SCHEMA}".{DBTableV2.PLATFORM_KEY} - WHERE key=%s - """ - organization_uid: int = execute_query(query, (token,)) - query_org = f""" - SELECT organization_id FROM "{Env.DB_SCHEMA}".{DBTableV2.ORGANIZATION} - WHERE id=%s - """ - organization_identifier: str = execute_query(query_org, (organization_uid,)) - return organization_uid, organization_identifier - else: - organization_identifier = get_account_from_bearer_token(token=token) - return None, organization_identifier + query = f""" + SELECT organization_id FROM "{Env.DB_SCHEMA}".{DBTableV2.PLATFORM_KEY} + WHERE key=%s + """ + organization_uid: int = execute_query(query, (token,)) + query_org = f""" + SELECT organization_id FROM "{Env.DB_SCHEMA}".{DBTableV2.ORGANIZATION} + WHERE id=%s + """ + organization_identifier: str = execute_query(query_org, (organization_uid,)) + return organization_uid, organization_identifier def execute_query(query: str, params: tuple = ()) -> Any: @@ -96,17 +82,11 @@ def validate_bearer_token(token: Optional[str]) -> bool: app.logger.error("Authentication failed. Empty bearer token") return False - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - platform_key_table = DBTableV2.PLATFORM_KEY - query = f""" - SELECT * FROM \"{Env.DB_SCHEMA}\".{platform_key_table} - WHERE key = '{token}' - """ - else: - platform_key_table = "account_platformkey" - query = f""" - SELECT * FROM {platform_key_table} WHERE key = '{token}' - """ + platform_key_table = DBTableV2.PLATFORM_KEY + query = f""" + SELECT * FROM \"{Env.DB_SCHEMA}\".{platform_key_table} + WHERE key = '{token}' + """ cursor = db.execute_sql(query) result_row = cursor.fetchone() @@ -241,62 +221,35 @@ def usage() -> Any: ) usage_id = uuid.uuid4() current_time = datetime.now() - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - query = f""" - INSERT INTO \"{Env.DB_SCHEMA}\".{DBTableV2.TOKEN_USAGE} ( - id, organization_id, workflow_id, - execution_id, adapter_instance_id, run_id, usage_type, - llm_usage_reason, model_name, embedding_tokens, prompt_tokens, - completion_tokens, total_tokens, cost_in_dollars, created_at, modified_at) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """ - usage_id = uuid.uuid4() - current_time = datetime.now() - params = ( - usage_id, - organization_uid, - workflow_id, - execution_id, - adapter_instance_id, - run_id, - usage_type, - llm_usage_reason, - model_name, - embedding_tokens, - prompt_tokens, - completion_tokens, - total_tokens, - cost_in_dollars, - current_time, - current_time, - ) - else: - query = f""" - INSERT INTO "{org_id}"."token_usage" (id, workflow_id, - execution_id, adapter_instance_id, run_id, usage_type, - llm_usage_reason, model_name, embedding_tokens, prompt_tokens, - completion_tokens, total_tokens, cost_in_dollars, created_at, modified_at) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """ - usage_id = uuid.uuid4() - current_time = datetime.now() - params = ( - usage_id, - workflow_id, - execution_id, - adapter_instance_id, - run_id, - usage_type, - llm_usage_reason, - model_name, - embedding_tokens, - prompt_tokens, - completion_tokens, - total_tokens, - cost_in_dollars, - current_time, - current_time, - ) + query = f""" + INSERT INTO \"{Env.DB_SCHEMA}\".{DBTableV2.TOKEN_USAGE} ( + id, organization_id, workflow_id, + execution_id, adapter_instance_id, run_id, usage_type, + llm_usage_reason, model_name, embedding_tokens, prompt_tokens, + completion_tokens, total_tokens, cost_in_dollars, created_at, modified_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + usage_id = uuid.uuid4() + current_time = datetime.now() + params = ( + usage_id, + organization_uid, + workflow_id, + execution_id, + adapter_instance_id, + run_id, + usage_type, + llm_usage_reason, + model_name, + embedding_tokens, + prompt_tokens, + completion_tokens, + total_tokens, + cost_in_dollars, + current_time, + current_time, + ) + try: with db.atomic() as transaction: db.execute_sql(query, params) diff --git a/platform-service/src/unstract/platform_service/helper/adapter_instance.py b/platform-service/src/unstract/platform_service/helper/adapter_instance.py index 813b48c3f..4b023596a 100644 --- a/platform-service/src/unstract/platform_service/helper/adapter_instance.py +++ b/platform-service/src/unstract/platform_service/helper/adapter_instance.py @@ -1,12 +1,10 @@ from typing import Any, Optional -from unstract.platform_service.constants import DBTableV2, FeatureFlag +from unstract.platform_service.constants import DBTableV2 from unstract.platform_service.exceptions import APIError from unstract.platform_service.extensions import db from unstract.platform_service.utils import EnvManager -from unstract.flags.feature_flag import check_feature_flag_status - DB_SCHEMA = EnvManager.get_required_setting("DB_SCHEMA", "unstract_v2") @@ -26,19 +24,12 @@ def get_adapter_instance_from_db( Returns: _type_: _description_ """ - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - query = ( - "SELECT id, adapter_id, adapter_name, adapter_type, adapter_metadata_b" - f' FROM "{DB_SCHEMA}".{DBTableV2.ADAPTER_INSTANCE} x ' - f"WHERE id='{adapter_instance_id}' and " - f"organization_id='{organization_uid}'" - ) - else: - query = ( - f"SELECT id, adapter_id, adapter_name, adapter_type, adapter_metadata_b" - f' FROM "{organization_id}".adapter_adapterinstance x ' - f"WHERE id='{adapter_instance_id}'" - ) + query = ( + "SELECT id, adapter_id, adapter_name, adapter_type, adapter_metadata_b" + f' FROM "{DB_SCHEMA}".{DBTableV2.ADAPTER_INSTANCE} x ' + f"WHERE id='{adapter_instance_id}' and " + f"organization_id='{organization_uid}'" + ) cursor = db.execute_sql(query) result_row = cursor.fetchone() if not result_row: diff --git a/platform-service/src/unstract/platform_service/helper/prompt_studio.py b/platform-service/src/unstract/platform_service/helper/prompt_studio.py index 565c3c2f9..35ef634db 100644 --- a/platform-service/src/unstract/platform_service/helper/prompt_studio.py +++ b/platform-service/src/unstract/platform_service/helper/prompt_studio.py @@ -1,12 +1,10 @@ from typing import Any -from unstract.platform_service.constants import DBTableV2, FeatureFlag +from unstract.platform_service.constants import DBTableV2 from unstract.platform_service.exceptions import APIError from unstract.platform_service.extensions import db from unstract.platform_service.utils import EnvManager -from unstract.flags.feature_flag import check_feature_flag_status - DB_SCHEMA = EnvManager.get_required_setting("DB_SCHEMA", "unstract_v2") @@ -25,20 +23,12 @@ def get_prompt_instance_from_db( Returns: _type_: _description_ """ - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - query = ( - "SELECT prompt_registry_id, tool_spec, " - "tool_metadata, tool_property FROM " - f'"{DB_SCHEMA}".{DBTableV2.PROMPT_STUDIO_REGISTRY} x ' - f"WHERE prompt_registry_id='{prompt_registry_id}'" - ) - else: - query = ( - f"SELECT prompt_registry_id, tool_spec, " - f"tool_metadata, tool_property FROM " - f'"{organization_id}".prompt_studio_registry_promptstudioregistry x' - f" WHERE prompt_registry_id='{prompt_registry_id}'" - ) + query = ( + "SELECT prompt_registry_id, tool_spec, " + "tool_metadata, tool_property FROM " + f'"{DB_SCHEMA}".{DBTableV2.PROMPT_STUDIO_REGISTRY} x ' + f"WHERE prompt_registry_id='{prompt_registry_id}'" + ) cursor = db.execute_sql(query) result_row = cursor.fetchone() if not result_row: diff --git a/prompt-service/sample.env b/prompt-service/sample.env index 697592d28..9289af409 100644 --- a/prompt-service/sample.env +++ b/prompt-service/sample.env @@ -4,6 +4,7 @@ PG_BE_PORT=5432 PG_BE_USERNAME=unstract_dev PG_BE_PASSWORD=unstract_pass PG_BE_DATABASE=unstract_db +DB_SCHEMA="unstract" # Redis REDIS_HOST="unstract-redis" @@ -26,8 +27,5 @@ EVALUATION_SERVER_IP=unstract-flipt EVALUATION_SERVER_PORT=9000 PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -# V2 Configurations -DB_SCHEMA="unstract_v2" - # Flipt Service FLIPT_SERVICE_AVAILABLE=False diff --git a/prompt-service/src/unstract/prompt_service/authentication_middleware.py b/prompt-service/src/unstract/prompt_service/authentication_middleware.py index a5a2e135c..5192df71a 100644 --- a/prompt-service/src/unstract/prompt_service/authentication_middleware.py +++ b/prompt-service/src/unstract/prompt_service/authentication_middleware.py @@ -2,12 +2,10 @@ from flask import Request, current_app from unstract.prompt_service.config import db -from unstract.prompt_service.constants import DBTableV2, FeatureFlag +from unstract.prompt_service.constants import DBTableV2 from unstract.prompt_service.db_utils import DBUtils from unstract.prompt_service.env_manager import EnvLoader -from unstract.flags.feature_flag import check_feature_flag_status - DB_SCHEMA = EnvLoader.get_env_or_die("DB_SCHEMA", "unstract_v2") @@ -20,10 +18,7 @@ def validate_bearer_token(token: Optional[str]) -> bool: current_app.logger.error("Authentication failed. Empty bearer token") return False - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - platform_key_table = f'"{DB_SCHEMA}".{DBTableV2.PLATFORM_KEY}' - else: - platform_key_table = "account_platformkey" + platform_key_table = f'"{DB_SCHEMA}".{DBTableV2.PLATFORM_KEY}' query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" cursor = db.execute_sql(query) @@ -71,12 +66,8 @@ def get_token_from_auth_header(request: Request) -> Optional[str]: @staticmethod def get_account_from_bearer_token(token: Optional[str]) -> str: - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - platform_key_table = DBTableV2.PLATFORM_KEY - organization_table = DBTableV2.ORGANIZATION - else: - platform_key_table = "account_platformkey" - organization_table = "account_organization" + platform_key_table = DBTableV2.PLATFORM_KEY + organization_table = DBTableV2.ORGANIZATION query = f"SELECT organization_id FROM {platform_key_table} WHERE key='{token}'" organization = DBUtils.execute_query(query) diff --git a/prompt-service/src/unstract/prompt_service/constants.py b/prompt-service/src/unstract/prompt_service/constants.py index 6f5a73878..eb3c4b804 100644 --- a/prompt-service/src/unstract/prompt_service/constants.py +++ b/prompt-service/src/unstract/prompt_service/constants.py @@ -85,7 +85,7 @@ class RunLevel(Enum): class FeatureFlag: """Temporary feature flags.""" - MULTI_TENANCY_V2 = "multi_tenancy_v2" + pass class DBTableV2: diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index 4a1d3e2dd..cb57cda4e 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -7,9 +7,8 @@ from dotenv import load_dotenv from flask import Flask, current_app, json -from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware from unstract.prompt_service.config import db -from unstract.prompt_service.constants import DBTableV2, FeatureFlag +from unstract.prompt_service.constants import DBTableV2 from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.db_utils import DBUtils from unstract.prompt_service.env_manager import EnvLoader @@ -18,8 +17,6 @@ from unstract.sdk.exceptions import SdkError from unstract.sdk.llm import LLM -from unstract.flags.feature_flag import check_feature_flag_status - load_dotenv() # Global variable to store plugins @@ -117,46 +114,6 @@ def initialize_plugin_endpoints(app: Flask) -> None: def query_usage_metadata(token: str, metadata: dict[str, Any]) -> dict[str, Any]: - if check_feature_flag_status(FeatureFlag.MULTI_TENANCY_V2): - return query_usage_metadata_v2(token, metadata) - org_id: str = AuthenticationMiddleware.get_account_from_bearer_token(token) - run_id: str = metadata["run_id"] - query: str = f""" - SELECT - usage_type, - llm_usage_reason, - model_name, - SUM(prompt_tokens) AS input_tokens, - SUM(completion_tokens) AS output_tokens, - SUM(total_tokens) AS total_tokens, - SUM(embedding_tokens) AS embedding_tokens, - SUM(cost_in_dollars) AS cost_in_dollars - FROM "{org_id}"."token_usage" - WHERE run_id = %s - GROUP BY usage_type, llm_usage_reason, model_name; - """ - logger: Logger = current_app.logger - try: - with db.atomic(): - logger.info( - "Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id - ) - cursor = db.execute_sql(query, (run_id,)) - results: list[tuple] = cursor.fetchall() - # Process results as needed - for row in results: - key, item = _get_key_and_item(row) - # Initialize the key as an empty list if it doesn't exist - if key not in metadata: - metadata[key] = [] - # Append the item to the list associated with the key - metadata[key].append(item) - except Exception as e: - logger.error(f"Error executing querying usage metadata: {e}") - return metadata - - -def query_usage_metadata_v2(token: str, metadata: dict[str, Any]) -> dict[str, Any]: DB_SCHEMA = EnvLoader.get_env_or_die("DB_SCHEMA", "unstract_v2") organization_uid, org_id = DBUtils.get_organization_from_bearer_token(token) run_id: str = metadata["run_id"]