Skip to content
This repository has been archived by the owner on Dec 29, 2024. It is now read-only.

Commit

Permalink
Merge pull request #76 from banodoco/green-head
Browse files Browse the repository at this point in the history
Green head
  • Loading branch information
peteromallet authored Feb 19, 2024
2 parents e19fae5 + 596cfea commit 490afd8
Show file tree
Hide file tree
Showing 38 changed files with 953 additions and 564 deletions.
10 changes: 8 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import threading
import time
import streamlit as st
from moviepy.editor import *
Expand Down Expand Up @@ -93,7 +92,14 @@ def main():
project_init()

from ui_components.setup import setup_app_ui
setup_app_ui()
from ui_components.components.welcome_page import welcome_page

data_repo = DataRepo()
app_setting = data_repo.get_app_setting_from_uuid()
if app_setting.welcome_state != 0:
setup_app_ui()
else:
welcome_page()

st.session_state['maintain_state'] = False

Expand Down
14 changes: 13 additions & 1 deletion backend/db_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,17 @@ def get_all_file_list(self, **kwargs):
del kwargs['data_per_page']
sort_order = kwargs['sort_order'] if 'sort_order' in kwargs else None
del kwargs['sort_order']

shot_uuid_list = []
if 'shot_uuid_list' in kwargs:
shot_uuid_list = kwargs['shot_uuid_list']
del kwargs['shot_uuid_list']

file_list = InternalFileObject.objects.filter(**kwargs).all()

if shot_uuid_list and len(shot_uuid_list):
file_list = file_list.filter(shot_uuid__in=shot_uuid_list)

if sort_order:
if sort_order == SortOrder.DESCENDING.value:
file_list = file_list.order_by('-created_on')
Expand Down Expand Up @@ -600,7 +609,7 @@ def get_inference_log_from_uuid(self, uuid):

return InternalResponse(payload, 'inference log fetched', True)

