Commit: [DOWNLOAD GIF]

Kye committed Apr 16, 2024
1 parent 115c186 commit d3d8b56
Showing 2 changed files with 32 additions and 36 deletions.
21 changes: 9 additions & 12 deletions servers/text_to_video/test.py
@@ -1,4 +1,3 @@

import torch
from diffusers import (
AnimateDiffPipeline,
@@ -37,7 +36,6 @@ def text_to_video(
str: The path to the exported GIF file.
"""
try:

device = "cuda"
dtype = torch.float16

@@ -46,14 +44,13 @@ def text_to_video(
base = "emilianJR/epiCRealism" # Choose to your favorite base model.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

pipe = AnimateDiffPipeline.from_pretrained(
base, motion_adapter=adapter, torch_dtype=dtype
).to(device)

logger.info(f"Initialized Model: {model_name}")



pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
@@ -69,12 +66,12 @@ def text_to_video(
# )
# outputs.append(output)
# out = export_to_gif([output], f"{output_path}_{i}.gif")
-# else:
-#     out = export_to_video([output], f"{output_path}_{i}.mp4")
+# else:
+#     out = export_to_video([output], f"{output_path}_{i}.mp4")
output = pipe(
-prompt = task,
-guidance_scale = guidance_scale,
-num_inference_steps = inference_steps
+prompt=task,
+guidance_scale=guidance_scale,
+num_inference_steps=inference_steps,
)
out = export_to_gif(output.frames[0], output_path)
return out
@@ -84,4 +81,4 @@ def text_to_video(


out = text_to_video(task="A girl in hijab studying in a library")
-print(out)
+print(out)
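
For context, here is a minimal standalone sketch of the AnimateDiff-Lightning text-to-GIF flow that test.py wraps. The step count, the Lightning checkpoint filename, guidance_scale=1.0, and the beta_schedule setting sit in collapsed parts of the diff and are taken from the model card rather than confirmed by the lines shown above, so treat them as assumptions.

# Minimal sketch, assuming the 4-step Lightning checkpoint and the
# checkpoint naming pattern from the ByteDance/AnimateDiff-Lightning model card.
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_gif
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

device = "cuda"
dtype = torch.float16

step = 4  # Lightning checkpoints exist for 1, 2, 4, and 8 steps (assumption)
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"  # any SD1.5-compatible base model

# Load the motion adapter weights, attach them to the base model, and switch
# to the trailing-timestep Euler scheduler that Lightning expects.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
pipe = AnimateDiffPipeline.from_pretrained(
    base, motion_adapter=adapter, torch_dtype=dtype
).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

# Generate the frames and export them as a GIF; export_to_gif returns the path.
output = pipe(
    prompt="A girl in hijab studying in a library",
    guidance_scale=1.0,
    num_inference_steps=step,
)
path = export_to_gif(output.frames[0], "output.gif")
print(path)
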
47 changes: 23 additions & 24 deletions text_to_video.py
@@ -14,6 +14,7 @@
from huggingface_hub import hf_hub_download
from loguru import logger
from safetensors.torch import load_file
+from fastapi.responses import FileResponse

from swarms_cloud.schema.text_to_video import TextToVideoRequest, TextToVideoResponse

@@ -33,6 +34,7 @@
allow_headers=["*"],
)


def text_to_video(
task: str,
model_name: str = "ByteDance/AnimateDiff-Lightning",
@@ -55,7 +57,6 @@ def text_to_video(
str: The path to the exported GIF file.
"""
try:

device = "cuda"
dtype = torch.float16

@@ -64,14 +65,13 @@ def text_to_video(
base = "emilianJR/epiCRealism" # Choose to your favorite base model.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

pipe = AnimateDiffPipeline.from_pretrained(
base, motion_adapter=adapter, torch_dtype=dtype
).to(device)

logger.info(f"Initialized Model: {model_name}")



pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
@@ -87,16 +87,16 @@ def text_to_video(
# )
# outputs.append(output)
# out = export_to_gif([output], f"{output_path}_{i}.gif")
-# else:
-#     out = export_to_video([output], f"{output_path}_{i}.mp4")
+# else:
+#     out = export_to_video([output], f"{output_path}_{i}.mp4")
output = pipe(
-prompt = task,
-guidance_scale = guidance_scale,
-num_inference_steps = inference_steps
+prompt=task,
+guidance_scale=guidance_scale,
+num_inference_steps=inference_steps,
)

logger.info(f"Output ready: {output}")

out = export_to_gif(output.frames[0], output_path)
logger.info(f"Exported to GIF: {out}")
return out
@@ -105,7 +105,6 @@ def text_to_video(
return None



@app.post("/v1/chat/completions", response_model=TextToVideoResponse)
async def create_chat_completion(
request: TextToVideoRequest, # token: str = Depends(authenticate_user)
@@ -133,20 +132,20 @@ async def create_chat_completion(
# logger.error(f"Error: {e}")
# raise HTTPException(status_code=500, detail="Internal Server Error")

-out = TextToVideoResponse(
-    status="success",
-    request_details=request,
-    video_url=response,
-    error=None,
-)
+# out = TextToVideoResponse(
+#     status="success",
+#     request_details=request,
+#     video_url=response,
+#     error=None,
+# )

logger.info(f"Response: {out}")
logger.info(f"Downloading the file: {response}")
-# out = FileResponse(
-#     path=response,
-#     filename=request.output_path,
-#     media_type="application/octet-stream",
-# )
+out = FileResponse(
+    path=response,
+    filename=request.output_path,
+    media_type="image/gif",  # Use the correct media type for GIFs
+)

return out
except Exception as e:
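
Because the endpoint now streams the GIF back as a FileResponse instead of returning a JSON TextToVideoResponse, a client has to save the raw response body. A minimal sketch follows; the request field names (task, output_path) and the local server URL are assumptions inferred from the handler code, not confirmed by this commit.

# Hypothetical client for the updated endpoint; field names and URL are assumptions.
import requests

payload = {
    "task": "A girl in hijab studying in a library",
    "output_path": "output.gif",
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()

# The server streams the GIF bytes with media_type "image/gif"; save them to disk.
with open("downloaded.gif", "wb") as f:
    f.write(resp.content)
print("Saved", len(resp.content), "bytes")
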
