Skip to content

Commit

Permalink
More linting - 95% done
Browse files Browse the repository at this point in the history
  • Loading branch information
inFocus7 committed Jan 15, 2024
1 parent fd7b6c3 commit 1764076
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 23 deletions.
22 changes: 18 additions & 4 deletions ui/listicles/interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""
The interface for the Listicles section of the web app.
"""
import json
import gradio as gr
import processing.image as image_processing
import json
import ui.listicles.utils as listicle_utils
import ui.components.openai as openai_components
import utils.gradio as gru


def render_listicles_section():
def render_listicles_section() -> None:
"""
Renders the Listicles section of the web app.
"""
gru.render_tool_description("Create images in the style of those 'Your birth month is your ___' TikToks.")
with gr.Tab("Generate Artifacts"):
send_artifacts_to_batch_button, listicle_image_output, listicle_json_output = render_generate_section()
Expand All @@ -20,7 +26,11 @@ def render_listicles_section():
)


def render_batch_section():
def render_batch_section() -> (gr.File, gr.Code):
"""
Renders the Batch Image Processing section of the web app.
:return: The input images and input json components.
"""
with gr.Column():
gr.Markdown("# Input")
with gr.Row(equal_height=False):
Expand Down Expand Up @@ -106,7 +116,11 @@ def set_json(json_file):
return input_batch_images, input_batch_json


def render_generate_section():
def render_generate_section() -> (gr.Button, gr.Gallery, gr.Code):
"""
Renders the Generate Artifacts section of the web app.
:return: The send artifacts to batch button, the listicle image output gallery, and the listicle json output.
"""
api_key, api_text_model, api_image_model = openai_components.render_openai_setup()
with gr.Row(equal_height=False):
with gr.Group():
Expand Down
52 changes: 44 additions & 8 deletions ui/listicles/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import gradio as gr
"""
This file contains the functions that are used by the Gradio UI to generate listicles.
"""
import os
import json
from typing import Optional
import gradio as gr
import processing.image as image_processing
import os
from utils import font_manager, image as image_utils
import api.chatgpt as chatgpt_api

Expand Down Expand Up @@ -93,7 +97,12 @@ def process(image_files, json_data,
return images


def validate_json(json_file):
def validate_json(json_file: str) -> None:
"""
Validates the JSON file to make sure it has the required fields.
:param json_file: The JSON file to validate.
:return: None
"""
if not json_file or len(json_file) == 0:
gr.Warning("No JSON in the code block.")
return
Expand All @@ -117,7 +126,13 @@ def validate_json(json_file):
gr.Info("JSON is valid!")


def send_artifacts_to_batch(listicle_images, json_data):
def send_artifacts_to_batch(listicle_images: gr.data_classes.RootModel, json_data: str) -> (list, str):
"""
Sends the artifacts to the batch processing section.
:param listicle_images: The list of images to send. This is a Gradio Gallery.
:param json_data: The JSON data to send.
:return: The list of images and the JSON data sent.
"""
if not listicle_images or len(listicle_images.root) == 0:
gr.Warning("No images to send.")
return
Expand All @@ -130,10 +145,17 @@ def send_artifacts_to_batch(listicle_images, json_data):
return listicle_images, json_data


def save_artifacts(listicle_images, image_type, json_data):
def save_artifacts(listicle_images: gr.data_classes.RootModel, image_type: gr.Dropdown, json_data: str) -> None:
"""
Saves the artifacts to disk.
:param listicle_images: The list of images to save. This is a Gradio Gallery.
:param image_type: The type of image to save.
:param json_data: The JSON data to save.
:return: None
"""
if not json_data or len(json_data) == 0:
gr.Warning("No JSON data to save.")
return
return None

# Save the images
save_dir = image_processing.save_images_to_disk(listicle_images, image_type)
Expand All @@ -148,8 +170,22 @@ def save_artifacts(listicle_images, image_type, json_data):
gr.Info(f"Saved generated artifacts to {save_dir}.")


def generate_listicle(api_key, api_text_model, api_image_model, number_of_items, topic, association,
rating_type, details="", generate_artifacts=False):
def generate_listicle(api_key: str, api_text_model: str, api_image_model: str, number_of_items: int, topic: str,
association: str, rating_type: str, details: str = "", generate_artifacts: bool = False) \
-> (Optional[str], Optional[str], Optional[list[str]]):
"""
Generates a listicle using the OpenAI API.
:param api_key: The OpenAI API key to use.
:param api_text_model: The OpenAI API text model to use (e.g. 'gpt-4').
:param api_image_model: The OpenAI API image model to use (e.g. 'dall-e-3').
:param number_of_items: The number of items to generate.
:param topic: The topic of the listicle.
:param association: What each item is associated with.
:param rating_type: What the rating represents.
:param details: Additional details about the listicle you want to generate.
:param generate_artifacts: Whether to generate artifacts (images and JSON) for the listicle.
:return: The listicle content, the listicle JSON, and the listicle images.
"""
openai = chatgpt_api.get_openai_client(api_key)
if openai is None:
gr.Warning("No OpenAI client. Cannot generate listicle.")
Expand Down
9 changes: 4 additions & 5 deletions ui/music/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
import subprocess
import re
import time
import cv2
import tempfile
from typing import List, Dict, Optional
import cv2
from moviepy.editor import AudioFileClip
import numpy as np
import librosa
from utils import font_manager
import utils.image as image_utils
import numpy as np
import tempfile
from api import chatgpt as chatgpt_api
from processing import image as image_processing
import librosa
from utils import progress, visualizer
import cProfile


def analyze_audio(audio_path: str, target_fps: int) -> (List[Dict[float, float]], np.ndarray):
Expand Down
9 changes: 5 additions & 4 deletions utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ def print_progress_bar(current_iteration: int, total_iterations: int, bar_length
start_time: Optional[float] = None, end: str = ''):
progress_percentage = (current_iteration / total_iterations) * 100
completed_length = int(bar_length * current_iteration // total_iterations)
bar = '█' * completed_length + '░' * (bar_length - completed_length)
progress_bar = '█' * completed_length + '░' * (bar_length - completed_length)

elapsed_time = None
estimated_remaining_time = None
iterations_per_sec = None
if start_time is not None:
elapsed_time = time.time() - start_time
if current_iteration > 0:
Expand All @@ -19,7 +21,6 @@ def print_progress_bar(current_iteration: int, total_iterations: int, bar_length
estimated_remaining_time = None

time_string = ''
if estimated_remaining_time is not None:
if estimated_remaining_time is not None and iterations_per_sec is not None:
time_string = f'[{elapsed_time:.2f}s/{estimated_remaining_time:.2f}s, {iterations_per_sec:.2f}it/s]'
print(f'\r{progress_percentage:3.0f}%|{bar}| {current_iteration}/{total_iterations} {time_string}', end=end, flush=True)

print(f'\r{progress_percentage:3.0f}%|{progress_bar}| {current_iteration}/{total_iterations} {time_string}', end=end, flush=True)
24 changes: 22 additions & 2 deletions utils/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""
This module defines the Visualizer class, which is used to draw the visualizer on the canvas.
"""
from typing import Dict, Optional
import numpy as np
import cv2