def get_all_inference_log_list(self, project_id=None, page=1, data_per_page=5, status_list=None, exclude_model_list=None):
def get_all_inference_log_list(self, project_id=None, page=1, data_per_page=5, status_list=None, exclude_model_list=None, model_name_list=""):
if project_id:
project = Project.objects.filter(uuid=project_id, is_disabled=False).first()
log_list = InferenceLog.objects.filter(project_id=project.id, is_disabled=False).order_by('-created_on').all()
Expand All @@ -612,6 +621,9 @@ def get_all_inference_log_list(self, project_id=None, page=1, data_per_page=5, s
else:
log_list = log_list.exclude(status__in=["", None])

if model_name_list:
log_list = log_list.filter(model_name__in=model_name_list)

log_list = log_list.exclude(model_id=None) # hackish sol to exclude non-image/video logs

paginator = Paginator(log_list, data_per_page)
Expand Down
23 changes: 23 additions & 0 deletions backend/migrations/0013_filter_keys_added.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 4.2.1 on 2024-02-18 07:55

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add denormalized filter columns: InferenceLog.model_name and
    InternalFileObject.shot_uuid.

    Both fields are plain CharFields (not foreign keys) that exist purely so
    list endpoints can filter logs by model name and files by shot without a
    join; they default to "" so existing rows need no backfill.
    """

    # Runs after the shot refactor migration that introduced the Shot model.
    dependencies = [
        ('backend', '0012_shot_added_and_redundant_fields_removed'),
    ]

    operations = [
        migrations.AddField(
            model_name='inferencelog',
            name='model_name',
            # Free-text model identifier used for filtering inference logs.
            field=models.CharField(blank=True, default='', max_length=512),
        ),
        migrations.AddField(
            model_name='internalfileobject',
            name='shot_uuid',
            # NOTE: intentionally not a ForeignKey — filtering-only reference
            # to a shot's uuid (see backend/models.py comment).
            field=models.CharField(blank=True, default='', max_length=255),
        ),
    ]
2 changes: 2 additions & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Meta:
class InferenceLog(BaseModel):
project = models.ForeignKey(Project, on_delete=models.CASCADE, null=True)
model = models.ForeignKey(AIModel, on_delete=models.DO_NOTHING, null=True)
model_name = models.CharField(max_length=512, default="", blank=True) # for filtering purposes
input_params = models.TextField(default="", blank=True)
output_details = models.TextField(default="", blank=True)
total_inference_time = models.FloatField(default=0)
Expand All @@ -97,6 +98,7 @@ class InternalFileObject(BaseModel):
tag = models.CharField(max_length=255,default="") # background_image, mask_image, canny_image etc..
project = models.ForeignKey(Project, on_delete=models.SET_NULL, default=None, null=True)
inference_log = models.ForeignKey(InferenceLog, on_delete=models.SET_NULL, default=None, null=True)
shot_uuid = models.CharField(max_length=255, default="", blank=True) # NOTE: this is not a foreignkey and purely for filtering purpose

class Meta:
app_label = 'backend'
Expand Down
2 changes: 2 additions & 0 deletions backend/serializers/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class CreateFileDao(serializers.Serializer):
hosted_url = serializers.CharField(max_length=512, required=False)
tag = serializers.CharField(max_length=100, allow_blank=True, required=False)
project_id = serializers.CharField(max_length=100, required=False)
shot_uuid = serializers.CharField(max_length=512, required=False, default="", allow_blank=True)
inference_log_id = serializers.CharField(max_length=100, allow_null=True, required=False)

def validate(self, data):
Expand Down Expand Up @@ -66,6 +67,7 @@ class CreateInferenceLogDao(serializers.Serializer):
output_details = serializers.CharField(required=False)
total_inference_time = serializers.CharField(required=False)
status = serializers.CharField(required=False, default="")
model_name = serializers.CharField(max_length=512, allow_blank=True, required=False, default="")


class CreateAIModelParamMapDao(serializers.Serializer):
Expand Down
16 changes: 14 additions & 2 deletions backend/serializers/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class Meta:
"total_inference_time",
"created_on",
"updated_on",
"status"
"status",
"model_name"
)


Expand All @@ -71,7 +72,18 @@ class InternalFileDto(serializers.ModelSerializer):
inference_log = InferenceLogDto()
class Meta:
model = InternalFileObject
fields = ('uuid', 'name', 'local_path', 'type', 'hosted_url', 'created_on', 'inference_log', 'project', 'tag')
fields = (
'uuid',
'name',
'local_path',
'type',
'hosted_url',
'created_on',
'inference_log',
'project',
'tag',
'shot_uuid'
)


class BasicShotDto(serializers.ModelSerializer):
Expand Down
33 changes: 30 additions & 3 deletions banodoco_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import time
import uuid
import psutil
import requests
import traceback
import sentry_sdk
Expand Down Expand Up @@ -69,11 +70,13 @@ def main():
print('runner running')
while True:
if TERMINATE_SCRIPT:
stop_server(8188)
return

if SERVER == 'development':
if not is_app_running():
if retries <= 0:
stop_server(8188)
print('runner stopped')
return
retries -= 1
Expand Down Expand Up @@ -147,6 +150,29 @@ def update_cache_dict(inference_type, log, timing_uuid, shot_uuid, timing_update
if str(log.project.uuid) not in shot_update_list:
shot_update_list[str(log.project.uuid)] = []
shot_update_list[str(log.project.uuid)].append(shot_uuid)

def find_process_by_port(port):
    """Return the pid of a process LISTENing on ``port``, or None.

    Scans all visible processes via psutil and inspects their socket
    connections. Processes that disappear or deny access mid-scan are
    skipped silently.

    Fix: the original only broke out of the inner connection loop after a
    match, so the outer scan kept iterating every remaining process (wasted
    work, and a later listener could overwrite the already-found pid).
    Returning immediately stops at the first match.

    :param port: TCP port number to look for.
    :return: pid (int) of the first matching process, or None if none found.
    """
    for proc in psutil.process_iter(attrs=['pid', 'name', 'connections']):
        try:
            # proc.info['connections'] can be None/absent for some processes
            for conn in (proc.info.get('connections') or []):
                if conn.status == psutil.CONN_LISTEN and conn.laddr.port == port:
                    app_logger.log(LoggingType.DEBUG, f"Process {proc.info['pid']} (Port {port})")
                    return proc.info['pid']
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            # process vanished or is inaccessible — ignore and keep scanning
            pass

    return None

def stop_server(port):
    """Terminate the process (comfy server) listening on ``port``, if any.

    Fix: the original signature was ``def stop_server(self, port)`` on a
    module-level function, but every call site invokes ``stop_server(8188)``
    — binding 8188 to ``self`` and raising TypeError for the missing
    ``port``. The stray ``self`` is removed.

    :param port: TCP port whose listening process should be stopped.
    """
    pid = find_process_by_port(port)
    if pid:
        app_logger.log(LoggingType.DEBUG, "comfy server stopped")
        process = psutil.Process(pid)
        process.terminate()
        # block until the process has actually exited
        process.wait()

def check_and_update_db():
# print("updating logs")
Expand Down Expand Up @@ -225,7 +251,7 @@ def check_and_update_db():
print("processing inference output")
process_inference_output(**origin_data)
timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None)
update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)
update_cache_dict(origin_data.get('inference_type', ""), log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)

except Exception as e:
app_logger.log(LoggingType.ERROR, f"Error: {e}")
Expand All @@ -248,7 +274,8 @@ def check_and_update_db():
try:
setup_comfy_runner()
start_time = time.time()
output = predict_gpu_output(data['workflow_input'], data['file_path_list'], data['output_node_ids'])
output = predict_gpu_output(data['workflow_input'], data['file_path_list'], \
data['output_node_ids'], data.get("extra_model_list", []), data.get("ignore_model_list", []))
end_time = time.time()

output = output[-1] # TODO: different models can have different logic
Expand All @@ -271,7 +298,7 @@ def check_and_update_db():
from ui_components.methods.common_methods import process_inference_output
process_inference_output(**origin_data)
timing_uuid, shot_uuid = origin_data.get('timing_uuid', None), origin_data.get('shot_uuid', None)
update_cache_dict(origin_data['inference_type'], log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)
update_cache_dict(origin_data.get('inference_type', ''), log, timing_uuid, shot_uuid, timing_update_list, shot_update_list, gallery_update_list)

except Exception as e:
print("error occured: ", str(e))
Expand Down
57 changes: 29 additions & 28 deletions banodoco_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,38 +103,39 @@ def create_new_project(user: InternalUserObject, project_name: str, width=512, h

shot = data_repo.create_shot(**shot_data)

# NOTE: removing sample timing frame
# create a sample timing frame
st.session_state["project_uuid"] = project.uuid
sample_file_location = "sample_assets/sample_images/v.jpeg"
img = Image.open(sample_file_location)
img = img.resize((width, height))

unique_file_name = f"{str(uuid.uuid4())}.png"
file_location = f"videos/{project.uuid}/resources/prompt_images/{unique_file_name}"
hosted_url = save_or_host_file(img, file_location, mime_type='image/png', dim=(width, height))
file_data = {
"name": str(uuid.uuid4()),
"type": InternalFileType.IMAGE.value,
"project_id": project.uuid,
"dim": (width, height),
}

if hosted_url:
file_data.update({'hosted_url': hosted_url})
else:
file_data.update({'local_path': file_location})
# st.session_state["project_uuid"] = project.uuid
# sample_file_location = "sample_assets/sample_images/v.jpeg"
# img = Image.open(sample_file_location)
# img = img.resize((width, height))

# unique_file_name = f"{str(uuid.uuid4())}.png"
# file_location = f"videos/{project.uuid}/resources/prompt_images/{unique_file_name}"
# hosted_url = save_or_host_file(img, file_location, mime_type='image/png', dim=(width, height))
# file_data = {
# "name": str(uuid.uuid4()),
# "type": InternalFileType.IMAGE.value,
# "project_id": project.uuid,
# "dim": (width, height),
# }

# if hosted_url:
# file_data.update({'hosted_url': hosted_url})
# else:
# file_data.update({'local_path': file_location})

source_image = data_repo.create_file(**file_data)
# source_image = data_repo.create_file(**file_data)

timing_data = {
"frame_time": 0.0,
"aux_frame_index": 0,
"source_image_id": source_image.uuid,
"shot_id": shot.uuid,
}
timing: InternalFrameTimingObject = data_repo.create_timing(**timing_data)
# timing_data = {
# "frame_time": 0.0,
# "aux_frame_index": 0,
# "source_image_id": source_image.uuid,
# "shot_id": shot.uuid,
# }
# timing: InternalFrameTimingObject = data_repo.create_timing(**timing_data)

add_image_variant(source_image.uuid, timing.uuid)
# add_image_variant(source_image.uuid, timing.uuid)

# create default ai models
model_list = create_predefined_models(user)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ wrapt==1.15.0
pydantic==1.10.9
streamlit-server-state==0.17.1
setproctitle==1.3.3
gitdb==4.0.11
gitdb==4.0.11
psutil==5.9.8
12 changes: 6 additions & 6 deletions ui_components/components/adjust_shot_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ def adjust_shot_page(shot_uuid: str, h2):
project_setting = data_repo.get_project_setting(shot.project.uuid)
number_of_pages = project_setting.total_shortlist_gallery_pages
page_number = 0
gallery_image_view(shot.project.uuid, shortlist=True,view=['add_and_remove_from_shortlist','add_to_this_shot'], shot=shot,sidebar=True)
gallery_image_view(shot.project.uuid, shortlist=True,view=['add_and_remove_from_shortlist','add_to_this_shot'], shot=shot, sidebar=True)

st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]")
st.markdown("***")
shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual")
project_setting = data_repo.get_project_setting(shot.project.uuid)


with st.expander("✨ Generate Images", expanded=True):
generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid'])
generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=None, shot_uuid=shot.uuid)
st.markdown("***")

st.markdown("***")
gallery_image_view(shot.project.uuid, shortlist=False,view=['add_and_remove_from_shortlist','add_to_this_shot','view_inference_details'], shot=shot,sidebar=False)
gallery_image_view(shot.project.uuid, shortlist=False,view=['add_and_remove_from_shortlist','add_to_this_shot','view_inference_details','shot_chooser'], shot=shot,sidebar=False)
else:
frame_styling_page(st.session_state["shot_uuid"], h2)
Loading

0 comments on commit 490afd8

Please sign in to comment.