support hf diffusers
dmarx authored Sep 30, 2022
1 parent b115952 commit 7e0ee92
Showing 6 changed files with 142 additions and 36 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -56,3 +56,13 @@ greater detail after the implementation stabilizes a bit more.

Most of the functionality in this notebook has been offloaded to a library I published to PyPI called `vktrs`. I strongly encourage you to import anything you need
from there rather than cutting and pasting functions into a notebook. Similarly, if you have ideas for improvements, please don't hesitate to submit a PR!
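
For example, a minimal sketch (assuming `remove_punctuation` takes and returns a string, as its use in the notebook suggests):

```
from vktrs.utils import remove_punctuation

print(remove_punctuation("hello, world!"))
```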

## Dev notes

Installing an unreleased package in Colab:

```
!pip install --upgrade setuptools build
!git clone --branch hf https://github.com/dmarx/video-killed-the-radio-star/
!cd video-killed-the-radio-star; python -m build; python -m pip install .[api,hf]
```
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.1
0.1.2
74 changes: 41 additions & 33 deletions Video_Killed_The_Radio_Star_Defusion.ipynb
@@ -76,7 +76,7 @@
"source": [
"%%capture\n",
"# @title # 0. Setup\n",
"!pip install vktrs[api]"
"!pip install vktrs[api,hf]"
]
},
{
@@ -124,7 +124,21 @@
"\n",
"if use_stability_api:\n",
" import os, getpass\n",
" os.environ['STABILITY_KEY'] = getpass.getpass('Enter your API Key')\n"
" os.environ['STABILITY_KEY'] = getpass.getpass('Enter your API Key')\n",
"else:\n",
" # use diffusers\n",
" !pip install diffusers\n",
" !pip install \"ipywidgets>=7,<8\"\n",
" !pip install transformers\n",
"\n",
" !sudo apt -qq install git-lfs\n",
" !git config --global credential.helper store\n",
"\n",
" from google.colab import output\n",
" from huggingface_hub import notebook_login\n",
"\n",
" output.enable_custom_widget_manager()\n",
" notebook_login()"
]
},
{
@@ -176,7 +190,7 @@
"\n",
"storyboard.params = dict(\n",
"\n",
" video_url = 'https://www.youtube.com/watch?v=WJaxFbdjm8c' # @param {type:'string'}\n",
" video_url = 'https://www.youtube.com/watch?v=REojIUxX4rw' # @param {type:'string'}\n",
" #, audio_fpath = '' # @param {type:'string'} # TO DO: drop reliance on youtube for audio\n",
" , audio_fpath = None\n",
" , theme_prompt = \"extremely detailed, painted by ralph steadman and radiohead, beautiful, wow\" # @param {type:'string'}\n",
@@ -432,28 +446,16 @@
"source": [
"# @title # 6. 🙭 Generate init images\n",
"\n",
"from tqdm.autonotebook import tqdm\n",
"from omegaconf import OmegaConf\n",
"\n",
"storyboard_fname = 'storyboard.yaml'\n",
"storyboard = OmegaConf.load(storyboard_fname)\n",
"\n",
"prompt_starts = storyboard.prompt_starts\n",
"\n",
"# force use api for now\n",
"storyboard.params.use_stability_api = True\n",
"\n",
"import copy\n",
"import datetime as dt\n",
"import string\n",
"from omegaconf import OmegaConf\n",
"from pathlib import Path\n",
"import random\n",
"import string\n",
"from tqdm.autonotebook import tqdm\n",
"\n",
"import PIL\n",
"from pathlib import Path\n",
"\n",
"from vktrs.api import (\n",
" get_image_for_prompt\n",
")\n",
"from vktrs.tsp import (\n",
" tsp_permute_frames,\n",
" batched_tsp_permute_frames,\n",
@@ -462,18 +464,23 @@
"from vktrs.utils import (\n",
" add_caption2image,\n",
" save_frame,\n",
" remove_punctuation,\n",
")\n",
"\n",
"from vktrs.utils import remove_punctuation\n",
"\n",
"if storyboard.params.use_stability_api:\n",
"storyboard_fname = 'storyboard.yaml'\n",
"storyboard = OmegaConf.load(storyboard_fname)\n",
"\n",
"prompt_starts = storyboard.prompt_starts\n",
"use_stability_api = storyboard.params.use_stability_api\n",
"\n",
"\n",
"if use_stability_api:\n",
" from vktrs.api import get_image_for_prompt\n",
"else:\n",
" raise NotImplementedError(\n",
" 'Image generation with this notebook currently depends on the stability api. '\n",
" 'Support for inference using the huggingface diffusers library (i.e. no api required) '\n",
" 'will be added soon.'\n",
" )\n",
" from vktrs.hf import HfHelper\n",
" helper = HfHelper()\n",
" get_image_for_prompt = helper.get_image_for_prompt\n",
"\n",
"\n",
"def get_variations_w_init(prompt, init_image, **kargs):\n",
@@ -494,15 +501,10 @@
" return images\n",
"\n",
"\n",
"#frames = []\n",
"\n",
"theme_prompt = storyboard.params.theme_prompt\n",
"optimal_ordering = storyboard.params.optimal_ordering\n",
"add_caption = False # storyboard.params.add_caption\n",
"\n",
"display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them\n",
"max_frames = storyboard.params.max_frames\n",
"image_consistency = storyboard.params.image_consistency\n",
"n_variations = storyboard.params.n_variations\n",
"max_variations_per_opt_pass = storyboard.params.max_variations_per_opt_pass\n",
"\n",
"height = storyboard.params.height\n",
"width = storyboard.params.width\n",
@@ -578,9 +580,15 @@
"storyboard_fname = 'storyboard.yaml'\n",
"storyboard = OmegaConf.load(storyboard_fname)\n",
"prompt_starts = OmegaConf.to_container(storyboard.prompt_starts, resolve=True)\n",
"\n",
"add_caption = storyboard.params.get('add_caption')\n",
"optimal_ordering = storyboard.params.optimal_ordering\n",
"display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them\n",
"image_consistency = storyboard.params.image_consistency\n",
"max_frames = storyboard.params.max_frames\n",
"max_variations_per_opt_pass = storyboard.params.max_variations_per_opt_pass\n",
"n_variations = storyboard.params.n_variations\n",
"\n",
"\n",
"# load init_images and generate variations as needed\n",
"# to do: use SDK args to request multiple images in single request...\n",
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dynamic = ["version"]

[project.optional-dependencies]
api = ["stability-sdk>=0.2.1"]
hf = ["diffusers","transformers","ftfy"]

[tool.setuptools.packages.find]
where =["."]
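With the new `hf` extra defined alongside `api`, both dependency groups can be pulled in with a single install, matching the notebook's setup cell:

```
pip install vktrs[api,hf]
```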
4 changes: 2 additions & 2 deletions vktrs/asr.py
@@ -105,7 +105,7 @@ def whisper_transmit_meta_across_alignment(
    else:
        rec_prev = token_large_index_segmentations[i-1]
        rec_large['start'] = rec_prev['start']
        rec_large['end'] = rec_prev['end']
        rec_large['end'] = rec_prev.get('end')

    token_large_index_segmentations[i] = rec_large

@@ -132,7 +132,7 @@ def whisper_segment_transcription(
print("still in phrase")
current_phrase.append(rec['token'])
start_prev = rec['start']
end_prev = rec['end']
end_prev = rec.get('end')
continue

# we're in the next phrase,
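(Context for the `.get('end')` changes above: `dict.get` returns `None` for a missing key instead of raising `KeyError`, so, presumably, token records that lack an `'end'` timestamp no longer crash the segmentation. A minimal sketch with a hypothetical token record:)

```
rec = {'token': 'hello', 'start': 1.2}  # hypothetical whisper token missing 'end'
print(rec.get('end'))  # None; rec['end'] would raise KeyError
```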
87 changes: 87 additions & 0 deletions vktrs/hf.py
@@ -0,0 +1,87 @@
from pathlib import Path
import torch
from torch import autocast
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)

# weird, why didn't this install with vktrs?
#!pip install pytokenizations yt-dlp python-tsp webvtt-py

#use_stability_api = False


# TODO: rename "start_schedule" to "strength" to match diffusers' naming
# (callers pass start_schedule = 1 - image_consistency)


class HfHelper:
    def __init__(
        self,
        device='cuda',
        device_img2img=None,
        device_text2img=None,
        model_path='.',
        model_id="CompVis/stable-diffusion-v1-4",
        download=True,
    ):
        if not device_img2img:
            device_img2img = device
        if not device_text2img:
            device_text2img = device
        self.device = device
        self.device_img2img = device_img2img
        self.device_text2img = device_text2img
        self.model_path = model_path
        self.model_id = model_id
        self.download = download
        self.load_pipelines()

    def load_pipelines(self):
        if self.download:
            # fetch fp16 weights from the hub (requires a huggingface login)
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,
                revision="fp16",
                torch_dtype=torch.float16,
                use_auth_token=True,
            )
            img2img = img2img.to(self.device)
            img2img.save_pretrained(self.model_path)
        else:
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_path,
                local_files_only=True,
            ).to(self.device)

        # share components between the two pipelines so the model
        # weights are only loaded (and held in memory) once
        text2img = StableDiffusionPipeline(
            vae=img2img.vae,
            text_encoder=img2img.text_encoder,
            tokenizer=img2img.tokenizer,
            unet=img2img.unet,
            feature_extractor=img2img.feature_extractor,
            scheduler=img2img.scheduler,
            safety_checker=img2img.safety_checker,
        )
        text2img.enable_attention_slicing()
        img2img.enable_attention_slicing()
        self.text2img = text2img
        self.img2img = img2img

    def get_image_for_prompt(self, prompt, **kwargs):
        # route to img2img when an init image is provided, otherwise text2img
        f = self.text2img if kwargs.get('init_image') is None else self.img2img
        # translate the stability api's `start_schedule` kwarg into diffusers' `strength`
        if kwargs.get('start_schedule') is not None:
            kwargs['strength'] = kwargs.pop('start_schedule')
        with autocast(self.device):
            return f(prompt, **kwargs).images
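
A minimal usage sketch of the new helper (illustrative: the prompt, file name, and `start_schedule` value here are hypothetical, and the hub download requires a huggingface login):

```
from vktrs.hf import HfHelper

helper = HfHelper()  # downloads CompVis/stable-diffusion-v1-4 on first use

# text2img: no init_image, so the text-to-image pipeline handles the prompt
images = helper.get_image_for_prompt("a painted storyboard frame")
images[0].save("frame0.png")

# img2img: passing init_image routes to the image-to-image pipeline;
# start_schedule is translated into the diffusers `strength` argument
variations = helper.get_image_for_prompt(
    "a painted storyboard frame",
    init_image=images[0],
    start_schedule=0.6,
)
```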
