Skip to content

Commit

Permalink
feat: Update stable-diffusion-neuron docker image (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
ratnopamc authored Apr 8, 2024
1 parent 6f26638 commit c50c181
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#----------------------------------------------------------------------
# NOTE: For deployment instructions, refer to the DoEKS website.
#----------------------------------------------------------------------
---
apiVersion: v1
kind: Namespace
metadata:
Expand Down Expand Up @@ -29,7 +28,7 @@ spec:
- name: stable-diffusion-v2
autoscaling_config:
metrics_interval_s: 0.2
-min_replicas: 8
+min_replicas: 2
max_replicas: 12
look_back_period_s: 2
downscale_delay_s: 30
Expand All @@ -55,7 +54,7 @@ spec:
spec:
containers:
- name: ray-head
-image: public.ecr.aws/data-on-eks/ray2.9.0-py310-stablediffusion-neuron:v1.0
+image: public.ecr.aws/data-on-eks/ray2.9.0-py310-stablediffusion-neuron:latest
imagePullPolicy: Always # Ensure the image is always pulled when updated
lifecycle:
preStop:
Expand Down Expand Up @@ -98,7 +97,7 @@ spec:
spec:
containers:
- name: ray-worker
-image: public.ecr.aws/data-on-eks/ray2.9.0-py310-stablediffusion-neuron:v1.0
+image: public.ecr.aws/data-on-eks/ray2.9.0-py310-stablediffusion-neuron:latest
imagePullPolicy: Always # Ensure the image is always pulled when updated
lifecycle:
preStop:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class StableDiffusionV2:
def __init__(self):
from optimum.neuron import NeuronStableDiffusionXLPipeline

-compiled_model_id = "aws-neuron/stable-diffusion-xl-base-1-0-1024x1024"
+model_id = os.getenv('MODEL_ID')

# To avoid saving the model locally, we can use the pre-compiled model directly from HF
-self.pipe = NeuronStableDiffusionXLPipeline.from_pretrained(compiled_model_id, device_ids=[0, 1])
+self.pipe = NeuronStableDiffusionXLPipeline.from_pretrained(model_id, device_ids=[0, 1])

async def generate(self, prompt: str):
assert len(prompt), "prompt parameter cannot be empty"
Expand Down

0 comments on commit c50c181

Please sign in to comment.