Skip to content

Commit

Permalink
add example for TRELLIS (#1022)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesfrye authored Dec 18, 2024
1 parent 24d7b4c commit cc64ef6
Showing 1 changed file with 252 additions and 0 deletions.
252 changes: 252 additions & 0 deletions misc/trellis3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"This example originally contributed by @sandeeppatra96 and @patraxo on GitHub"
import logging
import tempfile
import traceback

import modal
import requests
from fastapi import HTTPException, Request, Response, status

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

REPO_URL = "https://github.com/microsoft/TRELLIS.git"
MODEL_NAME = "JeffreyXiang/TRELLIS-image-large"
TRELLIS_DIR = "/trellis"
MINUTES = 60
HOURS = 60 * MINUTES

cuda_version = "12.2.0"
flavor = "devel"
os_version = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{os_version}"


def clone_repository():
import subprocess

subprocess.run(
["git", "clone", "--recurse-submodules", REPO_URL, TRELLIS_DIR],
check=True,
)


# The specific version of torch==2.4.0 to circumvent the flash attention wheel build error

trellis_image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
.apt_install(
"git",
"ffmpeg",
"cmake",
"clang",
"build-essential",
"libgl1-mesa-glx",
"libglib2.0-0",
"libgomp1",
"libxrender1",
"libxext6",
"ninja-build",
)
.pip_install("packaging", "ninja", "torch==2.4.0", "wheel", "setuptools")
.env(
{
# "MAX_JOBS": "16", # in case flash attention takes more time to build
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"CC": "clang",
"CXX": "clang++",
"CUDAHOSTCXX": "clang++",
"CUDA_HOME": "/usr/local/cuda-12.2",
"CPATH": "/usr/local/cuda-12.2/targets/x86_64-linux/include",
"LIBRARY_PATH": "/usr/local/cuda-12.2/targets/x86_64-linux/lib64",
"LD_LIBRARY_PATH": "/usr/local/cuda-12.2/targets/x86_64-linux/lib64",
"CFLAGS": "-Wno-narrowing",
"CXXFLAGS": "-Wno-narrowing",
"ATTN_BACKEND": "flash-attn", # or 'xformers'
"SPCONV_ALGO": "native", # or 'auto'
}
)
.pip_install("flash-attn==2.6.3", extra_options="--no-build-isolation")
.pip_install(
"git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8",
"numpy",
"pillow",
"imageio",
"onnxruntime",
"trimesh",
"safetensors",
"easydict",
"scipy",
"tqdm",
"einops",
"xformers",
"hf_transfer",
"opencv-python-headless",
"largesteps",
"spconv-cu118",
"rembg",
"torchvision",
"imageio-ffmpeg",
"xatlas",
"pyvista",
"pymeshfix",
"igraph",
"git+https://github.com/NVIDIAGameWorks/kaolin.git",
"https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl",
# "git+https://github.com/NVlabs/nvdiffrast.git", # build failed
"huggingface-hub",
"https://github.com/camenduru/wheels/releases/download/3090/diso-0.1.4-cp310-cp310-linux_x86_64.whl",
"https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
)
.pip_install("fastapi[standard]==0.115.6")
.entrypoint([])
.run_function(clone_repository)
)

app = modal.App(name="example-trellis-3d")


@app.cls(
image=trellis_image,
gpu=modal.gpu.L4(count=1),
timeout=1 * HOURS,
container_idle_timeout=1 * MINUTES,
)
class Model:
@modal.enter()
@modal.build()
def initialize(self):
import sys

sys.path.append(TRELLIS_DIR)

from trellis.pipelines import TrellisImageTo3DPipeline

try:
self.pipe = TrellisImageTo3DPipeline.from_pretrained(MODEL_NAME)
self.pipe.cuda()
logger.info("TRELLIS model initialized successfully")
except Exception as e:
error_msg = f"Error during model initialization: {str(e)}"
logger.error(error_msg)
logger.error(f"Traceback: {traceback.format_exc()}")
raise

def process_image(
self,
image_url: str,
simplify: float,
texture_size: int,
sparse_sampling_steps: int,
sparse_sampling_cfg: float,
slat_sampling_steps: int,
slat_sampling_cfg: int,
seed: int,
output_format: str,
):
import io
import os

from PIL import Image

try:
response = requests.get(image_url)
if response.status_code != 200:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to download image from provided URL",
)

image = Image.open(io.BytesIO(response.content))

logger.info("Starting model inference...")
outputs = self.pipe.run(
image,
seed=seed,
sparse_structure_sampler_params={
"steps": sparse_sampling_steps,
"cfg_strength": sparse_sampling_cfg,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_sampling_cfg,
},
)
logger.info("Model inference completed successfully")

if output_format == "glb":
from trellis.utils import postprocessing_utils

glb = postprocessing_utils.to_glb(
outputs["gaussian"][0],
outputs["mesh"][0],
simplify=simplify,
texture_size=texture_size,
)

temp_glb = tempfile.NamedTemporaryFile(
suffix=".glb", delete=False
)
temp_path = temp_glb.name
logger.info(f"Exporting mesh to: {temp_path}")
glb.export(temp_path)
temp_glb.close()

try:
with open(temp_path, "rb") as file:
content = file.read()
if os.path.exists(temp_path):
os.unlink(temp_path)
logger.info("Temp file cleaned up")
return Response(
content=content,
media_type="model/gltf-binary",
headers={
"Content-Disposition": "attachment; filename=output.glb",
},
)
except Exception as e:
if os.path.exists(temp_path):
os.unlink(temp_path)
raise e

else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {output_format}",
)

except Exception as e:
error_msg = f"Error during processing: {str(e)}"
logger.error(error_msg)
logger.error(f"Traceback: {traceback.format_exc()}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg,
)

@modal.web_endpoint(method="GET", docs=True)
async def generate(
self,
request: Request,
image_url: str,
simplify: float = 0.95,
texture_size: int = 1024,
sparse_sampling_steps: int = 12,
sparse_sampling_cfg: float = 7.5,
slat_sampling_steps: int = 12,
slat_sampling_cfg: int = 3,
seed: int = 42,
output_format: str = "glb",
):
return self.process_image(
image_url,
simplify,
texture_size,
sparse_sampling_steps,
sparse_sampling_cfg,
slat_sampling_steps,
slat_sampling_cfg,
seed,
output_format,
)

0 comments on commit cc64ef6

Please sign in to comment.