diff --git a/sample_assets/sample_images/main.png b/sample_assets/sample_images/main.png new file mode 100644 index 00000000..c99f6ad7 Binary files /dev/null and b/sample_assets/sample_images/main.png differ diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py index bb326d07..fc198b89 100644 --- a/ui_components/components/adjust_shot_page.py +++ b/ui_components/components/adjust_shot_page.py @@ -2,6 +2,7 @@ from ui_components.widgets.shot_view import shot_keyframe_element from ui_components.components.explorer_page import gallery_image_view from ui_components.components.explorer_page import generate_images_element +from ui_components.components.frame_styling_page import frame_styling_page from ui_components.widgets.frame_selector import frame_selector_widget, frame_view from utils import st_memory from utils.data_repo.data_repo import DataRepo @@ -9,38 +10,48 @@ def adjust_shot_page(shot_uuid: str, h2): + + with st.sidebar: + + frame_selection = frame_selector_widget(show_frame_selector=True) + data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) - with h2: - frame_selector_widget(show=['shot_selector']) + if frame_selection == "": - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") + with st.sidebar: + frame_view(view='Video') + with st.expander("📋 Explorer Shortlist",expanded=True): - st.markdown("***") + if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): + + project_setting = data_repo.get_project_setting(shot.project.uuid) + number_of_pages = project_setting.total_shortlist_gallery_pages + # page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True,key=f"main_page_number") + # st.markdown("***") + page_number = 0 + gallery_image_view(shot.project.uuid, shortlist=True,view=['add_and_remove_from_shortlist','add_to_this_shot'], shot=shot,sidebar=True) - with st.sidebar: - frame_view(view='Video') - - shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") - # with st.expander("📋 Explorer Shortlist",expanded=True): - shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"], - icons=['grid-3x3','airplane'], - menu_icon="cast", - default_index=st.session_state.get('shot_explorer_view', 0), - key="shot_explorer_view", orientation="horizontal", - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#868c91"}}) - - st.markdown("***") + + + + + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") - if shot_explorer_view == "Shortlist": - project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) st.markdown("***") - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=True, num_columns=4,view="individual_shot", shot=shot) - elif shot_explorer_view == "Explore": + + + shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") + project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + # st.markdown("***") + + with st.expander("✨ Generate Images", expanded=True): + generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + st.markdown("***") + page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True,key=f"main_page_number_{shot.project.uuid}") st.markdown("***") - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=False, num_columns=4,view="individual_shot", shot=shot) \ No newline at end of file + gallery_image_view(shot.project.uuid, shortlist=False,view=['add_and_remove_from_shortlist','add_to_this_shot'], shot=shot) + else: + frame_styling_page(st.session_state["shot_uuid"], h2) \ No newline at end of file diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py index e30e8282..ee319772 100644 --- a/ui_components/components/animate_shot_page.py +++ b/ui_components/components/animate_shot_page.py @@ -8,13 +8,14 @@ def animate_shot_page(shot_uuid: str, h2): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) - with h2: - frame_selector_widget(show=['shot_selector']) + + with st.sidebar: - frame_view() + frame_selector_widget(show_frame_selector=False) + frame_view(view='Video') st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") st.markdown("***") variant_comparison_grid(st.session_state['shot_uuid'], stage="Shots") - with st.expander("🎬 Choose Animation Style & Create Variants", expanded=True): + with st.expander("🎥 Generate Animation", expanded=True): animation_style_element(st.session_state['shot_uuid']) \ No newline at end of file diff --git a/ui_components/components/app_settings_page.py b/ui_components/components/app_settings_page.py index be505adc..b1d4ad2d 100644 --- a/ui_components/components/app_settings_page.py +++ b/ui_components/components/app_settings_page.py @@ -3,6 +3,7 @@ import webbrowser from shared.constants import SERVER, ServerType from utils.common_utils import get_current_user +from ui_components.components.query_logger_page import query_logger_page from utils.data_repo.data_repo import DataRepo @@ -37,4 +38,6 @@ def app_settings_page(): else: payment_link = data_repo.generate_payment_link(credits) payment_link = f""" PAYMENT LINK """ - st.markdown(payment_link, unsafe_allow_html=True) \ No newline at end of file + st.markdown(payment_link, unsafe_allow_html=True) + st.markdown("***") + query_logger_page() \ No newline at end of file diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py index 74d6b7c7..a18d0544 100644 --- a/ui_components/components/explorer_page.py +++ b/ui_components/components/explorer_page.py @@ -21,47 +21,30 @@ class InputImageStyling(ExtendedEnum): - EVOLVE_IMAGE = "Evolve Image" - MAINTAIN_STRUCTURE = "Maintain Structure" + TEXT2IMAGE = "Text to Image" + IMAGE2IMAGE = "Image to Image" + CONTROLNET_CANNY = "ControlNet Canny" + IPADAPTER_FACE = "IP-Adapter Face" + IPADAPTER_PLUS = "IP-Adapter Plus" + IPADPTER_FACE_AND_PLUS = "IP-Adapter Face & Plus" -def columnn_selecter(): - f1, f2 = st.columns([1, 1]) - with f1: - st_memory.number_input('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") - with f2: - st_memory.number_input('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") - def explorer_page(project_uuid): data_repo = DataRepo() project_setting = data_repo.get_project_setting(project_uuid) st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") + st.markdown("***") - z1, z2, z3 = st.columns([0.25,2,0.25]) - with z2: - with st.expander("Prompt Settings", expanded=True): - generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) - st.markdown("***") - columnn_selecter() - k1,k2 = st.columns([5,1]) - page_number = k1.radio("Select page:", options=range(1, project_setting.total_gallery_pages + 1), horizontal=True, key="main_gallery") - open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') + + with st.expander("✨ Generate Images", expanded=True): + generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) st.markdown("***") - gallery_image_view(project_uuid, page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, False, st.session_state['num_columns_explorer'],view="explorer") + + + gallery_image_view(project_uuid,False,view=['add_and_remove_from_shortlist','view_inference_details']) -def shortlist_element(project_uuid): - data_repo = DataRepo() - project_setting = data_repo.get_project_setting(project_uuid) - columnn_selecter() - k1,k2 = st.columns([5,1]) - shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") - with k2: - open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle') - st.markdown("***") - gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") - def generate_images_element(position='explorer', project_uuid=None, timing_uuid=None): @@ -84,73 +67,136 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= st.write("") st.write("") st.write("") - if st.button("🔄", key="switch_prompt_position_button", use_container_width=True, help="This will switch the order the prompt and magic prompt are used - earlier items gets more attention."): + if st.button("🔄", key="switch_prompt_position_button:", use_container_width=True, help="This will switch the order the prompt and magic prompt are used - earlier items gets more attention."): st.session_state['switch_prompt_position'] = not st.session_state.get('switch_prompt_position', False) st.experimental_rerun() - neg1, _ = st.columns([1.5,1]) + neg1, _ = st.columns([1,1.3]) with neg1: - negative_prompt = st_memory.text_input("Negative prompt", value="bad image, worst image, bad anatomy, washed out colors",\ + negative_prompt = st_memory.text_input("Negative prompt:", value="bad image, worst image, bad anatomy, washed out colors",\ key="explorer_neg_prompt", \ help="These are the things you wish to be excluded from the image") - if position=='explorer': - _, b1, b2, b3, _ = st.columns([0.1,1.25,2,2,0.1]) - _, c1, c2, _ = st.columns([1,2,2,1]) - else: - b1, b2, b3 = st.columns([1,2,1]) - c1, c2, _ = st.columns([2,2,2]) - + + b1, b2, b3,b4 = st.columns([1.5,1.5,1.5,1]) + c1, c2, _ = st.columns([2,2,2]) + with b1: - use_input_image = st_memory.checkbox("Use input image", key="use_input_image", value=False) + type_of_generation = st_memory.radio("How would you like to generate the image?", options=InputImageStyling.value_list(), key="type_of_generation_key", help="Evolve Image will evolve the image based on the prompt, while Maintain Structure will keep the structure of the image and change the style.",horizontal=True) - - if use_input_image: + + input_image_key = "input_image_1" + if input_image_key not in st.session_state: + st.session_state[input_image_key] = None + if 'input_image_2' not in st.session_state: + st.session_state['input_image_2'] = None + if type_of_generation != InputImageStyling.TEXT2IMAGE.value: with b2: - type_of_transformation = st_memory.radio("What type of transformation would you like to do?", options=InputImageStyling.value_list(), key="type_of_transformation_key", help="Evolve Image will evolve the image based on the prompt, while Maintain Structure will keep the structure of the image and change the style.",horizontal=True) - - with c1: - input_image_key = f"input_image_{position}" - if input_image_key not in st.session_state: - st.session_state[input_image_key] = None - - input_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="explorer_input_image", help="This will be the base image for the generation.") - if st.button("Upload", use_container_width=True): - st.session_state[input_image_key] = input_image + source_of_starting_image = st_memory.radio("Image source:", options=["Upload", "From Shot"], key="source_of_starting_image", help="This will be the base image for the generation.",horizontal=True) + if source_of_starting_image == "Upload": + input_image = st.file_uploader("Upload a starting image", type=["png", "jpg", "jpeg"], key="explorer_input_image", help="This will be the base image for the generation.") + else: + shot_list = data_repo.get_shot_list(project_uuid) + selection1, selection2 = st.columns([1,1]) + with selection1: + shot_name = st.selectbox("Shot:", options=[shot.name for shot in shot_list], key="explorer_shot_uuid", help="This will be the base image for the generation.") + shot_uuid = [shot.uuid for shot in shot_list if shot.name == shot_name][0] + frame_list = data_repo.get_timing_list_from_shot(shot_uuid) + with selection2: + list_of_timings = [i + 1 for i in range(len(frame_list))] + timing = st.selectbox("Frame #:", options=list_of_timings, key="explorer_frame_number", help="This will be the base image for the generation.") + #timing = st.number_input("Frame #:", min_value=1, max_value=len(frame_list), value=1, step=1, key="explorer_frame_number", help="This will be the base image for the generation.") + input_image = frame_list[timing - 1].primary_image.location + # make it a byte stream + st.image(frame_list[timing - 1].primary_image.location, use_column_width=True) + + if type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + source_of_starting_image_2 = st_memory.radio("How would you like to upload the second starting image?", options=["Upload", "From Shot"], key="source_of_starting_image_2", help="This will be the base image for the generation.",horizontal=True) + if source_of_starting_image_2 == "Upload": + input_image_2 = st.file_uploader("IP-Adapter Face image:", type=["png", "jpg", "jpeg"], key="explorer_input_image_2", help="This will be the base image for the generation.") + else: + selection1, selection2 = st.columns([1,1]) + with selection1: + shot_list = data_repo.get_shot_list(project_uuid) + shot_name = st.selectbox("Shot:", options=[shot.name for shot in shot_list], key="explorer_shot_uuid_2", help="This will be the base image for the generation.") + shot_uuid = [shot.uuid for shot in shot_list if shot.name == shot_name][0] + with selection2: + frame_list = data_repo.get_timing_list_from_shot(shot_uuid) + list_of_timings = [i + 1 for i in range(len(frame_list))] + timing = st.selectbox("Frame #:", options=list_of_timings, key="explorer_frame_number_2", help="This will be the base image for the generation.") + input_image_2 = frame_list[timing - 1].primary_image.location + st.image(frame_list[timing - 1].primary_image.location, use_column_width=True) + + # if type type is face and plus, then we need to make the text images + if type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + button_text = "Upload Images" + else: + button_text = "Upload Image" + if st.button(button_text, use_container_width=True): + st.session_state[input_image_key] = input_image + if type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + st.session_state['input_image_2'] = input_image_2 + st.rerun() + with b3: + # prompt_strength = round(1 - (strength_of_image / 100), 2) + if type_of_generation != InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + if st.session_state[input_image_key] is not None: + st.info("Current image:") + st.image(st.session_state[input_image_key], use_column_width=True) + else: + st.info("Current image:") + st.error("Please upload an image") with b3: edge_pil_img = None - strength_of_current_image = st_memory.number_input("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.") - if type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value: - prompt_strength = round(1 - (strength_of_current_image / 100), 2) - with c2: - if st.session_state[input_image_key] is not None: - input_image_bytes = st.session_state[input_image_key].getvalue() - pil_image = Image.open(io.BytesIO(input_image_bytes)) - blur_radius = (100 - strength_of_current_image) / 3 # Adjust this formula as needed - blurred_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius)) - st.image(blurred_image, use_column_width=True) - - elif type_of_transformation == InputImageStyling.MAINTAIN_STRUCTURE.value: - condition_scale = strength_of_current_image / 10 - with c2: - if st.session_state[input_image_key] is not None: - input_image_bytes = st.session_state[input_image_key] .getvalue() - pil_image = Image.open(io.BytesIO(input_image_bytes)) - cv_image = np.array(pil_image) - gray_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY) - lower_threshold = (100 - strength_of_current_image) * 3 - upper_threshold = lower_threshold * 3 - edges = cv2.Canny(gray_image, lower_threshold, upper_threshold) - edge_pil_img = Image.fromarray(edges) - st.image(edge_pil_img, use_column_width=True) - st.markdown("***") - + # strength_of_image = st_memory.slider("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_image_key", help="This will determine how much of the current image will be kept in the final image.") + if type_of_generation == InputImageStyling.IMAGE2IMAGE.value: + with b3: + strength_of_image = st_memory.slider("How much blur would you like to add to the image?", min_value=0, max_value=100, value=50, step=1, key="strength_of_image2image", help="This will determine how much of the current image will be kept in the final image.") + + + elif type_of_generation == InputImageStyling.CONTROLNET_CANNY.value: + with b3: + strength_of_image = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_controlnet_canny", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADAPTER_FACE.value: + with b3: + strength_of_image = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_face", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADAPTER_PLUS.value: + with b3: + strength_of_plus = st_memory.slider("How much of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_plus", help="This will determine how much of the current image will be kept in the final image.") + + elif type_of_generation == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + with b3: + if st.session_state[input_image_key] is not None: + st.info("IP-Adapter Face image:") + st.image(st.session_state[input_image_key], use_column_width=True) + strength_of_face = st_memory.slider("How strong would would you like the Face model to influence?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_face", help="This will determine how much of the current image will be kept in the final image.") + else: + st.info("IP-Adapter Face image:") + st.error("Please upload an image") + if st.session_state['input_image_2'] is not None: + st.info("IP-Adapter Plus image:") + st.image(st.session_state['input_image_2'], use_column_width=True) + strength_of_plus = st_memory.slider("How strong would you like to influence the Plus model?", min_value=0, max_value=100, value=50, step=1, key="strength_of_ipadapter_plus", help="This will determine how much of the current image will be kept in the final image.") + else: + st.info("IP-Adapter Plus image:") + st.error("Please upload an second image") + + if type_of_generation != InputImageStyling.TEXT2IMAGE.value: + + if st.session_state[input_image_key] is not None: + with b3: + if st.button("Clear input image", key="clear_input_image", use_container_width=True): + st.session_state[input_image_key] = None + st.session_state['input_image_2'] = None + st.rerun() - else: + if not st.session_state[input_image_key]: input_image = None - type_of_transformation = None - strength_of_current_image = None + type_of_generation = None + strength_of_image = None # st.markdown("***") model_name = "stable_diffusion_xl" if position=='explorer': @@ -182,7 +228,9 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= counter += 1 log = None - if not input_image: + + + if InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.TEXT2IMAGE.value: query_obj = MLQueryObject( timing_uuid=None, model_uuid=None, @@ -207,7 +255,7 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= output, log = ml_client.predict_model_output_standardized(replicate_model, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) else: - if type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value: + if InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.IMAGE2IMAGE.value: input_image_file = save_uploaded_image(input_image, project_uuid) query_obj = MLQueryObject( timing_uuid=None, @@ -227,7 +275,7 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= output, log = ml_client.predict_model_output_standardized(REPLICATE_MODEL.sdxl, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) - elif type_of_transformation == InputImageStyling.MAINTAIN_STRUCTURE.value: + elif InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.CONTROLNET_CANNY.value: input_image_file = save_uploaded_image(edge_pil_img, project_uuid) query_obj = MLQueryObject( timing_uuid=None, @@ -247,6 +295,15 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= ) output, log = ml_client.predict_model_output_standardized(REPLICATE_MODEL.sdxl_controlnet, query_obj, queue_inference=QUEUE_INFERENCE_QUERIES) + + elif InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.IPADAPTER_FACE.value: + st.write("Not implemented yet") + + elif InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.IPADAPTER_PLUS.value: + st.write("Not implemented yet") + + elif InputImageStyling.value_list()[st.session_state['type_of_generation_key']] == InputImageStyling.IPADPTER_FACE_AND_PLUS.value: + st.write("Not implemented yet") if log: inference_data = { @@ -260,6 +317,7 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= process_inference_output(**inference_data) st.info("Check the Generation Log to the left for the status.") + time.sleep(0.5) toggle_generate_inference(position) st.rerun() @@ -273,11 +331,38 @@ def toggle_generate_inference(position): st.session_state[position + '_generate_inference'] = not st.session_state[position + '_generate_inference'] -def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_detailed_view_for_all=False, shortlist=False, num_columns=2, view="main", shot=None): +def gallery_image_view(project_uuid, shortlist=False, view=["main"], shot=None, sidebar=False): data_repo = DataRepo() project_settings = data_repo.get_project_setting(project_uuid) shot_list = data_repo.get_shot_list(project_uuid) + k1,k2 = st.columns([5,1]) + if sidebar != True: + f1, f2 = st.columns([1, 1]) + with f1: + num_columns = st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") + with f2: + num_items_per_page = st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") + + if shortlist is False: + page_number = k1.radio("Select page:", options=range(1, project_settings.total_gallery_pages + 1), horizontal=True, key="main_gallery") + if 'view_inference_details' in view: + open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') + st.markdown("***") + else: + project_setting = data_repo.get_project_setting(project_uuid) + page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + open_detailed_view_for_all = False + st.markdown("***") + + + + else: + project_setting = data_repo.get_project_setting(project_uuid) + page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + open_detailed_view_for_all = False + num_items_per_page = 8 + num_columns = 2 gallery_image_list, res_payload = data_repo.get_all_file_list( file_type=InternalFileType.IMAGE.value, @@ -304,14 +389,23 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de # except (IOError, SyntaxError) as e: # return True # return False - + if shortlist is False: + fetch1, fetch2, fetch3, fetch4 = st.columns([0.25, 1, 1, 0.25]) + st.markdown("***") + with fetch2: + st.info("###### 25 images pending") + with fetch3: + image_pending = 8 + if image_pending: + if st.button("Check for new images", key=f"check_for_new_images_", use_container_width=True): + st.write("Fetching images...") + # st.markdown("***") total_image_count = res_payload['count'] if gallery_image_list and len(gallery_image_list): start_index = 0 end_index = min(start_index + num_items_per_page, total_image_count) shot_names = [s.name for s in shot_list] - shot_names.append('**Create New Shot**') - shot_names.insert(0, '') + shot_names.append('**Create New Shot**') for i in range(start_index, end_index, num_columns): cols = st.columns(num_columns) for j in range(num_columns): @@ -320,16 +414,14 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de st.image(gallery_image_list[i + j].location, use_column_width=True) # else: # st.error("The image is truncated and cannot be displayed.") - if view in ["explorer", "shortlist"]: + if 'add_and_remove_from_shortlist' in view: if shortlist: if st.button("Remove from shortlist ➖", key=f"shortlist_{gallery_image_list[i + j].uuid}",use_container_width=True, help="Remove from shortlist"): data_repo.update_file(gallery_image_list[i + j].uuid, tag=InternalFileTag.GALLERY_IMAGE.value) st.success("Removed From Shortlist") time.sleep(0.3) st.rerun() - else: - if st.button("Add to shortlist ➕", key=f"shortlist_{gallery_image_list[i + j].uuid}",use_container_width=True, help="Add to shortlist"): data_repo.update_file(gallery_image_list[i + j].uuid, tag=InternalFileTag.SHORTLISTED_GALLERY_IMAGE.value) st.success("Added To Shortlist") @@ -343,7 +435,7 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de input_params = json.loads(log.input_params) prompt = input_params.get('prompt', 'No prompt found') model = json.loads(log.output_details)['model_name'].split('/')[-1] - if view in ["explorer", "shortlist","individual_shot"]: + if 'view_inference_details' in view: with st.expander("Prompt Details", expanded=open_detailed_view_for_all): st.info(f"**Prompt:** {prompt}\n\n**Model:** {model}") @@ -355,8 +447,8 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de # ---------- add to shot btn --------------- if "last_shot_number" not in st.session_state: st.session_state["last_shot_number"] = 0 - if view not in ["explorer", "shortlist"]: - if view == "individual_shot": + if 'add_to_this_shot' in view or 'add_to_any_shot' in view: + if 'add_to_this_shot' in view: shot_name = shot.name else: shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"]) @@ -373,9 +465,9 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de else: if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True): - shot_number = shot_names.index(shot_name) + 1 - st.session_state["last_shot_number"] = shot_number - 1 - shot_uuid = shot_list[shot_number - 2].uuid + shot_number = shot_names.index(shot_name) + st.session_state["last_shot_number"] = shot_number + shot_uuid = shot_list[shot_number].uuid add_key_frame(gallery_image_list[i + j], False, shot_uuid, len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False) # removing this from the gallery view diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 4f252014..833ca693 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -21,24 +21,19 @@ def frame_styling_page(shot_uuid: str, h2): if len(timing_list) == 0: - with h2: - frame_selector_widget(show=['shot_selector','frame_selector']) - + st.markdown("#### There are no frames present in this shot yet.") else: with st.sidebar: - with h2: - - frame_selector_widget(show=['shot_selector','frame_selector']) - - st.session_state['styling_view'] = st_memory.menu('',\ - ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ - icons=['magic', 'crop', "paint-bucket", 'pencil'], \ - menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ - key="styling_view_selector", orientation="horizontal", \ - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) - + + st.session_state['styling_view'] = st_memory.menu('',\ + ["Generate", "Crop", "Inpaint","Scribble"], \ + icons=['magic', 'crop', "paint-bucket", 'pencil'], \ + menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ + key="styling_view_selector", orientation="horizontal", \ + styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) + frame_view(view="Key Frame") st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") @@ -48,18 +43,18 @@ def frame_styling_page(shot_uuid: str, h2): st.markdown("***") if st.session_state['styling_view'] == "Generate": - with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): + with st.expander("🛠️ Generate Variants", expanded=True): generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) - elif st.session_state['styling_view'] == "Crop/Move": + elif st.session_state['styling_view'] == "Crop": with st.expander("🤏 Crop, Move & Rotate", expanded=True): cropping_selector_element(shot_uuid) - elif st.session_state['styling_view'] == "Inpainting": + elif st.session_state['styling_view'] == "Inpaint": with st.expander("🌌 Inpainting", expanded=True): inpainting_element(st.session_state['current_frame_uuid']) - elif st.session_state['styling_view'] == "Scribbling": + elif st.session_state['styling_view'] == "Scribble": with st.expander("📝 Draw On Image", expanded=True): drawing_element(shot_uuid) diff --git a/ui_components/components/project_settings_page.py b/ui_components/components/project_settings_page.py index 598327c1..c50d4f36 100644 --- a/ui_components/components/project_settings_page.py +++ b/ui_components/components/project_settings_page.py @@ -10,9 +10,10 @@ def project_settings_page(project_uuid): data_repo = DataRepo() - + st.subheader("Project Settings") project_settings = data_repo.get_project_setting(project_uuid) - attach_audio_element(project_uuid, True) + + frame_sizes = ["512x512", "768x512", "512x768"] current_size = f"{project_settings.width}x{project_settings.height}" @@ -33,4 +34,7 @@ def project_settings_page(project_uuid): if st.button("Save"): data_repo.update_project_setting(project_uuid, width=width) data_repo.update_project_setting(project_uuid, height=height) - st.experimental_rerun() \ No newline at end of file + st.experimental_rerun() + + st.write("") + attach_audio_element(project_uuid, True) \ No newline at end of file diff --git a/ui_components/components/query_logger_page.py b/ui_components/components/query_logger_page.py index 43a5556c..162a80eb 100644 --- a/ui_components/components/query_logger_page.py +++ b/ui_components/components/query_logger_page.py @@ -6,14 +6,16 @@ from utils.data_repo.data_repo import DataRepo def query_logger_page(): - st.header("Inference Log list") + st.subheader("Inference Log list") data_repo = DataRepo() current_user = get_current_user() b1, b2 = st.columns([1, 1]) total_log_table_pages = st.session_state['total_log_table_pages'] if 'total_log_table_pages' in st.session_state else DefaultTimingStyleParams.total_log_table_pages - page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1) + list_of_pages = [i for i in range(1, total_log_table_pages + 1)] + page_number = b1.radio('Select page:', options=list_of_pages, key='inference_log_page_number', index=0, horizontal=True) + # page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1) inference_log_list, total_page_count = data_repo.get_all_inference_log_list( page=page_number, data_per_page=100 diff --git a/ui_components/components/shortlist_page.py b/ui_components/components/shortlist_page.py index d7b35bf0..179aed5b 100644 --- a/ui_components/components/shortlist_page.py +++ b/ui_components/components/shortlist_page.py @@ -1,21 +1,20 @@ import streamlit as st -from ui_components.components.explorer_page import columnn_selecter,gallery_image_view +from ui_components.components.explorer_page import gallery_image_view from utils.data_repo.data_repo import DataRepo from utils import st_memory def shortlist_page(project_uuid): - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") - st.markdown("***") + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") data_repo = DataRepo() project_setting = data_repo.get_project_setting(project_uuid) - columnn_selecter() - k1,k2 = st.columns([5,1]) - shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") - with k2: - open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) + # columnn_selecter() + # k1,k2 = st.columns([5,1]) + # shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + # with k2: + # open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) st.markdown("***") - gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") + gallery_image_view(project_uuid, True,view=['view_inference_details','add_to_any_shot','add_and_remove_from_shortlist']) \ No newline at end of file diff --git a/ui_components/components/timeline_view_page.py b/ui_components/components/timeline_view_page.py index 8799d2de..76e39fb8 100644 --- a/ui_components/components/timeline_view_page.py +++ b/ui_components/components/timeline_view_page.py @@ -17,22 +17,29 @@ def timeline_view_page(shot_uuid: str, h2): if "view" not in st.session_state: st.session_state["view"] = views[0] st.session_state["manual_select"] = None - + st.write("") with st.expander("📋 Explorer Shortlist",expanded=True): if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=10, open_detailed_view_for_all=False, shortlist=True, num_columns=2,view="sidebar") - + # page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + gallery_image_view(shot.project.uuid, shortlist=True,view=["add_and_remove_from_shortlist","add_to_any_shot"], shot=shot,sidebar=True) + ''' with h2: st.session_state['view'] = option_menu(None, views, icons=['palette', 'camera-reels', "hourglass", 'stopwatch'], menu_icon="cast", orientation="vertical", key="secti2on_selector", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}, manual_select=st.session_state["manual_select"]) if st.session_state["manual_select"] != None: st.session_state["manual_select"] = None - + ''' st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{st.session_state['view']}]") st.markdown("***") + slider1, slider2 = st.columns([1,5]) + with slider1: + show_video = st.toggle("Show Video:", value=False, key="show_video") + if show_video: + st.session_state["view"] = "Shots" + else: + st.session_state["view"] = "Key Frames" timeline_view(st.session_state["shot_uuid"], st.session_state['view']) \ No newline at end of file diff --git a/ui_components/setup.py b/ui_components/setup.py index c58400c1..d9540e61 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -137,7 +137,7 @@ def setup_app_ui(): st.session_state["index_of_page"] = 0 with st.sidebar: - main_view_types = ["Creative Process", "Tools & Settings", "Video Rendering"] + main_view_types = ["Creative Process", "Project Settings", "Video Rendering"] st.session_state['main_view_type'] = st_memory.menu(None, main_view_types, icons=['search-heart', 'tools', "play-circle", 'stopwatch'], menu_icon="cast", default_index=0, key="main_view_type_name", orientation="horizontal", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "red"}}) @@ -150,21 +150,23 @@ def setup_app_ui(): st.session_state['creative_process_manual_select'] = 0 st.session_state['page'] = creative_process_pages[0] - h1,h2 = st.columns([1.5,1]) - with h1: - # view_types = ["Explorer","Timeline","Individual"] - creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Adjust Frame", "Animate Shot"] - st.session_state['page'] = option_menu( - None, - creative_process_pages, - icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], - menu_icon="cast", - orientation="vertical", - key="section-selecto1r", - styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, - "nav-link-selected": {"background-color": "green"}}, - manual_select=st.session_state['creative_process_manual_select'] - ) + + # view_types = ["Explorer","Timeline","Individual"] + creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Animate Shot"] + st.session_state['page'] = option_menu( + None, + creative_process_pages, + icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], + menu_icon="cast", + orientation="vertical", + key="section-selecto1r", + styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, + "nav-link-selected": {"background-color": "green"}}, + manual_select=st.session_state['creative_process_manual_select'] + ) + + if st.session_state['page'] != "Adjust Shot": + st.session_state['current_frame_sidebar_selector'] = 0 if st.session_state['creative_process_manual_select'] != None: st.session_state['creative_process_manual_select'] = None @@ -179,9 +181,6 @@ def setup_app_ui(): elif st.session_state['page'] == "Timeline": timeline_view_page(st.session_state["shot_uuid"], h2) - elif st.session_state['page'] == "Adjust Frame": - frame_styling_page(st.session_state["shot_uuid"], h2) - elif st.session_state['page'] == "Adjust Shot": adjust_shot_page(st.session_state["shot_uuid"], h2) @@ -192,24 +191,11 @@ def setup_app_ui(): with st.expander("🔍 Generation Log", expanded=True): if st_memory.toggle("Open", value=True, key="generaton_log_toggle"): sidebar_logger(st.session_state["shot_uuid"]) - st.markdown("***") + # st.markdown("***") - elif st.session_state["main_view_type"] == "Tools & Settings": - with st.sidebar: - tool_pages = ["Query Logger", "Project Settings"] - - if st.session_state["page"] not in tool_pages: - st.session_state["page"] = tool_pages[0] - st.session_state["manual_select"] = None - - st.session_state['page'] = option_menu(None, tool_pages, icons=['pencil', 'palette', "hourglass", 'stopwatch'], menu_icon="cast", orientation="horizontal", key="secti2on_selector", styles={ - "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "green"}}, manual_select=st.session_state["manual_select"]) - if st.session_state["page"] == "Query Logger": - query_logger_page() - if st.session_state["page"] == "Custom Models": - custom_models_page(st.session_state["project_uuid"]) - elif st.session_state["page"] == "Project Settings": - project_settings_page(st.session_state["project_uuid"]) + elif st.session_state["main_view_type"] == "Project Settings": + + project_settings_page(st.session_state["project_uuid"]) elif st.session_state["main_view_type"] == "Video Rendering": video_rendering_page(st.session_state["project_uuid"]) diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py index d6d7f20d..6da3d7b5 100644 --- a/ui_components/widgets/animation_style_element.py +++ b/ui_components/widgets/animation_style_element.py @@ -13,7 +13,8 @@ import matplotlib.pyplot as plt def animation_style_element(shot_uuid): - + disable_generate = False + help = "" motion_modules = AnimateDiffCheckpoint.get_name_list() variant_count = 1 current_animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value # setting a default value @@ -30,55 +31,153 @@ def animation_style_element(shot_uuid): 'animation_tool': AnimationToolType.ANIMATEDIFF.value, } - + interpolation_style = 'ease-in-out' st.markdown("#### Key Frame Settings") - d1, d2 = st.columns([1, 4]) - st.session_state['frame_position'] = 0 - with d1: - setting_a_1, setting_a_2, = st.columns([1, 1]) - with setting_a_1: - type_of_frame_distribution = st_memory.radio("Type of key frame distribution:", options=["Linear", "Dynamic"], key="type_of_frame_distribution").lower() - if type_of_frame_distribution == "linear": - with setting_a_2: - linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="linear_frame_distribution_value") - dynamic_frame_distribution_values = [] - st.markdown("***") - setting_b_1, setting_b_2 = st.columns([1, 1]) - with setting_b_1: - type_of_key_frame_influence = st_memory.radio("Type of key frame length influence:", options=["Linear", "Dynamic"], key="type_of_key_frame_influence").lower() - if type_of_key_frame_influence == "linear": - with setting_b_2: - linear_key_frame_influence_value = st_memory.number_input("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="linear_key_frame_influence_value") - dynamic_key_frame_influence_values = [] - st.markdown("***") - - setting_d_1, setting_d_2 = st.columns([1, 1]) - - with setting_d_1: - type_of_cn_strength_distribution = st_memory.radio("Type of key frame strength control:", options=["Linear", "Dynamic"], key="type_of_cn_strength_distribution").lower() - if type_of_cn_strength_distribution == "linear": - with setting_d_2: - linear_cn_strength_value = st_memory.slider("Range of strength:", min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.01, key="linear_cn_strength_value") - dynamic_cn_strength_values = [] - - st.markdown("***") - footer1, _ = st.columns([2, 1]) - with footer1: - interpolation_style = 'ease-in-out' - motion_scale = st_memory.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.0, step=0.01, key="motion_scale") + type_of_setting = st_memory.radio("Type of key frame settings:", options=["Individual", "Bulk"], key="type_of_setting", horizontal=True) + if type_of_setting == "Individual": + items_per_row = 4 + strength_of_frames = [] + distances_to_next_frames = [] + speeds_of_transitions = [] + movements_between_frames = [] + for i in range(0, len(timing_list) , items_per_row): + with st.container(): + grid = st.columns([2 if j%2==0 else 1 for j in range(2*items_per_row)]) # Adjust the column widths + for j in range(items_per_row): + idx = i + j + if idx < len(timing_list): + + with grid[2*j]: # Adjust the index for image column + timing = timing_list[idx] + if timing.primary_image and timing.primary_image.location: + st.info(f"Frame {idx + 1}") + st.image(timing.primary_image.location, use_column_width=True) + # if not the last frame + + strength_of_frame = st.slider("Strength of current frame:", min_value=0.25, max_value=1.0, value=0.5, step=0.01, key=f"strength_of_frame_{idx}_{timing.uuid}") + strength_of_frames.append(strength_of_frame) + + else: + st.warning("No primary image present.") + with grid[2*j+1]: # Add the new column after the image column + if idx < len(timing_list) - 1: + st.write("") + distance_to_next_frame = st.slider("Distance to next frame:", min_value=4, max_value=32, value=16, step=1, key=f"distance_to_next_frame_{idx}_{timing.uuid}") + distances_to_next_frames.append(distance_to_next_frame) + + speed_of_transition = st.slider("Speed of transition:", min_value=0.45, max_value=0.7, value=0.6, step=0.01, key=f"speed_of_transition_{idx}_{timing.uuid}") + speeds_of_transitions.append(speed_of_transition) + + movement_between_frames = st.slider("Motion between frames:", min_value=0.2, max_value=0.95, value=0.5, step=0.01, key=f"movement_between_frames_{idx}_{timing.uuid}") + movements_between_frames.append(movement_between_frames) + + + if (i < len(timing_list) - 1) or (st.session_state["open_shot"] == shot.uuid) or (len(timing_list) % items_per_row != 0 and st.session_state["open_shot"] != shot.uuid): + st.markdown("***") + + + def transform_data(strength_of_frames, movements_between_frames, speeds_of_transitions, distances_to_next_frames): + + + def adjust_and_invert_relative_value(middle_value, relative_value): + if relative_value is not None: + adjusted_value = middle_value * relative_value + return round(middle_value - adjusted_value, 2) + return None + + def invert_value(value): + return round(1.0 - value, 2) if value is not None else None + + # Creating output_strength with relative and inverted start and end values + output_strength = [] + for i, strength in enumerate(strength_of_frames): + start_value = None if i == 0 else movements_between_frames[i - 1] + end_value = None if i == len(strength_of_frames) - 1 else movements_between_frames[i] + + # Adjusting and inverting start and end values relative to the middle value + adjusted_start = adjust_and_invert_relative_value(strength, start_value) + adjusted_end = adjust_and_invert_relative_value(strength, end_value) + + output_strength.append((adjusted_start, strength, adjusted_end)) + + # Creating output_speeds with inverted values + output_speeds = [(None, None) for _ in range(len(speeds_of_transitions) + 1)] + for i in range(len(speeds_of_transitions)): + current_tuple = list(output_speeds[i]) + next_tuple = list(output_speeds[i + 1]) + + inverted_speed = invert_value(speeds_of_transitions[i]) + current_tuple[1] = inverted_speed * 2 + next_tuple[0] = inverted_speed * 2 + + output_speeds[i] = tuple(current_tuple) + output_speeds[i + 1] = tuple(next_tuple) + + # Creating cumulative_distances + cumulative_distances = [0] + for distance in distances_to_next_frames: + cumulative_distances.append(cumulative_distances[-1] + distance) + + return output_strength, output_speeds, cumulative_distances - st.markdown("***") - if st.button("Reset to default settings", key="reset_animation_style"): - update_interpolation_settings(timing_list=timing_list) - st.rerun() + dynamic_strength_values, dynamic_key_frame_influence_values, dynamic_frame_distribution_values = transform_data(strength_of_frames, movements_between_frames, speeds_of_transitions, distances_to_next_frames) + type_of_frame_distribution = "dynamic" + type_of_key_frame_influence = "dynamic" + type_of_strength_distribution = "dynamic" + linear_frame_distribution_value = 16 + linear_key_frame_influence_value = 1.0 + linear_cn_strength_value = 1.0 + if st.toggle("Visualise motion data"): + columns = st.columns(max(7, len(timing_list))) + for idx, timing in enumerate(timing_list): + + markdown_text = f'##### **Frame {idx + 1}** ___' + + with columns[idx]: + st.markdown(markdown_text) + + keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) + keyframe_positions = [position + 4 - 1 for position in keyframe_positions] + keyframe_positions.insert(0, 0) + + last_key_frame_position = (keyframe_positions[-1] + 1) + strength_values = extract_strength_values(type_of_strength_distribution, dynamic_strength_values, keyframe_positions, linear_cn_strength_value) + key_frame_influence_values = extract_influence_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) + # calculate_weights(keyframe_positions, strength_values, buffer, key_frame_influence_values): + weights_list, frame_numbers_list = calculate_weights(keyframe_positions, strength_values, 4, key_frame_influence_values,last_key_frame_position) + plot_weights(weights_list, frame_numbers_list) + st.write(f"distribution: {str(dynamic_frame_distribution_values)[1:-1]}") + st.write(f"influence: {str(dynamic_key_frame_influence_values)[1:-1]}") + st.write(f"strength: {str(dynamic_strength_values)[1:-1]}") + + # drop all the first values in each list + # keyframe_positions = keyframe_positions[1:] + # strength_values = strength_values[1:] + #key_frame_influence_values = key_frame_influence_values[1:] + + # shirt all the keyframe values back by 4 + # keyframe_positions = [position - 3 for position in keyframe_positions] + + # make keyframe into a plain list + + # st.write(keyframe_positions) + # st.write(strength_values) + # st.write(key_frame_influence_values) + - with d2: + + elif type_of_setting == "Bulk": + + st.session_state['frame_position'] = 0 + type_of_frame_distribution = "linear" + type_of_key_frame_influence = "linear" + type_of_strength_distribution = "linear" columns = st.columns(max(7, len(timing_list))) disable_generate = False help = "" dynamic_frame_distribution_values = [] dynamic_key_frame_influence_values = [] - dynamic_cn_strength_values = [] + dynamic_strength_values = [] for idx, timing in enumerate(timing_list): @@ -92,188 +191,50 @@ def animation_style_element(shot_uuid): if timing.primary_image and timing.primary_image.location: columns[idx].image(timing.primary_image.location, use_column_width=True) - b = timing.primary_image.inference_params - if type_of_frame_distribution == "dynamic": - linear_frame_distribution_value = 16 - if f"frame_{idx+1}" not in st.session_state: - st.session_state[f"frame_{idx+1}"] = idx * 16 # Default values in increments of 16 - if idx == 0: # For the first frame, position is locked to 0 - with columns[idx]: - frame_position = st_memory.number_input(f"{idx+1} frame Position", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx}", disabled=True) - else: - min_value = st.session_state[f"frame_{idx}"] + 1 - with columns[idx]: - frame_position = st_memory.number_input(f"#{idx+1} position:", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx}") - # st.session_state[f"frame_{idx+1}"] = frame_position - dynamic_frame_distribution_values.append(frame_position) - - if type_of_key_frame_influence == "dynamic": - linear_key_frame_influence_value = 1.1 - with columns[idx]: - dynamic_key_frame_influence_individual_value = st_memory.slider(f"#{idx+1} length of influence:", min_value=0.0, max_value=5.0, value=1.0, step=0.1, key=f"dynamic_key_frame_influence_values_{idx}") - dynamic_key_frame_influence_values.append(str(dynamic_key_frame_influence_individual_value)) - - if type_of_cn_strength_distribution == "dynamic": - linear_cn_strength_value = (0.0,1.0) - with columns[idx]: - help_texts = ["For the first frame, it'll start at the endpoint and decline to the starting point", - "For the final frame, it'll start at the starting point and end at the endpoint", - "For intermediate frames, it'll start at the starting point, peak in the middle at the endpoint, and decline to the starting point"] - label_texts = [f"#{idx+1} end -> start:", f"#{idx+1} start -> end:", f"#{idx+1} start -> peak:"] - help_text = help_texts[0] if idx == 0 else help_texts[1] if idx == len(timing_list) - 1 else help_texts[2] - label_text = label_texts[0] if idx == 0 else label_texts[1] if idx == len(timing_list) - 1 else label_texts[2] - dynamic_cn_strength_individual_value = st_memory.slider(label_text, min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.1, key=f"dynamic_cn_strength_values_{idx}",help=help_text) - dynamic_cn_strength_values.append(str(dynamic_cn_strength_individual_value)) - - # Convert lists to strings - dynamic_frame_distribution_values = ",".join(map(str, dynamic_frame_distribution_values)) # Convert integers to strings before joining - dynamic_key_frame_influence_values = ",".join(dynamic_key_frame_influence_values) - dynamic_cn_strength_values = ",".join(dynamic_cn_strength_values) - # dynamic_start_and_endpoint_values = ",".join(dynamic_start_and_endpoint_values) - # st.write(dynamic_start_and_endpoint_values) - - - def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values, allow_extension=True): - if len(keyframe_positions) < 2 or len(keyframe_positions) != len(key_frame_influence_values): - return [] - - influence_ranges = [] - for i, position in enumerate(keyframe_positions): - influence_factor = key_frame_influence_values[i] - range_size = influence_factor * (keyframe_positions[-1] - keyframe_positions[0]) / (len(keyframe_positions) - 1) / 2 - - start_influence = position - range_size - end_influence = position + range_size - - # If extension beyond the adjacent keyframe is allowed, do not constrain the start and end influence. - if not allow_extension: - start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) - end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) - - influence_ranges.append((round(start_influence), round(end_influence))) - - return influence_ranges - - - def get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, images, linear_frame_distribution_value): - if type_of_frame_distribution == "dynamic": - return sorted([int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')]) - else: - return [i * linear_frame_distribution_value for i in range(len(images))] - - def extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): - if type_of_key_frame_influence == "dynamic": - return [float(influence.strip()) for influence in dynamic_key_frame_influence_values.split(',')] - else: - return [linear_key_frame_influence_value for _ in keyframe_positions] - - def extract_start_and_endpoint_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): - if type_of_key_frame_influence == "dynamic": - # If dynamic_key_frame_influence_values is a list of characters representing tuples, process it - if isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": - # Join the characters to form a single string and evaluate to convert into a list of tuples - string_representation = ''.join(dynamic_key_frame_influence_values) - dynamic_values = eval(f'[{string_representation}]') - else: - # If it's already a list of tuples or a single tuple, use it directly - dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] - return dynamic_values - else: - # Return a list of tuples with the linear_key_frame_influence_value as a tuple repeated for each position - return [linear_key_frame_influence_value for _ in keyframe_positions] - - def calculate_weights(influence_ranges, interpolation, start_and_endpoint_strength, last_key_frame_position): - weights_list = [] - frame_numbers_list = [] - - for i, (range_start, range_end) in enumerate(influence_ranges): - # Initialize variables - if i == 0: - strength_to, strength_from = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (0.0, 1.0) - else: - strength_from, strength_to = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (1.0, 0.0) - revert_direction_at_midpoint = (i != 0) and (i != len(influence_ranges) - 1) - - # if it's the first value, set influence range from 1.0 to 0.0 - if i == 0: - range_start = 0 - - # if it's the last value, set influence range to end at last_key_frame_position - if i == len(influence_ranges) - 1: - range_end = last_key_frame_position - - steps = range_end - range_start - diff = strength_to - strength_from - - # Calculate index for interpolation - index = np.linspace(0, 1, steps // 2 + 1) if revert_direction_at_midpoint else np.linspace(0, 1, steps) - - # Calculate weights based on interpolation type - if interpolation == "linear": - weights = np.linspace(strength_from, strength_to, len(index)) - elif interpolation == "ease-in": - weights = diff * np.power(index, 2) + strength_from - elif interpolation == "ease-out": - weights = diff * (1 - np.power(1 - index, 2)) + strength_from - elif interpolation == "ease-in-out": - weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from - - # If it's a middle keyframe, mirror the weights - if revert_direction_at_midpoint: - weights = np.concatenate([weights, weights[::-1]]) - - # Generate frame numbers - frame_numbers = np.arange(range_start, range_start + len(weights)) - - # "Dropper" component: For keyframes with negative start, drop the weights - if range_start < 0 and i > 0: - drop_count = abs(range_start) - weights = weights[drop_count:] - frame_numbers = frame_numbers[drop_count:] - - # Dropper component: for keyframes a range_End is greater than last_key_frame_position, drop the weights - if range_end > last_key_frame_position and i < len(influence_ranges) - 1: - drop_count = range_end - last_key_frame_position - weights = weights[:-drop_count] - frame_numbers = frame_numbers[:-drop_count] - - weights_list.append(weights) - frame_numbers_list.append(frame_numbers) - - return weights_list, frame_numbers_list - - - def plot_weights(weights_list, frame_numbers_list, frame_names): - plt.figure(figsize=(12, 6)) - - for i, weights in enumerate(weights_list): - frame_numbers = frame_numbers_list[i] - plt.plot(frame_numbers, weights, label=f'{frame_names[i]}') - - # Plot settings - plt.xlabel('Frame Number') - plt.ylabel('Weight') - plt.legend() - plt.ylim(0, 1.0) - plt.show() - st.set_option('deprecation.showPyplotGlobalUse', False) - st.pyplot() - - keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) - last_key_frame_position = keyframe_positions[-1] - cn_strength_values = extract_start_and_endpoint_values(type_of_cn_strength_distribution, dynamic_cn_strength_values, keyframe_positions, linear_cn_strength_value) - key_frame_influence_values = extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) - # start_and_endpoint_values = extract_start_and_endpoint_values(type_of_start_and_endpoint, dynamic_start_and_endpoint_values, keyframe_positions, linear_start_and_endpoint_value) - influence_ranges = calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values) - weights_list, frame_numbers_list = calculate_weights(influence_ranges, interpolation_style, cn_strength_values, last_key_frame_position) - frame_names = [f'Frame {i+1}' for i in range(len(influence_ranges))] - plot_weights(weights_list, frame_numbers_list, frame_names) + b = timing.primary_image.inference_params + d1, d2 = st.columns([1, 5]) + with d1: + + linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="linear_frame_distribution_value") + linear_key_frame_influence_value = st_memory.number_input("Length of key frame influence:", min_value=0.1, max_value=5.0, value=0.75, step=0.01, key="linear_key_frame_influence_value") + strength1, strength2 = st.columns([1, 1]) + with strength1: + bottom_of_strength_range = st_memory.number_input("Bottom of strength range:", min_value=0.0, max_value=1.0, value=0.35, step=0.01, key="bottom_of_strength_range") + with strength2: + top_of_strength_range = st_memory.number_input("Top of strength range:", min_value=0.0, max_value=1.0, value=0.5, step=0.01, key="top_of_strength_range") + + linear_cn_strength_value = (bottom_of_strength_range, top_of_strength_range) + + footer1, _ = st.columns([2, 1]) + with footer1: + interpolation_style = 'ease-in-out' + + if st.button("Reset to default settings", key="reset_animation_style"): + update_interpolation_settings(timing_list=timing_list) + st.rerun() + with d2: + columns = st.columns(max(7, len(timing_list))) + disable_generate = False + help = "" + + keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) + keyframe_positions = [position + 4 - 1 for position in keyframe_positions] + keyframe_positions.insert(0, 0) + + last_key_frame_position = (keyframe_positions[-1] + 1) + strength_values = extract_strength_values(type_of_strength_distribution, dynamic_strength_values, keyframe_positions, linear_cn_strength_value) + key_frame_influence_values = extract_influence_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) + # calculate_weights(keyframe_positions, strength_values, buffer, key_frame_influence_values): + weights_list, frame_numbers_list = calculate_weights(keyframe_positions, strength_values, 4, key_frame_influence_values,last_key_frame_position) + plot_weights(weights_list, frame_numbers_list) st.markdown("***") - e1, e2 = st.columns([1, 1]) + + st.markdown("#### Styling Settings") + e1, e2, e3 = st.columns([1, 1,1]) with e1: - st.markdown("#### Styling Settings") + strength_of_adherence = st_memory.slider("How much would you like to adhere to the input images?", min_value=0.0, max_value=1.0, value=0.5, step=0.01, key="stregnth_of_adherence") sd_model_list = [ "Realistic_Vision_V5.0.safetensors", "Counterfeit-V3.0_fp32.safetensors", @@ -283,14 +244,33 @@ def plot_weights(weights_list, frame_numbers_list, frame_names): ] # remove .safe tensors from the end of each model name + # motion_scale = st_memory.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.3, step=0.01, key="motion_scale") + motion_scale = 1.3 sd_model = st_memory.selectbox("Which model would you like to use?", options=sd_model_list, key="sd_model_video") - negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="bad image, worst quality", key="negative_prompt_video") - relative_ipadapter_strength = st_memory.slider("How much would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_strength") - relative_ipadapter_influence = st_memory.slider("For how long would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_influence") - soft_scaled_cn_weights_multipler = st_memory.slider("How much would you like to scale the CN weights?", min_value=0.0, max_value=10.0, value=0.85, step=0.1, key="soft_scaled_cn_weights_multiple_video") - append_to_prompt = st_memory.text_input("What would you like to append to the prompts?", key="append_to_prompt") + f1, f2 = st.columns([1, 1]) + with f1: + prompt1, prompt2 = st.columns([1, 1]) + with prompt1: + positive_prompt = st_memory.text_area("What would you like to see in the videos?", value="", key="positive_prompt_video") + with prompt2: + negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="bad image, worst quality", key="negative_prompt_video") + + soft_scaled_cn_weights_multiplier ="" + + with e2: + st.info("Higher values may cause flickering and sudden changes in the video. Lower values may cause the video to be less influenced by the input images.") + + + + # relative_ipadapter_strength = st_memory.slider("How much would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_strength") + # relative_ipadapter_influence = st_memory.slider("For how long would you like to influence the style?", min_value=0.0, max_value=5.0, value=1.1, step=0.1, key="ip_adapter_influence") + # soft_scaled_cn_weights_multipler = st_memory.slider("How much would you like to scale the CN weights?", min_value=0.0, max_value=10.0, value=0.85, step=0.1, key="soft_scaled_cn_weights_multiple_video") + # append_to_prompt = st_memory.text_input("What would you like to append to the prompts?", key="append_to_prompt") normalise_speed = True + + relative_ipadapter_strength = 1.0 + relative_ipadapter_influence = 0.0 project_settings = data_repo.get_project_setting(shot.project.uuid) width = project_settings.width @@ -307,11 +287,10 @@ def plot_weights(weights_list, frame_numbers_list, frame_names): interpolation_type=interpolation_style, stmfnet_multiplier=2, relative_ipadapter_strength=relative_ipadapter_strength, - relative_ipadapter_influence=relative_ipadapter_influence, - soft_scaled_cn_weights_multiplier=soft_scaled_cn_weights_multipler, - type_of_cn_strength_distribution=type_of_cn_strength_distribution, + relative_ipadapter_influence=relative_ipadapter_influence, + type_of_cn_strength_distribution=type_of_strength_distribution, linear_cn_strength_value=str(linear_cn_strength_value), - dynamic_cn_strength_values=str(dynamic_cn_strength_values), + dynamic_strength_values=str(dynamic_strength_values), type_of_frame_distribution=type_of_frame_distribution, linear_frame_distribution_value=linear_frame_distribution_value, dynamic_frame_distribution_values=dynamic_frame_distribution_values, @@ -328,7 +307,7 @@ def plot_weights(weights_list, frame_numbers_list, frame_names): if where_to_generate == "Cloud": animate_col_1, animate_col_2, _ = st.columns([1, 1, 2]) with animate_col_1: - variant_count = st.number_input("How many variants?", min_value=1, max_value=100, value=1, step=1, key="variant_count") + variant_count = st.number_input("How many variants?", min_value=1, max_value=5, value=1, step=1, key="variant_count") if st.button("Generate Animation Clip", key="generate_animation_clip", disabled=disable_generate, help=help): vid_quality = "full" if video_resolution == "Full Resolution" else "preview" @@ -433,7 +412,28 @@ def prepare_workflow_images(shot_uuid): buffer.seek(0) return buffer.getvalue() - +def extract_strength_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + + if type_of_key_frame_influence == "dynamic": + # Process the dynamic_key_frame_influence_values depending on its format + if isinstance(dynamic_key_frame_influence_values, str): + dynamic_values = eval(dynamic_key_frame_influence_values) + else: + dynamic_values = dynamic_key_frame_influence_values + + # Iterate through the dynamic values and convert tuples with two values to three values + dynamic_values_corrected = [] + for value in dynamic_values: + if len(value) == 2: + value = (value[0], value[1], value[0]) + dynamic_values_corrected.append(value) + + return dynamic_values_corrected + else: + # Process for linear or other types + if len(linear_key_frame_influence_value) == 2: + linear_key_frame_influence_value = (linear_key_frame_influence_value[0], linear_key_frame_influence_value[1], linear_key_frame_influence_value[0]) + return [linear_key_frame_influence_value for _ in range(len(keyframe_positions) - 1)] def create_workflow_json(image_locations, settings): import os @@ -458,7 +458,7 @@ def create_workflow_json(image_locations, settings): type_of_cn_strength_distribution=settings['type_of_cn_strength_distribution'] linear_cn_strength_value=settings['linear_cn_strength_value'] buffer = settings['buffer'] - dynamic_cn_strength_values = settings['dynamic_cn_strength_values'] + dynamic_strength_values = settings['dynamic_strength_values'] interpolation_type = settings['interpolation_type'] ckpt = settings['ckpt'] motion_scale = settings['motion_scale'] @@ -466,7 +466,7 @@ def create_workflow_json(image_locations, settings): relative_ipadapter_influence = settings['relative_ipadapter_influence'] image_dimension = settings['image_dimension'] output_format = settings['output_format'] - soft_scaled_cn_weights_multiplier = settings['soft_scaled_cn_weights_multiplier'] + # soft_scaled_cn_weights_multiplier = settings['soft_scaled_cn_weights_multiplier'] stmfnet_multiplier = settings['stmfnet_multiplier'] if settings['type_of_frame_distribution'] == 'linear': @@ -503,7 +503,7 @@ def create_workflow_json(image_locations, settings): node['widgets_values'][6] = dynamic_key_frame_influence_values node['widgets_values'][7] = type_of_cn_strength_distribution node['widgets_values'][8] = linear_cn_strength_value - node['widgets_values'][9] = dynamic_cn_strength_values + node['widgets_values'][9] = dynamic_strength_values node['widgets_values'][-1] = buffer node['widgets_values'][-2] = interpolation_type node['widgets_values'][-3] = soft_scaled_cn_weights_multiplier @@ -544,8 +544,267 @@ def update_interpolation_settings(values=None, timing_list=None): for idx in range(0, len(timing_list)): default_values[f'dynamic_frame_distribution_values_{idx}'] = (idx) * 16 default_values[f'dynamic_key_frame_influence_values_{idx}'] = 1.0 - default_values[f'dynamic_cn_strength_values_{idx}'] = (0.0,0.7) + default_values[f'dynamic_strength_values_{idx}'] = (0.0,0.7) for key, default_value in default_values.items(): st.session_state[key] = values.get(key, default_value) if values and values.get(key) is not None else default_value - # print(f"{key}: {st.session_state[key]}") \ No newline at end of file + # print(f"{key}: {st.session_state[key]}") + + +def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values, allow_extension=True): + if len(keyframe_positions) < 2 or len(keyframe_positions) != len(key_frame_influence_values): + return [] + + influence_ranges = [] + for i, position in enumerate(keyframe_positions): + influence_factor = key_frame_influence_values[i] + range_size = influence_factor * (keyframe_positions[-1] - keyframe_positions[0]) / (len(keyframe_positions) - 1) / 2 + + start_influence = position - range_size + end_influence = position + range_size + + # If extension beyond the adjacent keyframe is allowed, do not constrain the start and end influence. + if not allow_extension: + start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) + end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) + + influence_ranges.append((round(start_influence), round(end_influence))) + + return influence_ranges + +def extract_influence_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + # Check and convert linear_key_frame_influence_value if it's a float or string float + # if it's a string that starts with a parenthesis, convert it to a tuple + if isinstance(linear_key_frame_influence_value, str) and linear_key_frame_influence_value[0] == "(": + linear_key_frame_influence_value = eval(linear_key_frame_influence_value) + + + if not isinstance(linear_key_frame_influence_value, tuple): + if isinstance(linear_key_frame_influence_value, (float, str)): + try: + value = float(linear_key_frame_influence_value) + linear_key_frame_influence_value = (value, value) + except ValueError: + raise ValueError("linear_key_frame_influence_value must be a float or a string representing a float") + + number_of_outputs = len(keyframe_positions) - 1 + + if type_of_key_frame_influence == "dynamic": + # Convert list of individual float values into tuples + if all(isinstance(x, float) for x in dynamic_key_frame_influence_values): + dynamic_values = [(value, value) for value in dynamic_key_frame_influence_values] + elif isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": + string_representation = ''.join(dynamic_key_frame_influence_values) + dynamic_values = eval(f'[{string_representation}]') + else: + dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] + return dynamic_values[:number_of_outputs] + else: + return [linear_key_frame_influence_value for _ in range(number_of_outputs)] + + +def get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, images, linear_frame_distribution_value): + if type_of_frame_distribution == "dynamic": + # Check if the input is a string or a list + if isinstance(dynamic_frame_distribution_values, str): + # Sort the keyframe positions in numerical order + return sorted([int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')]) + elif isinstance(dynamic_frame_distribution_values, list): + return sorted(dynamic_frame_distribution_values) + else: + # Calculate the number of keyframes based on the total duration and linear_frames_per_keyframe + return [i * linear_frame_distribution_value for i in range(len(images))] + +def extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + if type_of_key_frame_influence == "dynamic": + return [float(influence.strip()) for influence in dynamic_key_frame_influence_values.split(',')] + else: + return [linear_key_frame_influence_value for _ in keyframe_positions] + +def extract_start_and_endpoint_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + if type_of_key_frame_influence == "dynamic": + # If dynamic_key_frame_influence_values is a list of characters representing tuples, process it + if isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": + # Join the characters to form a single string and evaluate to convert into a list of tuples + string_representation = ''.join(dynamic_key_frame_influence_values) + dynamic_values = eval(f'[{string_representation}]') + else: + # If it's already a list of tuples or a single tuple, use it directly + dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] + return dynamic_values + else: + # Return a list of tuples with the linear_key_frame_influence_value as a tuple repeated for each position + return [linear_key_frame_influence_value for _ in keyframe_positions] + +def calculate_weights(keyframe_positions, strength_values, buffer, key_frame_influence_values,last_key_frame_position): + + def calculate_influence_frame_number(key_frame_position, next_key_frame_position, distance): + # Calculate the absolute distance between key frames + key_frame_distance = abs(next_key_frame_position - key_frame_position) + + # Apply the distance multiplier + extended_distance = key_frame_distance * distance + + # Determine the direction of influence based on the positions of the key frames + if key_frame_position < next_key_frame_position: + # Normal case: influence extends forward + influence_frame_number = key_frame_position + extended_distance + else: + # Reverse case: influence extends backward + influence_frame_number = key_frame_position - extended_distance + + # Return the result rounded to the nearest integer + return round(influence_frame_number) + + def find_curve(batch_index_from, batch_index_to, strength_from, strength_to, interpolation,revert_direction_at_midpoint, last_key_frame_position,i, number_of_items,buffer): + + # Initialize variables based on the position of the keyframe + range_start = batch_index_from + range_end = batch_index_to + # if it's the first value, set influence range from 1.0 to 0.0 + if buffer > 0: + if i == 0: + range_start = 0 + elif i == 1: + range_start = buffer + else: + if i == 1: + range_start = 0 + + if i == number_of_items - 1: + range_end = last_key_frame_position + + steps = range_end - range_start + diff = strength_to - strength_from + + # Calculate index for interpolation + index = np.linspace(0, 1, steps // 2 + 1) if revert_direction_at_midpoint else np.linspace(0, 1, steps) + + # Calculate weights based on interpolation type + if interpolation == "linear": + weights = np.linspace(strength_from, strength_to, len(index)) + elif interpolation == "ease-in": + weights = diff * np.power(index, 2) + strength_from + elif interpolation == "ease-out": + weights = diff * (1 - np.power(1 - index, 2)) + strength_from + elif interpolation == "ease-in-out": + weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from + + if revert_direction_at_midpoint: + weights = np.concatenate([weights, weights[::-1]]) + + # Generate frame numbers + frame_numbers = np.arange(range_start, range_start + len(weights)) + + # "Dropper" component: For keyframes with negative start, drop the weights + if range_start < 0 and i > 0: + drop_count = abs(range_start) + weights = weights[drop_count:] + frame_numbers = frame_numbers[drop_count:] + + # Dropper component: for keyframes a range_End is greater than last_key_frame_position, drop the weights + if range_end > last_key_frame_position and i < number_of_items - 1: + drop_count = range_end - last_key_frame_position + weights = weights[:-drop_count] + frame_numbers = frame_numbers[:-drop_count] + + return weights, frame_numbers + weights_list = [] + frame_numbers_list = [] + + for i in range(len(keyframe_positions)): + + keyframe_position = keyframe_positions[i] + interpolation = "ease-in-out" + # strength_from = strength_to = 1.0 + + if i == 0: # buffer + + if buffer > 0: # First image with buffer + + strength_from = strength_to = strength_values[0][1] + else: + continue # Skip first image without buffer + batch_index_from = 0 + batch_index_to_excl = buffer + weights, frame_numbers = find_curve(batch_index_from, batch_index_to_excl, strength_from, strength_to, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + + elif i == 1: # first image + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + key_frame_influence_from, key_frame_influence_to = key_frame_influence_values[0] + start_strength, mid_strength, end_strength = strength_values[0] + + keyframe_position = keyframe_positions[i] + next_key_frame_position = keyframe_positions[i+1] + + batch_index_from = keyframe_position + batch_index_to_excl = calculate_influence_frame_number(keyframe_position, next_key_frame_position, key_frame_influence_to) + weights, frame_numbers = find_curve(batch_index_from, batch_index_to_excl, mid_strength, end_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + # interpolation = "ease-in" + + elif i == len(keyframe_positions) - 1: # last image + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + key_frame_influence_from,key_frame_influence_to = key_frame_influence_values[i-1] + start_strength, mid_strength, end_strength = strength_values[i-1] + # strength_from, strength_to = cn_strength_values[i-1] + + keyframe_position = keyframe_positions[i] + previous_key_frame_position = keyframe_positions[i-1] + + batch_index_from = calculate_influence_frame_number(keyframe_position, previous_key_frame_position, key_frame_influence_from) + batch_index_to_excl = keyframe_position + weights, frame_numbers = find_curve(batch_index_from, batch_index_to_excl, start_strength, mid_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + # interpolation = "ease-out" + + else: # middle images + + # GET IMAGE AND KEYFRAME INFLUENCE VALUES + key_frame_influence_from,key_frame_influence_to = key_frame_influence_values[i-1] + start_strength, mid_strength, end_strength = strength_values[i-1] + keyframe_position = keyframe_positions[i] + + # CALCULATE WEIGHTS FOR FIRST HALF + previous_key_frame_position = keyframe_positions[i-1] + batch_index_from = calculate_influence_frame_number(keyframe_position, previous_key_frame_position, key_frame_influence_from) + batch_index_to_excl = keyframe_position + first_half_weights, first_half_frame_numbers = find_curve(batch_index_from, batch_index_to_excl, start_strength, mid_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + + # CALCULATE WEIGHTS FOR SECOND HALF + next_key_frame_position = keyframe_positions[i+1] + batch_index_from = keyframe_position + batch_index_to_excl = calculate_influence_frame_number(keyframe_position, next_key_frame_position, key_frame_influence_to) + second_half_weights, second_half_frame_numbers = find_curve(batch_index_from, batch_index_to_excl, mid_strength, end_strength, interpolation, False, last_key_frame_position, i, len(keyframe_positions), buffer) + + # COMBINE FIRST AND SECOND HALF + weights = np.concatenate([first_half_weights, second_half_weights]) + frame_numbers = np.concatenate([first_half_frame_numbers, second_half_frame_numbers]) + + weights_list.append(weights) + frame_numbers_list.append(frame_numbers) + + + return weights_list, frame_numbers_list + + + +def plot_weights(weights_list, frame_numbers_list): + plt.figure(figsize=(12, 6)) + + # Drop the first list of values from both lists + weights_list = weights_list[1:] + frame_numbers_list = frame_numbers_list[1:] + + for i, weights in enumerate(weights_list): + frame_numbers = frame_numbers_list[i] + plt.plot(frame_numbers, weights, label=f'Frame {i + 1}') + + # Plot settings + plt.xlabel('Frame Number') + plt.ylabel('Weight') + plt.legend() + plt.ylim(0, 1.0) + plt.show() + st.set_option('deprecation.showPyplotGlobalUse', False) + st.pyplot() \ No newline at end of file diff --git a/ui_components/widgets/frame_movement_widgets.py b/ui_components/widgets/frame_movement_widgets.py index 4a8351d8..21227f6f 100644 --- a/ui_components/widgets/frame_movement_widgets.py +++ b/ui_components/widgets/frame_movement_widgets.py @@ -118,56 +118,26 @@ def replace_image_widget(timing_uuid, stage, options=["Uploaded Frame", "Other F timing = data_repo.get_timing_from_uuid(timing_uuid) timing_list = data_repo.get_timing_list_from_shot(timing.shot.uuid) - replace_with = options[0] if len(options) == 1 else st.radio("Replace with:", options, horizontal=True, key=f"replacement_entity_{stage}_{timing_uuid}") - - if replace_with == "Other Frame": - image_replacement_stage = st.radio( - "Select stage to use:", - [ImageStage.MAIN_VARIANT.value, ImageStage.SOURCE_IMAGE.value], - key=f"image_replacement_stage_{stage}_{timing_uuid}", - horizontal=True - ) - replacement_img_number = st.number_input("Select image to use:", min_value=1, max_value=len( - timing_list), value=0, key=f"replacement_img_number_{stage}") - - if image_replacement_stage == ImageStage.SOURCE_IMAGE.value: - selected_image = timing_list[replacement_img_number - 1].source_image - elif image_replacement_stage == ImageStage.MAIN_VARIANT.value: - selected_image = timing_list[replacement_img_number - 1].primary_image - - st.image(selected_image.location, use_column_width=True) - - if st.button("Replace with selected frame", disabled=False,key=f"replace_with_selected_frame_{stage}_{timing_uuid}"): - if stage == WorkflowStageType.SOURCE.value: - data_repo.update_specific_timing(timing.uuid, source_image_id=selected_image.uuid) - st.success("Replaced") - time.sleep(1) - st.rerun() - else: - number_of_image_variants = add_image_variant( - selected_image.uuid, timing.uuid) - promote_image_variant( - timing.uuid, number_of_image_variants - 1) - st.success("Replaced") - time.sleep(1) - st.rerun() - elif replace_with == "Uploaded Frame": - btn_text = 'Upload source image' if stage == WorkflowStageType.SOURCE.value else 'Replace frame' - uploaded_file = st.file_uploader(btn_text, type=[ - "png", "jpeg"], accept_multiple_files=False,key=f"uploaded_file_{stage}_{timing_uuid}") - if uploaded_file != None: - if st.button(btn_text): - if uploaded_file: - timing = data_repo.get_timing_from_uuid(timing.uuid) - if save_and_promote_image(uploaded_file, timing.shot.uuid, timing.uuid, stage): - st.success("Replaced") - time.sleep(1.5) - st.rerun() + + btn_text = 'Upload source image' if stage == WorkflowStageType.SOURCE.value else 'Replace frame' + uploaded_file = st.file_uploader(btn_text, type=[ + "png", "jpeg"], accept_multiple_files=False,key=f"uploaded_file_{stage}_{timing_uuid}") + if uploaded_file != None: + if st.button(btn_text): + if uploaded_file: + timing = data_repo.get_timing_from_uuid(timing.uuid) + if save_and_promote_image(uploaded_file, timing.shot.uuid, timing.uuid, stage): + st.success("Replaced") + time.sleep(1.5) + st.rerun() def jump_to_single_frame_view_button(display_number, timing_list, src,uuid=None): if st.button(f"Jump to #{display_number}", key=f"{src}_{uuid}", use_container_width=True): + st.session_state['current_frame_sidebar_selector'] = display_number + st.session_state["creative_process_manual_select"] = 3 + ''' st.session_state['prev_frame_index'] = st.session_state['current_frame_index'] = display_number st.session_state['current_frame_uuid'] = timing_list[st.session_state['current_frame_index'] - 1].uuid st.session_state['frame_styling_view_type_manual_select'] = 2 @@ -176,4 +146,5 @@ def jump_to_single_frame_view_button(display_number, timing_list, src,uuid=None) st.session_state["creative_process_manual_select"] = 4 st.session_state["styling_view_selector_manual_select"] = 0 st.session_state['page'] = "Key Frames" + ''' st.rerun() diff --git a/ui_components/widgets/frame_selector.py b/ui_components/widgets/frame_selector.py index a81dd7be..d9d3b431 100644 --- a/ui_components/widgets/frame_selector.py +++ b/ui_components/widgets/frame_selector.py @@ -10,7 +10,7 @@ -def frame_selector_widget(show: List[str]): +def frame_selector_widget(show_frame_selector=True): data_repo = DataRepo() timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) @@ -21,49 +21,63 @@ def frame_selector_widget(show: List[str]): if 'prev_shot_index' not in st.session_state: st.session_state['prev_shot_index'] = shot.shot_idx - if 'shot_selector' in show: + shot1, shot2 = st.columns([1, 1]) + with shot1: shot_names = [s.name for s in shot_list] shot_name = st.selectbox('Shot name:', shot_names, key="current_shot_sidebar_selector",index=shot_names.index(shot.name)) - # find shot index based on shot name - st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 + # find shot index based on shot name + st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 - if shot_name != shot.name: - st.session_state["shot_uuid"] = shot_list[shot_names.index(shot_name)].uuid - st.rerun() + if shot_name != shot.name: + st.session_state["shot_uuid"] = shot_list[shot_names.index(shot_name)].uuid + st.rerun() - if not ('current_shot_index' in st.session_state and st.session_state['current_shot_index']): - st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 - update_current_shot_index(st.session_state['current_shot_index']) + if not ('current_shot_index' in st.session_state and st.session_state['current_shot_index']): + st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 + update_current_shot_index(st.session_state['current_shot_index']) # st.write if frame_selector is present + + + if st.session_state['page'] == "Key Frames": + if st.session_state['current_frame_index'] > len_timing_list: + update_current_frame_index(len_timing_list) - if 'frame_selector' in show: - - if st.session_state['page'] == "Key Frames": - if st.session_state['current_frame_index'] > len_timing_list: - update_current_frame_index(len_timing_list) - - elif st.session_state['page'] == "Shots": - if st.session_state['current_shot_index'] > len(shot_list): - update_current_shot_index(len(shot_list)) + elif st.session_state['page'] == "Shots": + if st.session_state['current_shot_index'] > len(shot_list): + update_current_shot_index(len(shot_list)) - + if show_frame_selector: if len(timing_list): if 'prev_frame_index' not in st.session_state or st.session_state['prev_frame_index'] > len(timing_list): - st.session_state['prev_frame_index'] = 1 + + # Create a list of frames with a blank value as the first item + frame_list = [''] + [f'{i+1}' for i in range(len(timing_list))] + + + with shot2: + frame_selection = st_memory.selectbox('Select a frame:', frame_list, key="current_frame_sidebar_selector") - st.session_state['current_frame_index'] = st.number_input(f"Key frame # (out of {len(timing_list)})", 1, - len(timing_list), value=st.session_state['prev_frame_index'], - step=1, key="current_frame_sidebar_selector") - - update_current_frame_index(st.session_state['current_frame_index']) + # Only trigger the frame number extraction and current frame index update if a non-empty value is selected + if frame_selection != '': + + if st.button("Jump to shot view",use_container_width=True): + st.session_state['current_frame_sidebar_selector'] = 0 + st.rerun() + # st.session_state['creative_process_manual_select'] = 4 + st.session_state['current_frame_index'] = int(frame_selection.split(' ')[-1]) + update_current_frame_index(st.session_state['current_frame_index']) else: + frame_selection = "" st.error("No frames present") + return frame_selection + def frame_view(view="Key Frame"): data_repo = DataRepo() # time1, time2 = st.columns([1,1]) - st.markdown("***") + # st.markdown("***") + st.write("") timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) diff --git a/ui_components/widgets/image_zoom_widgets.py b/ui_components/widgets/image_zoom_widgets.py index a22efcc7..c731b1b6 100644 --- a/ui_components/widgets/image_zoom_widgets.py +++ b/ui_components/widgets/image_zoom_widgets.py @@ -78,7 +78,7 @@ def save_zoomed_image(image, timing_uuid, stage, promote=False): styled_image.uuid, timing_uuid) if promote: promote_image_variant(timing_uuid, number_of_image_variants - 1) - + ''' project_update_data = { "zoom_level": st.session_state['zoom_level_input'], "rotation_angle_value": st.session_state['rotation_angle_input'], @@ -93,7 +93,9 @@ def save_zoomed_image(image, timing_uuid, stage, promote=False): "zoom_details": f"{st.session_state['zoom_level_input']},{st.session_state['rotation_angle_input']},{st.session_state['x_shift']},{st.session_state['y_shift']}", } + data_repo.update_specific_timing(timing_uuid, **timing_update_data) + ''' def reset_zoom_element(): st.session_state['zoom_level_input_key'] = 100 diff --git a/ui_components/widgets/shot_view.py b/ui_components/widgets/shot_view.py index 68eae2f8..2b42b668 100644 --- a/ui_components/widgets/shot_view.py +++ b/ui_components/widgets/shot_view.py @@ -57,10 +57,11 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg with col3: move_frames_toggle = st_memory.toggle("Move Frames", value=True, key="move_frames_toggle") with col4: - replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") - - with col5: change_shot_toggle = st_memory.toggle("Change Shot", value=False, key="change_shot_toggle") + # replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") + + + st.markdown("***") @@ -85,7 +86,7 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg st.warning("No primary image present.") jump_to_single_frame_view_button(idx + 1, timing_list, f"jump_to_{idx + 1}",uuid=shot.uuid) if position != "Timeline": - timeline_view_buttons(idx, shot_uuid, replace_image_widget_toggle, copy_frame_toggle, move_frames_toggle,delete_frames_toggle, change_shot_toggle) + timeline_view_buttons(idx, shot_uuid, copy_frame_toggle, move_frames_toggle,delete_frames_toggle, change_shot_toggle) if (i < len(timing_list) - 1) or (st.session_state["open_shot"] == shot.uuid) or (len(timing_list) % items_per_row != 0 and st.session_state["open_shot"] != shot.uuid): st.markdown("***") # st.markdown("***") @@ -219,30 +220,8 @@ def update_shot_duration(shot_uuid): time.sleep(0.3) st.rerun() -def shot_video_element(shot_uuid): - data_repo = DataRepo() - - shot: InternalShotObject = data_repo.get_shot_from_uuid(shot_uuid) - - st.info(f"##### {shot.name}") - if shot.main_clip and shot.main_clip.location: - st.video(shot.main_clip.location) - else: - st.warning('''No video present''') - switch1,switch2 = st.columns([1,1]) - with switch1: - shot_adjustment_button(shot) - with switch2: - shot_animation_button(shot) - with st.expander("Details", expanded=False): - update_shot_name(shot.uuid) - update_shot_duration(shot.uuid) - move_shot_buttons(shot, "side") - delete_shot_button(shot.uuid) - if shot.main_clip: - create_video_download_button(shot.main_clip.location, tag="main_clip") @@ -274,6 +253,7 @@ def shot_adjustment_button(shot, show_label=False): button_label = "Shot Adjustment 🔧" if show_label else "🔧" if st.button(button_label, key=f"jump_to_shot_adjustment_{shot.uuid}", help=f"Shot adjustment view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid + st.session_state['current_frame_sidebar_selector'] = 0 st.session_state['creative_process_manual_select'] = 3 st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 1 @@ -284,21 +264,19 @@ def shot_animation_button(shot, show_label=False): button_label = "Shot Animation 🎞️" if show_label else "🎞️" if st.button(button_label, key=f"jump_to_shot_animation_{shot.uuid}", help=f"Shot animation view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid - st.session_state['creative_process_manual_select'] = 5 - st.session_state["manual_select"] = 1 + st.session_state['creative_process_manual_select'] = 4 + # st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 0 st.session_state['shot_view_index'] = 0 st.rerun() -def timeline_view_buttons(idx, shot_uuid, replace_image_widget_toggle, copy_frame_toggle, move_frames_toggle, delete_frames_toggle, change_shot_toggle): +def timeline_view_buttons(idx, shot_uuid, copy_frame_toggle, move_frames_toggle, delete_frames_toggle, change_shot_toggle): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) timing_list = shot.timing_list - if replace_image_widget_toggle: - replace_image_widget(timing_list[idx].uuid, stage=WorkflowStageType.STYLED.value, options=["Uploaded Frame"]) btn1, btn2, btn3, btn4 = st.columns([1, 1, 1, 1]) diff --git a/ui_components/widgets/timeline_view.py b/ui_components/widgets/timeline_view.py index 95fc2d1f..f1c8aef8 100644 --- a/ui_components/widgets/timeline_view.py +++ b/ui_components/widgets/timeline_view.py @@ -1,6 +1,6 @@ import streamlit as st from ui_components.methods.common_methods import add_new_shot -from ui_components.widgets.shot_view import shot_keyframe_element, shot_video_element +from ui_components.widgets.shot_view import shot_keyframe_element, shot_adjustment_button, shot_animation_button, update_shot_name, update_shot_duration, move_shot_buttons, delete_shot_button, create_video_download_button from utils.data_repo.data_repo import DataRepo from utils import st_memory @@ -9,13 +9,14 @@ def timeline_view(shot_uuid, stage): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) shot_list = data_repo.get_shot_list(shot.project.uuid) + timing_list: List[InternalFrameTimingObject] = shot.timing_list _, header_col_2 = st.columns([5.5,1.5]) - with header_col_2: - items_per_row = st_memory.slider("How many frames per row?", min_value=3, max_value=7, value=5, step=1, key="items_per_row_slider") - + #with header_col_2: + #items_per_row = st_memory.slider("How many frames per row?", min_value=3, max_value=7, value=5, step=1, key="items_per_row_slider") + ''' if stage == 'Key Frames': for shot in shot_list: with st.expander(f"_-_-_-_", expanded=True): @@ -27,18 +28,58 @@ def timeline_view(shot_uuid, stage): add_new_shot_element(shot, data_repo) else: - for idx, shot in enumerate(shot_list): - if idx % items_per_row == 0: - grid = st.columns(items_per_row) - with grid[idx % items_per_row]: - shot_video_element(shot.uuid) - if (idx + 1) % items_per_row == 0 or idx == len(shot_list) - 1: - st.markdown("***") - # if stage isn't - if idx == len(shot_list) - 1: - with grid[(idx + 1) % items_per_row]: - st.markdown("### Add new shot") - add_new_shot_element(shot, data_repo) + ''' + items_per_row = 4 + for idx, shot in enumerate(shot_list): + timing_list: List[InternalFrameTimingObject] = shot.timing_list + if idx % items_per_row == 0: + grid = st.columns(items_per_row) + + + with grid[idx % items_per_row]: + st.info(f"##### {shot.name}") + if stage == "Key Frames": + for i in range(0, len(timing_list), items_per_row): + if i % items_per_row == 0: + grid_timing = st.columns(items_per_row) + for j in range(items_per_row): + # idx = i + j + if i + j < len(timing_list): + with grid_timing[j]: + timing = timing_list[ i + j] + if timing.primary_image and timing.primary_image.location: + st.image(timing.primary_image.location, use_column_width=True) + else: + + if shot.main_clip and shot.main_clip.location: + st.video(shot.main_clip.location) + else: + st.warning('''No video present''') + + + switch1,switch2 = st.columns([1,1]) + with switch1: + shot_adjustment_button(shot) + with switch2: + shot_animation_button(shot) + + with st.expander("Details & settings", expanded=False): + update_shot_name(shot.uuid) + update_shot_duration(shot.uuid) + move_shot_buttons(shot, "side") + delete_shot_button(shot.uuid) + if shot.main_clip: + create_video_download_button(shot.main_clip.location, tag="main_clip") + + + if (idx + 1) % items_per_row == 0 or idx == len(shot_list) - 1: + st.markdown("***") + + # st.write(idx, len(shot_list) - 1, (idx + 1) % items_per_row, idx == len(shot_list) - 1) + if idx == len(shot_list) - 1: + with grid[(idx + 1) % items_per_row]: + st.markdown("### Add new shot") + add_new_shot_element(shot, data_repo) diff --git a/utils/common_utils.py b/utils/common_utils.py index aa70e8ef..65b13a07 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -246,5 +246,5 @@ def release_lock(key): def refresh_app(maintain_state=False): - st.session_state['maintain_state'] = maintain_state + # st.session_state['maintain_state'] = maintain_state st.rerun() \ No newline at end of file diff --git a/utils/ml_processor/replicate/utils.py b/utils/ml_processor/replicate/utils.py index bce7b5e5..3dd3d143 100644 --- a/utils/ml_processor/replicate/utils.py +++ b/utils/ml_processor/replicate/utils.py @@ -72,8 +72,8 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): data = { "prompt" : query_obj.prompt, "negative_prompt" : query_obj.negative_prompt, - "width" : 768 if query_obj.width == 512 else 1024, # 768 is the default for sdxl - "height" : 768 if query_obj.height == 512 else 1024, + "width" : 1024 if query_obj.width == 512 else 1024, # 768 is the default for sdxl + "height" : 1024 if query_obj.height == 512 else 1024, "prompt_strength": query_obj.strength, "mask": mask }