Stable diffusion mlx #474
base: main
Conversation
Just so I understand what's going on, how is the model split up between devices? Let's say I have 3 devices with different capabilities; how does that work?
There are no changes to that part. It's how the partition algorithm splits the shards across the devices.
I see. The difference here is that the layers are non-uniform. That means they won't necessarily get split proportionally to the memory used, right?
Yeah, the layers are non-uniform, so the memory per shard isn't exactly proportional to the number of layers. Can we split using the number of params?
This is probably fine as long as the layers aren't wildly different in size. Do you know roughly how different they are?
The UNet does have a couple of larger layers because of the upsampled dims, while the CLIP text encoder has comparatively smaller layers; it's made of transformer blocks, so it can be split much like the LLMs. We can combine two CLIP layers and split the UNet further to make the shards more uniform.
I think at some point it would make sense to allow more granular sharding of models than just transformer blocks anyway. This could involve updating to a memory-footprint heuristic based on dtypes and parameter counts rather than assuming uniform layer blocks; see the sketch below.
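As a sketch of what that heuristic could look like (hypothetical names; not code from this PR): weight each layer by param_count * dtype_bytes and cut contiguous ranges so that each device's share of bytes, not layer count, matches its capacity.

from dataclasses import dataclass

@dataclass
class LayerInfo:
  param_count: int  # parameters in this layer
  dtype_bytes: int  # bytes per parameter, e.g. 2 for fp16

def split_by_memory(layers: list[LayerInfo], capacities: list[float]) -> list[range]:
  # Assign contiguous layer ranges so each device's share of total bytes
  # (not layer count) is roughly proportional to its memory capacity.
  weights = [l.param_count * l.dtype_bytes for l in layers]
  total, total_cap = sum(weights), sum(capacities)
  ranges, start, acc, target = [], 0, 0.0, 0.0
  for i, cap in enumerate(capacities):
    target += total * cap / total_cap
    end = start
    # The last device sweeps up whatever remains.
    while end < len(layers) and (i == len(capacities) - 1 or acc + weights[end] <= target):
      acc += weights[end]
      end += 1
    ranges.append(range(start, end))
    start = end
  return ranges

Under this scheme, the UNet's heavier upsampling layers would pull shard boundaries toward them instead of being counted the same as a small CLIP block.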
Was running on this branch for a while doing image generation requests and got the following errors:
My bad - I accidentally had another Mac mini in the cluster on a different commit, which was why we were getting these errors. No issues.
Please fix conflicts @pranav4501
Resolved!
Can we also support chat history? I want to be able to give follow-up instructions like: Me: Generate an image of a robot in a flower field
Also, we need to persist the images in a system-wide location rather than in the exo repo itself.
For chat history, we can implement it as follows:
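A minimal sketch of one possible shape for this (hypothetical names throughout - ChatSession and generate_image are illustrative stand-ins, not this PR's code): keep the path of the last generated image per conversation and feed it back as image_url on follow-up turns.

class ChatSession:
  def __init__(self):
    self.last_image_path: str | None = None  # most recent generated image

def generate_image(payload: dict) -> str:
  """Stand-in for the actual call into the image-generation endpoint."""
  raise NotImplementedError

def handle_turn(session: ChatSession, model: str, prompt: str) -> str:
  payload = {"model": model, "prompt": prompt}
  if session.last_image_path:
    # Follow-up instruction: condition the new generation on the last image.
    payload["image_url"] = session.last_image_path
  session.last_image_path = generate_image(payload)
  return session.last_image_path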
Hi Alex,
I don't think it is ignoring the previous image here; it creates a latent from the encoded image and the encoded prompt to generate a new image, so it is not editing the previous image. We can set the strength of the prompt in the newly generated image: for a smaller strength, many features of the older image are carried over, but the features of the prompt are not fully generated.
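For intuition, this is the standard img2img trick: the strength knob decides how far into the noise schedule the encoded image is pushed before denoising begins. A minimal sketch in MLX, assuming a k-diffusion-style sigma schedule (illustrative, not this PR's exact code):

import mlx.core as mx

def img2img_start_latent(encoded_image: mx.array, sigmas: mx.array, strength: float):
  # strength in (0, 1]: 1.0 starts from pure noise (input image ignored);
  # small values keep most of the input image's structure.
  num_steps = len(sigmas)
  start_step = min(int(num_steps * (1 - strength)), num_steps - 1)
  noise = mx.random.normal(encoded_image.shape)
  latent = encoded_image + sigmas[start_step] * noise
  return latent, start_step  # denoise from start_step to the end of the schedule

This matches the behavior described above: at low strength the latent stays close to the encoded image, so its features carry over into the result.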
There are some architectural improvements I'd like to make here. Some high-level notes:
- inference_state shouldn't be necessary
- Special cases for stable diffusion
- Cache handling can be unified / made more flexible with e.g. prompt caching / preloading the cache with a serialized cache
- Tokenizer handling is hacky and should be unified - you had to write a new tokenizer, for example
For now I think this is good to merge if you fix the small things I made comments on, and we will work on these architectural changes later.
model = data.get("model", "")
prompt = data.get("prompt", "")
image_url = data.get("image_url", "")
print(f"model: {model}, prompt: {prompt}, stream: {stream}")
naked print
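One way to address this, assuming the handler can use the DEBUG level exo already reads from the environment in exo.helpers:

from exo.helpers import DEBUG  # exo's env-controlled debug level

if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")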
image_url = data.get("image_url", "")
print(f"model: {model}, prompt: {prompt}, stream: {stream}")
shard = build_base_shard(model, self.inference_engine_classname)
print(f"shard: {shard}")
naked print
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
naked print
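For reference, a self-contained version of this pattern with the warning gated the same way (the resize call is an assumption about what follows the warning, not the PR's exact code):

from io import BytesIO
from PIL import Image
from exo.helpers import DEBUG

def prepare_image(image_data: bytes) -> Image.Image:
  img = Image.open(BytesIO(image_data))
  # Stable diffusion needs dims divisible by 64: 8x VAE downsampling
  # times 8x of spatial reduction inside the UNet.
  W, H = (dim - dim % 64 for dim in (img.width, img.height))
  if W != img.width or H != img.height:
    if DEBUG >= 1: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
    img = img.resize((W, H), Image.LANCZOS)
  return img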
Why is there a build dir?
models_config = json.dumps(models_config)
models_config = json.loads(models_config)
is this to sanitize the json?
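A dumps/loads round-trip does sanitize in a loose sense: it deep-copies the structure, normalizes types JSON can't represent directly, and fails fast on anything non-serializable. For example:

import json

cfg = {"layers": (1, 2), 3: "c"}
print(json.loads(json.dumps(cfg)))
# {'layers': [1, 2], '3': 'c'} - tuples become lists, non-string keys
# become strings; non-serializable values raise TypeError at dumps time.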
Sharded stable diffusion inference for mlx
#159
Changes:
Sharding process: