Skip to content

Commit

Permalink
Add size option to choose size with pixels
Browse files Browse the repository at this point in the history
  • Loading branch information
bfirsh committed Oct 9, 2024
1 parent 9f36e63 commit 0e0d866
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,28 @@
"9:21": (640, 1536),
}

# 1 megapixel sizes
SIZES = {
f"{x}x{y}": (x, y) for x, y in ASPECT_RATIOS.values()
}
# 0.25 megapixel sizes
SIZES.update({
f"{x / 2}x{y / 2}": (x / 2, y / 2) for x, y in ASPECT_RATIOS.values()
})


@dataclass
class SharedInputs:
prompt: Input = Input(description="Prompt for generated image")
size: Input = Input(
description="Size of the generated image",
choices=list(SIZES.keys()),
default="1024x1024",
)
aspect_ratio: Input = Input(
description="Aspect ratio for the generated image",
choices=list(ASPECT_RATIOS.keys()),
default="1:1",
default=None,
)
num_outputs: Input = Input(
description="Number of outputs to generate", default=1, le=4, ge=1
Expand Down Expand Up @@ -99,7 +113,7 @@ class SharedInputs:
megapixels: Input = Input(
description="Approximate number of megapixels for generated image",
choices=["1", "0.25"],
default="1",
default=None,
)


Expand Down Expand Up @@ -246,11 +260,21 @@ def predict():
raise Exception("You need to instantiate a predictor for a specific flux model")

def preprocess(
self, aspect_ratio: str, seed: Optional[int], megapixels: str
self, size: str, aspect_ratio: str | None, seed: Optional[int], megapixels: str | None
) -> Dict:
width, height = ASPECT_RATIOS.get(aspect_ratio)
if megapixels == "0.25":
width, height = width // 2, height // 2
width, height = SIZES.get(size)

# Backwards compatibility for deprecated aspect_ratio and megapixels inputs
if aspect_ratio is not None or megapixels is not None:
# set defaults
if aspect_ratio is None:
aspect_ratio = "1024x1024"
if megapixels is None:
megapixels = "1"

width, height = ASPECT_RATIOS.get(aspect_ratio)
if megapixels == "0.25":
width, height = width // 2, height // 2

if not seed:
seed = int.from_bytes(os.urandom(2), "big")
Expand Down Expand Up @@ -468,6 +492,7 @@ def setup(self) -> None:
def predict(
self,
prompt: str = SHARED_INPUTS.prompt,
size: str = SHARED_INPUTS.size,
aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
num_outputs: int = SHARED_INPUTS.num_outputs,
num_inference_steps: int = Input(
Expand All @@ -483,7 +508,7 @@ def predict(
go_fast: bool = SHARED_INPUTS.go_fast,
megapixels: str = SHARED_INPUTS.megapixels,
) -> List[Path]:
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels)

if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
Expand Down Expand Up @@ -518,6 +543,7 @@ def setup(self) -> None:
def predict(
self,
prompt: str = SHARED_INPUTS.prompt,
size: str = SHARED_INPUTS.size,
aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
image: Path = Input(
description="Input image for image to image mode. The aspect ratio of your output will match this image",
Expand Down Expand Up @@ -549,7 +575,7 @@ def predict(
if image and go_fast:
print("img2img not supported with fp8 quantization; running with bf16")
go_fast = False
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels)

if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
Expand Down

0 comments on commit 0e0d866

Please sign in to comment.