Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add example for TRELLIS #1022

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions misc/trellis3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"This example originally contributed by @sandeeppatra96 and @patraxo on GitHub"
import logging
import tempfile
import traceback

import modal
import requests
from fastapi import HTTPException, Request, Response, status

# Module-level logger used for build- and inference-time diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upstream TRELLIS repository and the pretrained checkpoint to serve.
REPO_URL = "https://github.com/microsoft/TRELLIS.git"
MODEL_NAME = "JeffreyXiang/TRELLIS-image-large"
TRELLIS_DIR = "/trellis"  # clone target inside the container image
MINUTES = 60  # seconds
HOURS = 60 * MINUTES  # seconds

# Components of the NVIDIA CUDA base-image tag; the "devel" flavor ships the
# toolchain needed to compile CUDA extensions (flash-attn, nvdiffrast, ...).
cuda_version = "12.2.0"
flavor = "devel"
os_version = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{os_version}"


def clone_repository():
    """Clone the TRELLIS repo (with submodules) into TRELLIS_DIR.

    Runs during image build via ``.run_function`` so the source tree is
    baked into the container image.
    """
    import subprocess

    clone_cmd = [
        "git",
        "clone",
        "--recurse-submodules",
        REPO_URL,
        TRELLIS_DIR,
    ]
    # check=True raises CalledProcessError, failing the image build loudly.
    subprocess.run(clone_cmd, check=True)


# The specific version of torch==2.4.0 to circumvent the flash attention wheel build error

# Container image for TRELLIS: CUDA 12.2 devel base + Python 3.10, system build
# tools, a pinned torch/flash-attn pair, and the TRELLIS Python dependencies.
# NOTE: layer ordering matters — torch must be installed before flash-attn so
# the flash-attn build can link against it.
trellis_image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
    # System packages for compiling CUDA extensions and headless GL rendering.
    .apt_install(
        "git",
        "ffmpeg",
        "cmake",
        "clang",
        "build-essential",
        "libgl1-mesa-glx",
        "libglib2.0-0",
        "libgomp1",
        "libxrender1",
        "libxext6",
        "ninja-build",
    )
    # Build prerequisites; torch is pinned here so later source builds see it.
    .pip_install("packaging", "ninja", "torch==2.4.0", "wheel", "setuptools")
    # Compiler/CUDA environment used by the extension builds below.
    .env(
        {
            # "MAX_JOBS": "16", # in case flash attention takes more time to build
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "CC": "clang",
            "CXX": "clang++",
            "CUDAHOSTCXX": "clang++",
            "CUDA_HOME": "/usr/local/cuda-12.2",
            "CPATH": "/usr/local/cuda-12.2/targets/x86_64-linux/include",
            "LIBRARY_PATH": "/usr/local/cuda-12.2/targets/x86_64-linux/lib64",
            "LD_LIBRARY_PATH": "/usr/local/cuda-12.2/targets/x86_64-linux/lib64",
            "CFLAGS": "-Wno-narrowing",
            "CXXFLAGS": "-Wno-narrowing",
            "ATTN_BACKEND": "flash-attn",  # or 'xformers'
            "SPCONV_ALGO": "native",  # or 'auto'
        }
    )
    # flash-attn must build against the already-installed torch, hence no
    # build isolation.
    .pip_install("flash-attn==2.6.3", extra_options="--no-build-isolation")
    # TRELLIS runtime dependencies, including prebuilt wheels for extensions
    # that fail to build from source in this environment.
    .pip_install(
        "git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8",
        "numpy",
        "pillow",
        "imageio",
        "onnxruntime",
        "trimesh",
        "safetensors",
        "easydict",
        "scipy",
        "tqdm",
        "einops",
        "xformers",
        "hf_transfer",
        "opencv-python-headless",
        "largesteps",
        "spconv-cu118",
        "rembg",
        "torchvision",
        "imageio-ffmpeg",
        "xatlas",
        "pyvista",
        "pymeshfix",
        "igraph",
        "git+https://github.com/NVIDIAGameWorks/kaolin.git",
        "https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl",
        # "git+https://github.com/NVlabs/nvdiffrast.git", # build failed
        "huggingface-hub",
        "https://github.com/camenduru/wheels/releases/download/3090/diso-0.1.4-cp310-cp310-linux_x86_64.whl",
        "https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    )
    .pip_install("fastapi[standard]==0.115.6")
    # Clear the CUDA base image's default entrypoint so Modal controls startup.
    .entrypoint([])
    # Bake the cloned TRELLIS source tree into the image.
    .run_function(clone_repository)
)

# Modal application grouping the image above with the Model class below.
app = modal.App(name="example-trellis-3d")


