diff --git a/.env.sample b/.env.sample new file mode 100644 index 00000000..bc7acee1 --- /dev/null +++ b/.env.sample @@ -0,0 +1,6 @@ +SERVER=development +SERVER_URL=http://127.0.0.1:8000 +OFFLINE_MODE=True +GPU_INFERENCE_ENABLED=True +HOSTED_BACKGROUND_RUNNER_MODE=False +REPLICATE_KEY=xyz \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1a128308..5f552954 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +comfyui*.log venv .vscode @@ -22,6 +23,10 @@ test.py doc.py .env data.json +comfy_runner/ +ComfyUI/ +output/ +images.zip # generated file TODO: move inside videos depth.png diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arial.ttf b/arial.ttf deleted file mode 100644 index ff0815cd..00000000 Binary files a/arial.ttf and /dev/null differ diff --git a/backend/db_repo.py b/backend/db_repo.py index 7fa9490c..b3aedbee 100644 --- a/backend/db_repo.py +++ b/backend/db_repo.py @@ -15,7 +15,7 @@ import subprocess from typing import List import uuid -from shared.constants import InternalFileType, SortOrder +from shared.constants import InferenceStatus, InternalFileTag, InternalFileType, SortOrder from backend.serializers.dto import AIModelDto, AppSettingDto, BackupDto, BackupListDto, InferenceLogDto, InternalFileDto, ProjectDto, SettingDto, ShotDto, TimingDto, UserDto from shared.constants import AUTOMATIC_FILE_HOSTING, LOCAL_DATABASE_NAME, SERVER, ServerType @@ -371,6 +371,38 @@ def update_file(self, **kwargs): return InternalResponse(payload, 'file updated successfully', True) + def get_file_count_from_type(self, file_tag, project_uuid): + project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first() + file_count = InternalFileObject.objects.filter(tag=file_tag, project_id=project.id, is_disabled=False).count() + payload = { + 'data': file_count + } + + return InternalResponse(payload, 'file count fetched', True) + + def get_explorer_pending_stats(self, project_uuid, log_status_list): + project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first() + temp_image_count = InternalFileObject.objects.filter(tag=InternalFileTag.TEMP_GALLERY_IMAGE.value,\ + project_id=project.id, is_disabled=False).count() + pending_image_count = InferenceLog.objects.filter(status__in=log_status_list, is_disabled=False).count() + payload = { + 'data': { + 'temp_image_count': temp_image_count, + 'pending_image_count': pending_image_count + } + } + + return InternalResponse(payload, 'file count fetched', True) + + def update_temp_gallery_images(self, project_uuid): + project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first() + InternalFileObject.objects.filter( + tag=InternalFileTag.TEMP_GALLERY_IMAGE.value, + project_id=project.id, + is_disabled=False).update(tag=InternalFileTag.GALLERY_IMAGE.value) + + return True + # project def get_project_from_uuid(self, uuid): project = Project.objects.filter(uuid=uuid, is_disabled=False).first() @@ -603,8 +635,6 @@ def create_inference_log(self, **kwargs): if not attributes.is_valid(): return InternalResponse({}, attributes.errors, False) - print(attributes.data) - if 'project_id' in attributes.data and attributes.data['project_id']: project = Project.objects.filter(uuid=attributes.data['project_id'], is_disabled=False).first() if not project: diff --git a/backend/serializers/dto.py b/backend/serializers/dto.py index 4fac24df..57a11b54 100644 --- a/backend/serializers/dto.py +++ b/backend/serializers/dto.py @@ -71,7 
+71,7 @@ class InternalFileDto(serializers.ModelSerializer): inference_log = InferenceLogDto() class Meta: model = InternalFileObject - fields = ('uuid', 'name', 'local_path', 'type', 'hosted_url', 'created_on', 'inference_log', 'project') + fields = ('uuid', 'name', 'local_path', 'type', 'hosted_url', 'created_on', 'inference_log', 'project', 'tag') class BasicShotDto(serializers.ModelSerializer): diff --git a/banodoco_runner.py b/banodoco_runner.py index 769bd46d..27ade190 100644 --- a/banodoco_runner.py +++ b/banodoco_runner.py @@ -1,22 +1,26 @@ import json import os +import shutil import signal import sys import time +import uuid import requests +import traceback import sentry_sdk import setproctitle from dotenv import load_dotenv import django from shared.constants import OFFLINE_MODE, InferenceParamType, InferenceStatus, InferenceType, ProjectMetaData, HOSTED_BACKGROUND_RUNNER_MODE from shared.logging.constants import LoggingType -from shared.logging.logging import AppLogger +from shared.logging.logging import app_logger from ui_components.methods.file_methods import load_from_env, save_to_env from utils.common_utils import acquire_lock, release_lock from utils.data_repo.data_repo import DataRepo -from utils.ml_processor.replicate.constants import replicate_status_map +from utils.ml_processor.constants import replicate_status_map from utils.constants import RUNNER_PROCESS_NAME, AUTH_TOKEN, REFRESH_AUTH_TOKEN +from utils.ml_processor.gpu.utils import is_comfy_runner_present, predict_gpu_output, setup_comfy_runner load_dotenv() @@ -50,6 +54,7 @@ def handle_termination(signal, frame): print("Received termination signal. Cleaning up...") + global TERMINATE_SCRIPT TERMINATE_SCRIPT = True sys.exit(0) @@ -127,18 +132,41 @@ def is_app_running(): except requests.exceptions.RequestException as e: print("server not running") return False + +def update_cache_dict(inference_type, log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list): + if inference_type in [InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value, \ + InferenceType.FRAME_INPAINTING.value]: + if str(log.project.uuid) not in timing_update_list: + timing_update_list[str(log.project.uuid)] = [] + timing_update_list[str(log.project.uuid)].append(timing_uuid) + + elif inference_type == InferenceType.GALLERY_IMAGE_GENERATION.value: + gallery_update_list[str(log.project.uuid)] = True + + elif inference_type == InferenceType.FRAME_INTERPOLATION.value: + if str(log.project.uuid) not in shot_update_list: + shot_update_list[str(log.project.uuid)] = [] + shot_update_list[str(log.project.uuid)].append(shot_uuid) def check_and_update_db(): # print("updating logs") from backend.models import InferenceLog, AppSetting, User - app_logger = AppLogger() + # returning if db creation and migrations are pending + try: + user = User.objects.filter(is_disabled=False).first() + except Exception as e: + app_logger.log(LoggingType.DEBUG, "db creation pending..") + time.sleep(3) + return + + if not user: + return - user = User.objects.filter(is_disabled=False).first() app_setting = AppSetting.objects.filter(user_id=user.id, is_disabled=False).first() replicate_key = app_setting.replicate_key_decrypted if not replicate_key: - app_logger.log(LoggingType.ERROR, "Replicate key not found") + # app_logger.log(LoggingType.ERROR, "Replicate key not found") return log_list = InferenceLog.objects.filter(status__in=[InferenceStatus.QUEUED.value, InferenceStatus.IN_PROGRESS.value], @@ -152,6 +180,7 @@ def check_and_update_db(): for log in 
log_list: input_params = json.loads(log.input_params) replicate_data = input_params.get(InferenceParamType.REPLICATE_INFERENCE.value, None) + local_gpu_data = input_params.get(InferenceParamType.GPU_INFERENCE.value, None) if replicate_data: prediction_id = replicate_data['prediction_id'] @@ -186,7 +215,7 @@ def check_and_update_db(): update_data['total_inference_time'] = float(result['metrics']['predict_time']) InferenceLog.objects.filter(id=log.id).update(**update_data) - origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, None) + origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, {}) if origin_data and log_status == InferenceStatus.COMPLETED.value: from ui_components.methods.common_methods import process_inference_output @@ -195,20 +224,8 @@ def check_and_update_db(): origin_data['log_uuid'] = log.uuid print("processing inference output") process_inference_output(**origin_data) - - if origin_data['inference_type'] in [InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value, \ - InferenceType.FRAME_INPAINTING.value]: - if str(log.project.uuid) not in timing_update_list: - timing_update_list[str(log.project.uuid)] = [] - timing_update_list[str(log.project.uuid)].append(origin_data['timing_uuid']) - - elif origin_data['inference_type'] == InferenceType.GALLERY_IMAGE_GENERATION.value: - gallery_update_list[str(log.project.uuid)] = True - - elif origin_data['inference_type'] == InferenceType.FRAME_INTERPOLATION.value: - if str(log.project.uuid) not in shot_update_list: - shot_update_list[str(log.project.uuid)] = [] - shot_update_list[str(log.project.uuid)].append(origin_data['shot_uuid']) + timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None) + update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list) except Exception as e: app_logger.log(LoggingType.ERROR, f"Error: {e}") @@ -226,8 +243,43 @@ def check_and_update_db(): if response: app_logger.log(LoggingType.DEBUG, f"Error: {response.content}") sentry_sdk.capture_exception(response.content) + elif local_gpu_data: + data = json.loads(local_gpu_data) + try: + setup_comfy_runner() + start_time = time.time() + output = predict_gpu_output(data['workflow_input'], data['file_path_list'], data['output_node_ids']) + end_time = time.time() + + output = output[-1] # TODO: different models can have different logic + destination_path = "./videos/temp/" + str(uuid.uuid4()) + "." 
+ output.split(".")[-1] + shutil.copy2("./output/" + output, destination_path) + output_details = json.loads(log.output_details) + output_details['output'] = destination_path + update_data = { + "status" : InferenceStatus.COMPLETED.value, + "output_details" : json.dumps(output_details), + "total_inference_time" : end_time - start_time, + } + + InferenceLog.objects.filter(id=log.id).update(**update_data) + origin_data = json.loads(log.input_params).get(InferenceParamType.ORIGIN_DATA.value, {}) + origin_data['output'] = destination_path + origin_data['log_uuid'] = log.uuid + print("processing inference output") + + from ui_components.methods.common_methods import process_inference_output + process_inference_output(**origin_data) + timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None) + update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list) + + except Exception as e: + print("error occured: ", str(e)) + # sentry_sdk.capture_exception(e) + traceback.print_exc() + InferenceLog.objects.filter(id=log.id).update(status=InferenceStatus.FAILED.value) else: - # if not replicate data is present then removing the status + # if replicate/gpu data is not present then removing the status InferenceLog.objects.filter(id=log.id).update(status="") # adding update_data in the project diff --git a/banodoco_settings.py b/banodoco_settings.py index 12e6811e..4656e935 100644 --- a/banodoco_settings.py +++ b/banodoco_settings.py @@ -1,11 +1,12 @@ import copy +import time import uuid import streamlit as st from PIL import Image from shared.constants import SERVER, AIModelCategory, GuidanceType, InternalFileType, ServerType from shared.logging.constants import LoggingType -from shared.logging.logging import AppLogger +from shared.logging.logging import app_logger from shared.constants import AnimationStyleType from ui_components.methods.common_methods import add_image_variant from ui_components.methods.file_methods import save_or_host_file @@ -14,17 +15,25 @@ from utils.constants import ML_MODEL_LIST from utils.data_repo.data_repo import DataRepo -logger = AppLogger() + +def wait_for_db_ready(): + data_repo = DataRepo() + while True: + try: + user_count = data_repo.get_total_user_count() + break + except Exception as e: + app_logger.log(LoggingType.DEBUG, 'waiting for db...') + time.sleep(3) def project_init(): from utils.data_repo.data_repo import DataRepo data_repo = DataRepo() - # db initialization takes some time - # time.sleep(2) + wait_for_db_ready() + user_count = data_repo.get_total_user_count() # create a user if not already present (if dev mode) # if this is the local server with no user than create one and related data - user_count = data_repo.get_total_user_count() if SERVER == ServerType.DEVELOPMENT.value and not user_count: user_data = { "name" : "banodoco_user", @@ -33,7 +42,7 @@ def project_init(): "type" : "user" } user: InternalUserObject = data_repo.create_user(**user_data) - logger.log(LoggingType.INFO, "new temp user created: " + user.name) + app_logger.log(LoggingType.INFO, "new temp user created: " + user.name) create_new_user_data(user) # creating data for online user diff --git a/requirements.txt b/requirements.txt index 69e43127..8bc8db43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,4 +36,5 @@ extra-streamlit-components==0.1.56 wrapt==1.15.0 pydantic==1.10.9 streamlit-server-state==0.17.1 -setproctitle==1.3.3 \ No newline at end of file 
+setproctitle==1.3.3 +gitdb==4.0.11 \ No newline at end of file diff --git a/sample_assets/sample_images/main.png b/sample_assets/sample_images/main.png new file mode 100644 index 00000000..c99f6ad7 Binary files /dev/null and b/sample_assets/sample_images/main.png differ diff --git a/shared/constants.py b/shared/constants.py index 3672cd81..3ad951b1 100644 --- a/shared/constants.py +++ b/shared/constants.py @@ -61,6 +61,7 @@ class InternalFileTag(ExtendedEnum): TEMP_IMAGE = 'temp' GALLERY_IMAGE = 'gallery_image' SHORTLISTED_GALLERY_IMAGE = 'shortlisted_gallery_image' + TEMP_GALLERY_IMAGE = 'temp_gallery_image' # these generations are complete but not yet being shown in the gallery class AnimationStyleType(ExtendedEnum): CREATIVE_INTERPOLATION = "Creative Interpolation" @@ -93,6 +94,7 @@ class InferenceParamType(ExtendedEnum): REPLICATE_INFERENCE = "replicate_inference" # replicate url for queue inference and other data QUERY_DICT = "query_dict" # query dict of standardized inference params ORIGIN_DATA = "origin_data" # origin data - used to store file once inference is completed + GPU_INFERENCE = "gpu_inference" # gpu inference data class ProjectMetaData(ExtendedEnum): DATA_UPDATE = "data_update" # info regarding cache/data update when runner updates the db @@ -108,8 +110,10 @@ class SortOrder(ExtendedEnum): SERVER = os.getenv('SERVER', ServerType.PRODUCTION.value) AUTOMATIC_FILE_HOSTING = SERVER != ServerType.DEVELOPMENT.value # automatically upload project files to s3 (images, videos, gifs) -AWS_S3_BUCKET = 'banodoco' -AWS_S3_REGION = 'ap-south-1' # TODO: discuss this +AWS_S3_BUCKET = "banodoco-data-bucket-public" +AWS_S3_REGION = 'ap-south-1' +AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY", "") +AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY", "") OFFLINE_MODE = os.getenv('OFFLINE_MODE', False) # for picking up secrets and file storage LOCAL_DATABASE_NAME = 'banodoco_local.db' @@ -117,6 +121,7 @@ class SortOrder(ExtendedEnum): QUEUE_INFERENCE_QUERIES = True HOSTED_BACKGROUND_RUNNER_MODE = os.getenv('HOSTED_BACKGROUND_RUNNER_MODE', False) +GPU_INFERENCE_ENABLED = False if os.getenv('GPU_INFERENCE_ENABLED', False) in [False, 'False'] else True if OFFLINE_MODE: SECRET_ACCESS_TOKEN = os.getenv('SECRET_ACCESS_TOKEN', None) diff --git a/shared/file_upload/s3.py b/shared/file_upload/s3.py index 1978087e..35d17515 100644 --- a/shared/file_upload/s3.py +++ b/shared/file_upload/s3.py @@ -1,3 +1,4 @@ +import hashlib import mimetypes from urllib.parse import urlparse import boto3 @@ -6,9 +7,10 @@ import shutil import requests -from shared.constants import AWS_S3_BUCKET, AWS_S3_REGION +from shared.constants import AWS_ACCESS_KEY, AWS_S3_BUCKET, AWS_S3_REGION, AWS_SECRET_KEY from shared.logging.logging import AppLogger from shared.logging.constants import LoggingPayload, LoggingType +from ui_components.methods.file_methods import convert_file_to_base64 logger = AppLogger() # TODO: fix proper paths for file uploads @@ -29,14 +31,15 @@ def upload_file(file_location, aws_access_key, aws_secret_key, bucket=AWS_S3_BUC return url -def upload_file_from_obj(file, aws_access_key, aws_secret_key, bucket=AWS_S3_BUCKET): +def upload_file_from_obj(file, file_extension, bucket=AWS_S3_BUCKET): + aws_access_key, aws_secret_key = AWS_ACCESS_KEY, AWS_SECRET_KEY folder = 'test/' unique_tag = str(uuid.uuid4()) - file_extension = os.path.splitext(file.name)[1] filename = unique_tag + file_extension + file.seek(0) # Upload the file - content_type = mimetypes.guess_type(file.name)[0] + content_type = "application/octet-stream" if 
file_extension not in [".png", ".jpg"] else "image/png" # hackish sol, will fix later data = { "Body": file, "Bucket": bucket, diff --git a/shared/logging/logging.py b/shared/logging/logging.py index d71628f8..b2865431 100644 --- a/shared/logging/logging.py +++ b/shared/logging/logging.py @@ -59,4 +59,5 @@ def log(self, log_type: LoggingType, log_message, log_data = None): self.error(log_message) elif log_type in [LoggingType.INFERENCE_CALL, LoggingType.INFERENCE_RESULT]: self.info(log_message) - \ No newline at end of file + +app_logger = AppLogger() \ No newline at end of file diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py index bb326d07..591a81db 100644 --- a/ui_components/components/adjust_shot_page.py +++ b/ui_components/components/adjust_shot_page.py @@ -2,45 +2,38 @@ from ui_components.widgets.shot_view import shot_keyframe_element from ui_components.components.explorer_page import gallery_image_view from ui_components.components.explorer_page import generate_images_element +from ui_components.components.frame_styling_page import frame_styling_page from ui_components.widgets.frame_selector import frame_selector_widget, frame_view from utils import st_memory from utils.data_repo.data_repo import DataRepo - def adjust_shot_page(shot_uuid: str, h2): - data_repo = DataRepo() - shot = data_repo.get_shot_from_uuid(shot_uuid) - - with h2: - frame_selector_widget(show=['shot_selector']) - - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") - - st.markdown("***") - with st.sidebar: - frame_view(view='Video') - - shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") - # with st.expander("πŸ“‹ Explorer Shortlist",expanded=True): - shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"], - icons=['grid-3x3','airplane'], - menu_icon="cast", - default_index=st.session_state.get('shot_explorer_view', 0), - key="shot_explorer_view", orientation="horizontal", - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#868c91"}}) + frame_selection = frame_selector_widget(show_frame_selector=True) - st.markdown("***") + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) - if shot_explorer_view == "Shortlist": - project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + if frame_selection == "": + with st.sidebar: + frame_view(view='Video',show_current_frames=False) + with st.expander("πŸ“‹ Explorer Shortlist",expanded=True): + if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): + project_setting = data_repo.get_project_setting(shot.project.uuid) + number_of_pages = project_setting.total_shortlist_gallery_pages + page_number = 0 + gallery_image_view(shot.project.uuid, shortlist=True,view=['add_and_remove_from_shortlist','add_to_this_shot'], shot=shot,sidebar=True) + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") st.markdown("***") - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=True, num_columns=4,view="individual_shot", shot=shot) - elif shot_explorer_view == "Explore": + shot_keyframe_element(st.session_state["shot_uuid"], 4, 
position="Individual") project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + + with st.expander("✨ Generate Images", expanded=True): + generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + st.markdown("***") + st.markdown("***") - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=False, num_columns=4,view="individual_shot", shot=shot) \ No newline at end of file + gallery_image_view(shot.project.uuid, shortlist=False,view=['add_and_remove_from_shortlist','add_to_this_shot','view_inference_details'], shot=shot,sidebar=False) + else: + frame_styling_page(st.session_state["shot_uuid"], h2) \ No newline at end of file diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py index e30e8282..98a651d0 100644 --- a/ui_components/components/animate_shot_page.py +++ b/ui_components/components/animate_shot_page.py @@ -8,13 +8,16 @@ def animate_shot_page(shot_uuid: str, h2): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) - with h2: - frame_selector_widget(show=['shot_selector']) + + with st.sidebar: - frame_view() + frame_selector_widget(show_frame_selector=False) + frame_view(view='Video',show_current_frames=False) st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") st.markdown("***") variant_comparison_grid(st.session_state['shot_uuid'], stage="Shots") - with st.expander("🎬 Choose Animation Style & Create Variants", expanded=True): - animation_style_element(st.session_state['shot_uuid']) \ No newline at end of file + + animation_style_element(st.session_state['shot_uuid']) + + st.markdown("***") \ No newline at end of file diff --git a/ui_components/components/app_settings_page.py b/ui_components/components/app_settings_page.py index be505adc..22d69841 100644 --- a/ui_components/components/app_settings_page.py +++ b/ui_components/components/app_settings_page.py @@ -3,17 +3,17 @@ import webbrowser from shared.constants import SERVER, ServerType from utils.common_utils import get_current_user +from ui_components.components.query_logger_page import query_logger_page from utils.data_repo.data_repo import DataRepo def app_settings_page(): data_repo = DataRepo() - - if SERVER == ServerType.DEVELOPMENT.value: - st.subheader("Purchase Credits") - st.write("This feature is only available in production") + st.markdown("#### App Settings") + st.markdown("***") + if SERVER != ServerType.DEVELOPMENT.value: with st.expander("Purchase Credits", expanded=True): user_credits = get_current_user(invalidate_cache=True).total_credits @@ -37,4 +37,6 @@ def app_settings_page(): else: payment_link = data_repo.generate_payment_link(credits) payment_link = f""" PAYMENT LINK """ - st.markdown(payment_link, unsafe_allow_html=True) \ No newline at end of file + st.markdown(payment_link, unsafe_allow_html=True) + + query_logger_page() \ No newline at end of file diff --git a/ui_components/components/custom_models_page.py b/ui_components/components/custom_models_page.py index 0874cb16..305ca2a7 100644 --- a/ui_components/components/custom_models_page.py +++ 
b/ui_components/components/custom_models_page.py @@ -9,167 +9,168 @@ from utils.data_repo.data_repo import DataRepo -def custom_models_page(project_uuid): - data_repo = DataRepo() - - with st.expander("Existing models", expanded=True): - - st.subheader("Existing Models:") - - # TODO: the list should only show models trained by the user (not the default ones) - current_user_uuid = get_current_user_uuid() - model_list: List[InternalAIModelObject] = data_repo.get_all_ai_model_list(\ - model_category_list=[AIModelCategory.DREAMBOOTH.value, AIModelCategory.LORA.value], \ - user_id=current_user_uuid, custom_trained=True) - if model_list == []: - st.info("You don't have any models yet. Train a new model below.") - else: - header1, header2, header3, header4, header5, header6 = st.columns( - 6) - with header1: - st.markdown("###### Model Name") - with header2: - st.markdown("###### Trigger Word") - with header3: - st.markdown("###### Model Type") - with header4: - st.markdown("###### Example Image #1") - with header5: - st.markdown("###### Example Image #2") - with header6: - st.markdown("###### Example Image #3") - - for model in model_list: - col1, col2, col3, col4, col5, col6 = st.columns(6) - with col1: - st.text(model.name) - with col2: - if model.keyword != "": - st.text(model.keyword) - with col3: - if model.category != "": - st.text(model.category) - with col4: - if len(model.training_image_list) > 0: - st.image(model.training_image_list[0].location) - with col5: - if len(model.training_image_list) > 1: - st.image(model.training_image_list[1].location) - with col6: - if len(model.training_image_list) > 2: - st.image(model.training_image_list[2].location) - st.markdown("***") - - with st.expander("Train a new model", expanded=True): - st.subheader("Train a new model:") - - type_of_model = st.selectbox("Type of model:", [AIModelCategory.DREAMBOOTH.value, AIModelCategory.LORA.value], help="If you'd like to use other methods for model training, let us know - or implement it yourself :)") - model_name = st.text_input( - "Model name:", value="", help="No spaces or special characters please") - - if type_of_model == AIModelCategory.DREAMBOOTH.value: - instance_prompt = st.text_input( - "Trigger word:", value="", help="This is the word that will trigger the model") - class_prompt = st.text_input("Describe what your prompts depict generally:", - value="", help="This will help guide the model to learn what you want it to do") - max_train_steps = st.number_input( - "Max training steps:", value=2000, help=" The number of training steps to run. Fewer steps make it run faster but typically make it worse quality, and vice versa.") - type_of_task = "" - resolution = "" - controller_type = st.selectbox("What ControlNet controller would you like to use?", [ - "normal", "canny", "hed", "scribble", "seg", "openpose", "depth", "mlsd"]) - model_type_list = json.dumps([AIModelType.TXT2IMG.value]) - - elif type_of_model == AIModelCategory.LORA.value: - type_of_task = st.selectbox( - "Type of task:", ["Face", "Object", "Style"]).lower() - resolution = st.selectbox("Resolution:", [ - "512", "768", "1024"], help="The resolution for input images. 
All the images in the train/validation dataset will be resized to this resolution.") - instance_prompt = "" - class_prompt = "" - max_train_steps = "" - controller_type = "" - model_type_list = json.dumps([AIModelType.TXT2IMG.value]) - - uploaded_files = st.file_uploader("Images you'd like to train the model based on:", type=[ - 'png', 'jpg', 'jpeg'], key="prompt_file", accept_multiple_files=True) - if uploaded_files is not None: - column = 0 - for image in uploaded_files: - # if it's an even number - if uploaded_files.index(image) % 2 == 0: - column = column + 1 - row_1_key = str(column) + 'a' - row_2_key = str(column) + 'b' - row_1_key, row_2_key = st.columns([1, 1]) - with row_1_key: - st.image( - uploaded_files[uploaded_files.index(image)], width=300) - else: - with row_2_key: - st.image( - uploaded_files[uploaded_files.index(image)], width=300) - - st.write(f"You've selected {len(uploaded_files)} images.") - - if len(uploaded_files) <= 5 and model_name == "": - st.write( - "Select at least 5 images and fill in all the fields to train a new model.") - st.button("Train Model", disabled=True) - else: - if st.button("Train Model", disabled=False): - st.info("Loading...") - - # TODO: check the local storage - # directory = "videos/training_data" - # if not os.path.exists(directory): - # os.makedirs(directory) - - # for image in uploaded_files: - # with open(os.path.join(f"videos/training_data", image.name), "wb") as f: - # f.write(image.getbuffer()) - # images_for_model.append(image.name) - model_status = train_model(uploaded_files, instance_prompt, class_prompt, max_train_steps, - model_name, type_of_model, type_of_task, resolution, controller_type, model_type_list) - st.success(model_status) - - # with st.expander("Add model from internet"): - # st.subheader("Add a model the internet:") - # uploaded_type_of_model = st.selectbox("Type of model:", [ - # "LoRA", "Dreambooth"], key="uploaded_type_of_model", disabled=True, help="You can currently only upload LoRA models - this will change soon.") - # uploaded_model_name = st.text_input( - # "Model name:", value="", help="No spaces or special characters please", key="uploaded_model_name") - # uploaded_model_images = st.file_uploader("Please add at least 2 sample images from this model:", type=[ - # 'png', 'jpg', 'jpeg'], key="uploaded_prompt_file", accept_multiple_files=True) - # uploaded_link_to_model = st.text_input( - # "Link to model:", value="", key="uploaded_link_to_model") - # st.info("The model should be a direct link to a .safetensors files. 
You can find models on websites like: https://civitai.com/") - # if uploaded_model_name == "" or uploaded_link_to_model == "" or uploaded_model_images is None: - # st.write("Fill in all the fields to add a model from the internet.") - # st.button("Upload Model", disabled=True) - # else: - # if st.button("Upload Model", disabled=False): - # images_for_model = [] - # directory = "videos/training_data" - # if not os.path.exists(directory): - # os.makedirs(directory) - - # for image in uploaded_model_images: - # with open(os.path.join(f"videos/training_data", image.name), "wb") as f: - # f.write(image.getbuffer()) - # images_for_model.append(image.name) - # for i in range(len(images_for_model)): - # images_for_model[i] = 'videos/training_data/' + \ - # images_for_model[i] - # df = pd.read_csv("models.csv") - # df = df.append({}, ignore_index=True) - # new_row_index = df.index[-1] - # df.iloc[new_row_index, 0] = uploaded_model_name - # df.iloc[new_row_index, 4] = str(images_for_model) - # df.iloc[new_row_index, 5] = uploaded_type_of_model - # df.iloc[new_row_index, 6] = uploaded_link_to_model - # df.to_csv("models.csv", index=False) - # st.success( - # f"Successfully uploaded - the model '{model_name}' is now available for use!") - # time.sleep(1.5) - # st.rerun() +# NOTE: code not in use +# def custom_models_page(project_uuid): +# data_repo = DataRepo() + +# with st.expander("Existing models", expanded=True): + +# st.subheader("Existing Models:") + +# # TODO: the list should only show models trained by the user (not the default ones) +# current_user_uuid = get_current_user_uuid() +# model_list: List[InternalAIModelObject] = data_repo.get_all_ai_model_list(\ +# model_category_list=[AIModelCategory.DREAMBOOTH.value, AIModelCategory.LORA.value], \ +# user_id=current_user_uuid, custom_trained=True) +# if model_list == []: +# st.info("You don't have any models yet. 
Train a new model below.") +# else: +# header1, header2, header3, header4, header5, header6 = st.columns( +# 6) +# with header1: +# st.markdown("###### Model Name") +# with header2: +# st.markdown("###### Trigger Word") +# with header3: +# st.markdown("###### Model Type") +# with header4: +# st.markdown("###### Example Image #1") +# with header5: +# st.markdown("###### Example Image #2") +# with header6: +# st.markdown("###### Example Image #3") + +# for model in model_list: +# col1, col2, col3, col4, col5, col6 = st.columns(6) +# with col1: +# st.text(model.name) +# with col2: +# if model.keyword != "": +# st.text(model.keyword) +# with col3: +# if model.category != "": +# st.text(model.category) +# with col4: +# if len(model.training_image_list) > 0: +# st.image(model.training_image_list[0].location) +# with col5: +# if len(model.training_image_list) > 1: +# st.image(model.training_image_list[1].location) +# with col6: +# if len(model.training_image_list) > 2: +# st.image(model.training_image_list[2].location) +# st.markdown("***") + +# with st.expander("Train a new model", expanded=True): +# st.subheader("Train a new model:") + +# type_of_model = st.selectbox("Type of model:", [AIModelCategory.DREAMBOOTH.value, AIModelCategory.LORA.value], help="If you'd like to use other methods for model training, let us know - or implement it yourself :)") +# model_name = st.text_input( +# "Model name:", value="", help="No spaces or special characters please") + +# if type_of_model == AIModelCategory.DREAMBOOTH.value: +# instance_prompt = st.text_input( +# "Trigger word:", value="", help="This is the word that will trigger the model") +# class_prompt = st.text_input("Describe what your prompts depict generally:", +# value="", help="This will help guide the model to learn what you want it to do") +# max_train_steps = st.number_input( +# "Max training steps:", value=2000, help=" The number of training steps to run. Fewer steps make it run faster but typically make it worse quality, and vice versa.") +# type_of_task = "" +# resolution = "" +# controller_type = st.selectbox("What ControlNet controller would you like to use?", [ +# "normal", "canny", "hed", "scribble", "seg", "openpose", "depth", "mlsd"]) +# model_type_list = json.dumps([AIModelType.TXT2IMG.value]) + +# elif type_of_model == AIModelCategory.LORA.value: +# type_of_task = st.selectbox( +# "Type of task:", ["Face", "Object", "Style"]).lower() +# resolution = st.selectbox("Resolution:", [ +# "512", "768", "1024"], help="The resolution for input images. 
All the images in the train/validation dataset will be resized to this resolution.") +# instance_prompt = "" +# class_prompt = "" +# max_train_steps = "" +# controller_type = "" +# model_type_list = json.dumps([AIModelType.TXT2IMG.value]) + +# uploaded_files = st.file_uploader("Images you'd like to train the model based on:", type=[ +# 'png', 'jpg', 'jpeg'], key="prompt_file", accept_multiple_files=True) +# if uploaded_files is not None: +# column = 0 +# for image in uploaded_files: +# # if it's an even number +# if uploaded_files.index(image) % 2 == 0: +# column = column + 1 +# row_1_key = str(column) + 'a' +# row_2_key = str(column) + 'b' +# row_1_key, row_2_key = st.columns([1, 1]) +# with row_1_key: +# st.image( +# uploaded_files[uploaded_files.index(image)], width=300) +# else: +# with row_2_key: +# st.image( +# uploaded_files[uploaded_files.index(image)], width=300) + +# st.write(f"You've selected {len(uploaded_files)} images.") + +# if len(uploaded_files) <= 5 and model_name == "": +# st.write( +# "Select at least 5 images and fill in all the fields to train a new model.") +# st.button("Train Model", disabled=True) +# else: +# if st.button("Train Model", disabled=False): +# st.info("Loading...") + +# # TODO: check the local storage +# # directory = "videos/training_data" +# # if not os.path.exists(directory): +# # os.makedirs(directory) + +# # for image in uploaded_files: +# # with open(os.path.join(f"videos/training_data", image.name), "wb") as f: +# # f.write(image.getbuffer()) +# # images_for_model.append(image.name) +# model_status = train_model(uploaded_files, instance_prompt, class_prompt, max_train_steps, +# model_name, type_of_model, type_of_task, resolution, controller_type, model_type_list) +# st.success(model_status) + +# # with st.expander("Add model from internet"): +# # st.subheader("Add a model the internet:") +# # uploaded_type_of_model = st.selectbox("Type of model:", [ +# # "LoRA", "Dreambooth"], key="uploaded_type_of_model", disabled=True, help="You can currently only upload LoRA models - this will change soon.") +# # uploaded_model_name = st.text_input( +# # "Model name:", value="", help="No spaces or special characters please", key="uploaded_model_name") +# # uploaded_model_images = st.file_uploader("Please add at least 2 sample images from this model:", type=[ +# # 'png', 'jpg', 'jpeg'], key="uploaded_prompt_file", accept_multiple_files=True) +# # uploaded_link_to_model = st.text_input( +# # "Link to model:", value="", key="uploaded_link_to_model") +# # st.info("The model should be a direct link to a .safetensors files. 
You can find models on websites like: https://civitai.com/") +# # if uploaded_model_name == "" or uploaded_link_to_model == "" or uploaded_model_images is None: +# # st.write("Fill in all the fields to add a model from the internet.") +# # st.button("Upload Model", disabled=True) +# # else: +# # if st.button("Upload Model", disabled=False): +# # images_for_model = [] +# # directory = "videos/training_data" +# # if not os.path.exists(directory): +# # os.makedirs(directory) + +# # for image in uploaded_model_images: +# # with open(os.path.join(f"videos/training_data", image.name), "wb") as f: +# # f.write(image.getbuffer()) +# # images_for_model.append(image.name) +# # for i in range(len(images_for_model)): +# # images_for_model[i] = 'videos/training_data/' + \ +# # images_for_model[i] +# # df = pd.read_csv("models.csv") +# # df = df.append({}, ignore_index=True) +# # new_row_index = df.index[-1] +# # df.iloc[new_row_index, 0] = uploaded_model_name +# # df.iloc[new_row_index, 4] = str(images_for_model) +# # df.iloc[new_row_index, 5] = uploaded_type_of_model +# # df.iloc[new_row_index, 6] = uploaded_link_to_model +# # df.to_csv("models.csv", index=False) +# # st.success( +# # f"Successfully uploaded - the model '{model_name}' is now available for use!") +# # time.sleep(1.5) +# # st.rerun() diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py index 74d6b7c7..02cfb8a5 100644 --- a/ui_components/components/explorer_page.py +++ b/ui_components/components/explorer_page.py @@ -1,172 +1,190 @@ import json import streamlit as st -from ui_components.methods.common_methods import process_inference_output,add_new_shot, save_uploaded_image +from ui_components.methods.common_methods import get_canny_img, process_inference_output,add_new_shot, save_new_image, save_uploaded_image from ui_components.methods.file_methods import generate_pil_image from ui_components.methods.ml_methods import query_llama2 from ui_components.widgets.add_key_frame_element import add_key_frame from utils.common_utils import refresh_app from utils.constants import MLQueryObject from utils.data_repo.data_repo import DataRepo -from shared.constants import QUEUE_INFERENCE_QUERIES, AIModelType, InferenceType, InternalFileTag, InternalFileType, SortOrder +from shared.constants import GPU_INFERENCE_ENABLED, QUEUE_INFERENCE_QUERIES, AIModelType, InferenceType, InternalFileTag, InternalFileType, SortOrder from utils import st_memory import time from utils.enum import ExtendedEnum from utils.ml_processor.ml_interface import get_ml_client -from utils.ml_processor.replicate.constants import REPLICATE_MODEL -from PIL import Image, ImageFilter -import io -import cv2 +from utils.ml_processor.constants import ML_MODEL import numpy as np from utils import st_memory class InputImageStyling(ExtendedEnum): - EVOLVE_IMAGE = "Evolve Image" - MAINTAIN_STRUCTURE = "Maintain Structure" + TEXT2IMAGE = "Text to Image" + IMAGE2IMAGE = "Image to Image" + CONTROLNET_CANNY = "ControlNet Canny" + IPADAPTER_FACE = "IP-Adapter Face" + IPADAPTER_PLUS = "IP-Adapter Plus" + IPADPTER_FACE_AND_PLUS = "IP-Adapter Face & Plus" -def columnn_selecter(): - f1, f2 = st.columns([1, 1]) - with f1: - st_memory.number_input('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") - with f2: - st_memory.number_input('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") - def explorer_page(project_uuid): - data_repo = DataRepo() - project_setting = 
data_repo.get_project_setting(project_uuid) - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") st.markdown("***") - z1, z2, z3 = st.columns([0.25,2,0.25]) - with z2: - with st.expander("Prompt Settings", expanded=True): - generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) - st.markdown("***") - columnn_selecter() - k1,k2 = st.columns([5,1]) - page_number = k1.radio("Select page:", options=range(1, project_setting.total_gallery_pages + 1), horizontal=True, key="main_gallery") - open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') - st.markdown("***") - gallery_image_view(project_uuid, page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, False, st.session_state['num_columns_explorer'],view="explorer") - -def shortlist_element(project_uuid): - data_repo = DataRepo() - project_setting = data_repo.get_project_setting(project_uuid) - columnn_selecter() - k1,k2 = st.columns([5,1]) - shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") - with k2: - open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle') + with st.expander("✨ Generate Images", expanded=True): + generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) st.markdown("***") - gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") - + gallery_image_view(project_uuid,False,view=['add_and_remove_from_shortlist','view_inference_details']) def generate_images_element(position='explorer', project_uuid=None, timing_uuid=None): data_repo = DataRepo() project_settings = data_repo.get_project_setting(project_uuid) help_input='''This will generate a specific prompt based on your input.\n\n For example, "Sad scene of old Russian man, dreary style" might result in "Boris Karloff, 80 year old man wearing a suit, standing at funeral, dark blue watercolour."''' a1, a2, a3 = st.columns([1,1,0.3]) - with a1 if 'switch_prompt_position' not in st.session_state or st.session_state['switch_prompt_position'] == False else a2: prompt = st_memory.text_area("What's your base prompt?", key="explorer_base_prompt", help="This exact text will be included for each generation.") with a2 if 'switch_prompt_position' not in st.session_state or st.session_state['switch_prompt_position'] == False else a1: - magic_prompt = st_memory.text_area("What's your magic prompt?", key="explorer_magic_prompt", help=help_input) - #if magic_prompt != "": - # chaos_level = st_memory.slider("How much chaos would you like to add to the magic prompt?", min_value=0, max_value=100, value=20, step=1, key="chaos_level", help="This will determine how random the generated prompt will be.") - # temperature = chaos_level / 20 - temperature = 1.0 - with a3: - st.write("") - st.write("") - st.write("") - if st.button("πŸ”„", key="switch_prompt_position_button", use_container_width=True, help="This will switch the order the prompt and magic prompt are used - earlier items gets more attention."): - st.session_state['switch_prompt_position'] = not st.session_state.get('switch_prompt_position', False) - st.experimental_rerun() - - neg1, _ = st.columns([1.5,1]) - with neg1: - negative_prompt = 
st_memory.text_input("Negative prompt", value="bad image, worst image, bad anatomy, washed out colors",\ + negative_prompt = st_memory.text_area("Negative prompt:", value="bad image, worst image, bad anatomy, washed out colors",\ key="explorer_neg_prompt", \ help="These are the things you wish to be excluded from the image") - if position=='explorer': - _, b1, b2, b3, _ = st.columns([0.1,1.25,2,2,0.1]) - _, c1, c2, _ = st.columns([1,2,2,1]) - else: - b1, b2, b3 = st.columns([1,2,1]) - c1, c2, _ = st.columns([2,2,2]) - + + b1, b2, b3, _ = st.columns([1.5,1,1.5,1]) with b1: - use_input_image = st_memory.checkbox("Use input image", key="use_input_image", value=False) - - - if use_input_image: - with b2: - type_of_transformation = st_memory.radio("What type of transformation would you like to do?", options=InputImageStyling.value_list(), key="type_of_transformation_key", help="Evolve Image will evolve the image based on the prompt, while Maintain Structure will keep the structure of the image and change the style.",horizontal=True) + type_of_generation = st_memory.radio("How would you like to generate the image?", options=InputImageStyling.value_list(), key="type_of_generation_key", help="Evolve Image will evolve the image based on the prompt, while Maintain Structure will keep the structure of the image and change the style.",horizontal=True) - with c1: - input_image_key = f"input_image_{position}" - if input_image_key not in st.session_state: - st.session_state[input_image_key] = None + input_image_1_key = "input_image_1" + input_image_2_key = "input_image_2" + if input_image_1_key not in st.session_state: + st.session_state[input_image_1_key] = None + st.session_state[input_image_2_key] = None - input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="explorer_input_image", help="This will be the base image for the generation.") - if st.button("Upload", use_container_width=True): - st.session_state[input_image_key] = input_image + uploaded_image_1 = None + uploaded_image_2 = None - with b3: - edge_pil_img = None - strength_of_current_image = st_memory.number_input("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.") - if type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value: - prompt_strength = round(1 - (strength_of_current_image / 100), 2) - with c2: - if st.session_state[input_image_key] is not None: - input_image_bytes = st.session_state[input_image_key].getvalue() - pil_image = Image.open(io.BytesIO(input_image_bytes)) - blur_radius = (100 - strength_of_current_image) / 3 # Adjust this formula as needed - blurred_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius)) - st.image(blurred_image, use_column_width=True) - - elif type_of_transformation == InputImageStyling.MAINTAIN_STRUCTURE.value: - condition_scale = strength_of_current_image / 10 - with c2: - if st.session_state[input_image_key] is not None: - input_image_bytes = st.session_state[input_image_key] .getvalue() - pil_image = Image.open(io.BytesIO(input_image_bytes)) - cv_image = np.array(pil_image) - gray_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY) - lower_threshold = (100 - strength_of_current_image) * 3 - upper_threshold = lower_threshold * 3 - edges = cv2.Canny(gray_image, lower_threshold, upper_threshold) - edge_pil_img = Image.fromarray(edges) - st.image(edge_pil_img, use_column_width=True) - 
st.markdown("***") + # these require two images + ipadapter_types = [InputImageStyling.IPADPTER_FACE_AND_PLUS.value] + + # UI for image input if type_of_generation is not txt2img + if type_of_generation != InputImageStyling.TEXT2IMAGE.value: + # UI - Base Input + with b2: + source_of_starting_image = st_memory.radio("Image source:", options=["Upload", "From Shot"], key="source_of_starting_image", help="This will be the base image for the generation.",horizontal=True) + # image upload + if source_of_starting_image == "Upload": + uploaded_image_1 = st.file_uploader("Upload a starting image", type=["png", "jpg", "jpeg"], key="explorer_input_image", help="This will be the base image for the generation.") + # taking image from shots + else: + shot_list = data_repo.get_shot_list(project_uuid) + selection1, selection2 = st.columns([1,1]) + with selection1: + shot_name = st.selectbox("Shot:", options=[shot.name for shot in shot_list], key="explorer_shot_uuid", help="This will be the base image for the generation.") + + shot_uuid = [shot.uuid for shot in shot_list if shot.name == shot_name][0] + frame_list = data_repo.get_timing_list_from_shot(shot_uuid) + + with selection2: + list_of_timings = [i + 1 for i in range(len(frame_list))] + timing = st.selectbox("Frame #:", options=list_of_timings, key="explorer_frame_number", help="This will be the base image for the generation.") + uploaded_image_1 = frame_list[timing - 1].primary_image.location + # make it a byte stream + st.image(frame_list[timing - 1].primary_image.location, use_column_width=True) - else: - input_image = None - type_of_transformation = None - strength_of_current_image = None - # st.markdown("***") - model_name = "stable_diffusion_xl" - if position=='explorer': + # taking a second image in the case of ip_adapter_face_plus + if type_of_generation in ipadapter_types: + source_of_starting_image_2 = st_memory.radio("How would you like to upload the second starting image?", options=["Upload", "From Shot"], key="source_of_starting_image_2", help="This will be the base image for the generation.",horizontal=True) + if source_of_starting_image_2 == "Upload": + uploaded_image_2 = st.file_uploader("IP-Adapter Face image:", type=["png", "jpg", "jpeg"], key="explorer_input_image_2", help="This will be the base image for the generation.") + else: + selection1, selection2 = st.columns([1,1]) + with selection1: + shot_list = data_repo.get_shot_list(project_uuid) + shot_name = st.selectbox("Shot:", options=[shot.name for shot in shot_list], key="explorer_shot_uuid_2", help="This will be the base image for the generation.") + shot_uuid = [shot.uuid for shot in shot_list if shot.name == shot_name][0] + with selection2: + frame_list = data_repo.get_timing_list_from_shot(shot_uuid) + list_of_timings = [i + 1 for i in range(len(frame_list))] + timing = st.selectbox("Frame #:", options=list_of_timings, key="explorer_frame_number_2", help="This will be the base image for the generation.") + uploaded_image_2 = frame_list[timing - 1].primary_image.location + st.image(frame_list[timing - 1].primary_image.location, use_column_width=True) + + # if type type is face and plus, then we need to make the text images + button_text = "Upload Images" if type_of_generation in ipadapter_types else "Upload Image" + + if st.button(button_text, use_container_width=True): + st.session_state[input_image_1_key] = uploaded_image_1 + st.session_state[input_image_2_key] = uploaded_image_2 + st.rerun() + + # UI - Preview + with b3: + # prompt_strength = round(1 - 
(strength_of_image / 100), 2) + if type_of_generation not in ipadapter_types: + st.info("Current image:") + if st.session_state[input_image_1_key] is not None: + st.image(st.session_state[input_image_1_key], use_column_width=True) + else: + st.error("Please upload an image") + + if type_of_generation == InputImageStyling.IMAGE2IMAGE.value: + strength_of_image = st_memory.slider("How much blur would you like to add to the image?", min_value=0, max_value=100, value=50, step=1, key="strength_of_image2image", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.CONTROLNET_CANNY.value: + strength_of_image = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_controlnet_canny", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADAPTER_FACE.value: + strength_of_image = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_face", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADAPTER_PLUS.value: + strength_of_image = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_plus", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + # UI - displaying uploaded images + + # NOTE: hackish sol, text of plus and face are interchanged, will fix later + st.info("IP-Adapter Plus image:") + if st.session_state[input_image_1_key] is not None: + st.image(st.session_state[input_image_1_key], use_column_width=True) + strength_of_face = st_memory.slider("How strong would would you like the Face model to influence?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_face", help="This will determine how much of the current image will be kept in the final image.") + else: + st.error("Please upload an image") + + st.info("IP-Adapter Face image:") + if st.session_state[input_image_2_key] is not None: + st.image(st.session_state[input_image_2_key], use_column_width=True) + strength_of_plus = st_memory.slider("How strong would you like to influence the Plus model?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_plus", help="This will determine how much of the current image will be kept in the final image.") + else: + st.error("Please upload an second image") + + # UI - clear btn + if st.session_state[input_image_1_key] is not None: + with b3: + if st.button("Clear input image(s)", key="clear_input_image", use_container_width=True): + st.session_state[input_image_1_key] = None + st.session_state[input_image_2_key] = None + st.rerun() + + if position == 'explorer': _, d2,d3, _ = st.columns([0.25, 1,1, 0.25]) else: - d2,d3 = st.columns([1,1]) + d2, d3 = st.columns([1,1]) with d2: - number_to_generate = st.slider("How many images would you like to generate?", min_value=0, max_value=100, value=4, step=4, key="number_to_generate", help="It'll generate 4 from each variation.") + number_to_generate = st.slider("How many images would you like to generate?", min_value=0, max_value=100, value=4, step=2, key="number_to_generate", 
help="It'll generate 4 from each variation.") with d3: - st.write(" ") + st.write(" ") + # ------------------- Generating output ------------------------------------- if st.session_state.get(position + '_generate_inference'): ml_client = get_ml_client() counter = 0 + + magic_prompt, temperature = "", 0 for _ in range(number_to_generate): - + + ''' if counter % 4 == 0: if magic_prompt != "": input_text = "I want to flesh the following user input out - could you make it such that it retains the original meaning but is more specific and descriptive:\n\nfloral background|array of colorful wildflowers and green foliage forms a vibrant, natural backdrop.\nfancy old man|Barnaby Jasper Hawthorne, a dignified gentleman in his late seventies\ncomic book style|illustration style of a 1960s superhero comic book\nsky with diamonds|night sky filled with twinkling stars like diamonds on velvet\n20 y/o indian guy|Piyush Ahuja, a twenty-year-old Indian software engineer\ndark fantasy|a dark, gothic style similar to an Edgar Allen Poe novel\nfuturistic world|set in a 22nd century off-world colony called Ajita Iyera\nbeautiful lake|the crystal clear waters of a luminous blue alpine mountain lake\nminimalistic illustration|simple illustration with solid colors and basic geometrical shapes and figures\nmale blacksmith|Arun Thakkar, a Black country village blacksmith\ndesert sunrise|reddish orange sky at sunrise somewhere out in the Arabia desert\nforest|dense forest of Swedish pine trees\ngreece landscape|bright cyan sky meets turquoise on Santorini\nspace|shifting nebula clouds across the endless expanse of deep space\nwizard orcs|Poljak Ardell, a half-orc warlock\ntropical island|Palm tree-lined tropical paradise beach near Corfu\ncyberpunk cityscape |Neon holo displays reflect from steel surfaces of buildings in Cairo Cyberspace\njapanese garden & pond|peaceful asian zen koi fishpond surrounded by bonsai trees\nattractive young african woman|Chimene Nkasa, young Congolese social media star\ninsane style|wild and unpredictable artwork like Salvador Dali’s Persistence Of Memory painting\n30s european women|Francisca Sampere, 31 year old Spanish woman\nlighthouse|iconic green New England coastal lighthouse against grey sky\ngirl in hat|Dora Alamanni dressed up with straw boater hat\nretro poster design|stunning vintage 80s movie poster reminiscent of Blade Runner\nabstract color combinations|a modernist splatter painting with overlapping colors\nnordic style |simple line drawing of white on dark blue with clean geometrical figures and shapes\nyoung asian woman, abstract style|Kaya Suzuki's face rendered in bright, expressive brush strokes\nblue monster|large cobalt blue cartoonish creature similar to a yeti\nman at work|portrait sketch of business man working late night in the office\nunderwater sunbeams|aquatic creatures swimming through waves of refracting ocean sunlight\nhappy cat on table|tabby kitten sitting alert in anticipation on kitchen counter\ntop​\nold timey train robber|Wiley Hollister, mid-thirties outlaw\nchinese landscape|Mt. 
Taihang surrounded by clouds\nancient ruins, sci fi style|deserted ancient civilization under stormy ominous sky full of mysterious UFOs\nanime art|classic anime, in the style of Akira Toriyama\nold man, sad scene|Seneca Hawkins, older gentleman slumped forlorn on street bench in early autumn evening\ncathedral|interior view of Gothic church in Vienna\ndreamlike|spellbinding dreamlike atmosphere, work called Pookanaut\nbird on lake, evening time|grizzled kingfisher sitting regally facing towards beautiful ripple-reflected setting orange pink sum\nyoung female character, cutsey style|Aoife Delaney dressed up as Candyflud, cheerful child adventurer\ninteresting style|stunning cubist abstract geometrical block\nevil woman|Luisa Schultze, frightening murderess\nfashion model|Ishita Chaudry, an Indian fashionista with unique dress sense\ncastle, moody scene|grand Renaissance Palace in Prague against twilight mist filled with crows\ntropical paradise island|Pristine white sand beach with palm trees at Ile du Mariasi, Reunion\npoverty stricken village|simple shack-based settlement in rural Niger\ngothic horror creature|wretchedly deformed and hideous tatter-clad creature like Caliban from Shakespeare ’s Tempes\nlots of color|rainbow colored Dutch flower field\nattractive woman on holidays|Siena Chen in her best little black dress, walking down a glamorous Las Vegas Boulevard\nItalian city scene|Duomo di Milano on dark rainy night sky behind it\nhappy dog outdoor|bouncy Irish Setter frolickling around green grass in summer sun\nmedieval fantasy world|illustration work for Eye Of The Titan - novel by Rania D’Allara\nperson relaxing|Alejandro Gonzalez sitting crosslegged in elegant peacock blue kurta while reading book\nretro sci fi robot|Vintage, cartoonish android reminiscent of the Bender Futurama character. 
Named Clyde Frost.\ngeometric style|geometric abstract style based on 1960 Russian poster design by Alexander Rodchenk \nbeautiful girl face, vaporwave style|Rayna Vratasky, looking all pink and purple retro\nspooking |horrifying Chupacabra-like being staring intensely to camera\nbrazilian woman having fun|Analia Santos, playing puzzle game with friends\nfemale elf warrior|Finnula Thalas, an Eladrin paladin wielding two great warblades\nlsd trip scene|kaleidoscopic colorscape, filled with ephemerally shifting forms\nyoung african man headshot|Roger Mwafulo looking sharp with big lush smile\nsad or dying person|elderly beggar Jeon Hagopian slumped against trash can bin corner\nart |neurologically inspired psychedelian artwork like David Normal's β€œSentient Energy ” series\nattractive german woman|Johanna Hecker, blonde beauty with long hair wrapped in braid ties\nladybug|Cute ladybug perched on red sunset flower petals on summery meadow backdrop\nbeautiful asian women |Chiraya Phetlue, Thai-French model standing front view wearing white dress\nmindblowing style|trippy space illustration that could be cover for a book by Koyu Azumi\nmoody|forest full of thorn trees stretching into the horizon at dusk\nhappy family, abstract style|illustration work of mother, father and child from 2017 children’s picture book The Gifts Of Motherhood By Michelle Sparks\n" @@ -178,75 +196,153 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= else: # switch_prompt_position is True prompt_with_variations = f"{output_magic_prompt}, {prompt}" if prompt else output_magic_prompt - - + ''' counter += 1 log = None - if not input_image: + generation_method = InputImageStyling.value_list()[st.session_state['type_of_generation_key']] + if generation_method == InputImageStyling.TEXT2IMAGE.value: query_obj = MLQueryObject( timing_uuid=None, model_uuid=None, - guidance_scale=5, + guidance_scale=8, seed=-1, - num_inference_steps=30, - strength=1, + num_inference_steps=25, + strength=0.5, adapter_type=None, - prompt=prompt_with_variations, + prompt=prompt, negative_prompt=negative_prompt, height=project_settings.height, width=project_settings.width, project_uuid=project_uuid ) - model_list = data_repo.get_all_ai_model_list(model_type_list=[AIModelType.TXT2IMG.value], custom_trained=False) - model_dict = {} - for m in model_list: - model_dict[m.name] = m + # NOTE: code not is use + # model_list = data_repo.get_all_ai_model_list(model_type_list=[AIModelType.TXT2IMG.value], custom_trained=False) + # model_dict = {} + # for m in model_list: + # model_dict[m.name] = m + # replicate_model = ML_MODEL.get_model_by_db_obj(model_dict[model_name]) - replicate_model = REPLICATE_MODEL.get_model_by_db_obj(model_dict[model_name]) - output, log = ml_client.predict_model_output_standardized(replicate_model, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + output, log = ml_client.predict_model_output_standardized(ML_MODEL.sdxl, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif generation_method == InputImageStyling.IMAGE2IMAGE.value: + input_image_file = save_new_image(st.session_state[input_image_1_key], project_uuid) + query_obj = MLQueryObject( + timing_uuid=None, + model_uuid=None, + image_uuid=input_image_file.uuid, + guidance_scale=5, + seed=-1, + num_inference_steps=30, + strength=0.8, + adapter_type=None, + prompt=prompt, + negative_prompt=negative_prompt, + height=project_settings.height, + width=project_settings.width, + project_uuid=project_uuid + ) - else: - if 
type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value: - input_image_file = save_uploaded_image(input_image, project_uuid) - query_obj = MLQueryObject( - timing_uuid=None, - model_uuid=None, - image_uuid=input_image_file.uuid, - guidance_scale=5, - seed=-1, - num_inference_steps=30, - strength=prompt_strength, - adapter_type=None, - prompt=prompt, - negative_prompt=negative_prompt, - height=project_settings.height, - width=project_settings.width, - project_uuid=project_uuid - ) - - output, log = ml_client.predict_model_output_standardized(REPLICATE_MODEL.sdxl, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) - - elif type_of_transformation == InputImageStyling.MAINTAIN_STRUCTURE.value: - input_image_file = save_uploaded_image(edge_pil_img, project_uuid) - query_obj = MLQueryObject( - timing_uuid=None, - model_uuid=None, - image_uuid=input_image_file.uuid, - guidance_scale=5, - seed=-1, - num_inference_steps=30, - strength=0.5, - adapter_type=None, - prompt=prompt, - negative_prompt=negative_prompt, - height=project_settings.height, - width=project_settings.width, - project_uuid=project_uuid, - data={'condition_scale': condition_scale} - ) - - output, log = ml_client.predict_model_output_standardized(REPLICATE_MODEL.sdxl_controlnet, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + output, log = ml_client.predict_model_output_standardized(ML_MODEL.sdxl_img2img, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif generation_method == InputImageStyling.CONTROLNET_CANNY.value: + edge_pil_img = get_canny_img(st.session_state[input_image_1_key], low_threshold=50, high_threshold=150) # redundant incase of local inference + input_img = edge_pil_img if not GPU_INFERENCE_ENABLED else st.session_state[input_image_1_key] + input_image_file = save_new_image(input_img, project_uuid) + query_obj = MLQueryObject( + timing_uuid=None, + model_uuid=None, + image_uuid=input_image_file.uuid, + guidance_scale=8, + seed=-1, + num_inference_steps=30, + strength=strength_of_image/100, + adapter_type=None, + prompt=prompt, + low_threshold=0.2, + high_threshold=0.7, + negative_prompt=negative_prompt, + height=project_settings.height, + width=project_settings.width, + project_uuid=project_uuid, + data={'condition_scale': 1} + ) + + output, log = ml_client.predict_model_output_standardized(ML_MODEL.sdxl_controlnet, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif generation_method == InputImageStyling.IPADAPTER_FACE.value: + # validation + if not (st.session_state[input_image_1_key]): + st.error('Please upload an image') + return + + input_image_file = save_new_image(st.session_state[input_image_1_key], project_uuid) + query_obj = MLQueryObject( + timing_uuid=None, + model_uuid=None, + image_uuid=input_image_file.uuid, + guidance_scale=5, + seed=-1, + num_inference_steps=30, + strength=strength_of_image/100, + adapter_type=None, + prompt=prompt, + negative_prompt=negative_prompt, + height=project_settings.height, + width=project_settings.width, + project_uuid=project_uuid, + data={} + ) + + output, log = ml_client.predict_model_output_standardized(ML_MODEL.ipadapter_face, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif generation_method == InputImageStyling.IPADAPTER_PLUS.value: + input_image_file = save_new_image(st.session_state[input_image_1_key], project_uuid) + query_obj = MLQueryObject( + timing_uuid=None, + model_uuid=None, + image_uuid=input_image_file.uuid, + guidance_scale=5, + seed=-1, + num_inference_steps=30, + strength=strength_of_image/100, + 
adapter_type=None, + prompt=prompt, + negative_prompt=negative_prompt, + height=project_settings.height, + width=project_settings.width, + project_uuid=project_uuid, + data={'condition_scale': 1} + ) + + output, log = ml_client.predict_model_output_standardized(ML_MODEL.ipadapter_plus, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif generation_method == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + # validation + if not (st.session_state[input_image_2_key] and st.session_state[input_image_1_key]): + st.error('Please upload both images') + return + + plus_image_file = save_new_image(st.session_state[input_image_1_key], project_uuid) + face_image_file = save_new_image(st.session_state[input_image_2_key], project_uuid) + query_obj = MLQueryObject( + timing_uuid=None, + model_uuid=None, + image_uuid=plus_image_file.uuid, + guidance_scale=5, + seed=-1, + num_inference_steps=30, + strength=(strength_of_face/100, strength_of_plus/100), # (face, plus) + adapter_type=None, + prompt=prompt, + negative_prompt=negative_prompt, + height=project_settings.height, + width=project_settings.width, + project_uuid=project_uuid, + data={'file_image_2_uuid': face_image_file.uuid} + ) + + output, log = ml_client.predict_model_output_standardized(ML_MODEL.ipadapter_face_plus, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) if log: inference_data = { @@ -260,11 +356,25 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= process_inference_output(**inference_data) st.info("Check the Generation Log to the left for the status.") + time.sleep(0.5) toggle_generate_inference(position) st.rerun() # ----------- generate btn -------------- - st.button("Generate images", key="generate_images", use_container_width=True, type="primary", on_click=lambda: toggle_generate_inference(position)) + if prompt == "": + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please enter a prompt to generate images") + elif type_of_generation == InputImageStyling.IMAGE2IMAGE.value and st.session_state[input_image_1_key] is None: + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please upload an image") + elif type_of_generation == InputImageStyling.CONTROLNET_CANNY.value and st.session_state[input_image_1_key] is None: + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please upload an image") + elif type_of_generation == InputImageStyling.IPADAPTER_FACE.value and st.session_state[input_image_1_key] is None: + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please upload an image") + elif type_of_generation == InputImageStyling.IPADAPTER_PLUS.value and st.session_state[input_image_1_key] is None: + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please upload an image") + elif type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value and (st.session_state[input_image_1_key] is None or st.session_state[input_image_2_key] is None): + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", disabled=True, help="Please upload both images") + else: + st.button("Generate images", key="generate_images", use_container_width=True, type="primary", on_click=lambda: toggle_generate_inference(position)) def 
toggle_generate_inference(position): if position + '_generate_inference' not in st.session_state: @@ -272,12 +382,35 @@ def toggle_generate_inference(position): else: st.session_state[position + '_generate_inference'] = not st.session_state[position + '_generate_inference'] - -def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_detailed_view_for_all=False, shortlist=False, num_columns=2, view="main", shot=None): +def gallery_image_view(project_uuid, shortlist=False, view=["main"], shot=None, sidebar=False): data_repo = DataRepo() - project_settings = data_repo.get_project_setting(project_uuid) shot_list = data_repo.get_shot_list(project_uuid) + k1,k2 = st.columns([5,1]) + + if sidebar != True: + f1, f2 = st.columns([1, 1]) + with f1: + num_columns = st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") + with f2: + num_items_per_page = st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") + + if shortlist is False: + page_number = k1.radio("Select page:", options=range(1, project_settings.total_gallery_pages + 1), horizontal=True, key="main_gallery") + if 'view_inference_details' in view: + open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') + st.markdown("***") + else: + project_setting = data_repo.get_project_setting(project_uuid) + page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + open_detailed_view_for_all = False + st.markdown("***") + else: + project_setting = data_repo.get_project_setting(project_uuid) + page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + open_detailed_view_for_all = False + num_items_per_page = 8 + num_columns = 2 gallery_image_list, res_payload = data_repo.get_all_file_list( file_type=InternalFileType.IMAGE.value, @@ -297,21 +430,46 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de project_settings.total_shortlist_gallery_pages = res_payload['total_pages'] st.rerun() - # def is_image_truncated(image_path): - # try: - # img = Image.open(image_path) - # img.verify() # verify that it is, in fact an image - # except (IOError, SyntaxError) as e: - # return True - # return False - + if shortlist is False: + _, fetch2, fetch3, _ = st.columns([0.25, 1, 1, 0.25]) + # st.markdown("***") + explorer_stats = data_repo.get_explorer_pending_stats(project_uuid=project_uuid) + + if explorer_stats['temp_image_count'] + explorer_stats['pending_image_count']: + st.markdown("***") + + with fetch2: + total_number_pending = explorer_stats['temp_image_count'] + explorer_stats['pending_image_count'] + if total_number_pending: + + if explorer_stats['temp_image_count'] == 0 and explorer_stats['pending_image_count'] > 0: + st.info(f"###### {explorer_stats['pending_image_count']} images pending generation") + button_text = "Check for new images" + elif explorer_stats['temp_image_count'] > 0 and explorer_stats['pending_image_count'] == 0: + st.info(f"###### {explorer_stats['temp_image_count']} new images generated") + button_text = "Pull new images" + else: + st.info(f"###### {explorer_stats['pending_image_count']} images pending generation and {explorer_stats['temp_image_count']} ready to be fetched") + button_text = "Check for/pull new images" + + # st.info(f"###### {total_number_pending} images 
pending generation") + # st.info(f"###### {explorer_stats['temp_image_count']} new images generated") + # st.info(f"###### {explorer_stats['pending_image_count']} images pending generation") + + with fetch3: + if st.button(f"{button_text}", key=f"check_for_new_images_", use_container_width=True): + if explorer_stats['temp_image_count']: + data_repo.update_temp_gallery_images(project_uuid) + st.success("New images fetched") + time.sleep(0.3) + st.rerun() + total_image_count = res_payload['count'] if gallery_image_list and len(gallery_image_list): start_index = 0 end_index = min(start_index + num_items_per_page, total_image_count) shot_names = [s.name for s in shot_list] - shot_names.append('**Create New Shot**') - shot_names.insert(0, '') + shot_names.append('**Create New Shot**') for i in range(start_index, end_index, num_columns): cols = st.columns(num_columns) for j in range(num_columns): @@ -320,16 +478,14 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de st.image(gallery_image_list[i + j].location, use_column_width=True) # else: # st.error("The image is truncated and cannot be displayed.") - if view in ["explorer", "shortlist"]: + if 'add_and_remove_from_shortlist' in view: if shortlist: if st.button("Remove from shortlist βž–", key=f"shortlist_{gallery_image_list[i + j].uuid}",use_container_width=True, help="Remove from shortlist"): data_repo.update_file(gallery_image_list[i + j].uuid, tag=InternalFileTag.GALLERY_IMAGE.value) st.success("Removed From Shortlist") time.sleep(0.3) st.rerun() - else: - if st.button("Add to shortlist βž•", key=f"shortlist_{gallery_image_list[i + j].uuid}",use_container_width=True, help="Add to shortlist"): data_repo.update_file(gallery_image_list[i + j].uuid, tag=InternalFileTag.SHORTLISTED_GALLERY_IMAGE.value) st.success("Added To Shortlist") @@ -343,7 +499,7 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de input_params = json.loads(log.input_params) prompt = input_params.get('prompt', 'No prompt found') model = json.loads(log.output_details)['model_name'].split('/')[-1] - if view in ["explorer", "shortlist","individual_shot"]: + if 'view_inference_details' in view: with st.expander("Prompt Details", expanded=open_detailed_view_for_all): st.info(f"**Prompt:** {prompt}\n\n**Model:** {model}") @@ -355,8 +511,8 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de # ---------- add to shot btn --------------- if "last_shot_number" not in st.session_state: st.session_state["last_shot_number"] = 0 - if view not in ["explorer", "shortlist"]: - if view == "individual_shot": + if 'add_to_this_shot' in view or 'add_to_any_shot' in view: + if 'add_to_this_shot' in view: shot_name = shot.name else: shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"]) @@ -373,33 +529,14 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de else: if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True): - shot_number = shot_names.index(shot_name) + 1 - st.session_state["last_shot_number"] = shot_number - 1 - shot_uuid = shot_list[shot_number - 2].uuid + shot_number = shot_names.index(shot_name) + st.session_state["last_shot_number"] = shot_number + shot_uuid = shot_list[shot_number].uuid add_key_frame(gallery_image_list[i + j], False, shot_uuid, 
len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False) # removing this from the gallery view data_repo.update_file(gallery_image_list[i + j].uuid, tag="") - refresh_app(maintain_state=True) - + refresh_app(maintain_state=True) st.markdown("***") else: - st.warning("No images present") - - - -''' - -def update_max_frame_per_shot_element(project_uuid): - data_repo = DataRepo() - project_settings = data_repo.get_project_setting(project_uuid) - - - max_frames = st.number_input(label='Max frames per shot', min_value=1, value=project_settings.max_frames_per_shot) - - if max_frames != project_settings.max_frames_per_shot: - project_settings.max_frames_per_shot = max_frames - st.success("Updated") - time.sleep(0.3) - st.rerun() -''' \ No newline at end of file + st.warning("No images present") \ No newline at end of file diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 4f252014..69fae677 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -21,46 +21,43 @@ def frame_styling_page(shot_uuid: str, h2): if len(timing_list) == 0: - with h2: - frame_selector_widget(show=['shot_selector','frame_selector']) - + st.markdown("#### There are no frames present in this shot yet.") else: with st.sidebar: - with h2: - - frame_selector_widget(show=['shot_selector','frame_selector']) - - st.session_state['styling_view'] = st_memory.menu('',\ - ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ - icons=['magic', 'crop', "paint-bucket", 'pencil'], \ - menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ - key="styling_view_selector", orientation="horizontal", \ - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) - + + st.session_state['styling_view'] = st_memory.menu('',\ + ["Generate", "Crop", "Inpaint"], \ + icons=['magic', 'crop', "paint-bucket", 'pencil'], \ + menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ + key="styling_view_selector", orientation="horizontal", \ + styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) + frame_view(view="Key Frame") - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{shot.name} - #{st.session_state['current_frame_index']}] > :blue[{st.session_state['styling_view']}]") variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) - st.markdown("***") + if st.session_state['styling_view'] == "Generate": - with st.expander("πŸ› οΈ Generate Variants + Prompt Settings", expanded=True): + with st.expander("πŸ› οΈ Generate Variants", expanded=True): generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) - elif st.session_state['styling_view'] == "Crop/Move": + elif st.session_state['styling_view'] == "Crop": with st.expander("🀏 Crop, Move & Rotate", expanded=True): cropping_selector_element(shot_uuid) - elif st.session_state['styling_view'] == "Inpainting": + elif 
st.session_state['styling_view'] == "Inpaint": with st.expander("🌌 Inpainting", expanded=True): inpainting_element(st.session_state['current_frame_uuid']) - elif st.session_state['styling_view'] == "Scribbling": + elif st.session_state['styling_view'] == "Scribble": with st.expander("πŸ“ Draw On Image", expanded=True): drawing_element(shot_uuid) + st.markdown("***") + \ No newline at end of file diff --git a/ui_components/components/new_project_page.py b/ui_components/components/new_project_page.py index b079d38a..68e32b5a 100644 --- a/ui_components/components/new_project_page.py +++ b/ui_components/components/new_project_page.py @@ -14,6 +14,9 @@ def new_project_page(): # Initialize data repository data_repo = DataRepo() + # title + st.markdown("#### New Project") + st.markdown("***") # Define multicolumn layout project_column, filler_column = st.columns(2) @@ -122,4 +125,5 @@ def new_project_page(): st.session_state['app_settings'] = 0 st.success("Project created successfully!") time.sleep(1) - st.rerun() \ No newline at end of file + st.rerun() + st.markdown("***") \ No newline at end of file diff --git a/ui_components/components/project_settings_page.py b/ui_components/components/project_settings_page.py index 598327c1..1bc61972 100644 --- a/ui_components/components/project_settings_page.py +++ b/ui_components/components/project_settings_page.py @@ -10,15 +10,18 @@ def project_settings_page(project_uuid): data_repo = DataRepo() - + st.markdown("#### Project Settings") + st.markdown("***") project_settings = data_repo.get_project_setting(project_uuid) - attach_audio_element(project_uuid, True) + + frame_sizes = ["512x512", "768x512", "512x768"] current_size = f"{project_settings.width}x{project_settings.height}" current_index = frame_sizes.index(current_size) if current_size in frame_sizes else 0 - with st.expander("Frame Size", expanded=True): + with st.expander("πŸ–ΌοΈ Frame Size", expanded=True): + v1, v2, v3 = st.columns([4, 4, 2]) with v1: st.write("Current Size = ", project_settings.width, "x", project_settings.height) @@ -33,4 +36,7 @@ def project_settings_page(project_uuid): if st.button("Save"): data_repo.update_project_setting(project_uuid, width=width) data_repo.update_project_setting(project_uuid, height=height) - st.experimental_rerun() \ No newline at end of file + st.experimental_rerun() + + st.write("") + attach_audio_element(project_uuid, True) \ No newline at end of file diff --git a/ui_components/components/query_logger_page.py b/ui_components/components/query_logger_page.py index 43a5556c..03a69a1d 100644 --- a/ui_components/components/query_logger_page.py +++ b/ui_components/components/query_logger_page.py @@ -6,14 +6,16 @@ from utils.data_repo.data_repo import DataRepo def query_logger_page(): - st.header("Inference Log list") + st.markdown("##### Inference log") data_repo = DataRepo() current_user = get_current_user() b1, b2 = st.columns([1, 1]) total_log_table_pages = st.session_state['total_log_table_pages'] if 'total_log_table_pages' in st.session_state else DefaultTimingStyleParams.total_log_table_pages - page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1) + list_of_pages = [i for i in range(1, total_log_table_pages + 1)] + page_number = b1.radio('Select page:', options=list_of_pages, key='inference_log_page_number', index=0, horizontal=True) + # page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1) inference_log_list, total_page_count = 
data_repo.get_all_inference_log_list( page=page_number, data_per_page=100 @@ -43,4 +45,6 @@ def query_logger_page(): data['Status'].append(log.status) - st.table(data=data) \ No newline at end of file + st.table(data=data) + + st.markdown("***") \ No newline at end of file diff --git a/ui_components/components/shortlist_page.py b/ui_components/components/shortlist_page.py index d7b35bf0..179aed5b 100644 --- a/ui_components/components/shortlist_page.py +++ b/ui_components/components/shortlist_page.py @@ -1,21 +1,20 @@ import streamlit as st -from ui_components.components.explorer_page import columnn_selecter,gallery_image_view +from ui_components.components.explorer_page import gallery_image_view from utils.data_repo.data_repo import DataRepo from utils import st_memory def shortlist_page(project_uuid): - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") - st.markdown("***") + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") data_repo = DataRepo() project_setting = data_repo.get_project_setting(project_uuid) - columnn_selecter() - k1,k2 = st.columns([5,1]) - shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") - with k2: - open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) + # columnn_selecter() + # k1,k2 = st.columns([5,1]) + # shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + # with k2: + # open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) st.markdown("***") - gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") + gallery_image_view(project_uuid, True,view=['view_inference_details','add_to_any_shot','add_and_remove_from_shortlist']) \ No newline at end of file diff --git a/ui_components/components/timeline_view_page.py b/ui_components/components/timeline_view_page.py index 8799d2de..76e39fb8 100644 --- a/ui_components/components/timeline_view_page.py +++ b/ui_components/components/timeline_view_page.py @@ -17,22 +17,29 @@ def timeline_view_page(shot_uuid: str, h2): if "view" not in st.session_state: st.session_state["view"] = views[0] st.session_state["manual_select"] = None - + st.write("") with st.expander("πŸ“‹ Explorer Shortlist",expanded=True): if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=10, open_detailed_view_for_all=False, shortlist=True, num_columns=2,view="sidebar") - + # page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + gallery_image_view(shot.project.uuid, shortlist=True,view=["add_and_remove_from_shortlist","add_to_any_shot"], shot=shot,sidebar=True) + ''' with h2: st.session_state['view'] = option_menu(None, views, icons=['palette', 'camera-reels', "hourglass", 'stopwatch'], menu_icon="cast", 
orientation="vertical", key="secti2on_selector", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}, manual_select=st.session_state["manual_select"]) if st.session_state["manual_select"] != None: st.session_state["manual_select"] = None - + ''' st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{st.session_state['view']}]") st.markdown("***") + slider1, slider2 = st.columns([1,5]) + with slider1: + show_video = st.toggle("Show Video:", value=False, key="show_video") + if show_video: + st.session_state["view"] = "Shots" + else: + st.session_state["view"] = "Key Frames" timeline_view(st.session_state["shot_uuid"], st.session_state['view']) \ No newline at end of file diff --git a/ui_components/components/video_rendering_page.py b/ui_components/components/video_rendering_page.py index 8722ded8..b6875bbe 100644 --- a/ui_components/components/video_rendering_page.py +++ b/ui_components/components/video_rendering_page.py @@ -19,6 +19,8 @@ def video_rendering_page(project_uuid): "Planet_of_the_Snapes", "No_Country_for_Old_Yentas", "The_Expendable_Accountant", "The_Terminal_Illness", "A_Streetcar_Named_Retire", "The_Secret_Life_of_Walter_s_Mitty", "The_Hunger_Games_Catching_Foam", "The_Godfather_Part_Time_Job", "How_To_Kill_a_Mockingbird", "Star_Trek_III_The_Search_for_Spock_s_Missing_Sock", "Gone_with_the_Wind_Chimes", "Dr_No_Clue", "Ferris_Bueller_s_Day_Off_Sick", "Monty_Python_and_the_Holy_Fail", "A_Fistful_of_Quarters", "Willy_Wonka_and_the_Chocolate_Heartburn", "The_Good_the_Bad_and_the_Dandruff", "The_Princess_Bride_of_Frankenstein", "The_Wizard_of_Bras", "Pulp_Friction", "Die_Hard_with_a_Clipboard", "Indiana_Jones_and_the_Last_Audit", "Finding_Nemoy", "The_Silence_of_the_Lambs_The_Musical", "Titanic_2_The_Iceberg_Strikes_Back", "Fast_Times_at_Ridgemont_Mortuary", "The_Graduate_But_Only_Because_He_Has_an_Advanced_Degree", "Beauty_and_the_Yeast", "The_Blair_Witch_Takes_Manhattan", "Reservoir_Bitches", "Die_Hard_with_a_Pension"] random_name = random.choice(parody_movie_names) + st.markdown("#### Video Rendering") + st.markdown("***") final_video_name = st.text_input( "What would you like to name this video?", value=random_name) diff --git a/ui_components/constants.py b/ui_components/constants.py index 867b0ce5..ae7d5eae 100644 --- a/ui_components/constants.py +++ b/ui_components/constants.py @@ -16,6 +16,9 @@ class CreativeProcessType(ExtendedEnum): STYLING = "Key Frames" MOTION = "Shots" +class ShotMetaData(ExtendedEnum): + MOTION_DATA = "motion_data" # {"timing_data": [...]} + class DefaultTimingStyleParams: prompt = "" negative_prompt = "bad image, worst quality" @@ -55,6 +58,16 @@ class DefaultProjectSettingParams: total_shortlist_gallery_pages = 1 max_frames_per_shot = 30 +DEFAULT_SHOT_MOTION_VALUES = { + "strength_of_frame" : 0.5, + "distance_to_next_frame" : 1.0, + "speed_of_transition" : 0.6, + "freedom_between_frames" : 0.5, + "individual_prompt" : "", + "individual_negative_prompt" : "", + "motion_during_frame" : 1.3, +} + # TODO: make proper paths for every file CROPPED_IMG_LOCAL_PATH = "videos/temp/cropped.png" diff --git a/ui_components/methods/common_methods.py b/ui_components/methods/common_methods.py index 28b69546..b15baadf 100644 --- a/ui_components/methods/common_methods.py +++ b/ui_components/methods/common_methods.py @@ -1,4 +1,5 @@ import io +import random from typing import List import os from PIL import Image, ImageDraw, 
ImageOps, ImageFilter @@ -59,6 +60,9 @@ def clone_styling_settings(source_frame_number, target_frame_uuid): # TODO: image format is assumed to be PNG, change this later def save_new_image(img: Union[Image.Image, str, np.ndarray, io.BytesIO], project_uuid) -> InternalFileObject: + ''' + Saves an image into the project. The image is not added into any shot and is without tags. + ''' data_repo = DataRepo() img = generate_pil_image(img) @@ -245,7 +249,6 @@ def fetch_image_by_stage(shot_uuid, stage, frame_idx): else: return None - # returns a PIL image object def rotate_image(location, degree): if location.startswith('http') or location.startswith('https'): @@ -338,30 +341,27 @@ def promote_video_variant(shot_uuid, variant_uuid): data_repo.update_shot(uuid=shot.uuid, main_clip_id=variant_to_promote.uuid) - -def extract_canny_lines(image_path_or_url, project_uuid, low_threshold=50, high_threshold=150) -> InternalFileObject: - data_repo = DataRepo() - - # Check if the input is a URL - if image_path_or_url.startswith("http"): - response = r.get(image_path_or_url) - image_data = np.frombuffer(response.content, dtype=np.uint8) - image = cv2.imdecode(image_data, cv2.IMREAD_GRAYSCALE) +def get_canny_img(img_obj, low_threshold, high_threshold, invert_img=False): + if isinstance(img_obj, str): + if img_obj.startswith("http"): + response = r.get(img_obj) + image_data = np.frombuffer(response.content, dtype=np.uint8) + image = cv2.imdecode(image_data, cv2.IMREAD_GRAYSCALE) + else: + image = cv2.imread(img_obj, cv2.IMREAD_GRAYSCALE) else: - # Read the image from a local file - image = cv2.imread(image_path_or_url, cv2.IMREAD_GRAYSCALE) + image_data = generate_pil_image(img_obj) + image = np.array(image_data) - # Apply Gaussian blur to the image blurred_image = cv2.GaussianBlur(image, (5, 5), 0) - - # Apply the Canny edge detection canny_edges = cv2.Canny(blurred_image, low_threshold, high_threshold) - - # Reverse the colors (invert the image) - inverted_canny_edges = 255 - canny_edges - - # Convert the inverted Canny edge result to a PIL Image + inverted_canny_edges = 255 - canny_edges if invert_img else canny_edges new_canny_image = Image.fromarray(inverted_canny_edges) + return new_canny_image + +def extract_canny_lines(image_path_or_url, project_uuid, low_threshold=50, high_threshold=150) -> InternalFileObject: + data_repo = DataRepo() + new_canny_image = get_canny_img(image_path_or_url, low_threshold, high_threshold) # Save the new image unique_file_name = str(uuid.uuid4()) + ".png" @@ -382,6 +382,27 @@ def extract_canny_lines(image_path_or_url, project_uuid, low_threshold=50, high_ canny_image_file = data_repo.create_file(**file_data) return canny_image_file +def combine_mask_and_input_image(mask_path, input_image_path, overlap_color="transparent"): + # Open the input image and the mask + input_image = Image.open(input_image_path) + mask_image = Image.open(mask_path) + input_image = input_image.convert("RGBA") + + is_white = lambda pixel, threshold=245: all(value > threshold for value in pixel[:3]) + + fill_color = (0.5,0.5,0.5,1) # default grey + if overlap_color == "transparent": + fill_color = (0,0,0,0) + elif overlap_color == "grey": + fill_color = (0.5, 0.5, 0.5, 1) + + for x in range(mask_image.width): + for y in range(mask_image.height): + if is_white(mask_image.getpixel((x, y))): + input_image.putpixel((x, y), fill_color) + + return input_image + # the input image is an image created by the PIL library def create_or_update_mask(timing_uuid, image) -> InternalFileObject: data_repo = DataRepo() 
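The new `combine_mask_and_input_image` helper above clears every near-white pixel of the mask out of the input image, filling it with either a transparent or a grey colour. Below is a minimal, self-contained sketch of the same idea; it is illustrative only (the function name and the vectorised NumPy approach are mine, not the project's), and it uses integer RGBA values since Pillow pixel channels are 8-bit integers in the 0-255 range:

```python
# Illustrative sketch (not project code): blank out the white region of a mask
# in an RGBA input image, mirroring what combine_mask_and_input_image does.
import numpy as np
from PIL import Image

def apply_mask(input_image: Image.Image, mask_image: Image.Image,
               overlap_color: str = "transparent") -> Image.Image:
    """Return a copy of input_image with near-white mask pixels replaced."""
    rgba = np.array(input_image.convert("RGBA"))
    mask = np.array(mask_image.convert("RGB"))

    # A pixel counts as "white" when every channel exceeds the threshold.
    is_white = (mask > 245).all(axis=-1)

    # Pixel values are 8-bit integers, so the fill colours are 0-255 tuples.
    fill = (0, 0, 0, 0) if overlap_color == "transparent" else (128, 128, 128, 255)
    rgba[is_white] = fill
    return Image.fromarray(rgba, mode="RGBA")

if __name__ == "__main__":
    # Dummy inputs so the sketch runs on its own.
    img = Image.new("RGB", (64, 64), (200, 40, 40))
    msk = Image.new("RGB", (64, 64), (0, 0, 0))
    msk.paste((255, 255, 255), (16, 16, 48, 48))   # white square = area to clear
    out = apply_mask(img, msk)
    print(out.size, out.mode)
```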
@@ -595,134 +616,137 @@ def save_audio_file(uploaded_file, project_uuid): def execute_image_edit(type_of_mask_selection, type_of_mask_replacement, background_image, editing_image, prompt, negative_prompt, width, height, layer, timing_uuid): - from ui_components.methods.ml_methods import inpainting, remove_background, create_depth_mask_image + from ui_components.methods.ml_methods import inpainting data_repo = DataRepo() timing: InternalFrameTimingObject = data_repo.get_timing_from_uuid( timing_uuid) project = timing.shot.project inference_log = None - if type_of_mask_selection == "Automated Background Selection": - removed_background = remove_background(editing_image) - response = r.get(removed_background) - img = Image.open(BytesIO(response.content)) - hosted_url = save_or_host_file(img, SECOND_MASK_FILE_PATH) - add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_url or SECOND_MASK_FILE_PATH) - - if type_of_mask_replacement == "Replace With Image": - edited_image = replace_background(project.uuid, background_image) - - elif type_of_mask_replacement == "Inpainting": - path = project.get_temp_mask_file(SECOND_MASK_FILE).location - if path.startswith("http"): - response = r.get(path) - image = Image.open(BytesIO(response.content)) - else: - image = Image.open(path) - - converted_image = Image.new("RGB", image.size, (255, 255, 255)) - for x in range(image.width): - for y in range(image.height): - pixel = image.getpixel((x, y)) - if pixel[3] == 0: - converted_image.putpixel((x, y), (0, 0, 0)) - else: - converted_image.putpixel((x, y), (255, 255, 255)) - create_or_update_mask(timing_uuid, converted_image) - edited_image = inpainting( - editing_image, prompt, negative_prompt, timing.uuid, True) - - elif type_of_mask_selection == "Manual Background Selection": - if type_of_mask_replacement == "Replace With Image": - bg_img = generate_pil_image(editing_image) - mask_img = generate_pil_image(timing.mask.location) - - result_img = Image.new("RGBA", bg_img.size, (255, 255, 255, 0)) - for x in range(bg_img.size[0]): - for y in range(bg_img.size[1]): - if x < mask_img.size[0] and y < mask_img.size[1]: - if mask_img.getpixel((x, y)) == (255, 255, 255): - result_img.putpixel((x, y), (255, 255, 255, 0)) - else: - result_img.putpixel((x, y), bg_img.getpixel((x, y))) + if type_of_mask_selection == "Manual Background Selection": + # NOTE: code not is use + # if type_of_mask_replacement == "Replace With Image": + # bg_img = generate_pil_image(editing_image) + # mask_img = generate_pil_image(timing.mask.location) + + # result_img = Image.new("RGBA", bg_img.size, (255, 255, 255, 0)) + # for x in range(bg_img.size[0]): + # for y in range(bg_img.size[1]): + # if x < mask_img.size[0] and y < mask_img.size[1]: + # if mask_img.getpixel((x, y)) == (255, 255, 255): + # result_img.putpixel((x, y), (255, 255, 255, 0)) + # else: + # result_img.putpixel((x, y), bg_img.getpixel((x, y))) - hosted_manual_bg_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) - add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_manual_bg_url or SECOND_MASK_FILE_PATH) - edited_image = replace_background(project.uuid, background_image) + # hosted_manual_bg_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) + # add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_manual_bg_url or SECOND_MASK_FILE_PATH) + # edited_image = replace_background(project.uuid, background_image) - elif type_of_mask_replacement == "Inpainting": + if type_of_mask_replacement == "Inpainting": edited_image, log = 
inpainting(editing_image, prompt, negative_prompt, timing_uuid, False) inference_log = log - elif type_of_mask_selection == "Automated Layer Selection": - mask_location = create_depth_mask_image( - editing_image, layer, timing.uuid) - if type_of_mask_replacement == "Replace With Image": - if mask_location.startswith("http"): - mask = Image.open( - BytesIO(r.get(mask_location).content)).convert('1') - else: - mask = Image.open(mask_location).convert('1') - if editing_image.startswith("http"): - response = r.get(editing_image) - bg_img = Image.open(BytesIO(response.content)).convert('RGBA') - else: - bg_img = Image.open(editing_image).convert('RGBA') - - hosted_automated_bg_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) - add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_automated_bg_url or SECOND_MASK_FILE_PATH) - edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) - - elif type_of_mask_replacement == "Inpainting": - edited_image = inpainting( - editing_image, prompt, negative_prompt, timing_uuid, True) - - elif type_of_mask_selection == "Re-Use Previous Mask": - mask_location = timing.mask.location - if type_of_mask_replacement == "Replace With Image": - if mask_location.startswith("http"): - response = r.get(mask_location) - mask = Image.open(BytesIO(response.content)).convert('1') - else: - mask = Image.open(mask_location).convert('1') - if editing_image.startswith("http"): - response = r.get(editing_image) - bg_img = Image.open(BytesIO(response.content)).convert('RGBA') - else: - bg_img = Image.open(editing_image).convert('RGBA') + # NOTE: code not is use ------------------------------------- + # elif type_of_mask_selection == "Automated Background Selection": + # removed_background = remove_background(editing_image) + # response = r.get(removed_background) + # img = Image.open(BytesIO(response.content)) + # hosted_url = save_or_host_file(img, SECOND_MASK_FILE_PATH) + # add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_url or SECOND_MASK_FILE_PATH) + + # if type_of_mask_replacement == "Replace With Image": + # edited_image = replace_background(project.uuid, background_image) + + # elif type_of_mask_replacement == "Inpainting": + # path = project.get_temp_mask_file(SECOND_MASK_FILE).location + # if path.startswith("http"): + # response = r.get(path) + # image = Image.open(BytesIO(response.content)) + # else: + # image = Image.open(path) + + # converted_image = Image.new("RGB", image.size, (255, 255, 255)) + # for x in range(image.width): + # for y in range(image.height): + # pixel = image.getpixel((x, y)) + # if pixel[3] == 0: + # converted_image.putpixel((x, y), (0, 0, 0)) + # else: + # converted_image.putpixel((x, y), (255, 255, 255)) + # create_or_update_mask(timing_uuid, converted_image) + # edited_image = inpainting( + # editing_image, prompt, negative_prompt, timing.uuid, True) - hosted_image_replace_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) - add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_image_replace_url or SECOND_MASK_FILE_PATH) - edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) - - elif type_of_mask_replacement == "Inpainting": - edited_image = inpainting( - editing_image, prompt, negative_prompt, timing_uuid, True) - - elif type_of_mask_selection == "Invert Previous Mask": - if type_of_mask_replacement == "Replace With Image": - mask_location = timing.mask.location - if mask_location.startswith("http"): - response = 
r.get(mask_location) - mask = Image.open(BytesIO(response.content)).convert('1') - else: - mask = Image.open(mask_location).convert('1') - inverted_mask = ImageOps.invert(mask) - if editing_image.startswith("http"): - response = r.get(editing_image) - bg_img = Image.open(BytesIO(response.content)).convert('RGBA') - else: - bg_img = Image.open(editing_image).convert('RGBA') - masked_img = Image.composite(bg_img, Image.new( - 'RGBA', bg_img.size, (0, 0, 0, 0)), inverted_mask) - # TODO: standardise temproray fixes - hosted_prvious_invert_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) - add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_prvious_invert_url or SECOND_MASK_FILE_PATH) - edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) - - elif type_of_mask_replacement == "Inpainting": - edited_image = inpainting( - editing_image, prompt, negative_prompt, timing_uuid, False) + # elif type_of_mask_selection == "Automated Layer Selection": + # mask_location = create_depth_mask_image( + # editing_image, layer, timing.uuid) + # if type_of_mask_replacement == "Replace With Image": + # if mask_location.startswith("http"): + # mask = Image.open( + # BytesIO(r.get(mask_location).content)).convert('1') + # else: + # mask = Image.open(mask_location).convert('1') + # if editing_image.startswith("http"): + # response = r.get(editing_image) + # bg_img = Image.open(BytesIO(response.content)).convert('RGBA') + # else: + # bg_img = Image.open(editing_image).convert('RGBA') + + # hosted_automated_bg_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) + # add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_automated_bg_url or SECOND_MASK_FILE_PATH) + # edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) + + # elif type_of_mask_replacement == "Inpainting": + # edited_image = inpainting( + # editing_image, prompt, negative_prompt, timing_uuid, True) + + # elif type_of_mask_selection == "Re-Use Previous Mask": + # mask_location = timing.mask.location + # if type_of_mask_replacement == "Replace With Image": + # if mask_location.startswith("http"): + # response = r.get(mask_location) + # mask = Image.open(BytesIO(response.content)).convert('1') + # else: + # mask = Image.open(mask_location).convert('1') + # if editing_image.startswith("http"): + # response = r.get(editing_image) + # bg_img = Image.open(BytesIO(response.content)).convert('RGBA') + # else: + # bg_img = Image.open(editing_image).convert('RGBA') + + # hosted_image_replace_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) + # add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_image_replace_url or SECOND_MASK_FILE_PATH) + # edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) + + # elif type_of_mask_replacement == "Inpainting": + # edited_image = inpainting( + # editing_image, prompt, negative_prompt, timing_uuid, True) + + # elif type_of_mask_selection == "Invert Previous Mask": + # if type_of_mask_replacement == "Replace With Image": + # mask_location = timing.mask.location + # if mask_location.startswith("http"): + # response = r.get(mask_location) + # mask = Image.open(BytesIO(response.content)).convert('1') + # else: + # mask = Image.open(mask_location).convert('1') + # inverted_mask = ImageOps.invert(mask) + # if editing_image.startswith("http"): + # response = r.get(editing_image) + # bg_img = Image.open(BytesIO(response.content)).convert('RGBA') + # else: + # bg_img = 
Image.open(editing_image).convert('RGBA') + # masked_img = Image.composite(bg_img, Image.new( + # 'RGBA', bg_img.size, (0, 0, 0, 0)), inverted_mask) + # # TODO: standardise temproray fixes + # hosted_prvious_invert_url = save_or_host_file(result_img, SECOND_MASK_FILE_PATH) + # add_temp_file_to_project(project.uuid, SECOND_MASK_FILE, hosted_prvious_invert_url or SECOND_MASK_FILE_PATH) + # edited_image = replace_background(project.uuid, SECOND_MASK_FILE_PATH, background_image) + + # elif type_of_mask_replacement == "Inpainting": + # edited_image = inpainting( + # editing_image, prompt, negative_prompt, timing_uuid, False) + # --------------------------------------------------------------------- return edited_image, inference_log @@ -788,14 +812,19 @@ def process_inference_output(**kwargs): if not shot: return False + output = output[-1] if isinstance(output, list) else output # output can also be an url - if isinstance(output, str) and output.startswith("http"): - temp_output_file = generate_temp_file(output, '.mp4') - output = None - with open(temp_output_file.name, 'rb') as f: - output = f.read() - - os.remove(temp_output_file.name) + if isinstance(output, str): + if output.startswith("http"): + temp_output_file = generate_temp_file(output, '.mp4') + output = None + with open(temp_output_file.name, 'rb') as f: + output = f.read() + + os.remove(temp_output_file.name) + else: + with open(output, 'rb') as f: + output = f.read() if 'normalise_speed' in settings and settings['normalise_speed']: output = VideoProcessor.update_video_bytes_speed(output, shot.duration) @@ -841,7 +870,7 @@ def process_inference_output(**kwargs): hosted_url=output[0] if isinstance(output, list) else output, inference_log_id=log.uuid, project_id=project_uuid, - tag=InternalFileTag.GALLERY_IMAGE.value + tag=InternalFileTag.TEMP_GALLERY_IMAGE.value # will be updated to GALLERY_IMAGE once the user clicks 'check for new images' ) else: log_uuid = kwargs.get('log_uuid') @@ -944,4 +973,8 @@ def update_app_setting_keys(): app_logger.log(LoggingType.DEBUG, 'setting keys', None) data_repo.update_app_setting(replicate_username='bn') - data_repo.update_app_setting(replicate_key=key) \ No newline at end of file + data_repo.update_app_setting(replicate_key=key) + + +def random_seed(): + return random.randint(10**14, 10**15 - 1) \ No newline at end of file diff --git a/ui_components/methods/data_logger.py b/ui_components/methods/data_logger.py index 31aa60e1..7e0320ed 100644 --- a/ui_components/methods/data_logger.py +++ b/ui_components/methods/data_logger.py @@ -3,14 +3,13 @@ import time from shared.constants import InferenceStatus from shared.logging.constants import LoggingPayload, LoggingType -from shared.logging.logging import AppLogger from utils.common_utils import get_current_user_uuid from utils.data_repo.data_repo import DataRepo -from utils.ml_processor.replicate.constants import REPLICATE_MODEL, ReplicateModel +from utils.ml_processor.constants import ML_MODEL, MLModel -def log_model_inference(model: ReplicateModel, time_taken, **kwargs): +def log_model_inference(model: MLModel, time_taken, **kwargs): kwargs_dict = dict(kwargs) # removing object like bufferedreader, image_obj .. 
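The comment above ("removing object like bufferedreader, image_obj") refers to stripping values that cannot be JSON-serialised before the inference parameters are written to the log record. A small sketch of that filtering step is below, assuming the simplest policy of dropping anything `json.dumps` rejects; the helper name is hypothetical and not part of the codebase:

```python
import json

def jsonable_kwargs(kwargs: dict) -> str:
    """Serialise only the kwargs that survive a json.dumps probe."""
    clean = {}
    for key, value in kwargs.items():
        try:
            json.dumps(value)   # file handles, PIL images, buffers, etc. fail here
        except TypeError:
            continue            # drop non-serialisable values from the log payload
        clean[key] = value
    return json.dumps(clean)

# e.g. jsonable_kwargs({"prompt": "a forest", "image": open(__file__, "rb")})
# keeps "prompt" and silently drops the buffered reader.
```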
@@ -21,19 +20,11 @@ def log_model_inference(model: ReplicateModel, time_taken, **kwargs): data_str = json.dumps(kwargs_dict) time_taken = round(time_taken, 2) if time_taken else 0 - data = { - 'model_name': model.name, - 'model_version': model.version, - 'total_inference_time': time_taken, - 'input_params': data_str, - 'created_on': int(time.time()) - } - - system_logger = AppLogger() - logging_payload = LoggingPayload(message="logging inference data", data=data) + # system_logger = AppLogger() + # logging_payload = LoggingPayload(message="logging inference data", data=data) - # logging in console - system_logger.log(LoggingType.INFERENCE_CALL, logging_payload) + # # logging in console + # system_logger.log(LoggingType.INFERENCE_CALL, logging_payload) # storing the log in db data_repo = DataRepo() @@ -41,14 +32,14 @@ def log_model_inference(model: ReplicateModel, time_taken, **kwargs): ai_model = data_repo.get_ai_model_from_name(model.name, user_id) # hackish sol for insuring that inpainting logs don't have an empty model field - if ai_model is None and model.name in [REPLICATE_MODEL.sdxl_inpainting.name, REPLICATE_MODEL.ad_interpolation.name]: - ai_model = data_repo.get_ai_model_from_name(REPLICATE_MODEL.sdxl.name, user_id) + if ai_model is None and model.name in [ML_MODEL.sdxl_inpainting.name, ML_MODEL.ad_interpolation.name]: + ai_model = data_repo.get_ai_model_from_name(ML_MODEL.sdxl.name, user_id) log_data = { "project_id" : st.session_state["project_uuid"], "model_id" : ai_model.uuid if ai_model else None, "input_params" : data_str, - "output_details" : json.dumps({"model_name": model.name, "version": model.version}), + "output_details" : json.dumps({"model_name": model.display_name(), "version": model.version}), "total_inference_time" : time_taken, "status" : InferenceStatus.COMPLETED.value if time_taken else InferenceStatus.QUEUED.value, } diff --git a/ui_components/methods/file_methods.py b/ui_components/methods/file_methods.py index c8547b1e..ca824bd9 100644 --- a/ui_components/methods/file_methods.py +++ b/ui_components/methods/file_methods.py @@ -84,6 +84,12 @@ def normalize_size_internal_file_obj(file_obj: InternalFileObject, **kwargs): project_setting = data_repo.get_project_setting(file_obj.project.uuid) dim = (project_setting.width, project_setting.height) + create_new_file = True if 'create_new_file' in kwargs \ + and kwargs['create_new_file'] else False + + if create_new_file: + file_obj = create_duplicate_file(file_obj) + pil_file = generate_pil_image(file_obj.location) uploaded_url = save_or_host_file(pil_file, file_obj.location, mime_type='image/png', dim=dim) if uploaded_url: @@ -92,7 +98,6 @@ def normalize_size_internal_file_obj(file_obj: InternalFileObject, **kwargs): return file_obj - def save_or_host_file_bytes(video_bytes, path, ext=".mp4"): uploaded_url = None if SERVER != ServerType.DEVELOPMENT.value: @@ -130,7 +135,6 @@ def add_temp_file_to_project(project_uuid, key, file_path): } data_repo.update_project(**project_data) - def generate_temp_file(url, ext=".mp4"): response = requests.get(url) if not response.ok: @@ -142,7 +146,6 @@ def generate_temp_file(url, ext=".mp4"): return temp_file - def generate_pil_image(img: Union[Image.Image, str, np.ndarray, io.BytesIO]): # Check if img is a PIL image if isinstance(img, Image.Image): @@ -177,7 +180,6 @@ def generate_temp_file_from_uploaded_file(uploaded_file): temp_file.write(uploaded_file.read()) return temp_file - def convert_bytes_to_file(file_location_to_save, mime_type, file_bytes, project_uuid, 
inference_log_id=None, filename=None, tag="") -> InternalFileObject: data_repo = DataRepo() @@ -201,7 +203,6 @@ def convert_bytes_to_file(file_location_to_save, mime_type, file_bytes, project_ return file - def convert_file_to_base64(fh: io.IOBase) -> str: fh.seek(0) @@ -216,6 +217,13 @@ def convert_file_to_base64(fh: io.IOBase) -> str: s = encoded_body.decode("utf-8") return f"data:{mime_type};base64,{s}" +def resize_io_buffers(io_buffer, target_width, target_height, format="PNG"): + input_image = Image.open(io_buffer) + input_image = input_image.resize((target_width, target_height), Image.ANTIALIAS) + output_image_buffer = io.BytesIO() + input_image.save(output_image_buffer, format='PNG') + return output_image_buffer + ENV_FILE_PATH = '.env' def save_to_env(key, value): set_key(dotenv_path=ENV_FILE_PATH, key_to_set=key, value_to_set=value) @@ -230,7 +238,7 @@ def load_from_env(key): from PIL import Image from io import BytesIO -def zip_images(image_locations, zip_filename='images.zip'): +def zip_images(image_locations, zip_filename='images.zip', filename_list=[]): # Calculate the number of digits needed for padding num_digits = len(str(len(image_locations) - 1)) @@ -238,7 +246,10 @@ def zip_images(image_locations, zip_filename='images.zip'): for idx, image_location in enumerate(image_locations): # Pad the index with zeros padded_idx = str(idx).zfill(num_digits) - image_name = f"{padded_idx}.png" + if filename_list and len(filename_list) > idx: + image_name = filename_list[idx] + else: + image_name = f"{padded_idx}.png" if image_location.startswith('http'): response = requests.get(image_location) @@ -265,12 +276,10 @@ def zip_images(image_locations, zip_filename='images.zip'): return zip_filename - - def create_duplicate_file(file: InternalFileObject, project_uuid=None) -> InternalFileObject: data_repo = DataRepo() - unique_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5)) + ".mp4" + unique_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5)) file_data = { "name": unique_id + '_' + file.name, "type": file.type, diff --git a/ui_components/methods/ml_methods.py b/ui_components/methods/ml_methods.py index b704796d..d2ffe5e4 100644 --- a/ui_components/methods/ml_methods.py +++ b/ui_components/methods/ml_methods.py @@ -8,182 +8,185 @@ import uuid import urllib from backend.models import InternalFileObject -from shared.constants import QUEUE_INFERENCE_QUERIES, SERVER, AIModelCategory, InferenceType, InternalFileType, ServerType +from shared.constants import GPU_INFERENCE_ENABLED, QUEUE_INFERENCE_QUERIES, SERVER, AIModelCategory, InferenceType, InternalFileType, ServerType from ui_components.constants import MASK_IMG_LOCAL_PATH, TEMP_MASK_FILE -from ui_components.methods.common_methods import process_inference_output +from ui_components.methods.common_methods import combine_mask_and_input_image, process_inference_output +from ui_components.methods.file_methods import save_or_host_file from ui_components.models import InternalAIModelObject, InternalFrameTimingObject, InternalSettingObject from utils.constants import ImageStage, MLQueryObject from utils.data_repo.data_repo import DataRepo from utils.ml_processor.ml_interface import get_ml_client -from utils.ml_processor.replicate.constants import REPLICATE_MODEL, ReplicateModel +from utils.ml_processor.constants import ML_MODEL, MLModel -def trigger_restyling_process(timing_uuid, update_inference_settings, \ - transformation_stage, promote_new_generation, **kwargs): - data_repo = DataRepo() +# NOTE: 
code not is use +# def trigger_restyling_process(timing_uuid, update_inference_settings, \ +# transformation_stage, promote_new_generation, **kwargs): +# data_repo = DataRepo() - timing: InternalFrameTimingObject = data_repo.get_timing_from_uuid(timing_uuid) - project_settings: InternalSettingObject = data_repo.get_project_setting(timing.shot.project.uuid) +# timing: InternalFrameTimingObject = data_repo.get_timing_from_uuid(timing_uuid) +# project_settings: InternalSettingObject = data_repo.get_project_setting(timing.shot.project.uuid) - source_image = timing.source_image if transformation_stage == ImageStage.SOURCE_IMAGE.value else \ - timing.primary_image - - query_obj = MLQueryObject( - timing_uuid, - image_uuid=source_image.uuid if 'add_image_in_params' in kwargs and kwargs['add_image_in_params'] else None, - width=project_settings.width, - height=project_settings.height, - **kwargs - ) - - prompt = query_obj.prompt - if update_inference_settings is True: - prompt = prompt.replace(",", ".") - prompt = prompt.replace("\n", "") - - project_settings.batch_prompt = prompt - project_settings.batch_strength = query_obj.strength - project_settings.batch_negative_prompt = query_obj.negative_prompt - project_settings.batch_guidance_scale = query_obj.guidance_scale - project_settings.batch_seed = query_obj.seed - project_settings.batch_num_inference_steps = query_obj.num_inference_steps - # project_settings.batch_custom_models = query_obj.data.get('custom_models', []), - project_settings.batch_adapter_type = query_obj.adapter_type - # project_settings.batch_add_image_in_params = st.session_state['add_image_in_params'], - - query_obj.prompt = dynamic_prompting(prompt, source_image) - output, log = restyle_images(query_obj, QUEUE_INFERENCE_QUERIES) - - inference_data = { - "inference_type": InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value, - "output": output, - "log_uuid": log.uuid, - "timing_uuid": timing_uuid, - "promote_new_generation": promote_new_generation, - } - process_inference_output(**inference_data) +# source_image = timing.source_image if transformation_stage == ImageStage.SOURCE_IMAGE.value else \ +# timing.primary_image + +# query_obj = MLQueryObject( +# timing_uuid, +# image_uuid=source_image.uuid if 'add_image_in_params' in kwargs and kwargs['add_image_in_params'] else None, +# width=project_settings.width, +# height=project_settings.height, +# **kwargs +# ) + +# prompt = query_obj.prompt +# if update_inference_settings is True: +# prompt = prompt.replace(",", ".") +# prompt = prompt.replace("\n", "") + +# project_settings.batch_prompt = prompt +# project_settings.batch_strength = query_obj.strength +# project_settings.batch_negative_prompt = query_obj.negative_prompt +# project_settings.batch_guidance_scale = query_obj.guidance_scale +# project_settings.batch_seed = query_obj.seed +# project_settings.batch_num_inference_steps = query_obj.num_inference_steps +# # project_settings.batch_custom_models = query_obj.data.get('custom_models', []), +# project_settings.batch_adapter_type = query_obj.adapter_type +# # project_settings.batch_add_image_in_params = st.session_state['add_image_in_params'], + +# # query_obj.prompt = dynamic_prompting(prompt, source_image) +# output, log = restyle_images(query_obj, QUEUE_INFERENCE_QUERIES) + +# inference_data = { +# "inference_type": InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value, +# "output": output, +# "log_uuid": log.uuid, +# "timing_uuid": timing_uuid, +# "promote_new_generation": promote_new_generation, +# } +# 
process_inference_output(**inference_data) + +# def restyle_images(query_obj: MLQueryObject, queue_inference=False) -> InternalFileObject: +# data_repo = DataRepo() +# ml_client = get_ml_client() +# db_model = data_repo.get_ai_model_from_uuid(query_obj.model_uuid) + +# if db_model.category == AIModelCategory.LORA.value: +# model = ML_MODEL.clones_lora_training_2 +# output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) + +# elif db_model.category == AIModelCategory.CONTROLNET.value: +# adapter_type = query_obj.adapter_type +# if adapter_type == "normal": +# model = ML_MODEL.jagilley_controlnet_normal +# elif adapter_type == "canny": +# model = ML_MODEL.jagilley_controlnet_canny +# elif adapter_type == "hed": +# model = ML_MODEL.jagilley_controlnet_hed +# elif adapter_type == "scribble": +# model = ML_MODEL.jagilley_controlnet_scribble +# elif adapter_type == "seg": +# model = ML_MODEL.jagilley_controlnet_seg +# elif adapter_type == "hough": +# model = ML_MODEL.jagilley_controlnet_hough +# elif adapter_type == "depth2img": +# model = ML_MODEL.jagilley_controlnet_depth2img +# elif adapter_type == "pose": +# model = ML_MODEL.jagilley_controlnet_pose +# output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) + +# elif db_model.category == AIModelCategory.DREAMBOOTH.value: +# output, log = prompt_model_dreambooth(query_obj, queue_inference=queue_inference) + +# else: +# model = ML_MODEL.get_model_by_db_obj(db_model) # TODO: remove this dependency +# output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) + +# return output, log + +# def prompt_model_dreambooth(query_obj: MLQueryObject, queue_inference=False): +# data_repo = DataRepo() +# ml_client = get_ml_client() + +# model_uuid = query_obj.data.get('dreambooth_model_uuid', None) +# if not model_uuid: +# st.error('No dreambooth model selected') +# return - -def restyle_images(query_obj: MLQueryObject, queue_inference=False) -> InternalFileObject: - data_repo = DataRepo() - ml_client = get_ml_client() - db_model = data_repo.get_ai_model_from_uuid(query_obj.model_uuid) - - if db_model.category == AIModelCategory.LORA.value: - model = REPLICATE_MODEL.clones_lora_training_2 - output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) - - elif db_model.category == AIModelCategory.CONTROLNET.value: - adapter_type = query_obj.adapter_type - if adapter_type == "normal": - model = REPLICATE_MODEL.jagilley_controlnet_normal - elif adapter_type == "canny": - model = REPLICATE_MODEL.jagilley_controlnet_canny - elif adapter_type == "hed": - model = REPLICATE_MODEL.jagilley_controlnet_hed - elif adapter_type == "scribble": - model = REPLICATE_MODEL.jagilley_controlnet_scribble - elif adapter_type == "seg": - model = REPLICATE_MODEL.jagilley_controlnet_seg - elif adapter_type == "hough": - model = REPLICATE_MODEL.jagilley_controlnet_hough - elif adapter_type == "depth2img": - model = REPLICATE_MODEL.jagilley_controlnet_depth2img - elif adapter_type == "pose": - model = REPLICATE_MODEL.jagilley_controlnet_pose - output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) - - elif db_model.category == AIModelCategory.DREAMBOOTH.value: - output, log = prompt_model_dreambooth(query_obj, queue_inference=queue_inference) - - else: - model = REPLICATE_MODEL.get_model_by_db_obj(db_model) # TODO: remove this 
dependency - output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) - - return output, log - -def prompt_model_dreambooth(query_obj: MLQueryObject, queue_inference=False): - data_repo = DataRepo() - ml_client = get_ml_client() - - model_uuid = query_obj.data.get('dreambooth_model_uuid', None) - if not model_uuid: - st.error('No dreambooth model selected') - return - - dreambooth_model: InternalAIModelObject = data_repo.get_ai_model_from_uuid(model_uuid) +# dreambooth_model: InternalAIModelObject = data_repo.get_ai_model_from_uuid(model_uuid) - model_name = dreambooth_model.name - model_id = dreambooth_model.replicate_url - - if not dreambooth_model.version: - version = ml_client.get_model_version_from_id(model_id) - data_repo.update_ai_model(uuid=dreambooth_model.uuid, version=version) - else: - version = dreambooth_model.version - - app_setting = data_repo.get_app_setting_from_uuid() - model = ReplicateModel(f"{app_setting.replicate_username}/{model_name}", version) - output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) - - return output, log - - -def prompt_clip_interrogator(input_image, which_model, best_or_fast): - if which_model == "Stable Diffusion 1.5": - which_model = "ViT-L-14/openai" - elif which_model == "Stable Diffusion 2": - which_model = "ViT-H-14/laion2b_s32b_b79k" - - if not input_image.startswith("http"): - input_image = open(input_image, "rb") - - ml_client = get_ml_client() - output, _ = ml_client.predict_model_output( - REPLICATE_MODEL.clip_interrogator, image=input_image, clip_model_name=which_model, mode=best_or_fast) - - return output - -def prompt_model_blip2(input_image, query): - if not input_image.startswith("http"): - input_image = open(input_image, "rb") - - ml_client = get_ml_client() - output, _ = ml_client.predict_model_output( - REPLICATE_MODEL.salesforce_blip_2, image=input_image, question=query) - - return output - -def facial_expression_recognition(input_image): - input_image = input_image.location - if not input_image.startswith("http"): - input_image = open(input_image, "rb") - - ml_client = get_ml_client() - output, _ = ml_client.predict_model_output( - REPLICATE_MODEL.phamquiluan_face_recognition, input_path=input_image) - - emo_label = output[0]["emo_label"] - if emo_label == "disgust": - emo_label = "disgusted" - elif emo_label == "fear": - emo_label = "fearful" - elif emo_label == "surprised": - emo_label = "surprised" - emo_proba = output[0]["emo_proba"] - if emo_proba > 0.95: - emotion = (f"very {emo_label} expression") - elif emo_proba > 0.85: - emotion = (f"{emo_label} expression") - elif emo_proba > 0.75: - emotion = (f"somewhat {emo_label} expression") - elif emo_proba > 0.65: - emotion = (f"slightly {emo_label} expression") - elif emo_proba > 0.55: - emotion = (f"{emo_label} expression") - else: - emotion = (f"neutral expression") - return emotion +# model_name = dreambooth_model.name +# model_id = dreambooth_model.replicate_url + +# if not dreambooth_model.version: +# version = ml_client.get_model_version_from_id(model_id) +# data_repo.update_ai_model(uuid=dreambooth_model.uuid, version=version) +# else: +# version = dreambooth_model.version + +# app_setting = data_repo.get_app_setting_from_uuid() +# model = MLModel(f"{app_setting.replicate_username}/{model_name}", version) +# output, log = ml_client.predict_model_output_standardized(model, query_obj, queue_inference=queue_inference) + +# return output, log + +# NOTE: code not is 
use +# def prompt_clip_interrogator(input_image, which_model, best_or_fast): +# if which_model == "Stable Diffusion 1.5": +# which_model = "ViT-L-14/openai" +# elif which_model == "Stable Diffusion 2": +# which_model = "ViT-H-14/laion2b_s32b_b79k" + +# if not input_image.startswith("http"): +# input_image = open(input_image, "rb") + +# ml_client = get_ml_client() +# output, _ = ml_client.predict_model_output( +# ML_MODEL.clip_interrogator, image=input_image, clip_model_name=which_model, mode=best_or_fast) + +# return output + +# NOTE: code not is use +# def prompt_model_blip2(input_image, query): +# if not input_image.startswith("http"): +# input_image = open(input_image, "rb") + +# ml_client = get_ml_client() +# output, _ = ml_client.predict_model_output( +# ML_MODEL.salesforce_blip_2, image=input_image, question=query) + +# return output + +# NOTE: code not is use +# def facial_expression_recognition(input_image): +# input_image = input_image.location +# if not input_image.startswith("http"): +# input_image = open(input_image, "rb") + +# ml_client = get_ml_client() +# output, _ = ml_client.predict_model_output( +# ML_MODEL.phamquiluan_face_recognition, input_path=input_image) + +# emo_label = output[0]["emo_label"] +# if emo_label == "disgust": +# emo_label = "disgusted" +# elif emo_label == "fear": +# emo_label = "fearful" +# elif emo_label == "surprised": +# emo_label = "surprised" +# emo_proba = output[0]["emo_proba"] +# if emo_proba > 0.95: +# emotion = (f"very {emo_label} expression") +# elif emo_proba > 0.85: +# emotion = (f"{emo_label} expression") +# elif emo_proba > 0.75: +# emotion = (f"somewhat {emo_label} expression") +# elif emo_proba > 0.65: +# emotion = (f"slightly {emo_label} expression") +# elif emo_proba > 0.55: +# emotion = (f"{emo_label} expression") +# else: +# emotion = (f"neutral expression") +# return emotion def inpainting(input_image: str, prompt, negative_prompt, timing_uuid, mask_in_project=False) -> InternalFileObject: data_repo = DataRepo() @@ -201,103 +204,123 @@ def inpainting(input_image: str, prompt, negative_prompt, timing_uuid, mask_in_p if not input_image.startswith("http"): input_image = open(input_image, "rb") + query_obj = MLQueryObject( + timing_uuid=timing_uuid, + model_uuid=None, + guidance_scale=7.5, + seed=-1, + num_inference_steps=25, + strength=0.7, + adapter_type=None, + prompt=prompt, + negative_prompt=negative_prompt, + height=512, + width=512, + low_threshold=100, # update these default values + high_threshold=200, + image_uuid=None, + mask_uuid=None, + data={ + "input_image": input_image, + "mask": mask, + } + ) + ml_client = get_ml_client() - output, log = ml_client.predict_model_output( - REPLICATE_MODEL.sdxl_inpainting, - mask=mask, - image=input_image, - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - strength=0.99, - queue_inference=QUEUE_INFERENCE_QUERIES + output, log = ml_client.predict_model_output_standardized( + ML_MODEL.sdxl_inpainting, + query_obj, + QUEUE_INFERENCE_QUERIES ) return output, log -def remove_background(input_image): - if not input_image.startswith("http"): - input_image = open(input_image, "rb") +# NOTE: code not is use +# def remove_background(input_image): +# if not input_image.startswith("http"): +# input_image = open(input_image, "rb") - ml_client = get_ml_client() - output, _ = ml_client.predict_model_output( - REPLICATE_MODEL.pollination_modnet, image=input_image) - return output +# ml_client = get_ml_client() +# output, _ = ml_client.predict_model_output( +# 
ML_MODEL.pollination_modnet, image=input_image) +# return output -def create_depth_mask_image(input_image, layer, timing_uuid): - from ui_components.methods.common_methods import create_or_update_mask +# NOTE: code not is use +# def create_depth_mask_image(input_image, layer, timing_uuid): +# from ui_components.methods.common_methods import create_or_update_mask - if not input_image.startswith("http"): - input_image = open(input_image, "rb") - - ml_client = get_ml_client() - output, log = ml_client.predict_model_output( - REPLICATE_MODEL.cjwbw_midas, image=input_image, model_type="dpt_beit_large_512") - try: - temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png", mode='wb') - with urllib.request.urlopen(output) as response, open(temp_file.name, 'wb') as out_file: - out_file.write(response.read()) - except Exception as e: - print(e) - - depth_map = Image.open(temp_file.name) - os.remove(temp_file.name) - depth_map = depth_map.convert("L") # Convert to grayscale image - pixels = depth_map.load() - mask = Image.new("L", depth_map.size) - mask_pixels = mask.load() - - fg_mask = Image.new("L", depth_map.size) if "Foreground" in layer else None - mg_mask = Image.new( - "L", depth_map.size) if "Middleground" in layer else None - bg_mask = Image.new("L", depth_map.size) if "Background" in layer else None - - fg_pixels = fg_mask.load() if fg_mask else None - mg_pixels = mg_mask.load() if mg_mask else None - bg_pixels = bg_mask.load() if bg_mask else None - - for i in range(depth_map.size[0]): - for j in range(depth_map.size[1]): - depth_value = pixels[i, j] - - if fg_pixels: - fg_pixels[i, j] = 0 if depth_value > 200 else 255 - if mg_pixels: - mg_pixels[i, j] = 0 if depth_value <= 200 and depth_value > 50 else 255 - if bg_pixels: - bg_pixels[i, j] = 0 if depth_value <= 50 else 255 - - mask_pixels[i, j] = 255 - if fg_pixels: - mask_pixels[i, j] &= fg_pixels[i, j] - if mg_pixels: - mask_pixels[i, j] &= mg_pixels[i, j] - if bg_pixels: - mask_pixels[i, j] &= bg_pixels[i, j] - - return create_or_update_mask(timing_uuid, mask) - -def dynamic_prompting(prompt, source_image): - if "[expression]" in prompt: - prompt_expression = facial_expression_recognition(source_image) - prompt = prompt.replace("[expression]", prompt_expression) - - if "[location]" in prompt: - prompt_location = prompt_model_blip2( - source_image, "What's surrounding the character?") - prompt = prompt.replace("[location]", prompt_location) - - if "[mouth]" in prompt: - prompt_mouth = prompt_model_blip2( - source_image, "is their mouth open or closed?") - prompt = prompt.replace("[mouth]", "mouth is " + str(prompt_mouth)) - - if "[looking]" in prompt: - prompt_looking = prompt_model_blip2( - source_image, "the person is looking") - prompt = prompt.replace("[looking]", "looking " + str(prompt_looking)) - - return prompt +# if not input_image.startswith("http"): +# input_image = open(input_image, "rb") + +# ml_client = get_ml_client() +# output, log = ml_client.predict_model_output( +# ML_MODEL.cjwbw_midas, image=input_image, model_type="dpt_beit_large_512") +# try: +# temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png", mode='wb') +# with urllib.request.urlopen(output) as response, open(temp_file.name, 'wb') as out_file: +# out_file.write(response.read()) +# except Exception as e: +# print(e) + +# depth_map = Image.open(temp_file.name) +# os.remove(temp_file.name) +# depth_map = depth_map.convert("L") # Convert to grayscale image +# pixels = depth_map.load() +# mask = Image.new("L", depth_map.size) +# 
mask_pixels = mask.load() + +# fg_mask = Image.new("L", depth_map.size) if "Foreground" in layer else None +# mg_mask = Image.new( +# "L", depth_map.size) if "Middleground" in layer else None +# bg_mask = Image.new("L", depth_map.size) if "Background" in layer else None + +# fg_pixels = fg_mask.load() if fg_mask else None +# mg_pixels = mg_mask.load() if mg_mask else None +# bg_pixels = bg_mask.load() if bg_mask else None + +# for i in range(depth_map.size[0]): +# for j in range(depth_map.size[1]): +# depth_value = pixels[i, j] + +# if fg_pixels: +# fg_pixels[i, j] = 0 if depth_value > 200 else 255 +# if mg_pixels: +# mg_pixels[i, j] = 0 if depth_value <= 200 and depth_value > 50 else 255 +# if bg_pixels: +# bg_pixels[i, j] = 0 if depth_value <= 50 else 255 + +# mask_pixels[i, j] = 255 +# if fg_pixels: +# mask_pixels[i, j] &= fg_pixels[i, j] +# if mg_pixels: +# mask_pixels[i, j] &= mg_pixels[i, j] +# if bg_pixels: +# mask_pixels[i, j] &= bg_pixels[i, j] + +# return create_or_update_mask(timing_uuid, mask) + +# NOTE: code not is use +# def dynamic_prompting(prompt, source_image): +# if "[expression]" in prompt: +# prompt_expression = facial_expression_recognition(source_image) +# prompt = prompt.replace("[expression]", prompt_expression) + +# if "[location]" in prompt: +# prompt_location = prompt_model_blip2( +# source_image, "What's surrounding the character?") +# prompt = prompt.replace("[location]", prompt_location) + +# if "[mouth]" in prompt: +# prompt_mouth = prompt_model_blip2( +# source_image, "is their mouth open or closed?") +# prompt = prompt.replace("[mouth]", "mouth is " + str(prompt_mouth)) + +# if "[looking]" in prompt: +# prompt_looking = prompt_model_blip2( +# source_image, "the person is looking") +# prompt = prompt.replace("[looking]", "looking " + str(prompt_looking)) + +# return prompt def query_llama2(prompt, temperature): ml_client = get_ml_client() @@ -312,7 +335,7 @@ def query_llama2(prompt, temperature): "stop_sequences": "\n" } - output, log = ml_client.predict_model_output(REPLICATE_MODEL.llama_2_7b, **input) + output, log = ml_client.predict_model_output(ML_MODEL.llama_2_7b, **input) result = "" for item in output: result += item diff --git a/ui_components/methods/training_methods.py b/ui_components/methods/training_methods.py index f1f308cf..9e77b8cb 100644 --- a/ui_components/methods/training_methods.py +++ b/ui_components/methods/training_methods.py @@ -4,87 +4,89 @@ from utils.common_utils import get_current_user_uuid from utils.data_repo.data_repo import DataRepo from utils.ml_processor.ml_interface import get_ml_client -from utils.ml_processor.replicate.constants import REPLICATE_MODEL +from utils.ml_processor.constants import ML_MODEL -# NOTE: making an exception for this function, passing just the image urls instead of -# image files -def train_model(images_list, instance_prompt, class_prompt, max_train_steps, - model_name, type_of_model, type_of_task, resolution, controller_type, model_type_list): - # prepare and upload the training data (images.zip) - ml_client = get_ml_client() - try: - training_file_url = ml_client.upload_training_data(images_list) - except Exception as e: - raise e +# NOTE: code not in use +# NOTE: making an exception for this function, passing just the image urls instead of image files +# def train_model(images_list, instance_prompt, class_prompt, max_train_steps, +# model_name, type_of_model, type_of_task, resolution, controller_type, model_type_list): +# # prepare and upload the training data (images.zip) +# ml_client = 
get_ml_client() +# try: +# training_file_url = ml_client.upload_training_data(images_list) +# except Exception as e: +# raise e - # training the model - model_name = model_name.replace(" ", "-").lower() - if type_of_model == "Dreambooth": - return train_dreambooth_model(instance_prompt, class_prompt, training_file_url, - max_train_steps, model_name, images_list, controller_type, model_type_list) - elif type_of_model == "LoRA": - return train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list) +# # training the model +# model_name = model_name.replace(" ", "-").lower() +# if type_of_model == "Dreambooth": +# return train_dreambooth_model(instance_prompt, class_prompt, training_file_url, +# max_train_steps, model_name, images_list, controller_type, model_type_list) +# elif type_of_model == "LoRA": +# return train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list) +# NOTE: code not in use # INFO: images_list passed here are converted to internal files after they are used for training -def train_dreambooth_model(instance_prompt, class_prompt, training_file_url, max_train_steps, model_name, images_list: List[str], controller_type, model_type_list): - from ui_components.methods.common_methods import convert_image_list_to_file_list +# def train_dreambooth_model(instance_prompt, class_prompt, training_file_url, max_train_steps, model_name, images_list: List[str], controller_type, model_type_list): +# from ui_components.methods.common_methods import convert_image_list_to_file_list - ml_client = get_ml_client() - app_setting = DataRepo().get_app_setting_from_uuid() +# ml_client = get_ml_client() +# app_setting = DataRepo().get_app_setting_from_uuid() - response = ml_client.dreambooth_training( - training_file_url, instance_prompt, class_prompt, max_train_steps, model_name, controller_type, len(images_list), app_setting.replicate_username) - training_status = response["status"] +# response = ml_client.dreambooth_training( +# training_file_url, instance_prompt, class_prompt, max_train_steps, model_name, controller_type, len(images_list), app_setting.replicate_username) +# training_status = response["status"] - model_id = response["id"] - if training_status == "queued": - file_list = convert_image_list_to_file_list(images_list) - file_uuid_list = [file.uuid for file in file_list] - file_uuid_list = json.dumps(file_uuid_list) +# model_id = response["id"] +# if training_status == "queued": +# file_list = convert_image_list_to_file_list(images_list) +# file_uuid_list = [file.uuid for file in file_list] +# file_uuid_list = json.dumps(file_uuid_list) - model_data = { - "name": model_name, - "user_id": get_current_user_uuid(), - "replicate_model_id": model_id, - "replicate_url": response["model"], - "diffusers_url": "", - "category": AIModelCategory.DREAMBOOTH.value, - "training_image_list": file_uuid_list, - "keyword": instance_prompt, - "custom_trained": True, - "model_type": model_type_list - } +# model_data = { +# "name": model_name, +# "user_id": get_current_user_uuid(), +# "replicate_model_id": model_id, +# "replicate_url": response["model"], +# "diffusers_url": "", +# "category": AIModelCategory.DREAMBOOTH.value, +# "training_image_list": file_uuid_list, +# "keyword": instance_prompt, +# "custom_trained": True, +# "model_type": model_type_list +# } - data_repo = DataRepo() - data_repo.create_ai_model(**model_data) +# data_repo = DataRepo() +# data_repo.create_ai_model(**model_data) - return "Success - Training 
Started. Please wait 10-15 minutes for the model to be trained." - else: - return "Failed" +# return "Success - Training Started. Please wait 10-15 minutes for the model to be trained." +# else: +# return "Failed" +# NOTE: code not in use # INFO: images_list passed here are converted to internal files after they are used for training -def train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list): - from ui_components.methods.common_methods import convert_image_list_to_file_list +# def train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list): +# from ui_components.methods.common_methods import convert_image_list_to_file_list - data_repo = DataRepo() - ml_client = get_ml_client() - output = ml_client.predict_model_output(REPLICATE_MODEL.clones_lora_training, instance_data=training_file_url, - task=type_of_task, resolution=int(resolution)) +# data_repo = DataRepo() +# ml_client = get_ml_client() +# output = ml_client.predict_model_output(ML_MODEL.clones_lora_training, instance_data=training_file_url, +# task=type_of_task, resolution=int(resolution)) - file_list = convert_image_list_to_file_list(images_list) - file_uuid_list = [file.uuid for file in file_list] - file_uuid_list = json.dumps(file_uuid_list) - model_data = { - "name": model_name, - "user_id": get_current_user_uuid(), - "replicate_url": output, - "diffusers_url": "", - "category": AIModelCategory.LORA.value, - "training_image_list": file_uuid_list, - "custom_trained": True, - "model_type": model_type_list - } +# file_list = convert_image_list_to_file_list(images_list) +# file_uuid_list = [file.uuid for file in file_list] +# file_uuid_list = json.dumps(file_uuid_list) +# model_data = { +# "name": model_name, +# "user_id": get_current_user_uuid(), +# "replicate_url": output, +# "diffusers_url": "", +# "category": AIModelCategory.LORA.value, +# "training_image_list": file_uuid_list, +# "custom_trained": True, +# "model_type": model_type_list +# } - data_repo.create_ai_model(**model_data) - return f"Successfully trained - the model '{model_name}' is now available for use!" +# data_repo.create_ai_model(**model_data) +# return f"Successfully trained - the model '{model_name}' is now available for use!" 
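Throughout these hunks the Replicate-specific constants (REPLICATE_MODEL / ReplicateModel from utils.ml_processor.replicate.constants) are swapped for the backend-agnostic ML_MODEL / MLModel pair imported from utils.ml_processor.constants. That module is not part of this excerpt, so the sketch below only illustrates the interface the call sites above rely on (model.name, model.version, model.display_name(), ML_MODEL.sdxl_inpainting, ML_MODEL.get_model_by_db_obj(...), and MLModel(f"{username}/{model_name}", version)); the class layout, the placeholder identifiers, and the display_name() logic are assumptions for illustration, not the actual implementation.

from dataclasses import dataclass

@dataclass
class MLModel:
    # hypothetical shape inferred from the call sites in this diff
    name: str          # e.g. "owner/model-slug" (placeholder, not a real identifier)
    version: str = ""  # optional version hash

    def display_name(self) -> str:
        # assumption: strip the "owner/" prefix for logging / output_details purposes
        return self.name.split("/")[-1]

class ML_MODEL:
    # a few of the constants referenced above; the real module defines many more
    sdxl = MLModel("owner/sdxl")                          # placeholder slug
    sdxl_inpainting = MLModel("owner/sdxl-inpainting")    # placeholder slug
    ad_interpolation = MLModel("owner/ad-interpolation")  # placeholder slug
    llama_2_7b = MLModel("owner/llama-2-7b")              # placeholder slug

    @staticmethod
    def get_model_by_db_obj(db_model) -> MLModel:
        # assumption: the AI-model DB row carries a replicate_url and version
        return MLModel(db_model.replicate_url, db_model.version or "")

Under that reading, predict_model_output_standardized(ML_MODEL.sdxl_inpainting, query_obj, QUEUE_INFERENCE_QUERIES) and the display_name()-based output_details written by log_model_inference line up with the old REPLICATE_MODEL code paths without referencing Replicate directly.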
diff --git a/ui_components/methods/video_methods.py b/ui_components/methods/video_methods.py index 506871b4..0bf6f00b 100644 --- a/ui_components/methods/video_methods.py +++ b/ui_components/methods/video_methods.py @@ -40,6 +40,7 @@ def create_single_interpolated_clip(shot_uuid, quality, settings={}, variant_cou img_list = [t.primary_image.location for t in timing_list] settings.update(interpolation_steps=interpolation_steps) + settings.update(file_uuid_list=[t.primary_image.uuid for t in timing_list]) # res is an array of tuples (video_bytes, log) res = VideoInterpolator.create_interpolated_clip( @@ -250,7 +251,7 @@ def render_video(final_video_name, project_uuid, file_tag=InternalFileTag.GENERA for shot in shot_list: if not shot.main_clip: st.error("Please generate all videos") - time.sleep(0.3) + time.sleep(0.7) return False shot_video = sync_audio_and_duration(shot.main_clip, shot.uuid, audio_sync_required=False) diff --git a/ui_components/models.py b/ui_components/models.py index 252945a4..5f913245 100644 --- a/ui_components/models.py +++ b/ui_components/models.py @@ -1,5 +1,7 @@ import datetime import json +import os +from urllib.parse import urlparse from shared.constants import InferenceParamType, ProjectMetaData from ui_components.constants import DefaultProjectSettingParams, DefaultTimingStyleParams @@ -39,6 +41,17 @@ def inference_params(self): return json.loads(log.input_params) return None + + @property + def filename(self): + input_path = self.location + if urlparse(input_path).scheme: + parts = urlparse(input_path) + filename = os.path.basename(parts.path) + else: + filename = os.path.basename(input_path) + + return filename class InternalProjectObject: diff --git a/ui_components/setup.py b/ui_components/setup.py index c58400c1..0f65b7dd 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -7,7 +7,6 @@ from ui_components.widgets.timeline_view import timeline_view from ui_components.widgets.sidebar_logger import sidebar_logger from ui_components.components.app_settings_page import app_settings_page -from ui_components.components.custom_models_page import custom_models_page from ui_components.components.frame_styling_page import frame_styling_page from ui_components.components.shortlist_page import shortlist_page from ui_components.components.timeline_view_page import timeline_view_page @@ -137,7 +136,7 @@ def setup_app_ui(): st.session_state["index_of_page"] = 0 with st.sidebar: - main_view_types = ["Creative Process", "Tools & Settings", "Video Rendering"] + main_view_types = ["Creative Process", "Project Settings", "Video Rendering"] st.session_state['main_view_type'] = st_memory.menu(None, main_view_types, icons=['search-heart', 'tools', "play-circle", 'stopwatch'], menu_icon="cast", default_index=0, key="main_view_type_name", orientation="horizontal", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "red"}}) @@ -150,21 +149,23 @@ def setup_app_ui(): st.session_state['creative_process_manual_select'] = 0 st.session_state['page'] = creative_process_pages[0] - h1,h2 = st.columns([1.5,1]) - with h1: - # view_types = ["Explorer","Timeline","Individual"] - creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Adjust Frame", "Animate Shot"] - st.session_state['page'] = option_menu( - None, - creative_process_pages, - icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], - menu_icon="cast", - orientation="vertical", - key="section-selecto1r", - 
styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, - "nav-link-selected": {"background-color": "green"}}, - manual_select=st.session_state['creative_process_manual_select'] - ) + + # view_types = ["Explorer","Timeline","Individual"] + creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Animate Shot"] + st.session_state['page'] = option_menu( + None, + creative_process_pages, + icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], + menu_icon="cast", + orientation="vertical", + key="section-selecto1r", + styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, + "nav-link-selected": {"background-color": "green"}}, + manual_select=st.session_state['creative_process_manual_select'] + ) + + if st.session_state['page'] != "Adjust Shot": + st.session_state['current_frame_sidebar_selector'] = 0 if st.session_state['creative_process_manual_select'] != None: st.session_state['creative_process_manual_select'] = None @@ -179,9 +180,6 @@ def setup_app_ui(): elif st.session_state['page'] == "Timeline": timeline_view_page(st.session_state["shot_uuid"], h2) - elif st.session_state['page'] == "Adjust Frame": - frame_styling_page(st.session_state["shot_uuid"], h2) - elif st.session_state['page'] == "Adjust Shot": adjust_shot_page(st.session_state["shot_uuid"], h2) @@ -192,24 +190,11 @@ def setup_app_ui(): with st.expander("πŸ” Generation Log", expanded=True): if st_memory.toggle("Open", value=True, key="generaton_log_toggle"): sidebar_logger(st.session_state["shot_uuid"]) - st.markdown("***") + # st.markdown("***") - elif st.session_state["main_view_type"] == "Tools & Settings": - with st.sidebar: - tool_pages = ["Query Logger", "Project Settings"] - - if st.session_state["page"] not in tool_pages: - st.session_state["page"] = tool_pages[0] - st.session_state["manual_select"] = None - - st.session_state['page'] = option_menu(None, tool_pages, icons=['pencil', 'palette', "hourglass", 'stopwatch'], menu_icon="cast", orientation="horizontal", key="secti2on_selector", styles={ - "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "green"}}, manual_select=st.session_state["manual_select"]) - if st.session_state["page"] == "Query Logger": - query_logger_page() - if st.session_state["page"] == "Custom Models": - custom_models_page(st.session_state["project_uuid"]) - elif st.session_state["page"] == "Project Settings": - project_settings_page(st.session_state["project_uuid"]) + elif st.session_state["main_view_type"] == "Project Settings": + + project_settings_page(st.session_state["project_uuid"]) elif st.session_state["main_view_type"] == "Video Rendering": video_rendering_page(st.session_state["project_uuid"]) diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py index d6d7f20d..b1bbe819 100644 --- a/ui_components/widgets/animation_style_element.py +++ b/ui_components/widgets/animation_style_element.py @@ -3,7 +3,7 @@ import streamlit as st from typing import List from shared.constants import AnimationStyleType, AnimationToolType -from ui_components.constants import DefaultProjectSettingParams +from ui_components.constants import DEFAULT_SHOT_MOTION_VALUES, DefaultProjectSettingParams, ShotMetaData from ui_components.methods.video_methods import create_single_interpolated_clip from utils.data_repo.data_repo import DataRepo from utils.ml_processor.motion_module import 
AnimateDiffCheckpoint @@ -13,7 +13,8 @@ import matplotlib.pyplot as plt def animation_style_element(shot_uuid): - + disable_generate = False + help = "" motion_modules = AnimateDiffCheckpoint.get_name_list() variant_count = 1 current_animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value # setting a default value @@ -22,347 +23,328 @@ def animation_style_element(shot_uuid): shot: InternalShotObject = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) st.session_state['project_uuid'] = str(shot.project.uuid) timing_list: List[InternalFrameTimingObject] = shot.timing_list - - - video_resolution = None + buffer = 4 + settings = { 'animation_tool': AnimationToolType.ANIMATEDIFF.value, } - - st.markdown("#### Key Frame Settings") - d1, d2 = st.columns([1, 4]) - st.session_state['frame_position'] = 0 - with d1: - setting_a_1, setting_a_2, = st.columns([1, 1]) - with setting_a_1: - type_of_frame_distribution = st_memory.radio("Type of key frame distribution:", options=["Linear", "Dynamic"], key="type_of_frame_distribution").lower() - if type_of_frame_distribution == "linear": - with setting_a_2: - linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="linear_frame_distribution_value") - dynamic_frame_distribution_values = [] - st.markdown("***") - setting_b_1, setting_b_2 = st.columns([1, 1]) - with setting_b_1: - type_of_key_frame_influence = st_memory.radio("Type of key frame length influence:", options=["Linear", "Dynamic"], key="type_of_key_frame_influence").lower() - if type_of_key_frame_influence == "linear": - with setting_b_2: - linear_key_frame_influence_value = st_memory.number_input("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="linear_key_frame_influence_value") - dynamic_key_frame_influence_values = [] - st.markdown("***") - - setting_d_1, setting_d_2 = st.columns([1, 1]) - - with setting_d_1: - type_of_cn_strength_distribution = st_memory.radio("Type of key frame strength control:", options=["Linear", "Dynamic"], key="type_of_cn_strength_distribution").lower() - if type_of_cn_strength_distribution == "linear": - with setting_d_2: - linear_cn_strength_value = st_memory.slider("Range of strength:", min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.01, key="linear_cn_strength_value") - dynamic_cn_strength_values = [] + interpolation_style = 'ease-in-out' + + advanced1, advanced2, advanced3 = st.columns([1.0,1.5, 1.0]) + with advanced1: + st.markdown("#### Animation Settings") + + with advanced3: + with st.expander("Bulk edit"): + what_would_you_like_to_edit = st.selectbox("What would you like to edit?", options=["Seconds to next frames", "Speed of transitions", "Freedom between frames","Strength of frames", "Motion during frames"], key="what_would_you_like_to_edit") + if what_would_you_like_to_edit == "Seconds to next frames": + what_to_change_it_to = st.slider("What would you like to change it to?", min_value=0.25, max_value=6.00, step=0.25, value=1.0, key="what_to_change_it_to") + if what_would_you_like_to_edit == "Strength of frames": + what_to_change_it_to = st.slider("What would you like to change it to?", min_value=0.25, max_value=1.0, step=0.01, value=0.5, key="what_to_change_it_to") + elif what_would_you_like_to_edit == "Speed of transitions": + what_to_change_it_to = st.slider("What would you like to change it to?", min_value=0.45, max_value=0.7, step=0.01, value=0.6, key="what_to_change_it_to") + elif what_would_you_like_to_edit == 
"Freedom between frames": + what_to_change_it_to = st.slider("What would you like to change it to?", min_value=0.2, max_value=0.95, step=0.01, value=0.5, key="what_to_change_it_to") + elif what_would_you_like_to_edit == "Motion during frames": + what_to_change_it_to = st.slider("What would you like to change it to?", min_value=0.5, max_value=1.5, step=0.01, value=1.3, key="what_to_change_it_to") + + if st.button("Bulk edit", key="bulk_edit"): + if what_would_you_like_to_edit == "Strength of frames": + for idx, timing in enumerate(timing_list): + st.session_state[f'strength_of_frame_{shot.uuid}_{idx}'] = what_to_change_it_to + elif what_would_you_like_to_edit == "Seconds to next frames": + for idx, timing in enumerate(timing_list): + st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}'] = what_to_change_it_to + elif what_would_you_like_to_edit == "Speed of transitions": + for idx, timing in enumerate(timing_list): + st.session_state[f'speed_of_transition_{shot.uuid}_{idx}'] = what_to_change_it_to + elif what_would_you_like_to_edit == "Freedom between frames": + for idx, timing in enumerate(timing_list): + st.session_state[f'freedom_between_frames_{shot.uuid}_{idx}'] = what_to_change_it_to + elif what_would_you_like_to_edit == "Motion during frames": + for idx, timing in enumerate(timing_list): + st.session_state[f'motion_during_frame_{shot.uuid}_{idx}'] = what_to_change_it_to + + st.markdown("***") + type_of_setting = "Individual" + if type_of_setting == "Individual": + items_per_row = 3 + strength_of_frames = [] + distances_to_next_frames = [] + speeds_of_transitions = [] + freedoms_between_frames = [] + individual_prompts = [] + individual_negative_prompts = [] + motions_during_frames = [] + + for i in range(0, len(timing_list) , items_per_row): + with st.container(): + grid = st.columns([2 if j%2==0 else 1 for j in range(2*items_per_row)]) # Adjust the column widths + for j in range(items_per_row): + idx = i + j + if idx < len(timing_list): + with grid[2*j]: # Adjust the index for image column + timing = timing_list[idx] + if timing.primary_image and timing.primary_image.location: + st.info(f"**Frame {idx + 1}**") + st.image(timing.primary_image.location, use_column_width=True) + + motion_data = DEFAULT_SHOT_MOTION_VALUES + # setting default parameters (fetching data from the shot if it's present) + if f'strength_of_frame_{shot.uuid}_{idx}' not in st.session_state: + shot_meta_data = shot.meta_data_dict.get(ShotMetaData.MOTION_DATA.value, json.dumps({})) + timing_data = json.loads(shot_meta_data).get("timing_data", []) + if timing_data and len(timing_data) >= idx + 1: + motion_data = timing_data[idx] + + for k, v in motion_data.items(): + st.session_state[f'{k}_{shot.uuid}_{idx}'] = v + + # settings control + with st.expander("Advanced settings:"): + strength_of_frame = st.slider("Strength of current frame:", min_value=0.25, max_value=1.0, step=0.01, key=f"strength_of_frame_widget_{shot.uuid}_{idx}", value=st.session_state[f'strength_of_frame_{shot.uuid}_{idx}']) + strength_of_frames.append(strength_of_frame) + individual_prompt = st.text_input("What to include:", key=f"individual_prompt_widget_{idx}_{timing.uuid}", value=st.session_state[f'individual_prompt_{shot.uuid}_{idx}'], help="Use this sparingly, as it can have a large impact on the video and cause weird distortions.") + individual_prompts.append(individual_prompt) + individual_negative_prompt = st.text_input("What to avoid:", key=f"negative_prompt_widget_{idx}_{timing.uuid}", 
value=st.session_state[f'individual_negative_prompt_{shot.uuid}_{idx}'],help="Use this sparingly, as it can have a large impact on the video and cause weird distortions.") + individual_negative_prompts.append(individual_negative_prompt) + # motion_during_frame = st.slider("Motion during frame:", min_value=0.5, max_value=1.5, step=0.01, key=f"motion_during_frame_widget_{idx}_{timing.uuid}", value=st.session_state[f'motion_during_frame_{shot.uuid}_{idx}']) + motion_during_frame = 1.3 + motions_during_frames.append(motion_during_frame) + else: + st.warning("No primary image present.") + + # distance, speed and freedom settings (also aggregates them into arrays) + with grid[2*j+1]: # Add the new column after the image column + if idx < len(timing_list) - 1: + st.write("") + st.write("") + st.write("") + st.write("") + # if st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}'] is a int, make it a float + if isinstance(st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}'], int): + st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}'] = float(st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}']) + distance_to_next_frame = st.slider("Seconds to next frame:", min_value=0.25, max_value=6.00, step=0.25, key=f"distance_to_next_frame_widget_{idx}_{timing.uuid}", value=st.session_state[f'distance_to_next_frame_{shot.uuid}_{idx}']) + distances_to_next_frames.append(distance_to_next_frame) + speed_of_transition = st.slider("Speed of transition:", min_value=0.45, max_value=0.7, step=0.01, key=f"speed_of_transition_widget_{idx}_{timing.uuid}", value=st.session_state[f'speed_of_transition_{shot.uuid}_{idx}']) + speeds_of_transitions.append(speed_of_transition) + freedom_between_frames = st.slider("Freedom between frames:", min_value=0.2, max_value=0.95, step=0.01, key=f"freedom_between_frames_widget_{idx}_{timing.uuid}", value=st.session_state[f'freedom_between_frames_{shot.uuid}_{idx}']) + freedoms_between_frames.append(freedom_between_frames) + + if (i < len(timing_list) - 1) or (len(timing_list) % items_per_row != 0): + st.markdown("***") - st.markdown("***") - footer1, _ = st.columns([2, 1]) - with footer1: - interpolation_style = 'ease-in-out' - motion_scale = st_memory.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.0, step=0.01, key="motion_scale") + with advanced1: + if st.button("Save current settings", key="save_current_settings"): + update_session_state_with_animation_details(shot.uuid, timing_list, strength_of_frames, distances_to_next_frames, speeds_of_transitions, freedoms_between_frames, motions_during_frames, individual_prompts, individual_negative_prompts) + st.success("Settings saved successfully.") + time.sleep(0.7) + st.rerun() - st.markdown("***") - if st.button("Reset to default settings", key="reset_animation_style"): - update_interpolation_settings(timing_list=timing_list) - st.rerun() - - with d2: - columns = st.columns(max(7, len(timing_list))) - disable_generate = False - help = "" - dynamic_frame_distribution_values = [] - dynamic_key_frame_influence_values = [] - dynamic_cn_strength_values = [] - - - for idx, timing in enumerate(timing_list): - # Use modulus to cycle through colors - # color = color_names[idx % len(color_names)] - # Only create markdown text for the current index - markdown_text = f'##### **Frame {idx + 1}** ___' - - with columns[idx]: - st.markdown(markdown_text) - - if timing.primary_image and timing.primary_image.location: - columns[idx].image(timing.primary_image.location, use_column_width=True) - b = 
timing.primary_image.inference_params - if type_of_frame_distribution == "dynamic": - linear_frame_distribution_value = 16 - if f"frame_{idx+1}" not in st.session_state: - st.session_state[f"frame_{idx+1}"] = idx * 16 # Default values in increments of 16 - if idx == 0: # For the first frame, position is locked to 0 - with columns[idx]: - frame_position = st_memory.number_input(f"{idx+1} frame Position", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx}", disabled=True) - else: - min_value = st.session_state[f"frame_{idx}"] + 1 - with columns[idx]: - frame_position = st_memory.number_input(f"#{idx+1} position:", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx}") - # st.session_state[f"frame_{idx+1}"] = frame_position - dynamic_frame_distribution_values.append(frame_position) - - if type_of_key_frame_influence == "dynamic": - linear_key_frame_influence_value = 1.1 - with columns[idx]: - dynamic_key_frame_influence_individual_value = st_memory.slider(f"#{idx+1} length of influence:", min_value=0.0, max_value=5.0, value=1.0, step=0.1, key=f"dynamic_key_frame_influence_values_{idx}") - dynamic_key_frame_influence_values.append(str(dynamic_key_frame_influence_individual_value)) - - if type_of_cn_strength_distribution == "dynamic": - linear_cn_strength_value = (0.0,1.0) - with columns[idx]: - help_texts = ["For the first frame, it'll start at the endpoint and decline to the starting point", - "For the final frame, it'll start at the starting point and end at the endpoint", - "For intermediate frames, it'll start at the starting point, peak in the middle at the endpoint, and decline to the starting point"] - label_texts = [f"#{idx+1} end -> start:", f"#{idx+1} start -> end:", f"#{idx+1} start -> peak:"] - help_text = help_texts[0] if idx == 0 else help_texts[1] if idx == len(timing_list) - 1 else help_texts[2] - label_text = label_texts[0] if idx == 0 else label_texts[1] if idx == len(timing_list) - 1 else label_texts[2] - dynamic_cn_strength_individual_value = st_memory.slider(label_text, min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.1, key=f"dynamic_cn_strength_values_{idx}",help=help_text) - dynamic_cn_strength_values.append(str(dynamic_cn_strength_individual_value)) - - # Convert lists to strings - dynamic_frame_distribution_values = ",".join(map(str, dynamic_frame_distribution_values)) # Convert integers to strings before joining - dynamic_key_frame_influence_values = ",".join(dynamic_key_frame_influence_values) - dynamic_cn_strength_values = ",".join(dynamic_cn_strength_values) - # dynamic_start_and_endpoint_values = ",".join(dynamic_start_and_endpoint_values) - # st.write(dynamic_start_and_endpoint_values) + dynamic_strength_values, dynamic_key_frame_influence_values, dynamic_frame_distribution_values = transform_data(strength_of_frames, freedoms_between_frames, speeds_of_transitions, distances_to_next_frames) + + type_of_frame_distribution = "dynamic" + type_of_key_frame_influence = "dynamic" + type_of_strength_distribution = "dynamic" + linear_frame_distribution_value = 16 + linear_key_frame_influence_value = 1.0 + linear_cn_strength_value = 1.0 + + with st.sidebar: + with st.expander("πŸ“ˆ Visualise motion data", expanded=True): + if st_memory.toggle("Visualise motion data"): - - def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values, allow_extension=True): - if len(keyframe_positions) < 2 or len(keyframe_positions) != 
len(key_frame_influence_values): - return [] + keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) + keyframe_positions = [int(kf * 16) for kf in keyframe_positions] + last_key_frame_position = (keyframe_positions[-1]) + strength_values = extract_strength_values(type_of_strength_distribution, dynamic_strength_values, keyframe_positions, linear_cn_strength_value) + key_frame_influence_values = extract_influence_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) + weights_list, frame_numbers_list = calculate_weights(keyframe_positions, strength_values, 4, key_frame_influence_values,last_key_frame_position) + plot_weights(weights_list, frame_numbers_list) - influence_ranges = [] - for i, position in enumerate(keyframe_positions): - influence_factor = key_frame_influence_values[i] - range_size = influence_factor * (keyframe_positions[-1] - keyframe_positions[0]) / (len(keyframe_positions) - 1) / 2 - start_influence = position - range_size - end_influence = position + range_size - - # If extension beyond the adjacent keyframe is allowed, do not constrain the start and end influence. - if not allow_extension: - start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) - end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) + st.markdown("***") + st.markdown("#### Overall style settings") + + sd_model_list = [ + "Realistic_Vision_V5.1.safetensors", + "anything-v3-fp16-pruned.safetensors", + "counterfeitV30_25.safetensors", + "Deliberate_v2.safetensors", + "dreamshaper_8.safetensors", + "epicrealism_pureEvolutionV5.safetensors", + "majicmixRealistic_v6.safetensors", + "perfectWorld_v6Baked.safetensors", + "wd-illusion-fp16.safetensors", + "aniverse_v13.safetensors", + "juggernaut_v21.safetensor" + ] + # remove .safe tensors from the end of each model name + # motion_scale = st_memory.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.3, step=0.01, key="motion_scale") + z1,z2 = st.columns([1, 1]) + with z1: + sd_model = st_memory.selectbox("Which model would you like to use?", options=sd_model_list, key="sd_model_video") - influence_ranges.append((round(start_influence), round(end_influence))) + e1, e2, e3 = st.columns([1, 1,1]) - return influence_ranges - - - def get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, images, linear_frame_distribution_value): - if type_of_frame_distribution == "dynamic": - return sorted([int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')]) - else: - return [i * linear_frame_distribution_value for i in range(len(images))] - - def extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): - if type_of_key_frame_influence == "dynamic": - return [float(influence.strip()) for influence in dynamic_key_frame_influence_values.split(',')] - else: - return [linear_key_frame_influence_value for _ in keyframe_positions] + with e1: - def extract_start_and_endpoint_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): - if type_of_key_frame_influence == "dynamic": - # If dynamic_key_frame_influence_values is a list of characters representing tuples, process it - if 
isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": - # Join the characters to form a single string and evaluate to convert into a list of tuples - string_representation = ''.join(dynamic_key_frame_influence_values) - dynamic_values = eval(f'[{string_representation}]') - else: - # If it's already a list of tuples or a single tuple, use it directly - dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] - return dynamic_values - else: - # Return a list of tuples with the linear_key_frame_influence_value as a tuple repeated for each position - return [linear_key_frame_influence_value for _ in keyframe_positions] - def calculate_weights(influence_ranges, interpolation, start_and_endpoint_strength, last_key_frame_position): - weights_list = [] - frame_numbers_list = [] - for i, (range_start, range_end) in enumerate(influence_ranges): - # Initialize variables - if i == 0: - strength_to, strength_from = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (0.0, 1.0) - else: - strength_from, strength_to = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (1.0, 0.0) - revert_direction_at_midpoint = (i != 0) and (i != len(influence_ranges) - 1) - - # if it's the first value, set influence range from 1.0 to 0.0 - if i == 0: - range_start = 0 - - # if it's the last value, set influence range to end at last_key_frame_position - if i == len(influence_ranges) - 1: - range_end = last_key_frame_position - - steps = range_end - range_start - diff = strength_to - strength_from - - # Calculate index for interpolation - index = np.linspace(0, 1, steps // 2 + 1) if revert_direction_at_midpoint else np.linspace(0, 1, steps) - - # Calculate weights based on interpolation type - if interpolation == "linear": - weights = np.linspace(strength_from, strength_to, len(index)) - elif interpolation == "ease-in": - weights = diff * np.power(index, 2) + strength_from - elif interpolation == "ease-out": - weights = diff * (1 - np.power(1 - index, 2)) + strength_from - elif interpolation == "ease-in-out": - weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from - - # If it's a middle keyframe, mirror the weights - if revert_direction_at_midpoint: - weights = np.concatenate([weights, weights[::-1]]) - - # Generate frame numbers - frame_numbers = np.arange(range_start, range_start + len(weights)) - - # "Dropper" component: For keyframes with negative start, drop the weights - if range_start < 0 and i > 0: - drop_count = abs(range_start) - weights = weights[drop_count:] - frame_numbers = frame_numbers[drop_count:] - - # Dropper component: for keyframes a range_End is greater than last_key_frame_position, drop the weights - if range_end > last_key_frame_position and i < len(influence_ranges) - 1: - drop_count = range_end - last_key_frame_position - weights = weights[:-drop_count] - frame_numbers = frame_numbers[:-drop_count] - - weights_list.append(weights) - frame_numbers_list.append(frame_numbers) - - return weights_list, frame_numbers_list - - - def plot_weights(weights_list, frame_numbers_list, frame_names): - plt.figure(figsize=(12, 6)) - - for i, weights in enumerate(weights_list): - frame_numbers = frame_numbers_list[i] - plt.plot(frame_numbers, weights, label=f'{frame_names[i]}') - - # Plot settings - plt.xlabel('Frame Number') - plt.ylabel('Weight') - plt.legend() - plt.ylim(0, 1.0) - plt.show() - 
st.set_option('deprecation.showPyplotGlobalUse', False) - st.pyplot() - - keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) - last_key_frame_position = keyframe_positions[-1] - cn_strength_values = extract_start_and_endpoint_values(type_of_cn_strength_distribution, dynamic_cn_strength_values, keyframe_positions, linear_cn_strength_value) - key_frame_influence_values = extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) - # start_and_endpoint_values = extract_start_and_endpoint_values(type_of_start_and_endpoint, dynamic_start_and_endpoint_values, keyframe_positions, linear_start_and_endpoint_value) - influence_ranges = calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values) - weights_list, frame_numbers_list = calculate_weights(influence_ranges, interpolation_style, cn_strength_values, last_key_frame_position) - frame_names = [f'Frame {i+1}' for i in range(len(influence_ranges))] - plot_weights(weights_list, frame_numbers_list, frame_names) + strength_of_adherence = st_memory.slider("How much would you like to force adherence to the input images?", min_value=0.0, max_value=1.0, value=0.3, step=0.01, key="stregnth_of_adherence") + with e2: + st.info("Higher values may cause flickering and sudden changes in the video. Lower values may cause the video to be less influenced by the input images but can also fix colouring issues.") + + f1, f2, f3 = st.columns([1, 1, 1]) + + with f1: + overall_positive_prompt = st_memory.text_area("What would you like to see in the videos?", value="", key="positive_prompt_video") + with f2: + overall_negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="", key="negative_prompt_video") + + with f3: + st.write("") + st.write("") + st.info("Use these sparingly, as they can have a large impact on the video. You can also edit them for individual frames in the advanced settings above.") + soft_scaled_cn_weights_multiplier = "" st.markdown("***") - e1, e2 = st.columns([1, 1]) + st.markdown("#### Overall motion settings") + h1, h2, h3 = st.columns([0.5, 1.5, 1]) + with h1: + + type_of_motion_context = st.radio("Type of motion context:", options=["Low", "Standard", "High"], key="type_of_motion_context", horizontal=False, index=1) + + with h2: + st.info("This is how much the motion will be informed by the previous and next frames. 'High' can make it smoother but increase artifacts - while 'Low' make the motion less smooth but removes artifacts. 
Naturally, we recommend Standard.") - with e1: - st.markdown("#### Styling Settings") - sd_model_list = [ - "Realistic_Vision_V5.0.safetensors", - "Counterfeit-V3.0_fp32.safetensors", - "epic_realism.safetensors", - "dreamshaper_v8.safetensors", - "deliberate_v3.safetensors" - ] - - # remove .safe tensors from the end of each model name - sd_model = st_memory.selectbox("Which model would you like to use?", options=sd_model_list, key="sd_model_video") - negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="bad image, worst quality", key="negative_prompt_video") - relative_ipadapter_strength = st_memory.slider("How much would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_strength") - relative_ipadapter_influence = st_memory.slider("For how long would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_influence") - soft_scaled_cn_weights_multipler = st_memory.slider("How much would you like to scale the CN weights?", min_value=0.0, max_value=10.0, value=0.85, step=0.1, key="soft_scaled_cn_weights_multiple_video") - append_to_prompt = st_memory.text_input("What would you like to append to the prompts?", key="append_to_prompt") - - normalise_speed = True + i1, i2, i3 = st.columns([1, 1, 1]) + with i1: + motion_scale = st.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.3, step=0.01, key="motion_scale") + + with i2: + st.info("This is how much the video moves. Above 1.4 gets jittery, below 0.8 makes it too fluid.") + context_length = 16 + context_stride = 2 + context_overlap = 4 + + if type_of_motion_context == "Low": + context_length = 16 + context_stride = 1 + context_overlap = 2 + + elif type_of_motion_context == "Standard": + context_length = 16 + context_stride = 2 + context_overlap = 4 + elif type_of_motion_context == "High": + context_length = 16 + context_stride = 4 + context_overlap = 4 + + + relative_ipadapter_strength = 1.0 + relative_cn_strength = 0.0 project_settings = data_repo.get_project_setting(shot.project.uuid) width = project_settings.width height = project_settings.height img_dimension = f"{width}x{height}" + # st.write(dynamic_frame_distribution_values) + dynamic_frame_distribution_values = [float(value) * 16 for value in dynamic_frame_distribution_values] + + # st.write(dynamic_frame_distribution_values) + + individual_prompts = format_frame_prompts_with_buffer(dynamic_frame_distribution_values, individual_prompts, buffer) + individual_negative_prompts = format_frame_prompts_with_buffer(dynamic_frame_distribution_values, individual_negative_prompts, buffer) + + multipled_base_end_percent = 0.05 * (strength_of_adherence * 10) + multipled_base_adapter_strength = 0.05 * (strength_of_adherence * 20) + + motion_scales = format_motion_strengths_with_buffer(dynamic_frame_distribution_values, motions_during_frames, buffer) + settings.update( ckpt=sd_model, + width=width, + height=height, buffer=4, motion_scale=motion_scale, + motion_scales=motion_scales, image_dimension=img_dimension, output_format="video/h264-mp4", - negative_prompt=negative_prompt, + prompt=overall_positive_prompt, + negative_prompt=overall_negative_prompt, interpolation_type=interpolation_style, stmfnet_multiplier=2, relative_ipadapter_strength=relative_ipadapter_strength, - relative_ipadapter_influence=relative_ipadapter_influence, - soft_scaled_cn_weights_multiplier=soft_scaled_cn_weights_multipler, - 
type_of_cn_strength_distribution=type_of_cn_strength_distribution, - linear_cn_strength_value=str(linear_cn_strength_value), - dynamic_cn_strength_values=str(dynamic_cn_strength_values), - type_of_frame_distribution=type_of_frame_distribution, + relative_cn_strength=relative_cn_strength, + type_of_strength_distribution=type_of_strength_distribution, + linear_strength_value=str(linear_cn_strength_value), + dynamic_strength_values=str(dynamic_strength_values), linear_frame_distribution_value=linear_frame_distribution_value, dynamic_frame_distribution_values=dynamic_frame_distribution_values, + type_of_frame_distribution=type_of_frame_distribution, type_of_key_frame_influence=type_of_key_frame_influence, linear_key_frame_influence_value=float(linear_key_frame_influence_value), dynamic_key_frame_influence_values=dynamic_key_frame_influence_values, - normalise_speed=normalise_speed, - animation_style=AnimationStyleType.CREATIVE_INTERPOLATION.value + normalise_speed=False, + ipadapter_noise=0.3, + animation_style=AnimationStyleType.CREATIVE_INTERPOLATION.value, + context_length=context_length, + context_stride=context_stride, + context_overlap=context_overlap, + multipled_base_end_percent=multipled_base_end_percent, + multipled_base_adapter_strength=multipled_base_adapter_strength, + individual_prompts=individual_prompts, + individual_negative_prompts=individual_negative_prompts, + animation_stype=AnimationStyleType.CREATIVE_INTERPOLATION.value, + # make max_frame the final value in the dynamic_frame_distribution_values + max_frames=str(dynamic_frame_distribution_values[-1]) + + ) st.markdown("***") st.markdown("#### Generation Settings") - where_to_generate = st_memory.radio("Where would you like to generate the video?", options=["Cloud", "Local"], key="where_to_generate", horizontal=True) - if where_to_generate == "Cloud": - animate_col_1, animate_col_2, _ = st.columns([1, 1, 2]) - with animate_col_1: - variant_count = st.number_input("How many variants?", min_value=1, max_value=100, value=1, step=1, key="variant_count") - - if st.button("Generate Animation Clip", key="generate_animation_clip", disabled=disable_generate, help=help): - vid_quality = "full" if video_resolution == "Full Resolution" else "preview" - st.success("Generating clip - see status in the Generation Log in the sidebar. 
Press 'Refresh log' to update.") - - positive_prompt = "" - for idx, timing in enumerate(timing_list): - if timing.primary_image and timing.primary_image.location: - b = timing.primary_image.inference_params - prompt = b['prompt'] if b else "" - prompt += append_to_prompt # Appending the text to each prompt - frame_prompt = f"{idx * linear_frame_distribution_value}_" + prompt - positive_prompt += ":" + frame_prompt if positive_prompt else frame_prompt - else: - st.error("Please generate primary images") - time.sleep(0.7) - st.rerun() - - settings.update( - image_prompt_list=positive_prompt, - animation_stype=current_animation_style, - ) - - create_single_interpolated_clip( - shot_uuid, - vid_quality, - settings, - variant_count - ) - st.rerun() + animate_col_1, animate_col_2, _ = st.columns([1, 1, 2]) + with animate_col_1: + variant_count = st.number_input("How many variants?", min_value=1, max_value=5, value=1, step=1, key="variant_count") + + if st.button("Generate Animation Clip", key="generate_animation_clip", disabled=disable_generate, help=help): + # last keyframe position * 16 + duration = float(dynamic_frame_distribution_values[-1] / 16) + data_repo.update_shot(uuid=shot.uuid, duration=duration) + update_session_state_with_animation_details(shot.uuid, timing_list, strength_of_frames, distances_to_next_frames, speeds_of_transitions, freedoms_between_frames, motions_during_frames, individual_prompts, individual_negative_prompts) + vid_quality = "full" # TODO: add this if video_resolution == "Full Resolution" else "preview" + st.success("Generating clip - see status in the Generation Log in the sidebar. Press 'Refresh log' to update.") + + positive_prompt = "" + append_to_prompt = "" # TODO: add this + for idx, timing in enumerate(timing_list): + if timing.primary_image and timing.primary_image.location: + b = timing.primary_image.inference_params + prompt = b.get("prompt", "") if b else "" + prompt += append_to_prompt + frame_prompt = f"{idx * linear_frame_distribution_value}_" + prompt + positive_prompt += ":" + frame_prompt if positive_prompt else frame_prompt + else: + st.error("Please generate primary images") + time.sleep(0.7) + st.rerun() + + create_single_interpolated_clip( + shot_uuid, + vid_quality, + settings, + variant_count + ) + st.rerun() - with animate_col_2: + with animate_col_2: number_of_frames = len(timing_list) - if height==width: cost_per_key_frame = 0.035 else: @@ -371,43 +353,50 @@ def plot_weights(weights_list, frame_numbers_list, frame_names): cost_per_generation = cost_per_key_frame * number_of_frames * variant_count st.info(f"Generating a video with {number_of_frames} frames in the cloud will cost c. ${cost_per_generation:.2f} USD.") - elif where_to_generate == "Local": - h1,h2 = st.columns([1,1]) - with h1: - st.info("You can run this locally in ComfyUI but you'll need at least 16GB VRAM. 
To get started, you can follow the instructions [here](https://github.com/peteromallet/steerable-motion) and download the workflow and images below.") - - btn1, btn2 = st.columns([1,1]) - with btn1: - st.download_button( - label="Download workflow JSON", - data=json.dumps(prepare_workflow_json(shot_uuid, settings)), - file_name='workflow.json' - ) - with btn2: - st.download_button( - label="Download images", - data=prepare_workflow_images(shot_uuid), - file_name='data.zip' - ) - -def prepare_workflow_json(shot_uuid, settings): +def update_session_state_with_animation_details(shot_uuid, timing_list, strength_of_frames, distances_to_next_frames, speeds_of_transitions, freedoms_between_frames, motions_during_frames, individual_prompts, individual_negative_prompts): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) + meta_data = shot.meta_data_dict + timing_data = [] + for idx, timing in enumerate(timing_list): + if idx < len(timing_list): + st.session_state[f'strength_of_frame_{shot_uuid}_{idx}'] = strength_of_frames[idx] + st.session_state[f'individual_prompt_{shot_uuid}_{idx}'] = individual_prompts[idx] + st.session_state[f'individual_negative_prompt_{shot_uuid}_{idx}'] = individual_negative_prompts[idx] + st.session_state[f'motion_during_frame_{shot_uuid}_{idx}'] = motions_during_frames[idx] + if idx < len(timing_list) - 1: + st.session_state[f'distance_to_next_frame_{shot_uuid}_{idx}'] = distances_to_next_frames[idx] + st.session_state[f'speed_of_transition_{shot_uuid}_{idx}'] = speeds_of_transitions[idx] + st.session_state[f'freedom_between_frames_{shot_uuid}_{idx}'] = freedoms_between_frames[idx] + + # adding into the meta-data + state_data = { + "strength_of_frame" : strength_of_frames[idx], + "individual_prompt" : individual_prompts[idx], + "individual_negative_prompt" : individual_negative_prompts[idx], + "motion_during_frame" : motions_during_frames[idx], + "distance_to_next_frame" : distances_to_next_frames[idx] if idx < len(timing_list) - 1 else DEFAULT_SHOT_MOTION_VALUES["distance_to_next_frame"], + "speed_of_transition" : speeds_of_transitions[idx] if idx < len(timing_list) - 1 else DEFAULT_SHOT_MOTION_VALUES["speed_of_transition"], + "freedom_between_frames" : freedoms_between_frames[idx] if idx < len(timing_list) - 1 else DEFAULT_SHOT_MOTION_VALUES["freedom_between_frames"], + } + + timing_data.append(state_data) + + meta_data.update({ShotMetaData.MOTION_DATA.value : json.dumps({"timing_data": timing_data})}) + data_repo.update_shot(**{"uuid": shot_uuid, "meta_data": json.dumps(meta_data)}) + + +def format_frame_prompts_with_buffer(frame_numbers, individual_prompts, buffer): + adjusted_frame_numbers = [frame + buffer for frame in frame_numbers] + + # Preprocess prompts to remove any '/' or '"' from the values + processed_prompts = [prompt.replace("/", "").replace('"', '') for prompt in individual_prompts] + + # Format the adjusted frame numbers and processed prompts + formatted = ', '.join(f'"{int(frame)}": "{prompt}"' for frame, prompt in zip(adjusted_frame_numbers, processed_prompts)) + return formatted - positive_prompt = "" - for idx, timing in enumerate(shot.timing_list): - b = None - if timing.primary_image and timing.primary_image.location: - b = timing.primary_image.inference_params - prompt = b['prompt'] if b else "" - frame_prompt = f'"{idx * settings["linear_frame_distribution_value"]}":"{prompt}"' + ("," if idx != len(shot.timing_list) - 1 else "") - positive_prompt += frame_prompt - - settings['image_prompt_list'] = positive_prompt - 
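# The new format_frame_prompts_with_buffer above produces the '"frame": "prompt"'
# mapping string that the animation workflow consumes. An illustrative call with
# made-up values (a 4-frame buffer shifts every keyframe position):
#
#   format_frame_prompts_with_buffer([0, 16, 32], ["a cat", "a dog", "a bird"], 4)
#   # -> '"4": "a cat", "20": "a dog", "36": "a bird"'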
workflow_data = create_workflow_json(shot.timing_list, settings) - - return workflow_data - +''' def prepare_workflow_images(shot_uuid): import requests import io @@ -434,7 +423,31 @@ def prepare_workflow_images(shot_uuid): buffer.seek(0) return buffer.getvalue() +''' +def extract_strength_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + + if type_of_key_frame_influence == "dynamic": + # Process the dynamic_key_frame_influence_values depending on its format + if isinstance(dynamic_key_frame_influence_values, str): + dynamic_values = eval(dynamic_key_frame_influence_values) + else: + dynamic_values = dynamic_key_frame_influence_values + + # Iterate through the dynamic values and convert tuples with two values to three values + dynamic_values_corrected = [] + for value in dynamic_values: + if len(value) == 2: + value = (value[0], value[1], value[0]) + dynamic_values_corrected.append(value) + return dynamic_values_corrected + else: + # Process for linear or other types + if len(linear_key_frame_influence_value) == 2: + linear_key_frame_influence_value = (linear_key_frame_influence_value[0], linear_key_frame_influence_value[1], linear_key_frame_influence_value[0]) + return [linear_key_frame_influence_value for _ in range(len(keyframe_positions) - 1)] + +''' def create_workflow_json(image_locations, settings): import os @@ -458,7 +471,7 @@ def create_workflow_json(image_locations, settings): type_of_cn_strength_distribution=settings['type_of_cn_strength_distribution'] linear_cn_strength_value=settings['linear_cn_strength_value'] buffer = settings['buffer'] - dynamic_cn_strength_values = settings['dynamic_cn_strength_values'] + dynamic_strength_values = settings['dynamic_strength_values'] interpolation_type = settings['interpolation_type'] ckpt = settings['ckpt'] motion_scale = settings['motion_scale'] @@ -466,7 +479,7 @@ def create_workflow_json(image_locations, settings): relative_ipadapter_influence = settings['relative_ipadapter_influence'] image_dimension = settings['image_dimension'] output_format = settings['output_format'] - soft_scaled_cn_weights_multiplier = settings['soft_scaled_cn_weights_multiplier'] + # soft_scaled_cn_weights_multiplier = settings['soft_scaled_cn_weights_multiplier'] stmfnet_multiplier = settings['stmfnet_multiplier'] if settings['type_of_frame_distribution'] == 'linear': @@ -475,9 +488,7 @@ def create_workflow_json(image_locations, settings): batch_size = int(settings['dynamic_frame_distribution_values'].split(',')[-1]) + int(buffer) img_width, img_height = image_dimension.split("x") - - - + for node in json_data['nodes']: if node['id'] == 189: node['widgets_values'][-3] = int(img_width) @@ -503,7 +514,7 @@ def create_workflow_json(image_locations, settings): node['widgets_values'][6] = dynamic_key_frame_influence_values node['widgets_values'][7] = type_of_cn_strength_distribution node['widgets_values'][8] = linear_cn_strength_value - node['widgets_values'][9] = dynamic_cn_strength_values + node['widgets_values'][9] = dynamic_strength_values node['widgets_values'][-1] = buffer node['widgets_values'][-2] = interpolation_type node['widgets_values'][-3] = soft_scaled_cn_weights_multiplier @@ -522,7 +533,7 @@ def create_workflow_json(image_locations, settings): return json_data - +''' def update_interpolation_settings(values=None, timing_list=None): default_values = { 'type_of_frame_distribution': 0, @@ -544,8 +555,310 @@ def update_interpolation_settings(values=None, timing_list=None): 
for idx in range(0, len(timing_list)): default_values[f'dynamic_frame_distribution_values_{idx}'] = (idx) * 16 default_values[f'dynamic_key_frame_influence_values_{idx}'] = 1.0 - default_values[f'dynamic_cn_strength_values_{idx}'] = (0.0,0.7) + default_values[f'dynamic_strength_values_{idx}'] = (0.0,0.7) for key, default_value in default_values.items(): st.session_state[key] = values.get(key, default_value) if values and values.get(key) is not None else default_value - # print(f"{key}: {st.session_state[key]}") \ No newline at end of file + # print(f"{key}: {st.session_state[key]}") + + +def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values, allow_extension=True): + if len(keyframe_positions) < 2 or len(keyframe_positions) != len(key_frame_influence_values): + return [] + + influence_ranges = [] + for i, position in enumerate(keyframe_positions): + influence_factor = key_frame_influence_values[i] + range_size = influence_factor * (keyframe_positions[-1] - keyframe_positions[0]) / (len(keyframe_positions) - 1) / 2 + + start_influence = position - range_size + end_influence = position + range_size + + # If extension beyond the adjacent keyframe is allowed, do not constrain the start and end influence. + if not allow_extension: + start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) + end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) + + influence_ranges.append((round(start_influence), round(end_influence))) + + return influence_ranges + +def extract_influence_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + # Check and convert linear_key_frame_influence_value if it's a float or string float + # if it's a string that starts with a parenthesis, convert it to a tuple + if isinstance(linear_key_frame_influence_value, str) and linear_key_frame_influence_value[0] == "(": + linear_key_frame_influence_value = eval(linear_key_frame_influence_value) + + + if not isinstance(linear_key_frame_influence_value, tuple): + if isinstance(linear_key_frame_influence_value, (float, str)): + try: + value = float(linear_key_frame_influence_value) + linear_key_frame_influence_value = (value, value) + except ValueError: + raise ValueError("linear_key_frame_influence_value must be a float or a string representing a float") + + number_of_outputs = len(keyframe_positions) + + if type_of_key_frame_influence == "dynamic": + # Convert list of individual float values into tuples + if all(isinstance(x, float) for x in dynamic_key_frame_influence_values): + dynamic_values = [(value, value) for value in dynamic_key_frame_influence_values] + elif isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": + string_representation = ''.join(dynamic_key_frame_influence_values) + dynamic_values = eval(f'[{string_representation}]') + else: + dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] + return dynamic_values[:number_of_outputs] + else: + return [linear_key_frame_influence_value for _ in range(number_of_outputs)] + + +def get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, images, linear_frame_distribution_value): + if type_of_frame_distribution == "dynamic": + # Check if the input is a string or a list + if 
isinstance(dynamic_frame_distribution_values, str): + # Sort the keyframe positions in numerical order + return sorted([int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')]) + elif isinstance(dynamic_frame_distribution_values, list): + return sorted(dynamic_frame_distribution_values) + else: + # Calculate the number of keyframes based on the total duration and linear_frames_per_keyframe + return [i * linear_frame_distribution_value for i in range(len(images))] + +def extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + if type_of_key_frame_influence == "dynamic": + return [float(influence.strip()) for influence in dynamic_key_frame_influence_values.split(',')] + else: + return [linear_key_frame_influence_value for _ in keyframe_positions] + +def extract_start_and_endpoint_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + if type_of_key_frame_influence == "dynamic": + # If dynamic_key_frame_influence_values is a list of characters representing tuples, process it + if isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": + # Join the characters to form a single string and evaluate to convert into a list of tuples + string_representation = ''.join(dynamic_key_frame_influence_values) + dynamic_values = eval(f'[{string_representation}]') + else: + # If it's already a list of tuples or a single tuple, use it directly + dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] + return dynamic_values + else: + # Return a list of tuples with the linear_key_frame_influence_value as a tuple repeated for each position + return [linear_key_frame_influence_value for _ in keyframe_positions] + +def calculate_weights(keyframe_positions, strength_values, buffer, key_frame_influence_values,last_key_frame_position): + + def calculate_influence_frame_number(key_frame_position, next_key_frame_position, distance): + # Calculate the absolute distance between key frames + key_frame_distance = abs(next_key_frame_position - key_frame_position) + + # Apply the distance multiplier + extended_distance = key_frame_distance * distance + + # Determine the direction of influence based on the positions of the key frames + if key_frame_position < next_key_frame_position: + # Normal case: influence extends forward + influence_frame_number = key_frame_position + extended_distance + else: + # Reverse case: influence extends backward + influence_frame_number = key_frame_position - extended_distance + + # Return the result rounded to the nearest integer + return round(influence_frame_number) + + def find_curve(batch_index_from, batch_index_to, strength_from, strength_to, interpolation,revert_direction_at_midpoint, last_key_frame_position,i, number_of_items,buffer): + + # Initialize variables based on the position of the keyframe + range_start = batch_index_from + range_end = batch_index_to + # if it's the first value, set influence range from 1.0 to 0.0 + + + if i == number_of_items - 1: + range_end = last_key_frame_position + + steps = range_end - range_start + diff = strength_to - strength_from + + # Calculate index for interpolation + index = np.linspace(0, 1, steps // 2 + 1) if revert_direction_at_midpoint else np.linspace(0, 1, steps) + + # Calculate weights based on interpolation type + if 
interpolation == "linear": + weights = np.linspace(strength_from, strength_to, len(index)) + elif interpolation == "ease-in": + weights = diff * np.power(index, 2) + strength_from + elif interpolation == "ease-out": + weights = diff * (1 - np.power(1 - index, 2)) + strength_from + elif interpolation == "ease-in-out": + weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from + + if revert_direction_at_midpoint: + weights = np.concatenate([weights, weights[::-1]]) + + # Generate frame numbers + frame_numbers = np.arange(range_start, range_start + len(weights)) + + # "Dropper" component: For keyframes with negative start, drop the weights + if range_start < 0 and i > 0: + drop_count = abs(range_start) + weights = weights[drop_count:] + frame_numbers = frame_numbers[drop_count:] + + # Dropper component: for keyframes a range_End is greater than last_key_frame_position, drop the weights + if range_end > last_key_frame_position and i < number_of_items - 1: + drop_count = range_end - last_key_frame_position + weights = weights[:-drop_count] + frame_numbers = frame_numbers[:-drop_count] + + return weights, frame_numbers + + weights_list = [] + frame_numbers_list = [] + + + for i in range(len(keyframe_positions)): + + keyframe_position = keyframe_positions[i] + interpolation = "ease-in-out" + # strength_from = strength_to = 1.0 + + if i == 0: # first image + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + + key_frame_influence_from, key_frame_influence_to = key_frame_influence_values[i] + + start_strength, mid_strength, end_strength = strength_values[i] + + keyframe_position = keyframe_positions[i] + next_key_frame_position = keyframe_positions[i+1] + + batch_index_from = keyframe_position + + batch_index_to_excl = calculate_influence_frame_number(keyframe_position, next_key_frame_position, key_frame_influence_to) + + + weights, frame_numbers = find_curve(batch_index_from, batch_index_to_excl, mid_strength, end_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + # interpolation = "ease-in" + + elif i == len(keyframe_positions) - 1: # last image + + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + + + key_frame_influence_from,key_frame_influence_to = key_frame_influence_values[i] + start_strength, mid_strength, end_strength = strength_values[i] + # strength_from, strength_to = cn_strength_values[i-1] + + keyframe_position = keyframe_positions[i] + previous_key_frame_position = keyframe_positions[i-1] + + + batch_index_from = calculate_influence_frame_number(keyframe_position, previous_key_frame_position, key_frame_influence_from) + + batch_index_to_excl = keyframe_position + weights, frame_numbers = find_curve(batch_index_from, batch_index_to_excl, start_strength, mid_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + # interpolation = "ease-out" + + else: # middle images + + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + key_frame_influence_from,key_frame_influence_to = key_frame_influence_values[i] + start_strength, mid_strength, end_strength = strength_values[i] + keyframe_position = keyframe_positions[i] + + # CALCULATE WEIGHTS FOR FIRST HALF + previous_key_frame_position = keyframe_positions[i-1] + batch_index_from = calculate_influence_frame_number(keyframe_position, previous_key_frame_position, key_frame_influence_from) + batch_index_to_excl = keyframe_position + first_half_weights, first_half_frame_numbers = find_curve(batch_index_from, batch_index_to_excl, start_strength, mid_strength, interpolation, 
False, last_key_frame_position, i, len(keyframe_positions), buffer) + + # CALCULATE WEIGHTS FOR SECOND HALF + next_key_frame_position = keyframe_positions[i+1] + batch_index_from = keyframe_position + batch_index_to_excl = calculate_influence_frame_number(keyframe_position, next_key_frame_position, key_frame_influence_to) + second_half_weights, second_half_frame_numbers = find_curve(batch_index_from, batch_index_to_excl, mid_strength, end_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + + # COMBINE FIRST AND SECOND HALF + weights = np.concatenate([first_half_weights, second_half_weights]) + frame_numbers = np.concatenate([first_half_frame_numbers, second_half_frame_numbers]) + + weights_list.append(weights) + frame_numbers_list.append(frame_numbers) + + return weights_list, frame_numbers_list + +def plot_weights(weights_list, frame_numbers_list): + plt.figure(figsize=(12, 6)) + + + for i, weights in enumerate(weights_list): + frame_numbers = frame_numbers_list[i] + plt.plot(frame_numbers, weights, label=f'Frame {i + 1}') + + # Plot settings + plt.xlabel('Frame Number') + plt.ylabel('Weight') + plt.legend() + plt.ylim(0, 1.0) + plt.show() + st.set_option('deprecation.showPyplotGlobalUse', False) + st.pyplot() + + + +def transform_data(strength_of_frames, movements_between_frames, speeds_of_transitions, distances_to_next_frames): + def adjust_and_invert_relative_value(middle_value, relative_value): + if relative_value is not None: + adjusted_value = middle_value * relative_value + return round(middle_value - adjusted_value, 2) + return None + + def invert_value(value): + return round(1.0 - value, 2) if value is not None else None + + # Creating output_strength with relative and inverted start and end values + output_strength = [] + for i, strength in enumerate(strength_of_frames): + start_value = None if i == 0 else movements_between_frames[i - 1] + end_value = None if i == len(strength_of_frames) - 1 else movements_between_frames[i] + + # Adjusting and inverting start and end values relative to the middle value + adjusted_start = adjust_and_invert_relative_value(strength, start_value) + adjusted_end = adjust_and_invert_relative_value(strength, end_value) + + output_strength.append((adjusted_start, strength, adjusted_end)) + + # Creating output_speeds with inverted values + output_speeds = [(None, None) for _ in range(len(speeds_of_transitions) + 1)] + for i in range(len(speeds_of_transitions)): + current_tuple = list(output_speeds[i]) + next_tuple = list(output_speeds[i + 1]) + + inverted_speed = invert_value(speeds_of_transitions[i]) + current_tuple[1] = inverted_speed * 2 + next_tuple[0] = inverted_speed * 2 + + output_speeds[i] = tuple(current_tuple) + output_speeds[i + 1] = tuple(next_tuple) + + # Creating cumulative_distances + cumulative_distances = [0] + for distance in distances_to_next_frames: + cumulative_distances.append(cumulative_distances[-1] + distance) + + return output_strength, output_speeds, cumulative_distances + + + +def format_motion_strengths_with_buffer(frame_numbers, motion_strengths, buffer): + # Adjust the first frame number to 0 and shift the others by the buffer + adjusted_frame_numbers = [0] + [frame + buffer for frame in frame_numbers[1:]] + + # Format the adjusted frame numbers and strengths + formatted = ', '.join(f'{frame}:({strength})' for frame, strength in zip(adjusted_frame_numbers, motion_strengths)) + return formatted \ No newline at end of file diff --git a/ui_components/widgets/attach_audio_element.py 
b/ui_components/widgets/attach_audio_element.py index 56b186ac..5d44e569 100644 --- a/ui_components/widgets/attach_audio_element.py +++ b/ui_components/widgets/attach_audio_element.py @@ -8,7 +8,8 @@ def attach_audio_element(project_uuid, expanded): data_repo = DataRepo() project_setting: InternalSettingObject = data_repo.get_project_setting(project_uuid) - with st.expander("Audio", expanded=expanded): + with st.expander("🔊 Audio", expanded=expanded): + uploaded_file = st.file_uploader("Attach audio", type=[ "mp3"], help="This will attach this audio when you render a video") if st.button("Upload and attach new audio"): diff --git a/ui_components/widgets/frame_movement_widgets.py b/ui_components/widgets/frame_movement_widgets.py index 4a8351d8..21227f6f 100644 --- a/ui_components/widgets/frame_movement_widgets.py +++ b/ui_components/widgets/frame_movement_widgets.py @@ -118,56 +118,26 @@ def replace_image_widget(timing_uuid, stage, options=["Uploaded Frame", "Other F timing = data_repo.get_timing_from_uuid(timing_uuid) timing_list = data_repo.get_timing_list_from_shot(timing.shot.uuid) - replace_with = options[0] if len(options) == 1 else st.radio("Replace with:", options, horizontal=True, key=f"replacement_entity_{stage}_{timing_uuid}") - - if replace_with == "Other Frame": - image_replacement_stage = st.radio( - "Select stage to use:", - [ImageStage.MAIN_VARIANT.value, ImageStage.SOURCE_IMAGE.value], - key=f"image_replacement_stage_{stage}_{timing_uuid}", - horizontal=True - ) - replacement_img_number = st.number_input("Select image to use:", min_value=1, max_value=len( - timing_list), value=0, key=f"replacement_img_number_{stage}") - - if image_replacement_stage == ImageStage.SOURCE_IMAGE.value: - selected_image = timing_list[replacement_img_number - 1].source_image - elif image_replacement_stage == ImageStage.MAIN_VARIANT.value: - selected_image = timing_list[replacement_img_number - 1].primary_image - - st.image(selected_image.location, use_column_width=True) - - if st.button("Replace with selected frame", disabled=False,key=f"replace_with_selected_frame_{stage}_{timing_uuid}"): - if stage == WorkflowStageType.SOURCE.value: - data_repo.update_specific_timing(timing.uuid, source_image_id=selected_image.uuid) - st.success("Replaced") - time.sleep(1) - st.rerun() - else: - number_of_image_variants = add_image_variant( - selected_image.uuid, timing.uuid) - promote_image_variant( - timing.uuid, number_of_image_variants - 1) - st.success("Replaced") - time.sleep(1) - st.rerun() - elif replace_with == "Uploaded Frame": - btn_text = 'Upload source image' if stage == WorkflowStageType.SOURCE.value else 'Replace frame' - uploaded_file = st.file_uploader(btn_text, type=[ - "png", "jpeg"], accept_multiple_files=False,key=f"uploaded_file_{stage}_{timing_uuid}") - if uploaded_file != None: - if st.button(btn_text): - if uploaded_file: - timing = data_repo.get_timing_from_uuid(timing.uuid) - if save_and_promote_image(uploaded_file, timing.shot.uuid, timing.uuid, stage): - st.success("Replaced") - time.sleep(1.5) - st.rerun() + + btn_text = 'Upload source image' if stage == WorkflowStageType.SOURCE.value else 'Replace frame' + uploaded_file = st.file_uploader(btn_text, type=[ + "png", "jpeg"], accept_multiple_files=False,key=f"uploaded_file_{stage}_{timing_uuid}") + if uploaded_file != None: + if st.button(btn_text): + if uploaded_file: + timing = data_repo.get_timing_from_uuid(timing.uuid) + if save_and_promote_image(uploaded_file, timing.shot.uuid, timing.uuid, stage): + st.success("Replaced") + 
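# short pause so the "Replaced" confirmation is visible before the rerun refreshes the page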
time.sleep(1.5) + st.rerun() def jump_to_single_frame_view_button(display_number, timing_list, src,uuid=None): if st.button(f"Jump to #{display_number}", key=f"{src}_{uuid}", use_container_width=True): + st.session_state['current_frame_sidebar_selector'] = display_number + st.session_state["creative_process_manual_select"] = 3 + ''' st.session_state['prev_frame_index'] = st.session_state['current_frame_index'] = display_number st.session_state['current_frame_uuid'] = timing_list[st.session_state['current_frame_index'] - 1].uuid st.session_state['frame_styling_view_type_manual_select'] = 2 @@ -176,4 +146,5 @@ def jump_to_single_frame_view_button(display_number, timing_list, src,uuid=None) st.session_state["creative_process_manual_select"] = 4 st.session_state["styling_view_selector_manual_select"] = 0 st.session_state['page'] = "Key Frames" + ''' st.rerun() diff --git a/ui_components/widgets/frame_selector.py b/ui_components/widgets/frame_selector.py index a81dd7be..c6c4f0d3 100644 --- a/ui_components/widgets/frame_selector.py +++ b/ui_components/widgets/frame_selector.py @@ -10,7 +10,7 @@ -def frame_selector_widget(show: List[str]): +def frame_selector_widget(show_frame_selector=True): data_repo = DataRepo() timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) @@ -21,49 +21,65 @@ def frame_selector_widget(show: List[str]): if 'prev_shot_index' not in st.session_state: st.session_state['prev_shot_index'] = shot.shot_idx - if 'shot_selector' in show: + shot1, shot2 = st.columns([1, 1]) + with shot1: shot_names = [s.name for s in shot_list] shot_name = st.selectbox('Shot name:', shot_names, key="current_shot_sidebar_selector",index=shot_names.index(shot.name)) - # find shot index based on shot name - st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 + # find shot index based on shot name + st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 - if shot_name != shot.name: - st.session_state["shot_uuid"] = shot_list[shot_names.index(shot_name)].uuid - st.rerun() + if shot_name != shot.name: + st.session_state["shot_uuid"] = shot_list[shot_names.index(shot_name)].uuid + st.rerun() - if not ('current_shot_index' in st.session_state and st.session_state['current_shot_index']): - st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 - update_current_shot_index(st.session_state['current_shot_index']) + if not ('current_shot_index' in st.session_state and st.session_state['current_shot_index']): + st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 + update_current_shot_index(st.session_state['current_shot_index']) # st.write if frame_selector is present + + + if st.session_state['page'] == "Key Frames": + if st.session_state['current_frame_index'] > len_timing_list: + update_current_frame_index(len_timing_list) - if 'frame_selector' in show: - - if st.session_state['page'] == "Key Frames": - if st.session_state['current_frame_index'] > len_timing_list: - update_current_frame_index(len_timing_list) - - elif st.session_state['page'] == "Shots": - if st.session_state['current_shot_index'] > len(shot_list): - update_current_shot_index(len(shot_list)) + elif st.session_state['page'] == "Shots": + if st.session_state['current_shot_index'] > len(shot_list): + update_current_shot_index(len(shot_list)) - + if show_frame_selector: if len(timing_list): if 'prev_frame_index' not in st.session_state or st.session_state['prev_frame_index'] > len(timing_list): - st.session_state['prev_frame_index'] = 1 + + 
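# e.g. with 3 key frames the options are ['', '1', '2', '3']; the leading blank entry
# leaves the selector empty until the user explicitly picks a frame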
# Create a list of frames with a blank value as the first item + frame_list = [''] + [f'{i+1}' for i in range(len(timing_list))] + + + with shot2: + frame_selection = st_memory.selectbox('Select a frame:', frame_list, key="current_frame_sidebar_selector") - st.session_state['current_frame_index'] = st.number_input(f"Key frame # (out of {len(timing_list)})", 1, - len(timing_list), value=st.session_state['prev_frame_index'], - step=1, key="current_frame_sidebar_selector") - - update_current_frame_index(st.session_state['current_frame_index']) + # Only trigger the frame number extraction and current frame index update if a non-empty value is selected + if frame_selection != '': + + if st.button("Jump to shot view",use_container_width=True): + st.session_state['current_frame_sidebar_selector'] = 0 + st.rerun() + # st.session_state['creative_process_manual_select'] = 4 + st.session_state['current_frame_index'] = int(frame_selection.split(' ')[-1]) + update_current_frame_index(st.session_state['current_frame_index']) else: - st.error("No frames present") + frame_selection = "" + with shot2: + st.write("") + st.error("No frames present") + + return frame_selection -def frame_view(view="Key Frame"): +def frame_view(view="Key Frame",show_current_frames=True): data_repo = DataRepo() # time1, time2 = st.columns([1,1]) - st.markdown("***") + # st.markdown("***") + st.write("") timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) @@ -107,11 +123,12 @@ def frame_view(view="Key Frame"): with a2: update_shot_duration(shot.uuid) - st.markdown("---") + if show_current_frames: + st.markdown("---") - timing_list: List[InternalFrameTimingObject] = shot.timing_list + timing_list: List[InternalFrameTimingObject] = shot.timing_list - display_shot_frames(timing_list, False) + display_shot_frames(timing_list, False) st.markdown("---") diff --git a/ui_components/widgets/image_zoom_widgets.py b/ui_components/widgets/image_zoom_widgets.py index a22efcc7..c731b1b6 100644 --- a/ui_components/widgets/image_zoom_widgets.py +++ b/ui_components/widgets/image_zoom_widgets.py @@ -78,7 +78,7 @@ def save_zoomed_image(image, timing_uuid, stage, promote=False): styled_image.uuid, timing_uuid) if promote: promote_image_variant(timing_uuid, number_of_image_variants - 1) - + ''' project_update_data = { "zoom_level": st.session_state['zoom_level_input'], "rotation_angle_value": st.session_state['rotation_angle_input'], @@ -93,7 +93,9 @@ def save_zoomed_image(image, timing_uuid, stage, promote=False): "zoom_details": f"{st.session_state['zoom_level_input']},{st.session_state['rotation_angle_input']},{st.session_state['x_shift']},{st.session_state['y_shift']}", } + data_repo.update_specific_timing(timing_uuid, **timing_update_data) + ''' def reset_zoom_element(): st.session_state['zoom_level_input_key'] = 100 diff --git a/ui_components/widgets/inpainting_element.py b/ui_components/widgets/inpainting_element.py index 26f338b4..a219d869 100644 --- a/ui_components/widgets/inpainting_element.py +++ b/ui_components/widgets/inpainting_element.py @@ -73,6 +73,7 @@ def inpainting_element(timing_uuid): type_of_mask_selection = "Manual Background Selection" + # NOTE: removed other mask selection methods, will update the code later if type_of_mask_selection == "Manual Background Selection": if st.session_state['edited_image'] == "": with main_col_1: @@ -150,11 +151,8 @@ def inpainting_element(timing_uuid): st.session_state['edited_image'] = "" 
st.rerun() - with main_col_1: - st.session_state["type_of_mask_replacement"] = "Inpainting" - btn1, btn2 = st.columns([1, 1]) with btn1: prompt = st.text_area("Prompt:", help="Describe the whole image, but focus on the details you want changed!", @@ -163,9 +161,8 @@ def inpainting_element(timing_uuid): negative_prompt = st.text_area( "Negative Prompt:", help="Enter any things you want to make the model avoid!", value=DefaultProjectSettingParams.batch_negative_prompt, height=150) - edit1, edit2 = st.columns(2) - - with edit1: + col1, _ = st.columns(2) + with col1: if st.button(f'Run Edit'): if st.session_state["type_of_mask_replacement"] == "Inpainting": edited_image, log = execute_image_edit( diff --git a/ui_components/widgets/shot_view.py b/ui_components/widgets/shot_view.py index 68eae2f8..f2ab3057 100644 --- a/ui_components/widgets/shot_view.py +++ b/ui_components/widgets/shot_view.py @@ -57,10 +57,11 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg with col3: move_frames_toggle = st_memory.toggle("Move Frames", value=True, key="move_frames_toggle") with col4: - replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") - - with col5: change_shot_toggle = st_memory.toggle("Change Shot", value=False, key="change_shot_toggle") + # replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") + + + st.markdown("***") @@ -85,8 +86,8 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg st.warning("No primary image present.") jump_to_single_frame_view_button(idx + 1, timing_list, f"jump_to_{idx + 1}",uuid=shot.uuid) if position != "Timeline": - timeline_view_buttons(idx, shot_uuid, replace_image_widget_toggle, copy_frame_toggle, move_frames_toggle,delete_frames_toggle, change_shot_toggle) - if (i < len(timing_list) - 1) or (st.session_state["open_shot"] == shot.uuid) or (len(timing_list) % items_per_row != 0 and st.session_state["open_shot"] != shot.uuid): + timeline_view_buttons(idx, shot_uuid, copy_frame_toggle, move_frames_toggle,delete_frames_toggle, change_shot_toggle) + if (i < len(timing_list) - 1) or (st.session_state["open_shot"] == shot.uuid) or (len(timing_list) % items_per_row != 0 and st.session_state["open_shot"] != shot.uuid) or len(timing_list) % items_per_row == 0: st.markdown("***") # st.markdown("***") @@ -219,30 +220,8 @@ def update_shot_duration(shot_uuid): time.sleep(0.3) st.rerun() -def shot_video_element(shot_uuid): - data_repo = DataRepo() - - shot: InternalShotObject = data_repo.get_shot_from_uuid(shot_uuid) - - st.info(f"##### {shot.name}") - if shot.main_clip and shot.main_clip.location: - st.video(shot.main_clip.location) - else: - st.warning('''No video present''') - switch1,switch2 = st.columns([1,1]) - with switch1: - shot_adjustment_button(shot) - with switch2: - shot_animation_button(shot) - with st.expander("Details", expanded=False): - update_shot_name(shot.uuid) - update_shot_duration(shot.uuid) - move_shot_buttons(shot, "side") - delete_shot_button(shot.uuid) - if shot.main_clip: - create_video_download_button(shot.main_clip.location, tag="main_clip") @@ -274,6 +253,7 @@ def shot_adjustment_button(shot, show_label=False): button_label = "Shot Adjustment 🔧" if show_label else "🔧" if st.button(button_label, key=f"jump_to_shot_adjustment_{shot.uuid}", help=f"Shot adjustment view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid + 
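# presumably resets the sidebar frame selection (its blank/zero state) so the
# adjustment view opens on the shot itself rather than a previously selected frame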
st.session_state['current_frame_sidebar_selector'] = 0 st.session_state['creative_process_manual_select'] = 3 st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 1 @@ -284,21 +264,19 @@ def shot_animation_button(shot, show_label=False): button_label = "Shot Animation 🎞️" if show_label else "🎞️" if st.button(button_label, key=f"jump_to_shot_animation_{shot.uuid}", help=f"Shot animation view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid - st.session_state['creative_process_manual_select'] = 5 - st.session_state["manual_select"] = 1 + st.session_state['creative_process_manual_select'] = 4 + # st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 0 st.session_state['shot_view_index'] = 0 st.rerun() -def timeline_view_buttons(idx, shot_uuid, replace_image_widget_toggle, copy_frame_toggle, move_frames_toggle, delete_frames_toggle, change_shot_toggle): +def timeline_view_buttons(idx, shot_uuid, copy_frame_toggle, move_frames_toggle, delete_frames_toggle, change_shot_toggle): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) timing_list = shot.timing_list - if replace_image_widget_toggle: - replace_image_widget(timing_list[idx].uuid, stage=WorkflowStageType.STYLED.value, options=["Uploaded Frame"]) btn1, btn2, btn3, btn4 = st.columns([1, 1, 1, 1]) diff --git a/ui_components/widgets/sidebar_logger.py b/ui_components/widgets/sidebar_logger.py index 8f719908..cc550e9e 100644 --- a/ui_components/widgets/sidebar_logger.py +++ b/ui_components/widgets/sidebar_logger.py @@ -1,13 +1,14 @@ +import time import streamlit as st -from shared.constants import InferenceParamType, InferenceStatus +from shared.constants import InferenceParamType, InferenceStatus, InternalFileTag, InternalFileType from ui_components.widgets.frame_movement_widgets import jump_to_single_frame_view_button import json import math from ui_components.widgets.frame_selector import update_current_frame_index from utils.data_repo.data_repo import DataRepo -from utils.ml_processor.replicate.constants import REPLICATE_MODEL +from utils.ml_processor.constants import ML_MODEL def sidebar_logger(shot_uuid): data_repo = DataRepo() @@ -19,7 +20,7 @@ def sidebar_logger(shot_uuid): if a1.button("Refresh log", disabled=refresh_disabled, help="You can also press 'r' on your keyboard to refresh."): st.rerun() status_option = st.radio("Statuses to display:", options=["All", "In Progress", "Succeeded", "Failed"], key="status_option", index=0, horizontal=True) - + status_list = None if status_option == "In Progress": status_list = [InferenceStatus.QUEUED.value, InferenceStatus.IN_PROGRESS.value] @@ -73,7 +74,10 @@ def sidebar_logger(shot_uuid): prompt = input_params.get('prompt', 'No prompt found') st.write(f'"{prompt[:30]}..."' if len(prompt) > 30 else f'"{prompt}"') st.caption(f"Model:") - st.write(json.loads(log.output_details)['model_name'].split('/')[-1]) + try: + st.write(json.loads(log.output_details)['model_name'].split('/')[-1]) + except Exception as e: + st.write('') with c2: if output_url: @@ -96,13 +100,23 @@ def sidebar_logger(shot_uuid): elif log.status == InferenceStatus.CANCELED.value: st.warning("Canceled") + log_file = log_file_dict[log.uuid] if log.uuid in log_file_dict else None + if log_file: + if log_file.type == InternalFileType.IMAGE.value and log_file.tag != InternalFileTag.SHORTLISTED_GALLERY_IMAGE.value: + if st.button("Add to shortlist ➕", 
key=f"sidebar_shortlist_{log_file.uuid}",use_container_width=True, help="Add to shortlist"): + data_repo.update_file(log_file.uuid, tag=InternalFileTag.SHORTLISTED_GALLERY_IMAGE.value) + st.success("Added To Shortlist") + time.sleep(0.3) + st.rerun() + + if output_url and origin_data and 'timing_uuid' in origin_data and origin_data['timing_uuid']: timing = data_repo.get_timing_from_uuid(origin_data['timing_uuid']) if timing and st.session_state['frame_styling_view_type'] != "Timeline": jump_to_single_frame_view_button(timing.aux_frame_index + 1, timing_list, 'sidebar_'+str(log.uuid)) else: - if st.session_state['frame_styling_view_type'] != "Explorer": + if st.session_state['page'] != "Explore": if st.button(f"Jump to explorer", key=str(log.uuid)): # TODO: fix this st.session_state['main_view_type'] = "Creative Process" @@ -111,5 +125,5 @@ def sidebar_logger(shot_uuid): st.rerun() - + st.markdown("---") \ No newline at end of file diff --git a/ui_components/widgets/timeline_view.py b/ui_components/widgets/timeline_view.py index 95fc2d1f..444c31cb 100644 --- a/ui_components/widgets/timeline_view.py +++ b/ui_components/widgets/timeline_view.py @@ -1,6 +1,6 @@ import streamlit as st from ui_components.methods.common_methods import add_new_shot -from ui_components.widgets.shot_view import shot_keyframe_element, shot_video_element +from ui_components.widgets.shot_view import shot_keyframe_element, shot_adjustment_button, shot_animation_button, update_shot_name, update_shot_duration, move_shot_buttons, delete_shot_button, create_video_download_button from utils.data_repo.data_repo import DataRepo from utils import st_memory @@ -9,36 +9,63 @@ def timeline_view(shot_uuid, stage): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) shot_list = data_repo.get_shot_list(shot.project.uuid) + timing_list: List[InternalFrameTimingObject] = shot.timing_list _, header_col_2 = st.columns([5.5,1.5]) - with header_col_2: - items_per_row = st_memory.slider("How many frames per row?", min_value=3, max_value=7, value=5, step=1, key="items_per_row_slider") + #with header_col_2: - if stage == 'Key Frames': - for shot in shot_list: - with st.expander(f"_-_-_-_", expanded=True): - shot_keyframe_element(shot.uuid, items_per_row) + items_per_row = 4 + for idx, shot in enumerate(shot_list): + timing_list: List[InternalFrameTimingObject] = shot.timing_list + if idx % items_per_row == 0: + grid = st.columns(items_per_row) + + + with grid[idx % items_per_row]: + st.info(f"##### {shot.name}") + if stage == "Key Frames": + for i in range(0, len(timing_list), items_per_row): + if i % items_per_row == 0: + grid_timing = st.columns(items_per_row) + for j in range(items_per_row): + # idx = i + j + if i + j < len(timing_list): + with grid_timing[j]: + timing = timing_list[ i + j] + if timing.primary_image and timing.primary_image.location: + st.image(timing.primary_image.location, use_column_width=True) + else: + + if shot.main_clip and shot.main_clip.location: + st.video(shot.main_clip.location) + else: + st.warning('''No video present''') + + + switch1,switch2 = st.columns([1,1]) + with switch1: + shot_adjustment_button(shot) + with switch2: + shot_animation_button(shot) + + with st.expander("Details & settings", expanded=False): + update_shot_name(shot.uuid) + update_shot_duration(shot.uuid) + move_shot_buttons(shot, "side") + delete_shot_button(shot.uuid) + if shot.main_clip: + create_video_download_button(shot.main_clip.location, tag="main_clip") + + + if (idx + 1) % items_per_row == 0 or idx == 
len(shot_list) - 1: st.markdown("***") - st.markdown("### Add new shot") - shot1,shot2 = st.columns([0.75,3]) - with shot1: - add_new_shot_element(shot, data_repo) - else: - for idx, shot in enumerate(shot_list): - if idx % items_per_row == 0: - grid = st.columns(items_per_row) - with grid[idx % items_per_row]: - shot_video_element(shot.uuid) - if (idx + 1) % items_per_row == 0 or idx == len(shot_list) - 1: - st.markdown("***") - # if stage isn't - if idx == len(shot_list) - 1: - with grid[(idx + 1) % items_per_row]: - st.markdown("### Add new shot") - add_new_shot_element(shot, data_repo) + if idx == len(shot_list) - 1: + with grid[(idx + 1) % items_per_row]: + st.markdown("### Add new shot") + add_new_shot_element(shot, data_repo) diff --git a/ui_components/widgets/variant_comparison_grid.py b/ui_components/widgets/variant_comparison_grid.py index 9a4e0197..72d5100e 100644 --- a/ui_components/widgets/variant_comparison_grid.py +++ b/ui_components/widgets/variant_comparison_grid.py @@ -51,6 +51,7 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): if not len(variants): st.info("No variants present") + st.markdown("***") else: current_variant = shot.primary_interpolated_video_index if stage == CreativeProcessType.MOTION.value else int(timing.primary_variant_index) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/common_utils.py b/utils/common_utils.py index aa70e8ef..65b13a07 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -246,5 +246,5 @@ def release_lock(key): def refresh_app(maintain_state=False): - st.session_state['maintain_state'] = maintain_state + # st.session_state['maintain_state'] = maintain_state st.rerun() \ No newline at end of file diff --git a/utils/constants.py b/utils/constants.py index cfe7a9a7..5f80f90d 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -2,7 +2,7 @@ import json from shared.constants import AIModelCategory, AIModelType from utils.enum import ExtendedEnum -from utils.ml_processor.replicate.constants import REPLICATE_MODEL +from utils.ml_processor.constants import ML_MODEL import streamlit as st @@ -67,8 +67,8 @@ def to_json(self): ML_MODEL_LIST = [ { "name" : 'stable-diffusion-img2img-v2.1', - "version": REPLICATE_MODEL.img2img_sd_2_1.version, - "replicate_url" : REPLICATE_MODEL.img2img_sd_2_1.name, + "version": ML_MODEL.img2img_sd_2_1.version, + "replicate_url" : ML_MODEL.img2img_sd_2_1.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -76,8 +76,8 @@ def to_json(self): }, { "name" : 'depth2img', - "version": REPLICATE_MODEL.jagilley_controlnet_depth2img.version, - "replicate_url" : REPLICATE_MODEL.jagilley_controlnet_depth2img.name, + "version": ML_MODEL.jagilley_controlnet_depth2img.version, + "replicate_url" : ML_MODEL.jagilley_controlnet_depth2img.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -85,8 +85,8 @@ def to_json(self): }, { "name" : 'pix2pix', - "version": REPLICATE_MODEL.arielreplicate.version, - "replicate_url" : REPLICATE_MODEL.arielreplicate.name, + "version": ML_MODEL.arielreplicate.version, + "replicate_url" : ML_MODEL.arielreplicate.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -115,8 +115,8 @@ def to_json(self): }, { "name" : 'StyleGAN-NADA', - "version": 
REPLICATE_MODEL.stylegan_nada.version, - "replicate_url" : REPLICATE_MODEL.stylegan_nada.name, + "version": ML_MODEL.stylegan_nada.version, + "replicate_url" : ML_MODEL.stylegan_nada.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -124,8 +124,8 @@ def to_json(self): }, { "name" : 'real-esrgan-upscaling', - "version": REPLICATE_MODEL.real_esrgan_upscale.version, - "replicate_url" : REPLICATE_MODEL.real_esrgan_upscale.name, + "version": ML_MODEL.real_esrgan_upscale.version, + "replicate_url" : ML_MODEL.real_esrgan_upscale.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -133,8 +133,8 @@ def to_json(self): }, { "name" : 'controlnet_1_1_x_realistic_vision_v2_0', - "version": REPLICATE_MODEL.controlnet_1_1_x_realistic_vision_v2_0.version, - "replicate_url" : REPLICATE_MODEL.controlnet_1_1_x_realistic_vision_v2_0.name, + "version": ML_MODEL.controlnet_1_1_x_realistic_vision_v2_0.version, + "replicate_url" : ML_MODEL.controlnet_1_1_x_realistic_vision_v2_0.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -142,8 +142,8 @@ def to_json(self): }, { "name" : 'urpm-v1.3', - "version": REPLICATE_MODEL.urpm.version, - "replicate_url" : REPLICATE_MODEL.urpm.name, + "version": ML_MODEL.urpm.version, + "replicate_url" : ML_MODEL.urpm.name, "category" : AIModelCategory.BASE_SD.value, "keyword" : "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -151,8 +151,8 @@ def to_json(self): }, { "name": "stable_diffusion_xl", - "version": REPLICATE_MODEL.sdxl.version, - "replicate_url": REPLICATE_MODEL.sdxl.name, + "version": ML_MODEL.sdxl.version, + "replicate_url": ML_MODEL.sdxl.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.TXT2IMG.value, AIModelType.IMG2IMG.value]), @@ -160,8 +160,8 @@ def to_json(self): }, { "name": "realistic_vision_5", - "version": REPLICATE_MODEL.realistic_vision_v5.version, - "replicate_url": REPLICATE_MODEL.realistic_vision_v5.name, + "version": ML_MODEL.realistic_vision_v5.version, + "replicate_url": ML_MODEL.realistic_vision_v5.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.TXT2IMG.value]), @@ -169,8 +169,8 @@ def to_json(self): }, { "name": "deliberate_v3", - "version": REPLICATE_MODEL.deliberate_v3.version, - "replicate_url": REPLICATE_MODEL.deliberate_v3.name, + "version": ML_MODEL.deliberate_v3.version, + "replicate_url": ML_MODEL.deliberate_v3.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.TXT2IMG.value, AIModelType.IMG2IMG.value]), @@ -178,8 +178,8 @@ def to_json(self): }, { "name": "dreamshaper_v7", - "version": REPLICATE_MODEL.dreamshaper_v7.version, - "replicate_url": REPLICATE_MODEL.dreamshaper_v7.name, + "version": ML_MODEL.dreamshaper_v7.version, + "replicate_url": ML_MODEL.dreamshaper_v7.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.TXT2IMG.value, AIModelType.IMG2IMG.value]), @@ -187,8 +187,8 @@ def to_json(self): }, { "name": "epic_realism_v5", - "version": REPLICATE_MODEL.epicrealism_v5.version, - "replicate_url": REPLICATE_MODEL.epicrealism_v5.name, + "version": ML_MODEL.epicrealism_v5.version, + "replicate_url": ML_MODEL.epicrealism_v5.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": 
json.dumps([AIModelType.TXT2IMG.value, AIModelType.IMG2IMG.value]), @@ -196,8 +196,8 @@ def to_json(self): }, { "name": "sdxl_controlnet", - "version": REPLICATE_MODEL.sdxl_controlnet.version, - "replicate_url": REPLICATE_MODEL.sdxl_controlnet.name, + "version": ML_MODEL.sdxl_controlnet.version, + "replicate_url": ML_MODEL.sdxl_controlnet.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -205,8 +205,8 @@ def to_json(self): }, { "name": "sdxl_controlnet_openpose", - "version": REPLICATE_MODEL.sdxl_controlnet_openpose.version, - "replicate_url": REPLICATE_MODEL.sdxl_controlnet_openpose.name, + "version": ML_MODEL.sdxl_controlnet_openpose.version, + "replicate_url": ML_MODEL.sdxl_controlnet_openpose.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), @@ -214,8 +214,8 @@ def to_json(self): }, { "name": "realistic_vision_img2img", - "version": REPLICATE_MODEL.realistic_vision_v5_img2img.version, - "replicate_url": REPLICATE_MODEL.realistic_vision_v5_img2img.name, + "version": ML_MODEL.realistic_vision_v5_img2img.version, + "replicate_url": ML_MODEL.realistic_vision_v5_img2img.name, "category": AIModelCategory.BASE_SD.value, "keyword": "", "model_type": json.dumps([AIModelType.IMG2IMG.value]), diff --git a/utils/data_repo/api_repo.py b/utils/data_repo/api_repo.py index dcd11b8e..2b413b0f 100644 --- a/utils/data_repo/api_repo.py +++ b/utils/data_repo/api_repo.py @@ -52,6 +52,7 @@ def _setup_urls(self): # project self.PROJECT_URL = '/v1/data/project' self.PROJECT_LIST_URL = '/v1/data/project/list' + self.EXPLORER_STATS_URL = '/v1/data/project/stats' # project setting self.PROJECT_SETTING_URL = '/v1/data/project-setting' @@ -65,6 +66,7 @@ def _setup_urls(self): self.FILE_LIST_URL = '/v1/data/file/list' self.FILE_UUID_LIST_URL = '/v1/data/file/uuid-list' self.FILE_UPLOAD_URL = '/v1/data/file/upload' + self.FILE_EXTRA_URL = '/v1/data/file/extra' # TODO: fix url patterns # app setting self.APP_SETTING_URL = '/v1/data/app-setting' @@ -86,6 +88,7 @@ def _setup_urls(self): self.SHOT_INTERPOLATED_CLIP = '/v1/data/shot/interpolated-clip' self.SHOT_DUPLICATE_URL = '/v1/data/shot/duplicate' + def logout(self): delete_url_param(AUTH_TOKEN) st.rerun() @@ -246,6 +249,14 @@ def update_file(self, **kwargs): res = self.http_put(url=self.FILE_URL, data=kwargs) return InternalResponse(res['payload'], 'success', res['status']) + def get_file_count_from_type(self, file_tag, project_uuid): + res = self.http_get(self.FILE_EXTRA_URL, params={'file_tag': file_tag, 'project_uuid': project_uuid}) + return InternalResponse(res['payload'], 'success', res['status']) + + def update_temp_gallery_images(self, project_uuid): + res = self.http_put(self.FILE_EXTRA_URL, data={'uuid': project_uuid}) + return InternalResponse(res['payload'], 'success', res['status']) + # project def get_project_from_uuid(self, uuid): res = self.http_get(self.PROJECT_URL, params={'uuid': uuid}) @@ -520,4 +531,9 @@ def delete_shot(self, shot_uuid): def add_interpolated_clip(self, shot_uuid, **kwargs): kwargs['uuid'] = shot_uuid res = self.http_post(self.SHOT_INTERPOLATED_CLIP, data=kwargs) + return InternalResponse(res['payload'], 'success', res['status']) + + # combined + def get_explorer_pending_stats(self, project_uuid, log_status_list): + res = self.http_get(self.EXPLORER_STATS_URL, params={'project_uuid': project_uuid, 'log_status_list': log_status_list}) return InternalResponse(res['payload'], 'success', 
res['status']) \ No newline at end of file diff --git a/utils/data_repo/data_repo.py b/utils/data_repo/data_repo.py index ec4c234a..795e77d9 100644 --- a/utils/data_repo/data_repo.py +++ b/utils/data_repo/data_repo.py @@ -1,7 +1,7 @@ # this repo serves as a middleware between API backend and the frontend import json import time -from shared.constants import SECRET_ACCESS_TOKEN, InferenceParamType, InternalFileType, InternalResponse +from shared.constants import SECRET_ACCESS_TOKEN, InferenceParamType, InferenceStatus, InternalFileType, InternalResponse from shared.constants import SERVER, ServerType from shared.logging.constants import LoggingType from shared.logging.logging import AppLogger @@ -121,6 +121,11 @@ def create_file(self, **kwargs) -> InternalFileObject: uploaded_file_url = self.upload_file(file_content) kwargs.update({'hosted_url':uploaded_file_url}) + # handling the case of local inference; will fix later + if 'hosted_url' in kwargs and not kwargs['hosted_url'].startswith('http'): + kwargs['local_path'] = kwargs['hosted_url'] + del kwargs['hosted_url'] + res = self.db_repo.create_file(**kwargs) file = res.data['data'] if res.status else None file = InternalFileObject(**file) if file else None @@ -153,6 +160,13 @@ def update_file(self, file_uuid, **kwargs): file = res.data['data'] if res.status else None return InternalFileObject(**file) if file else None + def get_file_count_from_type(self, file_tag=None, project_uuid=None): + return self.db_repo.get_file_count_from_type(file_tag, project_uuid).data['data'] + + def update_temp_gallery_images(self, project_uuid): + self.db_repo.update_temp_gallery_images(project_uuid) + return True + # project def get_project_from_uuid(self, uuid): project = self.db_repo.get_project_from_uuid(uuid).data['data'] @@ -464,4 +478,12 @@ def duplicate_shot(self, shot_uuid): def add_interpolated_clip(self, shot_uuid, **kwargs): res = self.db_repo.add_interpolated_clip(shot_uuid, **kwargs) - return res.status \ No newline at end of file + return res.status + + # combined + # gives the count of 1. temp generated images 2.
inference logs with in-progress/pending status + def get_explorer_pending_stats(self, project_uuid): + log_status_list = [InferenceStatus.IN_PROGRESS.value, InferenceStatus.QUEUED.value] + res = self.db_repo.get_explorer_pending_stats(project_uuid, log_status_list) + count_data = res.data['data'] if res.status else {"temp_image_count": 0, "pending_image_count": 0} + return count_data \ No newline at end of file diff --git a/utils/media_processor/interpolator.py b/utils/media_processor/interpolator.py index 251742cc..ddaed642 100644 --- a/utils/media_processor/interpolator.py +++ b/utils/media_processor/interpolator.py @@ -4,14 +4,15 @@ import streamlit as st import requests as r import numpy as np -from shared.constants import AnimationStyleType, AnimationToolType +from shared.constants import QUEUE_INFERENCE_QUERIES, AnimationStyleType, AnimationToolType from ui_components.constants import DefaultTimingStyleParams from ui_components.methods.file_methods import generate_temp_file, zip_images from ui_components.models import InferenceLogObject +from utils.constants import MLQueryObject from utils.data_repo.data_repo import DataRepo from utils.ml_processor.ml_interface import get_ml_client -from utils.ml_processor.replicate.constants import REPLICATE_MODEL +from utils.ml_processor.constants import ML_MODEL class VideoInterpolator: @@ -58,16 +59,16 @@ def create_interpolated_clip(img_location_list, animation_style, settings, varia @staticmethod def video_through_frame_interpolation(img_location_list, settings, variant_count, queue_inference=False): ml_client = get_ml_client() - zip_filename = zip_images(img_location_list) - zip_url = ml_client.upload_training_data(zip_filename, delete_after_upload=True) - print("zipped file url: ", zip_url) - animation_tool = settings['animation_tool'] if 'animation_tool' in settings else AnimationToolType.G_FILM.value + # zip_filename = zip_images(img_location_list) + # zip_url = ml_client.upload_training_data(zip_filename, delete_after_upload=True) + # print("zipped file url: ", zip_url) + # animation_tool = settings['animation_tool'] if 'animation_tool' in settings else AnimationToolType.G_FILM.value final_res = [] for _ in range(variant_count): # if animation_tool == AnimationToolType.G_FILM.value: # res = ml_client.predict_model_output( - # REPLICATE_MODEL.google_frame_interpolation, + # ML_MODEL.google_frame_interpolation, # frame1=img1, # frame2=img2, # times_to_interpolate=settings['interpolation_steps'], @@ -79,37 +80,70 @@ def video_through_frame_interpolation(img_location_list, settings, variant_count # defaulting to animatediff interpolation if True: - - data = { + # NOTE: @Peter these are all the settings you passed in from the UI + sm_data = { "ckpt": settings['ckpt'], + "width": settings['width'], # "width": "512", + "height": settings['height'], # "height": "512", "buffer": settings['buffer'], - "image_list": zip_url, - "motion_scale": settings['motion_scale'], - "output_format": settings['output_format'], + "motion_scale": settings['motion_scale'], # "motion_scale": "1.0", + "motion_scales": settings['motion_scales'], "image_dimension": settings["image_dimension"], + "output_format": settings['output_format'], + "prompt": settings["prompt"], "negative_prompt": settings["negative_prompt"], - "image_prompt_list": settings["image_prompt_list"], + # "image_prompt_list": settings["image_prompt_list"], "interpolation_type": settings["interpolation_type"], "stmfnet_multiplier": settings["stmfnet_multiplier"], "relative_ipadapter_strength": 
settings["relative_ipadapter_strength"], - "relative_ipadapter_influence": settings["relative_ipadapter_influence"], - "soft_scaled_cn_weights_multiplier": settings["soft_scaled_cn_weights_multiplier"], - "type_of_cn_strength_distribution": settings["type_of_cn_strength_distribution"], - "linear_cn_strength_value": settings["linear_cn_strength_value"], - "dynamic_cn_strength_values": settings["dynamic_cn_strength_values"], + "relative_cn_strength": settings["relative_cn_strength"], + "type_of_strength_distribution": settings["type_of_strength_distribution"], + "linear_strength_value": settings["linear_strength_value"], + "dynamic_strength_values": settings["dynamic_strength_values"], "linear_frame_distribution_value": settings["linear_frame_distribution_value"], - "type_of_frame_distribution": settings["type_of_frame_distribution"], "dynamic_frame_distribution_values": settings["dynamic_frame_distribution_values"], + "type_of_frame_distribution": settings["type_of_frame_distribution"], "type_of_key_frame_influence": settings["type_of_key_frame_influence"], "linear_key_frame_influence_value": settings["linear_key_frame_influence_value"], - "dynamic_key_frame_influence_values": settings["dynamic_key_frame_influence_values"], - "relative_ipadapter_strength": settings["relative_ipadapter_strength"], - "relative_ipadapter_influence": settings["relative_ipadapter_influence"], - "ipadapter_noise": 0.25, - "queue_inference": True + "dynamic_key_frame_influence_values": settings["dynamic_key_frame_influence_values"], + "normalise_speed": settings["normalise_speed"], + "ipadapter_noise": settings["ipadapter_noise"], + "queue_inference": True, + "context_length": settings["context_length"], + "context_stride": settings["context_stride"], + "context_overlap": settings["context_overlap"], + "multipled_base_end_percent": settings["multipled_base_end_percent"], + "multipled_base_adapter_strength": settings["multipled_base_adapter_strength"], + "individual_prompts": settings["individual_prompts"], + "individual_negative_prompts": settings["individual_negative_prompts"], + "max_frames": settings["max_frames"], + } - res = ml_client.predict_model_output(REPLICATE_MODEL.ad_interpolation, **data) + # adding the input images + for idx, img_uuid in enumerate(settings['file_uuid_list']): + sm_data["file_image_" + str(idx) + "_uuid"] = img_uuid + + # NOTE: @Peter all the above settings are put in the 'data' parameter below + ml_query_object = MLQueryObject( + prompt="SM", # hackish fix + timing_uuid=None, + model_uuid=None, + guidance_scale=None, + seed=None, + num_inference_steps=None, + strength=None, + adapter_type=None, + negative_prompt="", + height=512, + width=512, + low_threshold=100, + high_threshold=200, + image_uuid=None, + mask_uuid=None, + data=sm_data + ) + res = ml_client.predict_model_output_standardized(ML_MODEL.ad_interpolation, ml_query_object, QUEUE_INFERENCE_QUERIES) final_res.append(res) diff --git a/utils/ml_processor/comfy_data_transform.py b/utils/ml_processor/comfy_data_transform.py new file mode 100644 index 00000000..55270f1d --- /dev/null +++ b/utils/ml_processor/comfy_data_transform.py @@ -0,0 +1,378 @@ +import os +import random +import tempfile +import uuid +from shared.constants import InternalFileType +from shared.logging.constants import LoggingType +from shared.logging.logging import app_logger +from ui_components.methods.common_methods import combine_mask_and_input_image, random_seed +from ui_components.methods.file_methods import save_or_host_file, zip_images +from 
utils.constants import MLQueryObject +from utils.data_repo.data_repo import DataRepo +from utils.ml_processor.constants import ML_MODEL, ComfyWorkflow, MLModel +import json + + +MODEL_PATH_DICT = { + ComfyWorkflow.SDXL: {"workflow_path": 'comfy_workflows/sdxl_workflow_api.json', "output_node_id": 19}, + ComfyWorkflow.SDXL_IMG2IMG: {"workflow_path": 'comfy_workflows/sdxl_img2img_workflow_api.json', "output_node_id": 31}, + ComfyWorkflow.SDXL_CONTROLNET: {"workflow_path": 'comfy_workflows/sdxl_controlnet_workflow_api.json', "output_node_id": 9}, + ComfyWorkflow.SDXL_CONTROLNET_OPENPOSE: {"workflow_path": 'comfy_workflows/sdxl_openpose_workflow_api.json', "output_node_id": 9}, + ComfyWorkflow.LLAMA_2_7B: {"workflow_path": 'comfy_workflows/llama_workflow_api.json', "output_node_id": 14}, + ComfyWorkflow.SDXL_INPAINTING: {"workflow_path": 'comfy_workflows/sdxl_inpainting_workflow_api.json', "output_node_id": 56}, + ComfyWorkflow.IP_ADAPTER_PLUS: {"workflow_path": 'comfy_workflows/ipadapter_plus_api.json', "output_node_id": 29}, + ComfyWorkflow.IP_ADAPTER_FACE: {"workflow_path": 'comfy_workflows/ipadapter_face_api.json', "output_node_id": 29}, + ComfyWorkflow.IP_ADAPTER_FACE_PLUS: {"workflow_path": 'comfy_workflows/ipadapter_face_plus_api.json', "output_node_id": 29}, + ComfyWorkflow.STEERABLE_MOTION: {"workflow_path": 'comfy_workflows/steerable_motion_api.json', "output_node_id": 281} +} + +# these methods return the workflow json along with the output node ids +class ComfyDataTransform: + @staticmethod + def get_workflow_json(model: ComfyWorkflow): + json_file_path = "./utils/ml_processor/" + MODEL_PATH_DICT[model]["workflow_path"] + with open(json_file_path) as f: + json_data = json.load(f) + return json_data, [MODEL_PATH_DICT[model]['output_node_id']] + + @staticmethod + def transform_sdxl_workflow(query: MLQueryObject): + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.SDXL) + + # workflow params + height, width = query.height, query.width + positive_prompt, negative_prompt = query.prompt, query.negative_prompt + steps, cfg = query.num_inference_steps, query.guidance_scale + + # updating params + seed = random_seed() + workflow["10"]["inputs"]["noise_seed"] = seed + workflow["5"]["width"], workflow["5"]["height"] = max(width, 1024), max(height, 1024) + workflow["6"]["inputs"]["text"] = workflow["15"]["inputs"]["text"] = positive_prompt + workflow["7"]["inputs"]["text"] = workflow["16"]["inputs"]["text"] = negative_prompt + workflow["10"]["inputs"]["steps"], workflow["10"]["inputs"]["cfg"] = steps, cfg + workflow["11"]["inputs"]["steps"], workflow["11"]["inputs"]["cfg"] = steps, cfg + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_sdxl_img2img_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.SDXL_IMG2IMG) + + # workflow params + height, width = 1024, 1024 + positive_prompt, negative_prompt = query.prompt, query.negative_prompt + steps, cfg = 20, 7 # hardcoding values + strength = round(query.strength / 100, 1) + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + + # updating params + workflow["37:0"]["inputs"]["image"] = image_name + workflow["42:0"]["inputs"]["text"] = positive_prompt + workflow["42:1"]["inputs"]["text"] = negative_prompt + workflow["42:2"]["inputs"]["steps"] = steps + workflow["42:2"]["inputs"]["cfg"] = cfg + 
workflow["42:2"]["inputs"]["denoise"] = 1 - strength + workflow["42:2"]["inputs"]["seed"] = random_seed() + + return json.dumps(workflow), output_node_ids + + + @staticmethod + def transform_sdxl_controlnet_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.SDXL_CONTROLNET) + + # workflow params + height, width = query.height, query.width + positive_prompt, negative_prompt = query.prompt, query.negative_prompt + steps, cfg = query.num_inference_steps, query.guidance_scale + low_threshold, high_threshold = query.low_threshold, query.high_threshold + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + + # updating params + workflow["3"]["inputs"]["seed"] = random_seed() + workflow["5"]["width"], workflow["5"]["height"] = width, height + workflow["17"]["width"], workflow["17"]["height"] = width, height + workflow["6"]["inputs"]["text"], workflow["7"]["inputs"]["text"] = positive_prompt, negative_prompt + workflow["12"]["inputs"]["low_threshold"], workflow["12"]["inputs"]["high_threshold"] = low_threshold, high_threshold + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + workflow["13"]["inputs"]["image"] = image_name + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_sdxl_controlnet_openpose_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.SDXL_CONTROLNET_OPENPOSE) + + # workflow params + height, width = query.height, query.width + positive_prompt, negative_prompt = query.prompt, query.negative_prompt + steps, cfg = query.num_inference_steps, query.guidance_scale + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + + # updating params + workflow["3"]["inputs"]["seed"] = random_seed() + workflow["5"]["width"], workflow["5"]["height"] = width, height + workflow["11"]["width"], workflow["11"]["height"] = width, height + workflow["6"]["inputs"]["text"], workflow["7"]["inputs"]["text"] = positive_prompt, negative_prompt + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + workflow["12"]["inputs"]["image"] = image_name + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_llama_2_7b_workflow(query: MLQueryObject): + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.LLAMA_2_7B) + + # workflow params + input_text = query.prompt + temperature = query.data.get("temperature", 0.8) + + # updating params + workflow["15"]["inputs"]["prompt"] = input_text + workflow["15"]["inputs"]["temperature"] = temperature + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_sdxl_inpainting_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.SDXL_INPAINTING) + + # workflow params + # node 'get_img_size' automatically fetches the size + positive_prompt, negative_prompt = query.prompt, query.negative_prompt + steps, cfg = query.num_inference_steps, query.guidance_scale + input_image = query.data.get('data', {}).get('input_image', None) + mask = query.data.get('data', {}).get('mask', None) + timing = data_repo.get_timing_from_uuid(query.timing_uuid) + + # inpainting workflows takes in an image and inpaints the transparent area + combined_img = combine_mask_and_input_image(mask, input_image) + filename = str(uuid.uuid4()) 
+ ".png" + hosted_url = save_or_host_file(combined_img, "videos/temp/" + filename) + + file_data = { + "name": filename, + "type": InternalFileType.IMAGE.value, + "project_id": timing.shot.project.uuid + } + + if hosted_url: + file_data.update({'hosted_url': hosted_url}) + else: + file_data.update({'local_path': "videos/temp/" + filename}) + file = data_repo.create_file(**file_data) + + # adding the combined image in query (and removing io buffers) + query.data = { + "data": { + "file_combined_img": file.uuid + } + } + + # updating params + workflow["3"]["inputs"]["seed"] = random_seed() + workflow["20"]["inputs"]["image"] = filename + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + workflow["34"]["inputs"]["text_g"] = workflow["34"]["inputs"]["text_l"] = positive_prompt + workflow["37"]["inputs"]["text_g"] = workflow["37"]["inputs"]["text_l"] = negative_prompt + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_ipadaptor_plus_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.IP_ADAPTER_PLUS) + + # workflow params + height, width = query.height, query.width + steps, cfg = query.num_inference_steps, query.guidance_scale + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + + # updating params + workflow["3"]["inputs"]["seed"] = random_seed() + workflow["5"]["width"], workflow["5"]["height"] = width, height + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + # workflow["24"]["inputs"]["image"] = image_name # ipadapter image + workflow["28"]["inputs"]["image"] = image_name # dummy image + workflow["6"]["inputs"]["text"] = query.prompt + workflow["7"]["inputs"]["text"] = query.negative_prompt + workflow["27"]["inputs"]["weight"] = query.strength + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_ipadaptor_face_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.IP_ADAPTER_FACE) + + # workflow params + height, width = query.height, query.width + steps, cfg = query.num_inference_steps, query.guidance_scale + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + strength = query.strength + + # updating params + workflow["3"]["inputs"]["seed"] = random_seed() + workflow["5"]["width"], workflow["5"]["height"] = width, height + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + workflow["24"]["inputs"]["image"] = image_name # ipadapter image + workflow["6"]["inputs"]["text"] = query.prompt + workflow["7"]["inputs"]["text"] = query.negative_prompt + workflow["36"]["inputs"]["weight"] = query.strength + workflow["36"]["inputs"]["weight_v2"] = query.strength + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_ipadaptor_face_plus_workflow(query: MLQueryObject): + data_repo = DataRepo() + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.IP_ADAPTER_FACE_PLUS) + + # workflow params + height, width = query.height, query.width + steps, cfg = query.num_inference_steps, query.guidance_scale + image = data_repo.get_file_from_uuid(query.image_uuid) + image_name = image.filename + image_2 = data_repo.get_file_from_uuid(query.data.get('data', {}).get("file_image_2_uuid", None)) + image_name_2 = image_2.filename if image_2 else None + + # updating params + 
workflow["3"]["inputs"]["seed"] = random_seed() + workflow["5"]["width"], workflow["5"]["height"] = width, height + workflow["3"]["inputs"]["steps"], workflow["3"]["inputs"]["cfg"] = steps, cfg + workflow["24"]["inputs"]["image"] = image_name # ipadapter image + workflow["28"]["inputs"]["image"] = image_name_2 # insight face image + workflow["6"]["inputs"]["text"] = query.prompt + workflow["7"]["inputs"]["text"] = query.negative_prompt + workflow["29"]["inputs"]["weight"] = query.strength[0] + workflow["27"]["inputs"]["weight"] = query.strength[1] + + return json.dumps(workflow), output_node_ids + + @staticmethod + def transform_steerable_motion_workflow(query: MLQueryObject): + + + sm_data = query.data.get('data', {}) + workflow, output_node_ids = ComfyDataTransform.get_workflow_json(ComfyWorkflow.STEERABLE_MOTION) + #project_settings = data_repo.get_project_setting(shot.project.uuid) + # width = project_settings.width + # height = project_settings.height + + + print(sm_data) + workflow['464']['inputs']['height'] = sm_data.get('height') + workflow['464']['inputs']['width'] = sm_data.get('width') + + workflow['461']['inputs']['ckpt_name'] = sm_data.get('ckpt') + + workflow['473']['inputs']['buffer'] = sm_data.get('buffer') + workflow['187']['inputs']['motion_scale'] = sm_data.get('motion_scale') + # workflow['548']['inputs']['text'] = sm_data.get('motion_scales') + workflow['281']['inputs']['format'] = sm_data.get('output_format') + workflow['536']['inputs']['pre_text'] = sm_data.get('prompt') + workflow['537']['inputs']['pre_text'] = sm_data.get('negative_prompt') + workflow['292']['inputs']['multiplier'] = sm_data.get('stmfnet_multiplier') + workflow['473']['inputs']['relative_ipadapter_strength'] = sm_data.get('relative_ipadapter_strength') + workflow['473']['inputs']['relative_cn_strength'] = sm_data.get('relative_cn_strength') + workflow['473']['inputs']['type_of_strength_distribution'] = sm_data.get('type_of_strength_distribution') + workflow['473']['inputs']['linear_strength_value'] = sm_data.get('linear_strength_value') + + workflow['473']['inputs']['dynamic_strength_values'] = str(sm_data.get('dynamic_strength_values'))[1:-1] + workflow['473']['inputs']['linear_frame_distribution_value'] = sm_data.get('linear_frame_distribution_value') + workflow['473']['inputs']['dynamic_frame_distribution_values'] = ', '.join(str(int(value)) for value in sm_data.get('dynamic_frame_distribution_values')) + workflow['473']['inputs']['type_of_frame_distribution'] = sm_data.get('type_of_frame_distribution') + workflow['473']['inputs']['type_of_key_frame_influence'] = sm_data.get('type_of_key_frame_influence') + workflow['473']['inputs']['linear_key_frame_influence_value'] = sm_data.get('linear_key_frame_influence_value') + + # print(dynamic_key_frame_influence_values) + workflow['473']['inputs']['dynamic_key_frame_influence_values'] = str(sm_data.get('dynamic_key_frame_influence_values'))[1:-1] + workflow['473']['inputs']['ipadapter_noise'] = sm_data.get('ipadapter_noise') + workflow['342']['inputs']['context_length'] = sm_data.get('context_length') + workflow['342']['inputs']['context_stride'] = sm_data.get('context_stride') + workflow['342']['inputs']['context_overlap'] = sm_data.get('context_overlap') + workflow['468']['inputs']['end_percent'] = sm_data.get('multipled_base_end_percent') + workflow['470']['inputs']['strength_model'] = sm_data.get('multipled_base_adapter_strength') + workflow["482"]["inputs"]["seed"] = random_seed() + workflow["536"]["inputs"]["text"] = 
sm_data.get('individual_prompts') + # make max_frames an int + + + workflow["536"]["inputs"]["max_frames"] = int(float(sm_data.get('max_frames'))) + workflow["537"]["inputs"]["max_frames"] = int(float(sm_data.get('max_frames'))) + workflow["537"]["inputs"]["text"] = sm_data.get('individual_negative_prompts') + + + return json.dumps(workflow), output_node_ids + + +# NOTE: only populating with models currently in use +MODEL_WORKFLOW_MAP = { + ML_MODEL.sdxl.workflow_name: ComfyDataTransform.transform_sdxl_workflow, + ML_MODEL.sdxl_controlnet.workflow_name: ComfyDataTransform.transform_sdxl_controlnet_workflow, + ML_MODEL.sdxl_controlnet_openpose.workflow_name: ComfyDataTransform.transform_sdxl_controlnet_openpose_workflow, + ML_MODEL.llama_2_7b.workflow_name: ComfyDataTransform.transform_llama_2_7b_workflow, + ML_MODEL.sdxl_inpainting.workflow_name: ComfyDataTransform.transform_sdxl_inpainting_workflow, + ML_MODEL.ipadapter_plus.workflow_name: ComfyDataTransform.transform_ipadaptor_plus_workflow, + ML_MODEL.ipadapter_face.workflow_name: ComfyDataTransform.transform_ipadaptor_face_workflow, + ML_MODEL.ipadapter_face_plus.workflow_name: ComfyDataTransform.transform_ipadaptor_face_plus_workflow, + ML_MODEL.ad_interpolation.workflow_name: ComfyDataTransform.transform_steerable_motion_workflow, + ML_MODEL.sdxl_img2img.workflow_name: ComfyDataTransform.transform_sdxl_img2img_workflow +} + +# returns stringified json of the workflow +def get_model_workflow_from_query(model: MLModel, query_obj: MLQueryObject) -> str: + if model.workflow_name not in MODEL_WORKFLOW_MAP: + app_logger.log(LoggingType.ERROR, f"model {model.workflow_name} not supported for local inference") + raise ValueError(f'Model {model.workflow_name} not supported for local inference') + + return MODEL_WORKFLOW_MAP[model.workflow_name](query_obj) + +def get_workflow_json_url(workflow_json): + from utils.ml_processor.ml_interface import get_ml_client + ml_client = get_ml_client() + temp_fd, temp_json_path = tempfile.mkstemp(suffix='.json') + + with open(temp_json_path, 'w') as temp_json_file: + temp_json_file.write(workflow_json) + + return ml_client.upload_training_data(temp_json_path, delete_after_upload=True) + +def get_file_list_from_query_obj(query_obj: MLQueryObject): + file_uuid_list = [] + + if query_obj.image_uuid: + file_uuid_list.append(query_obj.image_uuid) + + if query_obj.mask_uuid: + file_uuid_list.append(query_obj.mask_uuid) + + for k, v in query_obj.data.get('data', {}).items(): + if k.startswith("file_"): + file_uuid_list.append(v) + + return file_uuid_list + +# returns the zip file which can be passed to the comfy_runner replicate endpoint +def get_file_zip_url(file_uuid_list, index_files=False) -> str: + from utils.ml_processor.ml_interface import get_ml_client + + data_repo = DataRepo() + ml_client = get_ml_client() + + file_list = data_repo.get_image_list_from_uuid_list(file_uuid_list) + filename_list = [f.filename for f in file_list] if not index_files else [] # file names would be indexed like 1.png, 2.png ... 
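+ # note: assuming that when index_files=True the empty filename_list makes zip_images fall back to indexed names (1.png, 2.png, ...) inside the archive, per the comment above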
+ zip_path = zip_images([f.location for f in file_list], 'videos/temp/input_images.zip', filename_list) + + return ml_client.upload_training_data(zip_path, delete_after_upload=True) \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/ipadapter_face_api.json b/utils/ml_processor/comfy_workflows/ipadapter_face_api.json new file mode 100644 index 00000000..d53a3ec8 --- /dev/null +++ b/utils/ml_processor/comfy_workflows/ipadapter_face_api.json @@ -0,0 +1,195 @@ +{ + "3": { + "inputs": { + "seed": 862782529735965, + "steps": 24, + "cfg": 9.25, + "sampler_name": "ddim", + "scheduler": "normal", + "denoise": 1, + "model": [ + "36", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "21": { + "inputs": { + "ipadapter_file": "ip-adapter_sdxl.safetensors" + }, + "class_type": "IPAdapterModelLoader", + "_meta": { + "title": "Load IPAdapter Model" + } + }, + "24": { + "inputs": { + "image": "rWA-3_T7_400x400.jpg", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "29": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "36": { + "inputs": { + "weight": 0.75, + "noise": 0.3, + "weight_type": "linear", + "start_at": 0, + "end_at": 1, + "faceid_v2": true, + "weight_v2": 0.75, + "unfold_batch": false, + "ipadapter": [ + "21", + 0 + ], + "clip_vision": [ + "41", + 0 + ], + "insightface": [ + "37", + 0 + ], + "image": [ + "40", + 0 + ], + "model": [ + "4", + 0 + ] + }, + "class_type": "IPAdapterApplyFaceID", + "_meta": { + "title": "Apply IPAdapter FaceID" + } + }, + "37": { + "inputs": { + "provider": "CUDA" + }, + "class_type": "InsightFaceLoader", + "_meta": { + "title": "Load InsightFace" + } + }, + "40": { + "inputs": { + "crop_position": "center", + "sharpening": 0, + "pad_around": true, + "image": [ + "24", + 0 + ] + }, + "class_type": "PrepImageForInsightFace", + "_meta": { + "title": "Prepare Image For InsightFace" + } + }, + "41": { + "inputs": { + "clip_name": "SDXL/pytorch_model.bin" + }, + "class_type": "CLIPVisionLoader", + "_meta": { + "title": "Load CLIP Vision" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/ipadapter_face_plus_api.json b/utils/ml_processor/comfy_workflows/ipadapter_face_plus_api.json new file mode 100644 index 00000000..906e0e9c --- /dev/null +++ b/utils/ml_processor/comfy_workflows/ipadapter_face_plus_api.json @@ -0,0 +1,269 @@ +{ + "3": { + "inputs": { + "seed": 244730832305022, + "steps": 24, + "cfg": 9.25, + 
"sampler_name": "ddim", + "scheduler": "normal", + "denoise": 1, + "model": [ + "27", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "man sitting on a bus, futuristic style", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "photography, text, watermark, blurry, haze, low contrast, low quality, underexposed, ugly, deformed, boring, bad quality, cartoon, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, low detail, low quality, double face, 2 faces, cropped, ugly, low-res, tiling, grainy, cropped, ostentatious, ugly, oversaturated, grain, low resolution, disfigured, blurry, bad anatomy, disfigured, poorly drawn face, mutant, mutated, extra limb, ugly, poorly drawn hands, missing limbs, blurred, floating limbs, disjointed limbs, deformed hands, blurred, out of focus, long neck, long body, ugly, disgusting, childish, cut off cropped, distorted, imperfect, surreal, bad hands, text, error, extra digit, fewer digits, cropped , worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, Lots of hands, extra limbs, extra fingers, conjoined fingers, deformed fingers, old, ugly eyes, imperfect eyes, skewed eyes , unnatural face, stiff face, stiff body, unbalanced body, unnatural body, lacking body, details are not clear, cluttered, details are sticky, details are low, distorted details, ugly hands, imperfect hands, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2) bad hands, fused ha nd, missing hand, disappearing arms, hands, disappearing thigh, disappearing calf, disappearing legs, ui, missing fingers", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "21": { + "inputs": { + "ipadapter_file": "ip-adapter_sdxl.safetensors" + }, + "class_type": "IPAdapterModelLoader", + "_meta": { + "title": "Load IPAdapter Model" + } + }, + "23": { + "inputs": { + "clip_name": "SDXL/pytorch_model.bin" + }, + "class_type": "CLIPVisionLoader", + "_meta": { + "title": "Load CLIP Vision" + } + }, + "24": { + "inputs": { + "image": "boy_sunshine.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "26": { + "inputs": { + "ipadapter_file": "ip-adapter_sdxl.safetensors" + }, + "class_type": "IPAdapterModelLoader", + "_meta": { + "title": "Load IPAdapter Model" + } + }, + "27": { + "inputs": { + "weight": 0.65, + "noise": 0.3, + "weight_type": "original", + 
"start_at": 0, + "end_at": 0.396, + "unfold_batch": false, + "ipadapter": [ + "26", + 0 + ], + "clip_vision": [ + "23", + 0 + ], + "image": [ + "39", + 0 + ], + "model": [ + "36", + 0 + ] + }, + "class_type": "IPAdapterApply", + "_meta": { + "title": "Apply IPAdapter" + } + }, + "28": { + "inputs": { + "image": "king_dark.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "29": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "36": { + "inputs": { + "weight": 0.75, + "noise": 0.3, + "weight_type": "linear", + "start_at": 0, + "end_at": 1, + "faceid_v2": true, + "weight_v2": 0.75, + "unfold_batch": false, + "ipadapter": [ + "21", + 0 + ], + "clip_vision": [ + "41", + 0 + ], + "insightface": [ + "37", + 0 + ], + "image": [ + "40", + 0 + ], + "model": [ + "4", + 0 + ] + }, + "class_type": "IPAdapterApplyFaceID", + "_meta": { + "title": "Apply IPAdapter FaceID" + } + }, + "37": { + "inputs": { + "provider": "CUDA" + }, + "class_type": "InsightFaceLoader", + "_meta": { + "title": "Load InsightFace" + } + }, + "39": { + "inputs": { + "interpolation": "LANCZOS", + "crop_position": "pad", + "sharpening": 0, + "image": [ + "28", + 0 + ] + }, + "class_type": "PrepImageForClipVision", + "_meta": { + "title": "Prepare Image For Clip Vision" + } + }, + "40": { + "inputs": { + "crop_position": "center", + "sharpening": 0, + "pad_around": true, + "image": [ + "24", + 0 + ] + }, + "class_type": "PrepImageForInsightFace", + "_meta": { + "title": "Prepare Image For InsightFace" + } + }, + "41": { + "inputs": { + "clip_name": "SDXL/pytorch_model.bin" + }, + "class_type": "CLIPVisionLoader", + "_meta": { + "title": "Load CLIP Vision" + } + } + } + \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/ipadapter_plus_api.json b/utils/ml_processor/comfy_workflows/ipadapter_plus_api.json new file mode 100644 index 00000000..28f35179 --- /dev/null +++ b/utils/ml_processor/comfy_workflows/ipadapter_plus_api.json @@ -0,0 +1,180 @@ +{ + "3": { + "inputs": { + "seed": 641608455784125, + "steps": 24, + "cfg": 9.25, + "sampler_name": "ddim", + "scheduler": "normal", + "denoise": 1, + "model": [ + "27", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "23": { + "inputs": { + "clip_name": "SDXL/pytorch_model.bin" + }, + "class_type": "CLIPVisionLoader", + "_meta": { + "title": "Load CLIP Vision" + } + }, + "26": { + "inputs": { + "ipadapter_file": "ip-adapter_sdxl.safetensors" + }, + 
"class_type": "IPAdapterModelLoader", + "_meta": { + "title": "Load IPAdapter Model" + } + }, + "27": { + "inputs": { + "weight": 0.65, + "noise": 0.3, + "weight_type": "original", + "start_at": 0, + "end_at": 0.396, + "unfold_batch": false, + "ipadapter": [ + "26", + 0 + ], + "clip_vision": [ + "23", + 0 + ], + "image": [ + "39", + 0 + ], + "model": [ + "4", + 0 + ] + }, + "class_type": "IPAdapterApply", + "_meta": { + "title": "Apply IPAdapter" + } + }, + "28": { + "inputs": { + "image": "714d97a3fe2dcf645f1b500d523d4d3c848acc65bfda3602c56305fc.jpg", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "29": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "39": { + "inputs": { + "interpolation": "LANCZOS", + "crop_position": "pad", + "sharpening": 0, + "image": [ + "28", + 0 + ] + }, + "class_type": "PrepImageForClipVision", + "_meta": { + "title": "Prepare Image For Clip Vision" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/llama_workflow_api.json b/utils/ml_processor/comfy_workflows/llama_workflow_api.json new file mode 100644 index 00000000..da5186b5 --- /dev/null +++ b/utils/ml_processor/comfy_workflows/llama_workflow_api.json @@ -0,0 +1,53 @@ +{ + "1": { + "inputs": { + "Model": "llama-2-7b.Q5_0.gguf", + "n_ctx": 0 + }, + "class_type": "Load LLM Model Basic", + "_meta": { + "title": "Load LLM Model Basic" + } + }, + "14": { + "inputs": { + "text": [ + "15", + 0 + ] + }, + "class_type": "ShowText|pysssss", + "_meta": { + "title": "Show Text 🐍" + } + }, + "15": { + "inputs": { + "prompt": "write a poem on finding your way in about 100 words", + "suffix": "", + "max_response_tokens": 500, + "temperature": 0.8, + "top_p": 0.95, + "min_p": 0.05, + "typical_p": 1, + "echo": false, + "frequency_penalty": 0, + "presence_penalty": 0, + "repeat_penalty": 1.1, + "top_k": 40, + "seed": 273, + "tfs_z": 1, + "mirostat_mode": 0, + "mirostat_tau": 5, + "mirostat_eta": 0.1, + "LLM": [ + "1", + 0 + ] + }, + "class_type": "Call LLM Advanced", + "_meta": { + "title": "Call LLM Advanced" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/sdxl_controlnet_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_controlnet_workflow_api.json new file mode 100644 index 00000000..6e7630ed --- /dev/null +++ b/utils/ml_processor/comfy_workflows/sdxl_controlnet_workflow_api.json @@ -0,0 +1,189 @@ +{ + "3": { + "inputs": { + "seed": 741148140596738, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "14", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "a person standing in an open field", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "text, watermark", + "clip": [ + "4", + 1 + ] + }, + "class_type": 
"CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "12": { + "inputs": { + "low_threshold": 0.19999999999999984, + "high_threshold": 0.7, + "image": [ + "17", + 0 + ] + }, + "class_type": "Canny", + "_meta": { + "title": "Canny" + } + }, + "13": { + "inputs": { + "image": "boy_sunshine.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "14": { + "inputs": { + "strength": 1, + "conditioning": [ + "6", + 0 + ], + "control_net": [ + "22", + 0 + ], + "image": [ + "12", + 0 + ] + }, + "class_type": "ControlNetApply", + "_meta": { + "title": "Apply ControlNet" + } + }, + "17": { + "inputs": { + "upscale_method": "nearest-exact", + "width": 1024, + "height": 1024, + "crop": "center", + "image": [ + "13", + 0 + ] + }, + "class_type": "ImageScale", + "_meta": { + "title": "Upscale Image" + } + }, + "22": { + "inputs": { + "control_net_name": "canny_diffusion_pytorch_model.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "25": { + "inputs": { + "images": [ + "12", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/sdxl_img2img_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_img2img_workflow_api.json new file mode 100644 index 00000000..d1e90a01 --- /dev/null +++ b/utils/ml_processor/comfy_workflows/sdxl_img2img_workflow_api.json @@ -0,0 +1,134 @@ +{ + "1": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "31": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "44:0", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "37:0": { + "inputs": { + "image": "cca52d68-91aa-4db8-b724-2370d03ff987.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "37:1": { + "inputs": { + "pixels": [ + "37:0", + 0 + ], + "vae": [ + "1", + 2 + ] + }, + "class_type": "VAEEncode", + "_meta": { + "title": "VAE Encode" + } + }, + "42:0": { + "inputs": { + "text": "pic of a king", + "clip": [ + "1", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "Positive prompt" + } + }, + "42:1": { + "inputs": { + "text": "", + "clip": [ + "1", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "Negative prompt (not used)" + } + }, + "42:2": { + "inputs": { + "seed": 89273174590337, + "steps": 20, + "cfg": 7, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 0.6, + "model": [ + "1", + 0 + ], + "positive": [ + "42:0", + 0 + ], + "negative": [ + "42:1", + 0 + ], + "latent_image": [ + "37:1", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "44:0": { + "inputs": { + "samples": [ + "42:2", + 0 + ], + "vae": [ + "1", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "44:1": { + "inputs": { + "images": [ + "44:0", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + 
} +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json new file mode 100644 index 00000000..e45d1d6b --- /dev/null +++ b/utils/ml_processor/comfy_workflows/sdxl_inpainting_workflow_api.json @@ -0,0 +1,304 @@ +{ + "3": { + "inputs": { + "seed": 996231241255407, + "steps": 16, + "cfg": 6, + "sampler_name": "dpmpp_sde", + "scheduler": "karras", + "denoise": 1, + "model": [ + "49", + 0 + ], + "positive": [ + "34", + 0 + ], + "negative": [ + "37", + 0 + ], + "latent_image": [ + "26", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "29", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "20": { + "inputs": { + "image": "boy_sunshine.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "26": { + "inputs": { + "grow_mask_by": 1, + "pixels": [ + "20", + 0 + ], + "vae": [ + "29", + 2 + ], + "mask": [ + "20", + 1 + ] + }, + "class_type": "VAEEncodeForInpaint", + "_meta": { + "title": "VAE Encode (for Inpainting)" + } + }, + "29": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "30": { + "inputs": { + "unet_name": "inpainting_diffusion_pytorch_model.fp16.safetensors" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "UNETLoader" + } + }, + "33": { + "inputs": { + "image": [ + "20", + 0 + ] + }, + "class_type": "Get image size", + "_meta": { + "title": "Get image size" + } + }, + "34": { + "inputs": { + "width": [ + "33", + 0 + ], + "height": [ + "33", + 1 + ], + "crop_w": 0, + "crop_h": 0, + "target_width": [ + "33", + 0 + ], + "target_height": [ + "33", + 1 + ], + "text_g": "man fishing, ZipRealism, Zip2D", + "text_l": "man fishing, ZipRealism, Zip2D", + "clip": [ + "49", + 1 + ] + }, + "class_type": "CLIPTextEncodeSDXL", + "_meta": { + "title": "CLIPTextEncodeSDXL" + } + }, + "37": { + "inputs": { + "width": 1024, + "height": 1024, + "crop_w": 0, + "crop_h": 0, + "target_width": 1024, + "target_height": 1024, + "text_g": "ZipRealism_Neg, AC_Neg1, AC_Neg2,", + "text_l": "ZipRealism_Neg, AC_Neg1, AC_Neg2,", + "clip": [ + "29", + 1 + ] + }, + "class_type": "CLIPTextEncodeSDXL", + "_meta": { + "title": "CLIPTextEncodeSDXL" + } + }, + "49": { + "inputs": { + "switch_1": "Off", + "lora_name_1": "None", + "strength_model_1": 1.3, + "strength_clip_1": 1, + "switch_2": "Off", + "lora_name_2": "None", + "strength_model_2": 1, + "strength_clip_2": 1, + "switch_3": "Off", + "lora_name_3": "None", + "strength_model_3": 1, + "strength_clip_3": 1, + "model": [ + "30", + 0 + ], + "clip": [ + "29", + 1 + ] + }, + "class_type": "LoraStackLoader_PoP", + "_meta": { + "title": "LoraStackLoader_PoP" + } + }, + "50": { + "inputs": { + "ascore": 6, + "width": 1024, + "height": 1024, + "text": "man fishing, ZipRealism, Zip2D", + "clip": [ + "51", + 1 + ] + }, + "class_type": "CLIPTextEncodeSDXLRefiner", + "_meta": { + "title": "CLIPTextEncodeSDXLRefiner" + } + }, + "51": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0_0.9vae.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "52": { + "inputs": { + "ascore": 6, + "width": 1024, + "height": 1024, + "text": "ZipRealism_Neg, AC_Neg1, AC_Neg2,", + "clip": [ + "51", + 1 + ] 
+ }, + "class_type": "CLIPTextEncodeSDXLRefiner", + "_meta": { + "title": "CLIPTextEncodeSDXLRefiner" + } + }, + "54": { + "inputs": { + "add_noise": "enable", + "noise_seed": 0, + "steps": 20, + "cfg": 6, + "sampler_name": "dpmpp_sde", + "scheduler": "karras", + "start_at_step": 16, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "51", + 0 + ], + "positive": [ + "50", + 0 + ], + "negative": [ + "52", + 0 + ], + "latent_image": [ + "3", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced)" + } + }, + "55": { + "inputs": { + "samples": [ + "54", + 0 + ], + "vae": [ + "51", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "56": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "55", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "57": { + "inputs": { + "images": [ + "8", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/sdxl_openpose_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_openpose_workflow_api.json new file mode 100644 index 00000000..50dff7a5 --- /dev/null +++ b/utils/ml_processor/comfy_workflows/sdxl_openpose_workflow_api.json @@ -0,0 +1,191 @@ +{ + "3": { + "inputs": { + "seed": 253663277835217, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "16", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0_0.9vae.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 624, + "height": 624, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "a ballerina, romantic sunset, 4k photo", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "text, watermark", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "10": { + "inputs": { + "detect_hand": "enable", + "detect_body": "enable", + "detect_face": "enable", + "resolution": "v1.1", + "image": [ + "11", + 0 + ] + }, + "class_type": "OpenposePreprocessor", + "_meta": { + "title": "OpenPose Pose" + } + }, + "11": { + "inputs": { + "upscale_method": "nearest-exact", + "width": 623, + "height": 623, + "crop": "disabled", + "image": [ + "12", + 0 + ] + }, + "class_type": "ImageScale", + "_meta": { + "title": "Upscale Image" + } + }, + "12": { + "inputs": { + "image": "boy_sunshine.png", + "upload": "image" + }, + "class_type": "LoadImage", + "_meta": { + "title": "Load Image" + } + }, + "13": { + "inputs": { + "images": [ + "10", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + 
"14": { + "inputs": { + "control_net_name": "OpenPoseXL2.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "16": { + "inputs": { + "strength": 1, + "conditioning": [ + "6", + 0 + ], + "control_net": [ + "14", + 0 + ], + "image": [ + "10", + 0 + ] + }, + "class_type": "ControlNetApply", + "_meta": { + "title": "Apply ControlNet" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/sdxl_workflow_api.json b/utils/ml_processor/comfy_workflows/sdxl_workflow_api.json new file mode 100644 index 00000000..c47e562e --- /dev/null +++ b/utils/ml_processor/comfy_workflows/sdxl_workflow_api.json @@ -0,0 +1,178 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint - BASE" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "text, watermark", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 721897303308196, + "steps": 25, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 20, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced) - BASE" + } + }, + "11": { + "inputs": { + "add_noise": "disable", + "noise_seed": 0, + "steps": 25, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 20, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "12", + 0 + ], + "positive": [ + "15", + 0 + ], + "negative": [ + "16", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced", + "_meta": { + "title": "KSampler (Advanced) - REFINER" + } + }, + "12": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint - REFINER" + } + }, + "15": { + "inputs": { + "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it", + "clip": [ + "12", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "16": { + "inputs": { + "text": "text, watermark", + "clip": [ + "12", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "17": { + "inputs": { + "samples": [ + "11", + 0 + ], + "vae": [ + "12", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "19": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "17", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/comfy_workflows/steerable_motion_api.json b/utils/ml_processor/comfy_workflows/steerable_motion_api.json new file mode 100644 index 
00000000..f9884c9d --- /dev/null +++ b/utils/ml_processor/comfy_workflows/steerable_motion_api.json @@ -0,0 +1,451 @@ +{ + "187": { + "inputs": { + "model_name": "v3_sd15_mm.ckpt", + "beta_schedule": "sqrt_linear (AnimateDiff)", + "motion_scale": 1.3, + "apply_v2_models_properly": false, + "model": [ + "473", + 3 + ], + "context_options": [ + "342", + 0 + ] + }, + "class_type": "ADE_AnimateDiffLoaderWithContext", + "_meta": { + "title": "AnimateDiff Loader [Legacy] πŸŽ­πŸ…πŸ…“β‘ " + } + }, + "207": { + "inputs": { + "add_noise": "enable", + "noise_seed": 5, + "steps": 20, + "cfg": 8, + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 20, + "return_with_leftover_noise": "disable", + "preview_method": "auto", + "vae_decode": "true", + "model": [ + "187", + 0 + ], + "positive": [ + "505", + 0 + ], + "negative": [ + "505", + 1 + ], + "latent_image": [ + "464", + 0 + ], + "optional_vae": [ + "458", + 0 + ] + }, + "class_type": "KSampler Adv. (Efficient)", + "_meta": { + "title": "KSampler Adv. (Efficient), CN sampler" + } + }, + "281": { + "inputs": { + "frame_rate": 16, + "loop_count": 0, + "filename_prefix": "steerable-motion/AD_", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 20, + "save_metadata": true, + "pingpong": false, + "save_output": true, + "images": [ + "292", + 0 + ] + }, + "class_type": "VHS_VideoCombine", + "_meta": { + "title": "Video Combine πŸŽ₯πŸ…₯πŸ…—πŸ…’" + } + }, + "292": { + "inputs": { + "ckpt_name": "stmfnet.pth", + "clear_cache_after_n_frames": 15, + "multiplier": 2, + "duplicate_first_last_frames": true, + "cache_in_fp16": false, + "frames": [ + "354", + 2 + ] + }, + "class_type": "STMFNet VFI", + "_meta": { + "title": "STMFNet VFI" + } + }, + "342": { + "inputs": { + "context_length": 16, + "context_stride": 2, + "context_overlap": 4, + "context_schedule": "uniform", + "closed_loop": false, + "fuse_method": "flat", + "use_on_equal_length": false, + "start_percent": 0, + "guarantee_steps": 1 + }, + "class_type": "ADE_AnimateDiffUniformContextOptions", + "_meta": { + "title": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“" + } + }, + "354": { + "inputs": { + "split_index": 4, + "images": [ + "207", + 5 + ] + }, + "class_type": "VHS_SplitImages", + "_meta": { + "title": "Split Image Batch πŸŽ₯πŸ…₯πŸ…—πŸ…’" + } + }, + "369": { + "inputs": { + "ipadapter_file": "ip-adapter-plus_sd15.bin" + }, + "class_type": "IPAdapterModelLoader", + "_meta": { + "title": "Load IPAdapter Model" + } + }, + "370": { + "inputs": { + "clip_name": "SD1.5/pytorch_model.bin" + }, + "class_type": "CLIPVisionLoader", + "_meta": { + "title": "Load CLIP Vision" + } + }, + "389": { + "inputs": { + "images": [ + "401", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "401": { + "inputs": { + "directory": "./ComfyUI/input/", + "image_load_cap": 0, + "skip_first_images": 0, + "select_every_nth": 1 + }, + "class_type": "VHS_LoadImagesPath", + "_meta": { + "title": "Load Images (Path) πŸŽ₯πŸ…₯πŸ…—πŸ…’" + } + }, + "436": { + "inputs": { + "images": [ + "473", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "458": { + "inputs": { + "vae_name": "vae-ft-mse-840000-ema-pruned.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "461": { + "inputs": { + "ckpt_name": "Realistic_Vision_V5.1.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "464": { + 
"inputs": { + "width": [ + "508", + 0 + ], + "height": [ + "508", + 1 + ], + "batch_size": [ + "473", + 5 + ] + }, + "class_type": "ADE_EmptyLatentImageLarge", + "_meta": { + "title": "Empty Latent Image (Big Batch) πŸŽ­πŸ…πŸ…“" + } + }, + "467": { + "inputs": { + "sparsectrl_name": "v3_sd15_sparsectrl_rgb.ckpt", + "use_motion": true, + "motion_strength": 1, + "motion_scale": 1, + "sparse_method": [ + "473", + 4 + ] + }, + "class_type": "ACN_SparseCtrlLoaderAdvanced", + "_meta": { + "title": "Load SparseCtrl Model πŸ›‚πŸ…πŸ…’πŸ…" + } + }, + "468": { + "inputs": { + "strength": 0.6, + "start_percent": 0, + "end_percent": 0.05, + "positive": [ + "473", + 1 + ], + "negative": [ + "473", + 2 + ], + "control_net": [ + "467", + 0 + ], + "image": [ + "469", + 0 + ] + }, + "class_type": "ACN_AdvancedControlNetApply", + "_meta": { + "title": "Apply Advanced ControlNet πŸ›‚πŸ…πŸ…’πŸ…" + } + }, + "469": { + "inputs": { + "image": [ + "401", + 0 + ], + "vae": [ + "458", + 0 + ], + "latent_size": [ + "464", + 0 + ] + }, + "class_type": "ACN_SparseCtrlRGBPreprocessor", + "_meta": { + "title": "RGB SparseCtrl πŸ›‚πŸ…πŸ…’πŸ…" + } + }, + "470": { + "inputs": { + "lora_name": "v3_sd15_adapter.ckpt", + "strength_model": 0.01, + "strength_clip": 0.25, + "model": [ + "461", + 0 + ], + "clip": [ + "461", + 1 + ] + }, + "class_type": "LoraLoader", + "_meta": { + "title": "Load LoRA" + } + }, + "473": { + "inputs": { + "control_net_name": "control_v11f1e_sd15_tile_fp16.safetensors", + "type_of_frame_distribution": "linear", + "linear_frame_distribution_value": 16, + "dynamic_frame_distribution_values": "0,16,32,48", + "type_of_key_frame_influence": "linear", + "linear_key_frame_influence_value": "0.75", + "dynamic_key_frame_influence_values": "", + "type_of_strength_distribution": "linear", + "linear_strength_value": "(0.4,0.5,0.4)", + "dynamic_strength_values": "(0.0,1.0),(0.0,1.0),(0.0,1.0),(0.0,1.0)", + "soft_scaled_cn_weights_multiplier": 0.85, + "buffer": 4, + "relative_cn_strength": 0, + "relative_ipadapter_strength": 1, + "ipadapter_noise": 0.2, + "ipadapter_start_at": 0, + "ipadapter_end_at": 0.6, + "cn_start_at": 0, + "cn_end_at": 0.65, + "positive": [ + "536", + 0 + ], + "negative": [ + "537", + 1 + ], + "images": [ + "401", + 0 + ], + "model": [ + "470", + 0 + ], + "ipadapter": [ + "369", + 0 + ], + "clip_vision": [ + "370", + 0 + ] + }, + "class_type": "BatchCreativeInterpolation", + "_meta": { + "title": "Batch Creative Interpolation πŸŽžοΈπŸ…’πŸ…œ" + } + }, + "482": { + "inputs": { + "seed": 32, + "steps": 25, + "cfg": 5, + "sampler_name": "dpmpp_2m_sde_gpu", + "scheduler": "exponential", + "denoise": 1 + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "505": { + "inputs": { + "strength": 0.5, + "start_percent": 0.6, + "end_percent": 0.675, + "positive": [ + "468", + 0 + ], + "negative": [ + "468", + 1 + ], + "control_net": [ + "467", + 0 + ], + "image": [ + "469", + 0 + ] + }, + "class_type": "ACN_AdvancedControlNetApply", + "_meta": { + "title": "Apply Advanced ControlNet πŸ›‚πŸ…πŸ…’πŸ…" + } + }, + "508": { + "inputs": { + "image": [ + "401", + 0 + ] + }, + "class_type": "GetImageSize+", + "_meta": { + "title": "πŸ”§ Get Image Size" + } + }, + "536": { + "inputs": { + "text": "\"4\": \"\", \"36\": \"\", \"68\": \"\"", + "max_frames": 120, + "current_frame": 0, + "print_output": false, + "pre_text": "", + "app_text": "", + "pw_a": 0, + "pw_b": 0, + "pw_c": 0, + "pw_d": 0, + "clip": [ + "470", + 1 + ] + }, + "class_type": "PromptSchedule", + "_meta": { + 
"title": "Positive Prompt" + } + }, + "537": { + "inputs": { + "text": "\"4\": \"\", \"36\": \"\", \"68\": \"\"", + "max_frames": 120, + "current_frame": 0, + "print_output": false, + "pre_text": "", + "app_text": "", + "pw_a": 0, + "pw_b": 0, + "pw_c": 0, + "pw_d": 0, + "clip": [ + "470", + 1 + ] + }, + "class_type": "PromptSchedule", + "_meta": { + "title": "Negative Prompt" + } + } +} \ No newline at end of file diff --git a/utils/ml_processor/constants.py b/utils/ml_processor/constants.py new file mode 100644 index 00000000..4260b95b --- /dev/null +++ b/utils/ml_processor/constants.py @@ -0,0 +1,120 @@ +from dataclasses import dataclass +from shared.constants import InferenceStatus +from utils.enum import ExtendedEnum + + +class ComfyWorkflow(ExtendedEnum): + IP_ADAPTER_PLUS = "ip_adapter_plus" + IP_ADAPTER_FACE = "ip_adapter_face" + IP_ADAPTER_FACE_PLUS = "ip_adapter_face_plus" + SDXL = "sdxl" + SDXL_CONTROLNET = "sdxl_controlnet" + SDXL_CONTROLNET_OPENPOSE = "sdxl_controlnet_openpose" + LLAMA_2_7B = "llama_2_7b" + SDXL_INPAINTING = "sdxl-inpainting" + STEERABLE_MOTION = "steerable_motion" + SDXL_IMG2IMG = "sdxl_img2img" + +@dataclass +class MLModel: + # properties for replicate (result of ad-hoc coding new features :<) + name: str + version: str + + # workflow name (multiple workflows can be run through a common replicate endpoint) + workflow_name: str = None + + def display_name(self): + for model in ML_MODEL.__dict__.values(): + if isinstance(model, MLModel): + if (self.workflow_name and model.workflow_name != self.workflow_name): + continue + + if self.name == model.name: + return model.workflow_name.value if model.workflow_name else model.name.split("/")[-1] + return None + + +# comfy runner replicate endpoint +class ComfyRunnerModel: + name = "voku682/comfy_runner" + version = "36d691e7ae92a8f29194bb6ee5aa61a6ab23c77ad7fb5b2cb6f31641512ca21c" + +class ML_MODEL: + sdxl_inpainting = MLModel("lucataco/sdxl-inpainting", "f03c01943bacdee38d6a5d216586bf9bfbfd799350aed263aa32980efc173f0b") + clones_lora_training = MLModel("cloneofsimo/lora-training", "b2a308762e36ac48d16bfadc03a65493fe6e799f429f7941639a6acec5b276cc") + clones_lora_training_2 = MLModel("cloneofsimo/lora", "fce477182f407ffd66b94b08e761424cabd13b82b518754b83080bc75ad32466") + google_frame_interpolation = MLModel("google-research/frame-interpolation", "4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d") + pollination_modnet = MLModel("pollinations/modnet", "da7d45f3b836795f945f221fc0b01a6d3ab7f5e163f13208948ad436001e2255") + clip_interrogator = MLModel("pharmapsychotic/clip-interrogator", "a4a8bafd6089e1716b06057c42b19378250d008b80fe87caa5cd36d40c1eda90") + gfp_gan = MLModel("xinntao/gfpgan", "6129309904ce4debfde78de5c209bce0022af40e197e132f08be8ccce3050393") + ghost_face_swap = MLModel("arielreplicate/ghost_face_swap", "106df0aaf9690354379d8cd291ad337f6b3ea02fe07d90feb1dafd64820066fa") + stylegan_nada = MLModel("rinongal/stylegan-nada", "6b2af4ac56fa2384f8f86fc7620943d5fc7689dcbb6183733743a215296d0e30") + img2img_sd_2_1 = MLModel("cjwbw/stable-diffusion-img2img-v2.1", "650c347f19a96c8a0379db998c4cd092e0734534591b16a60df9942d11dec15b") + cjwbw_style_hair = MLModel("cjwbw/style-your-hair", "c4c7e5a657e2e1abccd57625093522a9928edeccee77e3f55d57c664bcd96fa2") + depth2img_sd = MLModel("jagilley/stable-diffusion-depth2img", "68f699d395bc7c17008283a7cef6d92edc832d8dc59eb41a6cafec7fc70b85bc") + salesforce_blip_2 = MLModel("salesforce/blip-2", "4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608") + 
phamquiluan_face_recognition = MLModel("phamquiluan/facial-expression-recognition", "b16694d5bfed43612f1bfad7015cf2b7883b732651c383fe174d4b7783775ff5") + arielreplicate = MLModel("arielreplicate/instruct-pix2pix", "10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430") + cjwbw_midas = MLModel("cjwbw/midas", "a6ba5798f04f80d3b314de0f0a62277f21ab3503c60c84d4817de83c5edfdae0") + jagilley_controlnet_normal = MLModel("jagilley/controlnet-normal", "cc8066f617b6c99fdb134bc1195c5291cf2610875da4985a39de50ee1f46d81c") + jagilley_controlnet_canny = MLModel("jagilley/controlnet-canny", "aff48af9c68d162388d230a2ab003f68d2638d88307bdaf1c2f1ac95079c9613") + jagilley_controlnet_hed = MLModel("jagilley/controlnet-hed", "cde353130c86f37d0af4060cd757ab3009cac68eb58df216768f907f0d0a0653") + jagilley_controlnet_scribble = MLModel("jagilley/controlnet-scribble", "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117") + jagilley_controlnet_seg = MLModel("jagilley/controlnet-seg", "f967b165f4cd2e151d11e7450a8214e5d22ad2007f042f2f891ca3981dbfba0d") + jagilley_controlnet_hough = MLModel("jagilley/controlnet-hough", "854e8727697a057c525cdb45ab037f64ecca770a1769cc52287c2e56472a247b") + jagilley_controlnet_depth2img = MLModel("jagilley/controlnet-depth2img", "922c7bb67b87ec32cbc2fd11b1d5f94f0ba4f5519c4dbd02856376444127cc60") + jagilley_controlnet_pose = MLModel("jagilley/controlnet-pose", "0304f7f774ba7341ef754231f794b1ba3d129e3c46af3022241325ae0c50fb99") + real_esrgan_upscale = MLModel("cjwbw/real-esrgan", "d0ee3d708c9b911f122a4ad90046c5d26a0293b99476d697f6bb7f2e251ce2d4") + controlnet_1_1_x_realistic_vision_v2_0 = MLModel("usamaehsan/controlnet-1.1-x-realistic-vision-v2.0", "7fbf4c86671738f97896c9cb4922705adfcdcf54a6edab193bb8c176c6b34a69") + urpm = MLModel("mcai/urpm-v1.3-img2img", "4df956e8dbfebf1afaf0c3ee98ad426ec58c4262d24360d054582e5eab2cb5f6") + sdxl = MLModel("stability-ai/sdxl", "af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33", ComfyWorkflow.SDXL) + sdxl_img2img = MLModel("stability-ai/sdxl", "af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33", ComfyWorkflow.SDXL_IMG2IMG) + + # addition 30/9/2023 + realistic_vision_v5 = MLModel("heedster/realistic-vision-v5", "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c") + deliberate_v3 = MLModel("pagebrain/deliberate-v3", "1851b62340ae657f05f8b8c8a020e3f9a46efde9fe80f273eef026c0003252ac") + dreamshaper_v7 = MLModel("pagebrain/dreamshaper-v7", "0deba88df4e49b302585e1a7b6bd155e18962c1048966a40fe60ba05805743ff") + epicrealism_v5 = MLModel("pagebrain/epicrealism-v5", "222465e57e4d9812207f14133c9499d47d706ecc41a8bf400120285b2f030b42") + sdxl_controlnet = MLModel("lucataco/sdxl-controlnet", "db2ffdbdc7f6cb4d6dab512434679ee3366ae7ab84f89750f8947d5594b79a47", ComfyWorkflow.SDXL_CONTROLNET) + realistic_vision_v5_img2img = MLModel("lucataco/realistic-vision-v5-img2img", "82bbb4595458d6be142450fc6d8c4d79c936b92bd184dd2d6dd71d0796159819") + ad_interpolation = MLModel(ComfyRunnerModel.name, ComfyRunnerModel.version, ComfyWorkflow.STEERABLE_MOTION) + + # addition 17/10/2023 + llama_2_7b = MLModel("meta/llama-2-7b", "527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef", ComfyWorkflow.LLAMA_2_7B) + + # addition 11/11/2023 + sdxl_controlnet_openpose = MLModel("lucataco/sdxl-controlnet-openpose", "d63e0b238b2d963d90348e2dad19830fbe372a7a43d90d234b2b63cae76d4397", ComfyWorkflow.SDXL_CONTROLNET_OPENPOSE) + + # addition 05/02/2024 (workflows) + ipadapter_plus = MLModel(ComfyRunnerModel.name, 
ComfyRunnerModel.version, ComfyWorkflow.IP_ADAPTER_PLUS) + ipadapter_face = MLModel(ComfyRunnerModel.name, ComfyRunnerModel.version, ComfyWorkflow.IP_ADAPTER_FACE) + ipadapter_face_plus = MLModel(ComfyRunnerModel.name, ComfyRunnerModel.version, ComfyWorkflow.IP_ADAPTER_FACE_PLUS) + + @staticmethod + def get_model_by_db_obj(model_db_obj): + for model in ML_MODEL.__dict__.values(): + if isinstance(model, MLModel) and model.name == model_db_obj.replicate_url and model.version == model_db_obj.version: + return model + return None + +DEFAULT_LORA_MODEL_URL = "https://replicate.delivery/pbxt/nWm6eP9ojwVvBCaWoWZVawOKRfgxPJmkVk13ES7PX36Y66kQA/tmpxuz6k_k2datazip.safetensors" + +CONTROLNET_MODELS = [ + ML_MODEL.jagilley_controlnet_normal, + ML_MODEL.jagilley_controlnet_canny, + ML_MODEL.jagilley_controlnet_hed, + ML_MODEL.jagilley_controlnet_scribble, + ML_MODEL.jagilley_controlnet_seg, + ML_MODEL.jagilley_controlnet_hough, + ML_MODEL.jagilley_controlnet_depth2img, + ML_MODEL.jagilley_controlnet_pose, +] + +replicate_status_map = { + "starting": InferenceStatus.QUEUED.value, + "processing": InferenceStatus.IN_PROGRESS.value, + "succeeded": InferenceStatus.COMPLETED.value, + "failed": InferenceStatus.FAILED.value, + "canceled": InferenceStatus.CANCELED.value +} \ No newline at end of file diff --git a/utils/ml_processor/gpu/gpu.py b/utils/ml_processor/gpu/gpu.py new file mode 100644 index 00000000..1f6f0c0d --- /dev/null +++ b/utils/ml_processor/gpu/gpu.py @@ -0,0 +1,108 @@ +import json +from shared.constants import InferenceParamType +from shared.logging.logging import AppLogger +from ui_components.methods.data_logger import log_model_inference +from ui_components.methods.file_methods import normalize_size_internal_file_obj +from utils.constants import MLQueryObject +from utils.data_repo.data_repo import DataRepo +from utils.ml_processor.comfy_data_transform import get_model_workflow_from_query +from utils.ml_processor.constants import ML_MODEL, ComfyWorkflow, MLModel +from utils.ml_processor.gpu.utils import predict_gpu_output, setup_comfy_runner +from utils.ml_processor.ml_interface import MachineLearningProcessor +import time + + +# NOTE: add credit management methods such as update_usage_credits, check_usage_credits etc.
for hosting +class GPUProcessor(MachineLearningProcessor): + def __init__(self): + setup_comfy_runner() + data_repo = DataRepo() + self.app_settings = data_repo.get_app_secrets_from_user_uuid() + super().__init__() + + def predict_model_output_standardized(self, model: MLModel, query_obj: MLQueryObject, queue_inference=False): + data_repo = DataRepo() + workflow_json, output_node_ids = get_model_workflow_from_query(model, query_obj) + file_uuid_list = [] + + if query_obj.image_uuid: + file_uuid_list.append(query_obj.image_uuid) + + for k, v in query_obj.data.get('data', {}).items(): + if k.startswith("file_"): + file_uuid_list.append(v) + + file_list = data_repo.get_image_list_from_uuid_list(file_uuid_list) + + models_using_sdxl = [ + ComfyWorkflow.SDXL.value, + ComfyWorkflow.SDXL_IMG2IMG.value, + ComfyWorkflow.SDXL_CONTROLNET.value, + ComfyWorkflow.SDXL_INPAINTING.value, + ComfyWorkflow.IP_ADAPTER_FACE.value, + ComfyWorkflow.IP_ADAPTER_FACE_PLUS.value, + ComfyWorkflow.IP_ADAPTER_PLUS.value + ] + + # maps old_file_name : new_resized_file_name + new_file_map = {} + if model.display_name() in models_using_sdxl: + res = [] + for file in file_list: + new_width, new_height = 1024 if query_obj.width == 512 else 768, 1024 if query_obj.height == 512 else 768 + # although the new_file created using create_new_file has the same location as the original file, it is + # scaled to the original resolution after inference save (so resize has no effect) + new_file = normalize_size_internal_file_obj(file, dim=[new_width, new_height], create_new_file=True) + res.append(new_file) + new_file_map[file.filename] = new_file.filename + + file_list = res + + file_path_list = [f.location for f in file_list] + + # replacing old files with resized files + # if len(new_file_map.keys()): + # workflow_json = json.loads(workflow_json) + # for node in workflow_json: + # if "inputs" in workflow_json[node]: + # for k, v in workflow_json[node]["inputs"].items(): + # if isinstance(v, str) and v in new_file_map: + # workflow_json[node]["inputs"][k] = new_file_map[v] + + # workflow_json = json.dumps(workflow_json) + + data = { + "workflow_input": workflow_json, + "file_path_list": file_path_list, + "output_node_ids": output_node_ids + } + + params = { + "prompt": query_obj.prompt, # hackish sol + InferenceParamType.QUERY_DICT.value: query_obj.to_json(), + InferenceParamType.GPU_INFERENCE.value: json.dumps(data) + } + return self.predict_model_output(model, **params) if not queue_inference else self.queue_prediction(model, **params) + + def predict_model_output(self, replicate_model: MLModel, **kwargs): + queue_inference = kwargs.get('queue_inference', False) + if queue_inference: + return self.queue_prediction(replicate_model, **kwargs) + + data = kwargs.get(InferenceParamType.GPU_INFERENCE.value, None) + data = json.loads(data) + start_time = time.time() + output = predict_gpu_output(data['workflow_input'], data['file_path_list'], data['output_node_ids']) + end_time = time.time() + + log = log_model_inference(replicate_model, end_time - start_time, **kwargs) + return output, log + + def queue_prediction(self, replicate_model, **kwargs): + log = log_model_inference(replicate_model, None, **kwargs) + return None, log + + def upload_training_data(self, zip_file_name, delete_after_upload=False): + # TODO: fix for online hosting + # return the local file path as it is + return zip_file_name \ No newline at end of file diff --git a/utils/ml_processor/gpu/utils.py b/utils/ml_processor/gpu/utils.py new file mode 100644 index 
00000000..b6fa69a4 --- /dev/null +++ b/utils/ml_processor/gpu/utils.py @@ -0,0 +1,48 @@ +import importlib +import os +import sys +import subprocess +import time +from git import Repo +from shared.logging.constants import LoggingType +from shared.logging.logging import app_logger + + +COMFY_RUNNER_PATH = "./comfy_runner" + +def predict_gpu_output(workflow: str, file_path_list=[], output_node=None) -> str: + # spec = importlib.util.spec_from_file_location('my_module', f'{COMFY_RUNNER_PATH}/inf.py') + # comfy_runner = importlib.util.module_from_spec(spec) + # spec.loader.exec_module(comfy_runner) + + # hackish solution: waiting for the comfy repo to be cloned + while not is_comfy_runner_present(): + time.sleep(2) + + sys.path.append(str(os.getcwd()) + COMFY_RUNNER_PATH[1:]) + from comfy_runner.inf import ComfyRunner + + comfy_runner = ComfyRunner() + output = comfy_runner.predict( + workflow_input=workflow, + file_path_list=file_path_list, + stop_server_after_completion=True, + output_node_ids=output_node + ) + + return output['file_paths'] # ignoring text output for now {"file_paths": [], "text_content": []} + +def is_comfy_runner_present(): + return os.path.exists(COMFY_RUNNER_PATH) # hackish solution, will fix later + +# TODO: convert comfy_runner into a package for easy import +def setup_comfy_runner(): + if is_comfy_runner_present(): + return + + app_logger.log(LoggingType.INFO, 'cloning comfy runner') + comfy_repo_url = "https://github.com/piyushK52/comfy-runner" + Repo.clone_from(comfy_repo_url, COMFY_RUNNER_PATH[2:], single_branch=True, branch='feature/package') + + # installing dependencies + subprocess.run(['pip', 'install', '-r', COMFY_RUNNER_PATH + '/requirements.txt'], check=True) \ No newline at end of file diff --git a/utils/ml_processor/ml_interface.py b/utils/ml_processor/ml_interface.py index d6a8ec1a..1e0692a6 100644 --- a/utils/ml_processor/ml_interface.py +++ b/utils/ml_processor/ml_interface.py @@ -1,11 +1,27 @@ from abc import ABC +from shared.constants import GPU_INFERENCE_ENABLED + def get_ml_client(): from utils.ml_processor.replicate.replicate import ReplicateProcessor + from utils.ml_processor.gpu.gpu import GPUProcessor - return ReplicateProcessor() + return ReplicateProcessor() if not GPU_INFERENCE_ENABLED else GPUProcessor() class MachineLearningProcessor(ABC): def __init__(self): pass + + def predict_model_output_standardized(self, *args, **kwargs): + pass + + def predict_model_output(self, *args, **kwargs): + pass + + def upload_training_data(self, *args, **kwargs): + pass + + # NOTE: implementation not necessary as this functionality is removed from the app + def dreambooth_training(self, *args, **kwargs): + pass \ No newline at end of file diff --git a/utils/ml_processor/replicate/constants.py b/utils/ml_processor/replicate/constants.py deleted file mode 100644 index 9356a5ae..00000000 --- a/utils/ml_processor/replicate/constants.py +++ /dev/null @@ -1,82 +0,0 @@ -from dataclasses import dataclass - -from shared.constants import InferenceStatus - - -@dataclass -class ReplicateModel: - name: str - version: str - -class REPLICATE_MODEL: - sdxl_inpainting = ReplicateModel("lucataco/sdxl-inpainting", "f03c01943bacdee38d6a5d216586bf9bfbfd799350aed263aa32980efc173f0b") - clones_lora_training = ReplicateModel("cloneofsimo/lora-training", "b2a308762e36ac48d16bfadc03a65493fe6e799f429f7941639a6acec5b276cc") - clones_lora_training_2 = ReplicateModel("cloneofsimo/lora", "fce477182f407ffd66b94b08e761424cabd13b82b518754b83080bc75ad32466") - google_frame_interpolation =
ReplicateModel("google-research/frame-interpolation", "4f88a16a13673a8b589c18866e540556170a5bcb2ccdc12de556e800e9456d3d") - pollination_modnet = ReplicateModel("pollinations/modnet", "da7d45f3b836795f945f221fc0b01a6d3ab7f5e163f13208948ad436001e2255") - clip_interrogator = ReplicateModel("pharmapsychotic/clip-interrogator", "a4a8bafd6089e1716b06057c42b19378250d008b80fe87caa5cd36d40c1eda90") - gfp_gan = ReplicateModel("xinntao/gfpgan", "6129309904ce4debfde78de5c209bce0022af40e197e132f08be8ccce3050393") - ghost_face_swap = ReplicateModel("arielreplicate/ghost_face_swap", "106df0aaf9690354379d8cd291ad337f6b3ea02fe07d90feb1dafd64820066fa") - stylegan_nada = ReplicateModel("rinongal/stylegan-nada", "6b2af4ac56fa2384f8f86fc7620943d5fc7689dcbb6183733743a215296d0e30") - img2img_sd_2_1 = ReplicateModel("cjwbw/stable-diffusion-img2img-v2.1", "650c347f19a96c8a0379db998c4cd092e0734534591b16a60df9942d11dec15b") - cjwbw_style_hair = ReplicateModel("cjwbw/style-your-hair", "c4c7e5a657e2e1abccd57625093522a9928edeccee77e3f55d57c664bcd96fa2") - depth2img_sd = ReplicateModel("jagilley/stable-diffusion-depth2img", "68f699d395bc7c17008283a7cef6d92edc832d8dc59eb41a6cafec7fc70b85bc") - salesforce_blip_2 = ReplicateModel("salesforce/blip-2", "4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608") - phamquiluan_face_recognition = ReplicateModel("phamquiluan/facial-expression-recognition", "b16694d5bfed43612f1bfad7015cf2b7883b732651c383fe174d4b7783775ff5") - arielreplicate = ReplicateModel("arielreplicate/instruct-pix2pix", "10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430") - cjwbw_midas = ReplicateModel("cjwbw/midas", "a6ba5798f04f80d3b314de0f0a62277f21ab3503c60c84d4817de83c5edfdae0") - jagilley_controlnet_normal = ReplicateModel("jagilley/controlnet-normal", "cc8066f617b6c99fdb134bc1195c5291cf2610875da4985a39de50ee1f46d81c") - jagilley_controlnet_canny = ReplicateModel("jagilley/controlnet-canny", "aff48af9c68d162388d230a2ab003f68d2638d88307bdaf1c2f1ac95079c9613") - jagilley_controlnet_hed = ReplicateModel("jagilley/controlnet-hed", "cde353130c86f37d0af4060cd757ab3009cac68eb58df216768f907f0d0a0653") - jagilley_controlnet_scribble = ReplicateModel("jagilley/controlnet-scribble", "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117") - jagilley_controlnet_seg = ReplicateModel("jagilley/controlnet-seg", "f967b165f4cd2e151d11e7450a8214e5d22ad2007f042f2f891ca3981dbfba0d") - jagilley_controlnet_hough = ReplicateModel("jagilley/controlnet-hough", "854e8727697a057c525cdb45ab037f64ecca770a1769cc52287c2e56472a247b") - jagilley_controlnet_depth2img = ReplicateModel("jagilley/controlnet-depth2img", "922c7bb67b87ec32cbc2fd11b1d5f94f0ba4f5519c4dbd02856376444127cc60") - jagilley_controlnet_pose = ReplicateModel("jagilley/controlnet-pose", "0304f7f774ba7341ef754231f794b1ba3d129e3c46af3022241325ae0c50fb99") - real_esrgan_upscale = ReplicateModel("cjwbw/real-esrgan", "d0ee3d708c9b911f122a4ad90046c5d26a0293b99476d697f6bb7f2e251ce2d4") - controlnet_1_1_x_realistic_vision_v2_0 = ReplicateModel("usamaehsan/controlnet-1.1-x-realistic-vision-v2.0", "7fbf4c86671738f97896c9cb4922705adfcdcf54a6edab193bb8c176c6b34a69") - urpm = ReplicateModel("mcai/urpm-v1.3-img2img", "4df956e8dbfebf1afaf0c3ee98ad426ec58c4262d24360d054582e5eab2cb5f6") - sdxl = ReplicateModel("stability-ai/sdxl", "af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33") - - # addition 30/9/2023 - realistic_vision_v5 = ReplicateModel("heedster/realistic-vision-v5", 
"c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c") - deliberate_v3 = ReplicateModel("pagebrain/deliberate-v3", "1851b62340ae657f05f8b8c8a020e3f9a46efde9fe80f273eef026c0003252ac") - dreamshaper_v7 = ReplicateModel("pagebrain/dreamshaper-v7", "0deba88df4e49b302585e1a7b6bd155e18962c1048966a40fe60ba05805743ff") - epicrealism_v5 = ReplicateModel("pagebrain/epicrealism-v5", "222465e57e4d9812207f14133c9499d47d706ecc41a8bf400120285b2f030b42") - sdxl_controlnet = ReplicateModel("lucataco/sdxl-controlnet", "db2ffdbdc7f6cb4d6dab512434679ee3366ae7ab84f89750f8947d5594b79a47") - realistic_vision_v5_img2img = ReplicateModel("lucataco/realistic-vision-v5-img2img", "82bbb4595458d6be142450fc6d8c4d79c936b92bd184dd2d6dd71d0796159819") - ad_interpolation = ReplicateModel("peter942/steerable-motion", "aa308181b5df669f20e56411b74ebafd7c01f82f7fe2a34a4b9382d4bd8155ba") - - # addition 17/10/2023 - llama_2_7b = ReplicateModel("meta/llama-2-7b", "527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef") - - # addition 11/11/2023 - sdxl_controlnet_openpose = ReplicateModel("lucataco/sdxl-controlnet-openpose", "d63e0b238b2d963d90348e2dad19830fbe372a7a43d90d234b2b63cae76d4397") - - @staticmethod - def get_model_by_db_obj(model_db_obj): - for model in REPLICATE_MODEL.__dict__.values(): - if isinstance(model, ReplicateModel) and model.name == model_db_obj.replicate_url and model.version == model_db_obj.version: - return model - return None - -DEFAULT_LORA_MODEL_URL = "https://replicate.delivery/pbxt/nWm6eP9ojwVvBCaWoWZVawOKRfgxPJmkVk13ES7PX36Y66kQA/tmpxuz6k_k2datazip.safetensors" - -CONTROLNET_MODELS = [ - REPLICATE_MODEL.jagilley_controlnet_normal, - REPLICATE_MODEL.jagilley_controlnet_canny, - REPLICATE_MODEL.jagilley_controlnet_hed, - REPLICATE_MODEL.jagilley_controlnet_scribble, - REPLICATE_MODEL.jagilley_controlnet_seg, - REPLICATE_MODEL.jagilley_controlnet_hough, - REPLICATE_MODEL.jagilley_controlnet_depth2img, - REPLICATE_MODEL.jagilley_controlnet_pose, -] - -replicate_status_map = { - "starting": InferenceStatus.QUEUED.value, - "processing": InferenceStatus.IN_PROGRESS.value, - "succeeded": InferenceStatus.COMPLETED.value, - "failed": InferenceStatus.FAILED.value, - "canceled": InferenceStatus.CANCELED.value -} \ No newline at end of file diff --git a/utils/ml_processor/replicate/replicate.py b/utils/ml_processor/replicate/replicate.py index d1b94c64..f0038451 100644 --- a/utils/ml_processor/replicate/replicate.py +++ b/utils/ml_processor/replicate/replicate.py @@ -17,7 +17,7 @@ import zipfile from PIL import Image -from utils.ml_processor.replicate.constants import REPLICATE_MODEL, ReplicateModel +from utils.ml_processor.constants import ML_MODEL, MLModel from ui_components.methods.data_logger import log_model_inference from utils.ml_processor.replicate.utils import check_user_credits, get_model_params_from_query_obj @@ -45,7 +45,7 @@ def update_usage_credits(self, time_taken): cost = round(time_taken * 0.004, 3) data_repo.update_usage_credits(-cost) - def get_model(self, input_model: ReplicateModel): + def get_model(self, input_model: MLModel): model = replicate.models.get(input_model.name) model_version = model.versions.get(input_model.version) if input_model.version else model return model_version @@ -56,13 +56,18 @@ def get_model_by_name(self, model_name, model_version=None): return model_version # it converts the standardized query_obj into params required by replicate - def predict_model_output_standardized(self, model: ReplicateModel, query_obj: MLQueryObject, 
queue_inference=False): + params = get_model_params_from_query_obj(model, query_obj) + + # removing buffers + query_obj.data = {} + params[InferenceParamType.QUERY_DICT.value] = query_obj.to_json() + params["prompt"] = query_obj.prompt return self.predict_model_output(model, **params) if not queue_inference else self.queue_prediction(model, **params) @check_user_credits - def predict_model_output(self, replicate_model: ReplicateModel, **kwargs): + def predict_model_output(self, replicate_model: MLModel, **kwargs): # TODO: make unified interface for directing to queue_prediction queue_inference = kwargs.get('queue_inference', False) if queue_inference: @@ -92,7 +97,7 @@ def predict_model_output(self, replicate_model: ReplicateModel, **kwargs): log = log_model_inference(replicate_model, end_time - start_time, **kwargs) self.update_usage_credits(end_time - start_time) - if replicate_model == REPLICATE_MODEL.clip_interrogator: + if replicate_model == ML_MODEL.clip_interrogator: output = output # adding this for organisation purposes else: output = [output[-1]] if isinstance(output, list) else output @@ -100,7 +105,7 @@ def predict_model_output(self, replicate_model: ReplicateModel, **kwargs): return output, log @check_user_credits - def queue_prediction(self, replicate_model: ReplicateModel, **kwargs): + def queue_prediction(self, replicate_model: MLModel, **kwargs): url = "https://api.replicate.com/v1/predictions" headers = { "Authorization": "Token " + os.environ.get("REPLICATE_API_TOKEN"), @@ -153,7 +158,7 @@ def queue_prediction(self, replicate_model: ReplicateModel, **kwargs): self.logger.log(LoggingType.ERROR, f"Error in creating prediction: {response.content}") @check_user_credits - def predict_model_output_async(self, replicate_model: ReplicateModel, **kwargs): + def predict_model_output_async(self, replicate_model: MLModel, **kwargs): res = asyncio.run(self._multi_async_prediction(replicate_model, **kwargs)) output_list = [] @@ -169,12 +174,12 @@ def predict_model_output_async(self, replicate_model: ReplicateModel, **kwargs): return output_list - async def _multi_async_prediction(self, replicate_model: ReplicateModel, **kwargs): + async def _multi_async_prediction(self, replicate_model: MLModel, **kwargs): variant_count = kwargs['variant_count'] if ('variant_count' in kwargs and kwargs['variant_count']) else 1 res = await asyncio.gather(*[self._async_model_prediction(replicate_model, **kwargs) for _ in range(variant_count)]) return res - async def _async_model_prediction(self, replicate_model: ReplicateModel, **kwargs): + async def _async_model_prediction(self, replicate_model: MLModel, **kwargs): model_version = self.get_model(replicate_model) start_time = time.time() output = await asyncio.to_thread(model_version.predict, **kwargs) @@ -184,7 +189,7 @@ async def _async_model_prediction(self, replicate_model: ReplicateModel, **kwarg @check_user_credits def inpainting(self, video_name, input_image, prompt, negative_prompt): - model = self.get_model(REPLICATE_MODEL.sdxl_inpainting) + model = self.get_model(ML_MODEL.sdxl_inpainting) mask = "mask.png" mask = upload_file("mask.png", self.app_settings['aws_access_key'], self.app_settings['aws_secret_key']) @@ -219,7 +224,6 @@ def upload_training_data(self, zip_file_name, delete_after_upload=False): return serving_url - # TODO: figure how to resolve model location setting, right now it's hardcoded to peter942/modnet @check_user_credits
def dreambooth_training(self, training_file_url, instance_prompt, \ diff --git a/utils/ml_processor/replicate/utils.py b/utils/ml_processor/replicate/utils.py index bce7b5e5..785dec05 100644 --- a/utils/ml_processor/replicate/utils.py +++ b/utils/ml_processor/replicate/utils.py @@ -1,7 +1,11 @@ +import io +from PIL import Image +from ui_components.methods.file_methods import normalize_size_internal_file_obj, resize_io_buffers from utils.common_utils import user_credits_available from utils.constants import MLQueryObject from utils.data_repo.data_repo import DataRepo -from utils.ml_processor.replicate.constants import CONTROLNET_MODELS, REPLICATE_MODEL +from utils.ml_processor.comfy_data_transform import get_file_list_from_query_obj, get_file_zip_url, get_model_workflow_from_query, get_workflow_json_url +from utils.ml_processor.constants import CONTROLNET_MODELS, ML_MODEL, ComfyRunnerModel, ComfyWorkflow def check_user_credits(method): @@ -28,6 +32,43 @@ async def wrapper(self, *args, **kwargs): def get_model_params_from_query_obj(model, query_obj: MLQueryObject): data_repo = DataRepo() + # handling comfy_runner workflows + if model.name == ComfyRunnerModel.name: + workflow_json, output_node_ids = get_model_workflow_from_query(model, query_obj) + workflow_file = get_workflow_json_url(workflow_json) + + models_using_sdxl = [ + ComfyWorkflow.SDXL.value, + ComfyWorkflow.SDXL_IMG2IMG.value, + ComfyWorkflow.SDXL_CONTROLNET.value, + ComfyWorkflow.SDXL_INPAINTING.value, + ComfyWorkflow.IP_ADAPTER_FACE.value, + ComfyWorkflow.IP_ADAPTER_FACE_PLUS.value, + ComfyWorkflow.IP_ADAPTER_PLUS.value + ] + + # resizing image for sdxl + file_uuid_list = get_file_list_from_query_obj(query_obj) + if model.display_name() in models_using_sdxl and len(file_uuid_list): + new_uuid_list = [] + for file_uuid in file_uuid_list: + new_width, new_height = 1024 if query_obj.width == 512 else 768, 1024 if query_obj.height == 512 else 768 + file = data_repo.get_file_from_uuid(file_uuid) + new_file = normalize_size_internal_file_obj(file, dim=[new_width, new_height], create_new_file=True) + new_uuid_list.append(new_file.uuid) + + file_uuid_list = new_uuid_list + + index_files = True if model.display_name() in ['steerable_motion'] else False + file_zip = get_file_zip_url(file_uuid_list, index_files=index_files) + + data = { + "workflow_json": workflow_file, + "file_list": file_zip + } + + return data + input_image, mask = None, None if query_obj.image_uuid: image = data_repo.get_file_from_uuid(query_obj.image_uuid) @@ -43,7 +84,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if not mask.startswith('http'): mask = open(mask, 'rb') - if model == REPLICATE_MODEL.img2img_sd_2_1: + if model == ML_MODEL.img2img_sd_2_1: data = { "prompt_strength" : query_obj.strength, "prompt" : query_obj.prompt, @@ -58,30 +99,53 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if input_image: data['image'] = input_image - elif model == REPLICATE_MODEL.real_esrgan_upscale: + elif model == ML_MODEL.real_esrgan_upscale: data = { "image": input_image, "upscale": query_obj.data.get('upscale', 2), } - elif model == REPLICATE_MODEL.stylegan_nada: + elif model == ML_MODEL.stylegan_nada: data = { "input": input_image, "output_style": query_obj.prompt } - elif model == REPLICATE_MODEL.sdxl: + elif model in [ML_MODEL.sdxl, ML_MODEL.sdxl_img2img]: + new_width, new_height = 1024 if query_obj.width == 512 else 768, 1024 if query_obj.height == 512 else 768 data = { "prompt" : query_obj.prompt, "negative_prompt" : 
query_obj.negative_prompt, - "width" : 768 if query_obj.width == 512 else 1024, # 768 is the default for sdxl - "height" : 768 if query_obj.height == 512 else 1024, + "width" : new_width, # 768 is the default for sdxl + "height" : new_height, "prompt_strength": query_obj.strength, - "mask": mask + "mask": mask, + "disable_safety_checker": True, } if input_image: - data['image'] = input_image + output_image_buffer = resize_io_buffers(input_image, new_width, new_height) + data['image'] = output_image_buffer + + elif model == ML_MODEL.sdxl_inpainting: + new_width, new_height = 1024 if query_obj.width == 512 else 768, 1024 if query_obj.height == 512 else 768 + data = { + "prompt" : query_obj.prompt, + "negative_prompt" : query_obj.negative_prompt, + "width" : new_width, # 768 is the default for sdxl + "height" : new_height, + "strength": query_obj.strength, + "scheduler": "K_EULER", + "guidance_scale": 8, + "steps": 20, + "mask": query_obj.data.get("data", {}).get("mask", None), + "image": query_obj.data.get("data", {}).get("input_image", None), + "disable_safety_checker": True, + } - elif model == REPLICATE_MODEL.arielreplicate: + if input_image: + output_image_buffer = resize_io_buffers(input_image, new_width, new_height) + data['image'] = output_image_buffer + + elif model == ML_MODEL.arielreplicate: data = { "instruction_text" : query_obj.prompt, "seed" : query_obj.seed, @@ -93,7 +157,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if input_image: data['input_image'] = input_image - elif model == REPLICATE_MODEL.urpm: + elif model == ML_MODEL.urpm: data = { 'prompt': query_obj.prompt, 'negative_prompt': query_obj.negative_prompt, @@ -107,7 +171,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if input_image: data['image'] = input_image - elif model == REPLICATE_MODEL.controlnet_1_1_x_realistic_vision_v2_0: + elif model == ML_MODEL.controlnet_1_1_x_realistic_vision_v2_0: data = { 'prompt': query_obj.prompt, 'ddim_steps': query_obj.num_inference_steps, @@ -119,7 +183,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if input_image: data['image'] = input_image - elif model == REPLICATE_MODEL.realistic_vision_v5: + elif model == ML_MODEL.realistic_vision_v5: if not (query_obj.guidance_scale >= 3.5 and query_obj.guidance_scale <= 7.0): raise ValueError("Guidance scale must be between 3.5 and 7.0") @@ -132,7 +196,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): 'steps': query_obj.num_inference_steps, 'seed': query_obj.seed if query_obj.seed not in [-1, 0] else 0 } - elif model == REPLICATE_MODEL.deliberate_v3 or model == REPLICATE_MODEL.dreamshaper_v7 or model == REPLICATE_MODEL.epicrealism_v5: + elif model == ML_MODEL.deliberate_v3 or model == ML_MODEL.dreamshaper_v7 or model == ML_MODEL.epicrealism_v5: data = { 'prompt': query_obj.prompt, 'negative_prompt': query_obj.negative_prompt, @@ -153,7 +217,8 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if mask: data['mask'] = mask - elif model == REPLICATE_MODEL.sdxl_controlnet: + elif model == ML_MODEL.sdxl_controlnet: + new_width, new_height = 1024 if query_obj.width == 512 else 768, 1024 if query_obj.height == 512 else 768 data = { 'prompt': query_obj.prompt, 'negative_prompt': query_obj.negative_prompt, @@ -162,9 +227,10 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): } if input_image: - data['image'] = input_image + output_image_buffer = resize_io_buffers(input_image, new_width, new_height) + 
data['image'] = output_image_buffer - elif model == REPLICATE_MODEL.sdxl_controlnet_openpose: + elif model == ML_MODEL.sdxl_controlnet_openpose: data = { 'prompt': query_obj.prompt, 'negative_prompt': query_obj.negative_prompt, @@ -175,7 +241,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): if input_image: data['image'] = input_image - elif model == REPLICATE_MODEL.realistic_vision_v5_img2img: + elif model == ML_MODEL.realistic_vision_v5_img2img: data = { 'prompt': query_obj.prompt, 'negative_prompt': query_obj.negative_prompt, @@ -188,7 +254,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): data['image'] = input_image elif model in CONTROLNET_MODELS: - if model == REPLICATE_MODEL.jagilley_controlnet_scribble and query_obj.data.get('canny_image', None): + if model == ML_MODEL.jagilley_controlnet_scribble and query_obj.data.get('canny_image', None): input_image = data_repo.get_file_from_uuid(query_obj.data['canny_image']).location if not input_image.startswith('http'): input_image = open(input_image, 'rb') @@ -211,7 +277,7 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): 'high_threshold': query_obj.high_threshold, } - elif model in [REPLICATE_MODEL.clones_lora_training_2]: + elif model in [ML_MODEL.clones_lora_training_2]: if query_obj.adapter_type: adapter_condition_image = input_image