class Visualizer:
"""
This class is used to draw the visualizer on the canvas.
Will be replaced with a more general solution in the future to allow for more customization.
"""
def __init__(self, base_size, max_size, color, dot_count, width, height):
self.base_size = base_size
self.max_size = max_size
Expand All @@ -13,7 +21,11 @@ def __init__(self, base_size, max_size, color, dot_count, width, height):
self.cached_dot_positions = None
self.cached_resized_drawing = {}

def initialize_static_values(self):
def initialize_static_values(self: "Visualizer") -> None:
"""
Initializes static values for the visualizer.
:return: None.
"""
# Calculate and store dot positions
x_positions = (self.width / self.dot_count[0]) * np.arange(self.dot_count[0]) + (
self.width / self.dot_count[0] / 2)
Expand All @@ -23,7 +35,15 @@ def initialize_static_values(self):
self.cached_dot_positions = [(grid_x[y, x], grid_y[y, x]) for x in range(self.dot_count[0]) for y in
range(self.dot_count[1])]

def draw_visualizer(self, canvas, frequency_data, custom_drawing=None):
def draw_visualizer(self: "Visualizer", canvas: np.ndarray, frequency_data: Dict[float, float],
custom_drawing: Optional[np.ndarray] = None) -> None:
"""
Draws the visualizer on the canvas (a single frame).
:param canvas: The canvas to draw on.
:param frequency_data: The frequency data to use for drawing which correlates to the loudness + frequency.
:param custom_drawing: A custom drawing to use instead of the default circle.
:return: None.
"""
# Calculate and store dot positions
dot_count_x = self.dot_count[0]
dot_count_y = self.dot_count[1]
Expand Down

0 comments on commit 1764076

Please sign in to comment.