@app.cls(
    image=trellis_image,
    gpu=modal.gpu.L4(count=1),
    timeout=1 * HOURS,
    container_idle_timeout=1 * MINUTES,
)
class Model:
    """Serves the TRELLIS image-to-3D pipeline as a Modal web endpoint.

    The pipeline is loaded once per container (and at image build time) and
    reused across requests; `generate` exposes a GET endpoint that returns a
    binary GLB mesh built from a single input image URL.
    """

    @modal.enter()
    @modal.build()
    def initialize(self):
        """Load the pretrained TRELLIS pipeline and move it to the GPU.

        Decorated with both @modal.build (download weights while building the
        image) and @modal.enter (re-load on container start) so cold starts
        are fast.

        Raises:
            Exception: re-raised after logging if model loading fails.
        """
        import sys

        # The repo was cloned into TRELLIS_DIR during image build; make its
        # packages importable at runtime.
        sys.path.append(TRELLIS_DIR)

        from trellis.pipelines import TrellisImageTo3DPipeline

        try:
            self.pipe = TrellisImageTo3DPipeline.from_pretrained(MODEL_NAME)
            self.pipe.cuda()
            logger.info("TRELLIS model initialized successfully")
        except Exception as e:
            error_msg = f"Error during model initialization: {str(e)}"
            logger.error(error_msg)
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise

    def process_image(
        self,
        image_url: str,
        simplify: float,
        texture_size: int,
        sparse_sampling_steps: int,
        sparse_sampling_cfg: float,
        slat_sampling_steps: int,
        slat_sampling_cfg: int,
        seed: int,
        output_format: str,
    ):
        """Run the full image -> 3D pipeline and return the exported asset.

        Args:
            image_url: URL of the source image to reconstruct.
            simplify: mesh simplification ratio for postprocessing.
            texture_size: texture resolution for the exported GLB.
            sparse_sampling_steps, sparse_sampling_cfg: sparse-structure
                sampler settings (steps / cfg_strength).
            slat_sampling_steps, slat_sampling_cfg: SLat sampler settings
                (steps / cfg_strength).
            seed: RNG seed for reproducible generation.
            output_format: only "glb" is currently supported.

        Returns:
            fastapi.Response carrying the GLB bytes as a file attachment.

        Raises:
            HTTPException: 400 for a bad image URL or unsupported output
                format; 500 for unexpected failures during inference/export.
        """
        import io
        import os

        from PIL import Image

        try:
            response = requests.get(image_url)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to download image from provided URL",
                )

            image = Image.open(io.BytesIO(response.content))

            logger.info("Starting model inference...")
            outputs = self.pipe.run(
                image,
                seed=seed,
                sparse_structure_sampler_params={
                    "steps": sparse_sampling_steps,
                    "cfg_strength": sparse_sampling_cfg,
                },
                slat_sampler_params={
                    "steps": slat_sampling_steps,
                    "cfg_strength": slat_sampling_cfg,
                },
            )
            logger.info("Model inference completed successfully")

            # Guard clause: reject unsupported formats before postprocessing.
            if output_format != "glb":
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Unsupported output format: {output_format}",
                )

            from trellis.utils import postprocessing_utils

            glb = postprocessing_utils.to_glb(
                outputs["gaussian"][0],
                outputs["mesh"][0],
                simplify=simplify,
                texture_size=texture_size,
            )

            # The GLB exporter needs a filesystem path, so export to a named
            # temp file and read it back. Close the handle before exporting
            # (the exporter writes by path, not through this handle); the
            # `finally` guarantees cleanup on every path, success or failure.
            temp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
            temp_path = temp_glb.name
            temp_glb.close()
            try:
                logger.info(f"Exporting mesh to: {temp_path}")
                glb.export(temp_path)
                with open(temp_path, "rb") as file:
                    content = file.read()
            finally:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                    logger.info("Temp file cleaned up")

            return Response(
                content=content,
                media_type="model/gltf-binary",
                headers={
                    "Content-Disposition": "attachment; filename=output.glb",
                },
            )

        except HTTPException:
            # Bug fix: let the deliberate 4xx errors raised above propagate
            # unchanged. Previously they were swallowed by the generic handler
            # below and re-raised as 500s, losing the intended status code.
            raise
        except Exception as e:
            error_msg = f"Error during processing: {str(e)}"
            logger.error(error_msg)
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=error_msg,
            )

    @modal.web_endpoint(method="GET", docs=True)
    async def generate(
        self,
        request: Request,
        image_url: str,
        simplify: float = 0.95,
        texture_size: int = 1024,
        sparse_sampling_steps: int = 12,
        sparse_sampling_cfg: float = 7.5,
        slat_sampling_steps: int = 12,
        slat_sampling_cfg: int = 3,
        seed: int = 42,
        output_format: str = "glb",
    ):
        """GET endpoint: turn an image URL into a downloadable GLB mesh.

        Thin async wrapper that forwards query parameters to process_image;
        `request` is accepted for FastAPI compatibility but unused.
        """
        return self.process_image(
            image_url,
            simplify,
            texture_size,
            sparse_sampling_steps,
            sparse_sampling_cfg,
            slat_sampling_steps,
            slat_sampling_cfg,
            seed,
            output_format,
        )