Skip to content
This repository has been archived by the owner on Dec 29, 2024. It is now read-only.

Commit

Permalink
Merge branch 'staging' of github.com:banodoco/banodoco
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushK52 committed Feb 14, 2024
2 parents 4e3610f + c58b5e3 commit 05404cb
Show file tree
Hide file tree
Showing 68 changed files with 5,226 additions and 1,694 deletions.
6 changes: 6 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SERVER=development
SERVER_URL=http://127.0.0.1:8000
OFFLINE_MODE=True
GPU_INFERENCE_ENABLED=True
HOSTED_BACKGROUND_RUNNER_MODE=False
REPLICATE_KEY=xyz
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
comfyui*.log

venv
.vscode
Expand All @@ -22,6 +23,10 @@ test.py
doc.py
.env
data.json
comfy_runner/
ComfyUI/
output/
images.zip

# generated file TODO: move inside videos
depth.png
Expand Down
Empty file added __init__.py
Empty file.
Binary file removed arial.ttf
Binary file not shown.
36 changes: 33 additions & 3 deletions backend/db_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
from typing import List
import uuid
from shared.constants import InternalFileType, SortOrder
from shared.constants import InferenceStatus, InternalFileTag, InternalFileType, SortOrder
from backend.serializers.dto import AIModelDto, AppSettingDto, BackupDto, BackupListDto, InferenceLogDto, InternalFileDto, ProjectDto, SettingDto, ShotDto, TimingDto, UserDto

from shared.constants import AUTOMATIC_FILE_HOSTING, LOCAL_DATABASE_NAME, SERVER, ServerType
Expand Down Expand Up @@ -371,6 +371,38 @@ def update_file(self, **kwargs):

return InternalResponse(payload, 'file updated successfully', True)

def get_file_count_from_type(self, file_tag, project_uuid):
    """Count active (non-disabled) files with a given tag inside one project.

    Args:
        file_tag: tag value to match on InternalFileObject.tag.
        project_uuid: uuid of the owning Project.

    Returns:
        InternalResponse whose payload['data'] is the integer count, or a
        failure response when the project uuid does not resolve.
    """
    project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first()
    # .first() returns None on a miss; guard instead of crashing on project.id
    if not project:
        return InternalResponse({}, 'invalid project uuid', False)

    file_count = InternalFileObject.objects.filter(
        tag=file_tag, project_id=project.id, is_disabled=False).count()
    payload = {
        'data': file_count
    }

    return InternalResponse(payload, 'file count fetched', True)

def get_explorer_pending_stats(self, project_uuid, log_status_list):
    """Fetch explorer gallery stats: staged temp images and pending inference logs.

    Args:
        project_uuid: uuid of the Project whose temp gallery images are counted.
        log_status_list: InferenceLog status values treated as "pending".

    Returns:
        InternalResponse whose payload['data'] holds 'temp_image_count' and
        'pending_image_count', or a failure response for an unknown project.
    """
    project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first()
    # .first() returns None on a miss; guard instead of crashing on project.id
    if not project:
        return InternalResponse({}, 'invalid project uuid', False)

    temp_image_count = InternalFileObject.objects.filter(tag=InternalFileTag.TEMP_GALLERY_IMAGE.value,\
        project_id=project.id, is_disabled=False).count()
    # NOTE(review): this count is NOT filtered by project — it spans all
    # projects' logs. Confirm whether a global pending count is intended.
    pending_image_count = InferenceLog.objects.filter(status__in=log_status_list, is_disabled=False).count()
    payload = {
        'data': {
            'temp_image_count': temp_image_count,
            'pending_image_count': pending_image_count
        }
    }

    return InternalResponse(payload, 'file count fetched', True)

def update_temp_gallery_images(self, project_uuid):
    """Promote a project's TEMP_GALLERY_IMAGE files to GALLERY_IMAGE.

    Args:
        project_uuid: uuid of the Project whose staged images are promoted.

    Returns:
        True on success, False when the project uuid does not resolve.
    """
    project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first()
    # .first() returns None on a miss; guard instead of crashing on project.id
    if not project:
        return False

    InternalFileObject.objects.filter(
        tag=InternalFileTag.TEMP_GALLERY_IMAGE.value,
        project_id=project.id,
        is_disabled=False).update(tag=InternalFileTag.GALLERY_IMAGE.value)

    return True

# project
def get_project_from_uuid(self, uuid):
project = Project.objects.filter(uuid=uuid, is_disabled=False).first()
Expand Down Expand Up @@ -603,8 +635,6 @@ def create_inference_log(self, **kwargs):
if not attributes.is_valid():
return InternalResponse({}, attributes.errors, False)

