# SDXL deployment example on `inf2` (#538)
## Summary
- added `sdxl` deployment with tests
- updated `neuron/device.py` import to be lazy so that dynamic env vars are used

## Related issues
<!-- For example: "Closes #1234" -->

## Checks
- [x] `make lint`: I've run `make lint` to lint the changes in this PR.
- [x] `make test`: I've made sure the tests (`make test-cpu` or `make test`) are passing.
- Additional tests:
  - [ ] Benchmark tests (when contributing new models)
  - [ ] GPU/HW tests
Showing 13 changed files with 255 additions and 18 deletions.
@@ -0,0 +1,34 @@

## Embeddings Service

Start the server via:

```bash
nos serve up -c serve.yaml --http
```

Optionally, you can provide the `inf2` runtime flag, but this is automatically inferred.

```bash
nos serve up -c serve.yaml --http --runtime inf2
```

### Run the tests

```bash
pytest -sv ./tests/test_embeddings_client.py
```

### Call the service

You can also call the service via the REST API directly:

```bash
curl \
-X POST http://<service-ip>:8000/v1/infer \
-H 'Content-Type: application/json' \
-d '{
    "model_id": "BAAI/bge-small-en-v1.5",
    "inputs": {
        "texts": ["fox jumped over the moon"]
    }
}'
```
@@ -0,0 +1,26 @@ | ||
# Usage: sky launch -c <cluster-name> job-inf2.yaml | ||
# image_id: ami-09c62125a680f0ead # us-east-2 | ||
# image_id: ami-0d4155c8606f16f5b # us-west-1 | ||
# image_id: ami-096319086cc3d5f23 # us-west-2 | ||
|
||
file_mounts: | ||
/app: . | ||
|
||
resources: | ||
cloud: aws | ||
region: us-west-2 | ||
instance_type: inf2.8xlarge | ||
image_id: ami-096319086cc3d5f23 # us-west-2 | ||
disk_size: 256 | ||
ports: | ||
- 8000 | ||
|
||
setup: | | ||
sudo apt-get install -y docker-compose-plugin | ||
cd /app && python3 -m venv .venv && source .venv/bin/activate | ||
pip install git+https://github.com/autonomi-ai/nos.git pytest | ||
run: | | ||
source /app/.venv/bin/activate | ||
cd /app && NOS_LOGGING_LEVEL=DEBUG nos serve up -c serve.yaml --http |
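A typical SkyPilot workflow for this job file might look like the sketch below; the cluster name `inf2-sdxl` is an illustrative placeholder, not something defined by this PR:

```bash
# Illustrative SkyPilot workflow (cluster name is a placeholder).
sky launch -c inf2-sdxl job-inf2.yaml   # provision the inf2.8xlarge instance and start the server
sky status                              # check cluster state and find the head IP for port 8000
sky down inf2-sdxl                      # tear the cluster down when finished
```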
@@ -0,0 +1,113 @@

```python
"""SDXL model accelerated with AWS Neuron (using optimum-neuron)."""
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Dict, List, Union

import torch
from PIL import Image

from nos.constants import NOS_CACHE_DIR
from nos.hub import HuggingFaceHubConfig
from nos.neuron.device import NeuronDevice


@dataclass(frozen=True)
class StableDiffusionInf2Config(HuggingFaceHubConfig):
    """SDXL model configuration for Inf2."""

    batch_size: int = 1
    """Batch size for the model."""

    image_height: int = 1024
    """Height of the image."""

    image_width: int = 1024
    """Width of the image."""

    compiler_args: Dict[str, Any] = field(
        default_factory=lambda: {"auto_cast": "matmul", "auto_cast_type": "bf16"}, repr=False
    )
    """Compiler arguments for the model."""

    @property
    def id(self) -> str:
        """Model ID."""
        return f"{self.model_name}-bs-{self.batch_size}-{self.image_height}x{self.image_width}-{self.compiler_args.get('auto_cast_type', 'fp32')}"


class StableDiffusionXLInf2:
    configs = {
        "stabilityai/stable-diffusion-xl-base-1.0-inf2": StableDiffusionInf2Config(
            model_name="stabilityai/stable-diffusion-xl-base-1.0",
        ),
    }

    def __init__(self, model_name: str = "stabilityai/stable-diffusion-xl-base-1.0-inf2"):
        from nos.logging import logger

        NeuronDevice.setup_environment()
        try:
            cfg = StableDiffusionXLInf2.configs[model_name]
        except KeyError:
            raise ValueError(f"Invalid model_name: {model_name}, available models: {self.configs.keys()}")
        self.logger = logger
        self.model = None
        self.__load__(cfg)

    def __load__(self, cfg: StableDiffusionInf2Config):
        from optimum.neuron import NeuronStableDiffusionXLPipeline

        if self.model is not None:
            self.logger.debug(f"De-allocating existing model [cfg={self.cfg}, id={self.cfg.id}]")
            del self.model
            self.model = None
        self.cfg = cfg

        # Load model from cache if available, otherwise load from HF and compile
        # (cache is specific to model_name, batch_size and image dimensions)
        self.logger.debug(f"Loading model [cfg={self.cfg}, id={self.cfg.id}]")
        cache_dir = NOS_CACHE_DIR / "neuron" / self.cfg.id
        if Path(cache_dir).exists():
            self.logger.debug(f"Loading model from {cache_dir}")
            self.model = NeuronStableDiffusionXLPipeline.from_pretrained(str(cache_dir))
            self.logger.debug(f"Loaded model from {cache_dir}")
        else:
            input_shapes = {
                "batch_size": self.cfg.batch_size,
                "height": self.cfg.image_height,
                "width": self.cfg.image_width,
            }
            self.model = NeuronStableDiffusionXLPipeline.from_pretrained(
                self.cfg.model_name, export=True, **self.cfg.compiler_args, **input_shapes
            )
            self.model.save_pretrained(str(cache_dir))
            self.logger.debug(f"Saved model to {cache_dir}")
        self.logger.debug(f"Loaded neuron model [id={self.cfg.id}]")

    @torch.inference_mode()
    def __call__(
        self,
        prompts: Union[str, List[str]],
        num_images: int = 1,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
    ) -> List[Image.Image]:
        """Generate images from text prompt."""

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(prompts, list) and len(prompts) != 1:
            raise ValueError(f"Invalid number of prompts: {len(prompts)}, expected: 1")
        if height != self.cfg.image_height or width != self.cfg.image_width:
            cfg = replace(self.cfg, image_height=height, image_width=width)
            self.logger.debug(f"Re-loading model [cfg={cfg}, id={cfg.id}, prev_id={self.cfg.id}]")
            self.__load__(cfg)
        assert self.model is not None
        return self.model(
            prompts,
            num_images_per_prompt=num_images,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images
```
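Compiled artifacts are cached under `NOS_CACHE_DIR/neuron/<cfg.id>`, so each distinct configuration is compiled once and reused on subsequent loads. As a sketch of how that cache key works (the extra registry entry below is hypothetical, not part of this PR), an additional resolution could be pre-registered alongside the default entry:

```python
# Hypothetical sketch: register an additional 768x768 configuration so it
# compiles and caches under its own cfg.id, separate from the 1024x1024 default.
from models.sdxl_inf2 import StableDiffusionInf2Config, StableDiffusionXLInf2

StableDiffusionXLInf2.configs["stabilityai/stable-diffusion-xl-base-1.0-inf2-768"] = StableDiffusionInf2Config(
    model_name="stabilityai/stable-diffusion-xl-base-1.0",
    image_height=768,
    image_width=768,
)

# The default entry's cache key resolves to:
#   stabilityai/stable-diffusion-xl-base-1.0-bs-1-1024x1024-bf16
print(StableDiffusionXLInf2.configs["stabilityai/stable-diffusion-xl-base-1.0-inf2"].id)
```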
@@ -0,0 +1,14 @@

```yaml
images:
  custom-inf2:
    base: autonomi/nos:latest-inf2
    env:
      NOS_LOGGING_LEVEL: DEBUG
      NOS_NEURON_CORES: 2
      NEURON_RT_VISIBLE_CORES: 2

models:
  stabilityai/stable-diffusion-xl-base-1.0-inf2:
    model_cls: StableDiffusionXLInf2
    model_path: models/sdxl_inf2.py
    default_method: __call__
    runtime_env: custom-inf2
```
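Once `nos serve up -c serve.yaml --http` is running, the model should be reachable over the same `/v1/infer` REST route used in the embeddings example. The input keys below (`prompts`, `height`, `width`, `num_inference_steps`) are inferred from the model's `__call__` signature rather than documented here, so treat this as a sketch:

```bash
curl \
-X POST http://<service-ip>:8000/v1/infer \
-H 'Content-Type: application/json' \
-d '{
    "model_id": "stabilityai/stable-diffusion-xl-base-1.0-inf2",
    "inputs": {
        "prompts": ["a photo of an astronaut riding a horse on mars"],
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50
    }
}'
```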
@@ -0,0 +1,9 @@

```python
def test_sdxl_inf2():
    from models.sdxl_inf2 import StableDiffusionXLInf2
    from PIL import Image

    model = StableDiffusionXLInf2()
    prompts = "a photo of an astronaut riding a horse on mars"
    response = model(prompts=prompts, height=1024, width=1024, num_inference_steps=50)
    assert response is not None
    assert isinstance(response[0], Image.Image)
```
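Mirroring the embeddings example, this test can be run with pytest directly on the inf2 host; the file path below is an assumption about where the test is saved, so adjust it to your layout:

```bash
# Path is assumed; adjust to wherever this test file lives in the example.
pytest -sv ./tests/test_sdxl_inf2.py
```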
@@ -0,0 +1,21 @@

```python
import pytest


@pytest.mark.parametrize("model_id", ["stabilityai/stable-diffusion-xl-base-1.0-inf2"])
def test_sdxl_inf2_client(model_id):
    from PIL import Image

    from nos.client import Client

    # Create a client
    client = Client("[::]:50051")
    assert client.WaitForServer()

    # Load the SDXL model
    model = client.Module(model_id)

    # Run inference
    prompts = "a photo of an astronaut riding a horse on mars"
    response = model(prompts=prompts, height=1024, width=1024, num_inference_steps=50)
    assert response is not None
    assert isinstance(response[0], Image.Image)
```
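Outside of pytest, the same gRPC client flow can be used as a small standalone script; the output filename below is arbitrary:

```python
# Standalone sketch of the client flow used in the test above
# (assumes a nos server is already running on port 50051).
from nos.client import Client

client = Client("[::]:50051")
assert client.WaitForServer()

model = client.Module("stabilityai/stable-diffusion-xl-base-1.0-inf2")
response = model(
    prompts="a photo of an astronaut riding a horse on mars",
    height=1024,
    width=1024,
    num_inference_steps=50,
)
response[0].save("sdxl_inf2_sample.png")  # arbitrary output filename
```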