print(attributes.data)

if 'project_id' in attributes.data and attributes.data['project_id']:
project = Project.objects.filter(uuid=attributes.data['project_id'], is_disabled=False).first()
if not project:
Expand Down
2 changes: 1 addition & 1 deletion backend/serializers/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class InternalFileDto(serializers.ModelSerializer):
inference_log = InferenceLogDto()
class Meta:
model = InternalFileObject
fields = ('uuid', 'name', 'local_path', 'type', 'hosted_url', 'created_on', 'inference_log', 'project')
fields = ('uuid', 'name', 'local_path', 'type', 'hosted_url', 'created_on', 'inference_log', 'project', 'tag')


class BasicShotDto(serializers.ModelSerializer):
Expand Down
94 changes: 73 additions & 21 deletions banodoco_runner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import json
import os
import shutil
import signal
import sys
import time
import uuid
import requests
import traceback
import sentry_sdk
import setproctitle
from dotenv import load_dotenv
import django
from shared.constants import OFFLINE_MODE, InferenceParamType, InferenceStatus, InferenceType, ProjectMetaData, HOSTED_BACKGROUND_RUNNER_MODE
from shared.logging.constants import LoggingType
from shared.logging.logging import AppLogger
from shared.logging.logging import app_logger
from ui_components.methods.file_methods import load_from_env, save_to_env
from utils.common_utils import acquire_lock, release_lock
from utils.data_repo.data_repo import DataRepo
from utils.ml_processor.replicate.constants import replicate_status_map
from utils.ml_processor.constants import replicate_status_map

from utils.constants import RUNNER_PROCESS_NAME, AUTH_TOKEN, REFRESH_AUTH_TOKEN
from utils.ml_processor.gpu.utils import is_comfy_runner_present, predict_gpu_output, setup_comfy_runner


load_dotenv()
Expand Down Expand Up @@ -50,6 +54,7 @@

def handle_termination(signal, frame):
    """Signal handler: flag the runner loop for shutdown and exit cleanly."""
    global TERMINATE_SCRIPT
    print("Received termination signal. Cleaning up...")
    # the main loop checks this flag; set it before exiting so any
    # atexit-style cleanup sees the terminating state
    TERMINATE_SCRIPT = True
    sys.exit(0)

Expand Down Expand Up @@ -127,18 +132,41 @@ def is_app_running():
except requests.exceptions.RequestException as e:
print("server not running")
return False

def update_cache_dict(inference_type, log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list):
    """Record, per project, which caches need refreshing for a finished inference.

    Mutates the three *_update_list dicts in place, keyed by the log's
    project uuid (as a string):
      - timing updates for frame-timing / inpainting inferences
      - a boolean flag for gallery image generations
      - shot updates for frame interpolations
    """
    project_key = str(log.project.uuid)

    if inference_type in (InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value,
                          InferenceType.FRAME_INPAINTING.value):
        timing_update_list.setdefault(project_key, []).append(timing_uuid)

    elif inference_type == InferenceType.GALLERY_IMAGE_GENERATION.value:
        gallery_update_list[project_key] = True

    elif inference_type == InferenceType.FRAME_INTERPOLATION.value:
        shot_update_list.setdefault(project_key, []).append(shot_uuid)

def check_and_update_db():
# print("updating logs")
from backend.models import InferenceLog, AppSetting, User

app_logger = AppLogger()
# returning if db creation and migrations are pending
try:
user = User.objects.filter(is_disabled=False).first()
except Exception as e:
app_logger.log(LoggingType.DEBUG, "db creation pending..")
time.sleep(3)
return

if not user:
return

user = User.objects.filter(is_disabled=False).first()
app_setting = AppSetting.objects.filter(user_id=user.id, is_disabled=False).first()
replicate_key = app_setting.replicate_key_decrypted
if not replicate_key:
app_logger.log(LoggingType.ERROR, "Replicate key not found")
# app_logger.log(LoggingType.ERROR, "Replicate key not found")
return

log_list = InferenceLog.objects.filter(status__in=[InferenceStatus.QUEUED.value, InferenceStatus.IN_PROGRESS.value],
Expand All @@ -152,6 +180,7 @@ def check_and_update_db():
for log in log_list:
input_params = json.loads(log.input_params)
replicate_data = input_params.get(InferenceParamType.REPLICATE_INFERENCE.value, None)
local_gpu_data = input_params.get(InferenceParamType.GPU_INFERENCE.value, None)
if replicate_data:
prediction_id = replicate_data['prediction_id']

Expand Down Expand Up @@ -186,7 +215,7 @@ def check_and_update_db():
update_data['total_inference_time'] = float(result['metrics']['predict_time'])

InferenceLog.objects.filter(id=log.id).update(**update_data)
origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, None)
origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, {})
if origin_data and log_status == InferenceStatus.COMPLETED.value:
from ui_components.methods.common_methods import process_inference_output

Expand All @@ -195,20 +224,8 @@ def check_and_update_db():
origin_data['log_uuid'] = log.uuid
print("processing inference output")
process_inference_output(**origin_data)

if origin_data['inference_type'] in [InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value, \
InferenceType.FRAME_INPAINTING.value]:
if str(log.project.uuid) not in timing_update_list:
timing_update_list[str(log.project.uuid)] = []
timing_update_list[str(log.project.uuid)].append(origin_data['timing_uuid'])

elif origin_data['inference_type'] == InferenceType.GALLERY_IMAGE_GENERATION.value:
gallery_update_list[str(log.project.uuid)] = True

elif origin_data['inference_type'] == InferenceType.FRAME_INTERPOLATION.value:
if str(log.project.uuid) not in shot_update_list:
shot_update_list[str(log.project.uuid)] = []
shot_update_list[str(log.project.uuid)].append(origin_data['shot_uuid'])
timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None)
update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)

except Exception as e:
app_logger.log(LoggingType.ERROR, f"Error: {e}")
Expand All @@ -226,8 +243,43 @@ def check_and_update_db():
if response:
app_logger.log(LoggingType.DEBUG, f"Error: {response.content}")
sentry_sdk.capture_exception(response.content)
elif local_gpu_data:
data = json.loads(local_gpu_data)
try:
setup_comfy_runner()
start_time = time.time()
output = predict_gpu_output(data['workflow_input'], data['file_path_list'], data['output_node_ids'])
end_time = time.time()

output = output[-1] # TODO: different models can have different logic
destination_path = "./videos/temp/" + str(uuid.uuid4()) + "." + output.split(".")[-1]
shutil.copy2("./output/" + output, destination_path)
output_details = json.loads(log.output_details)
output_details['output'] = destination_path
update_data = {
"status" : InferenceStatus.COMPLETED.value,
"output_details" : json.dumps(output_details),
"total_inference_time" : end_time - start_time,
}

InferenceLog.objects.filter(id=log.id).update(**update_data)
origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, {})
origin_data['output'] = destination_path
origin_data['log_uuid'] = log.uuid
print("processing inference output")

from ui_components.methods.common_methods import process_inference_output
process_inference_output(**origin_data)
timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None)
update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)

except Exception as e:
print("error occured: ", str(e))
# sentry_sdk.capture_exception(e)
traceback.print_exc()
InferenceLog.objects.filter(id=log.id).update(status=InferenceStatus.FAILED.value)
else:
# if not replicate data is present then removing the status
# if replicate/gpu data is not present then removing the status
InferenceLog.objects.filter(id=log.id).update(status="")

# adding update_data in the project
Expand Down
21 changes: 15 additions & 6 deletions banodoco_settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import copy
import time
import uuid
import streamlit as st

from PIL import Image
from shared.constants import SERVER, AIModelCategory, GuidanceType, InternalFileType, ServerType
from shared.logging.constants import LoggingType
from shared.logging.logging import AppLogger
from shared.logging.logging import app_logger
from shared.constants import AnimationStyleType
from ui_components.methods.common_methods import add_image_variant
from ui_components.methods.file_methods import save_or_host_file
Expand All @@ -14,17 +15,25 @@
from utils.constants import ML_MODEL_LIST
from utils.data_repo.data_repo import DataRepo

logger = AppLogger()

def wait_for_db_ready():
    """Block until the database answers a trivial query (migrations done)."""
    repo = DataRepo()
    db_ready = False
    while not db_ready:
        try:
            # any successful query proves the schema exists
            repo.get_total_user_count()
            db_ready = True
        except Exception:
            app_logger.log(LoggingType.DEBUG, 'waiting for db...')
            time.sleep(3)

def project_init():
from utils.data_repo.data_repo import DataRepo
data_repo = DataRepo()

# db initialization takes some time
# time.sleep(2)
wait_for_db_ready()
user_count = data_repo.get_total_user_count()
# create a user if not already present (if dev mode)
# if this is the local server with no user then create one and related data
user_count = data_repo.get_total_user_count()
if SERVER == ServerType.DEVELOPMENT.value and not user_count:
user_data = {
"name" : "banodoco_user",
Expand All @@ -33,7 +42,7 @@ def project_init():
"type" : "user"
}
user: InternalUserObject = data_repo.create_user(**user_data)
logger.log(LoggingType.INFO, "new temp user created: " + user.name)
app_logger.log(LoggingType.INFO, "new temp user created: " + user.name)

create_new_user_data(user)
# creating data for online user
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ extra-streamlit-components==0.1.56
wrapt==1.15.0
pydantic==1.10.9
streamlit-server-state==0.17.1
setproctitle==1.3.3
setproctitle==1.3.3
gitdb==4.0.11
Binary file added sample_assets/sample_images/main.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 7 additions & 2 deletions shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class InternalFileTag(ExtendedEnum):
TEMP_IMAGE = 'temp'
GALLERY_IMAGE = 'gallery_image'
SHORTLISTED_GALLERY_IMAGE = 'shortlisted_gallery_image'
TEMP_GALLERY_IMAGE = 'temp_gallery_image' # these generations are complete but not yet being shown in the gallery

class AnimationStyleType(ExtendedEnum):
CREATIVE_INTERPOLATION = "Creative Interpolation"
Expand Down Expand Up @@ -93,6 +94,7 @@ class InferenceParamType(ExtendedEnum):
REPLICATE_INFERENCE = "replicate_inference" # replicate url for queue inference and other data
QUERY_DICT = "query_dict" # query dict of standardized inference params
ORIGIN_DATA = "origin_data" # origin data - used to store file once inference is completed
GPU_INFERENCE = "gpu_inference" # gpu inference data

class ProjectMetaData(ExtendedEnum):
DATA_UPDATE = "data_update" # info regarding cache/data update when runner updates the db
Expand All @@ -108,15 +110,18 @@ class SortOrder(ExtendedEnum):
SERVER = os.getenv('SERVER', ServerType.PRODUCTION.value)

AUTOMATIC_FILE_HOSTING = SERVER != ServerType.DEVELOPMENT.value # automatically upload project files to s3 (images, videos, gifs)
AWS_S3_BUCKET = 'banodoco'
AWS_S3_REGION = 'ap-south-1' # TODO: discuss this
AWS_S3_BUCKET = "banodoco-data-bucket-public"
AWS_S3_REGION = 'ap-south-1'
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY", "")
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY", "")
OFFLINE_MODE = os.getenv('OFFLINE_MODE', False) # for picking up secrets and file storage

LOCAL_DATABASE_NAME = 'banodoco_local.db'
ENCRYPTION_KEY = os.getenv('ENCRYPTION_KEY', 'J2684nBgNUYa_K0a6oBr5H8MpSRW0EJ52Qmq7jExE-w=')

QUEUE_INFERENCE_QUERIES = True
HOSTED_BACKGROUND_RUNNER_MODE = os.getenv('HOSTED_BACKGROUND_RUNNER_MODE', False)
GPU_INFERENCE_ENABLED = False if os.getenv('GPU_INFERENCE_ENABLED', False) in [False, 'False'] else True

if OFFLINE_MODE:
SECRET_ACCESS_TOKEN = os.getenv('SECRET_ACCESS_TOKEN', None)
Expand Down
11 changes: 7 additions & 4 deletions shared/file_upload/s3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import mimetypes
from urllib.parse import urlparse
import boto3
Expand All @@ -6,9 +7,10 @@
import shutil

import requests
from shared.constants import AWS_S3_BUCKET, AWS_S3_REGION
from shared.constants import AWS_ACCESS_KEY, AWS_S3_BUCKET, AWS_S3_REGION, AWS_SECRET_KEY
from shared.logging.logging import AppLogger
from shared.logging.constants import LoggingPayload, LoggingType
from ui_components.methods.file_methods import convert_file_to_base64
logger = AppLogger()

# TODO: fix proper paths for file uploads
Expand All @@ -29,14 +31,15 @@ def upload_file(file_location, aws_access_key, aws_secret_key, bucket=AWS_S3_BUC

return url

def upload_file_from_obj(file, aws_access_key, aws_secret_key, bucket=AWS_S3_BUCKET):
def upload_file_from_obj(file, file_extension, bucket=AWS_S3_BUCKET):
aws_access_key, aws_secret_key = AWS_ACCESS_KEY, AWS_SECRET_KEY
folder = 'test/'
unique_tag = str(uuid.uuid4())
file_extension = os.path.splitext(file.name)[1]
filename = unique_tag + file_extension
file.seek(0)

# Upload the file
content_type = mimetypes.guess_type(file.name)[0]
content_type = "application/octet-stream" if file_extension not in [".png", ".jpg"] else "image/png" # hackish sol, will fix later
data = {
"Body": file,
"Bucket": bucket,
Expand Down
Loading

0 comments on commit 05404cb

Please sign in to comment.