Skip to content

Commit

Permalink
refactor: replace custom request with third-party API package (webuia…
Browse files Browse the repository at this point in the history
…pi) for SD-model invocation

feat: add stablediffusion model  services
  • Loading branch information
cmgzn committed Sep 27, 2024
1 parent ca23ad8 commit e49e50e
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 136 deletions.
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
extra_litellm_requires = ["litellm"]
extra_zhipuai_requires = ["zhipuai"]
extra_ollama_requires = ["ollama>=0.1.7"]
extra_sd_webuiapi_requires = ["webuiapi"]

# Full requires
extra_full_requires = (
Expand All @@ -102,6 +103,7 @@
+ extra_litellm_requires
+ extra_zhipuai_requires
+ extra_ollama_requires
+ extra_sd_webuiapi_requires
)

# For online workstation
Expand Down Expand Up @@ -140,6 +142,7 @@
"litellm": extra_litellm_requires,
"zhipuai": extra_zhipuai_requires,
"gemini": extra_gemini_requires,
"stablediffusion": extra_sd_webuiapi_requires,
# For service functions
"service": extra_service_requires,
# For distribution mode
Expand Down
215 changes: 79 additions & 136 deletions src/agentscope/models/stablediffusion_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-
"""Model wrapper for stable diffusion models."""
from abc import ABC
import base64
import json
import time
from typing import Any, Optional, Union, Sequence
from typing import Any, Union, Sequence

import requests
from loguru import logger

try:
import webuiapi
except ImportError:
webuiapi = None

from . import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import Msg
from ..manager import FileManager
from ..utils.common import _convert_to_str
Expand All @@ -23,9 +22,10 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
To use SD-webui API, please
1. First download stable-diffusion-webui from
https://github.com/AUTOMATIC1111/stable-diffusion-webui and
install it with 'webui-user.bat'
install it
2. Move your checkpoint to 'models/Stable-diffusion' folder
3. Start launch.py with the '--api' parameter to start the server
3. Start launch.py with the '--api --port=7862' parameters
4. Install the 'webuiapi' package by 'pip install webuiapi'
After that, you can use the SD-webui API and
query the available parameters on the http://localhost:7862/docs page
"""
Expand All @@ -35,15 +35,10 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC):
def __init__(
self,
config_name: str,
host: str = "127.0.0.1:7862",
base_url: Optional[Union[str, None]] = None,
use_https: bool = False,
generate_args: dict = None,
headers: dict = None,
options: dict = None,
timeout: int = 30,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_interval: int = _DEFAULT_RETRY_INTERVAL,
host: str = "127.0.0.1",
port: int = 7862,
**kwargs: Any,
) -> None:
"""
Expand All @@ -52,46 +47,30 @@ def __init__(
Args:
config_name (`str`):
The name of the model config.
host (`str`, default `"127.0.0.1:7862"`):
The host port of the stable-diffusion webui server.
base_url (`str`, default `None`):
Base URL for the stable-diffusion webui services.
Generated from host and use_https if not provided.
use_https (`bool`, default `False`):
Whether to generate the base URL with HTTPS protocol or HTTP.
generate_args (`dict`, default `None`):
The extra keyword arguments used in SD api generation,
e.g. `{"steps": 50}`.
headers (`dict`, default `None`):
HTTP request headers.
options (`dict`, default `None`):
The keyword arguments to change the webui settings
The keyword arguments to change the sd-webui settings
such as model or CLIP skip; these changes will persist,
e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned"}`.
host (`str`, default `"127.0.0.1"`):
The host of the stable-diffusion webui server.
port (`int`, default `7862`):
The port of the stable-diffusion webui server.
"""
# Construct base_url based on HTTPS usage if not provided
if base_url is None:
if use_https:
base_url = f"https://{host}"
else:
base_url = f"http://{host}"

self.base_url = base_url
self.options_url = f"{base_url}/sdapi/v1/options"
# Initialize the SD-webui API
self.api = webuiapi.WebUIApi(host=host, port=port, **kwargs)
self.generate_args = generate_args or {}

# Initialize the HTTP session and update the request headers
self.session = requests.Session()
if headers:
self.session.headers.update(headers)

# Set options if provided
if options:
self._set_options(options)
self.api.set_options(options)
logger.info(f"Set webui options: {options}")

# Get the default model name from the web-options
model_name = (
self._get_options()["sd_model_checkpoint"].split("[")[0].strip()
self.api.get_options()["sd_model_checkpoint"].split("[")[0].strip()
)
# Update the model name
if self.generate_args.get("override_settings"):
Expand All @@ -102,116 +81,29 @@ def __init__(

super().__init__(config_name=config_name, model_name=model_name)

self.timeout = timeout
self.max_retries = max_retries
self.retry_interval = retry_interval

@property
def url(self) -> str:
"""SD-webui API endpoint URL"""
raise NotImplementedError()

def _get_options(self) -> dict:
response = self.session.get(url=self.options_url)
if response.status_code != 200:
logger.error(f"Failed to get options with {response.json()}")
raise RuntimeError(f"Failed to get options with {response.json()}")
return response.json()

def _set_options(self, options: dict) -> None:
response = self.session.post(url=self.options_url, json=options)
if response.status_code != 200:
logger.error(json.dumps(options, indent=4))
raise RuntimeError(f"Failed to set options with {response.json()}")
logger.info("Optionsset successfully")

def _invoke_model(self, payload: dict) -> dict:
"""Invoke SD webui API and record the invocation if needed"""
# step1: prepare post requests
for i in range(1, self.max_retries + 1):
response = self.session.post(url=self.url, json=payload)

if response.status_code == requests.codes.ok:
break

if i < self.max_retries:
logger.warning(
f"Failed to call the model with "
f"requests.codes == {response.status_code}, retry "
f"{i + 1}/{self.max_retries} times",
)
time.sleep(i * self.retry_interval)

# step2: record model invocation
# record the model api invocation, which will be skipped if
# `FileManager.save_api_invocation` is `False`
self._save_model_invocation(
arguments=payload,
response=response.json(),
)

# step3: return the response json
if response.status_code == requests.codes.ok:
return response.json()
else:
logger.error(
json.dumps({"url": self.url, "json": payload}, indent=4),
)
raise RuntimeError(
f"Failed to call the model with {response.json()}",
)

def _parse_response(self, response: dict) -> ModelResponse:
"""Parse the response json data into ModelResponse"""
return ModelResponse(raw=response)


class StableDiffusionImageSynthesisWrapper(StableDiffusionWrapperBase):
"""Stable Diffusion Text-to-Image (txt2img) API Wrapper"""

model_type: str = "sd_txt2img"

@property
def url(self) -> str:
return f"{self.base_url}/sdapi/v1/txt2img"

def _parse_response(self, response: dict) -> ModelResponse:
session_parameters = response["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = (
session_parameters["batch_size"] * session_parameters["n_iter"]
)

self.monitor.update_image_tokens(
model_name=self.model_name,
image_count=image_count,
resolution=size,
)

# Get image base64code as a list
images = response["images"]
b64_images = [base64.b64decode(image) for image in images]

file_manager = FileManager.get_instance()
# Return local url
image_urls = [file_manager.save_image(_) for _ in b64_images]
text = "Image saved to " + "\n".join(image_urls)
return ModelResponse(text=text, image_urls=image_urls, raw=response)

def __call__(
self,
prompt: str,
save_local: bool = True,
**kwargs: Any,
) -> ModelResponse:
"""
Args:
prompt (`str`):
The prompt string to generate images from.
save_local (`bool`, default `True`):
Whether to save the generated images locally.
**kwargs (`Any`):
The keyword arguments to SD-webui txt2img API, e.g.
`n_iter`, `steps`, `seed`, `width`, etc. Please refer to
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API
or http://localhost:7860/docs
or http://localhost:7862/docs
for more detailed arguments.
Returns:
`ModelResponse`:
Expand All @@ -226,10 +118,61 @@ def __call__(
}

# step2: forward to generate response
response = self._invoke_model(payload)
response = self.api.txt2img(**payload)

# step3: save model invocation and update monitor
self._save_model_invocation_and_update_monitor(
payload=payload,
response=response.json,
)

# step4: parse the response
PIL_images = response.images

file_manager = FileManager.get_instance()
if save_local:
# Save images
image_urls = [file_manager.save_image(_) for _ in PIL_images]
text = "Image saved to " + "\n".join(image_urls)
else:
image_urls = PIL_images
text = None

return ModelResponse(
text=text,
image_urls=image_urls,
raw=response.json,
)

def _save_model_invocation_and_update_monitor(
self,
payload: dict,
response: dict,
) -> None:
"""Save the model invocation and update the monitor accordingly.
Args:
payload (`dict`):
The keyword arguments passed to the SD-webui txt2img API.
response (`dict`):
The JSON dict of the txt2img result; its "parameters" entry
(width, height, batch_size, n_iter) is used to update the monitor.
"""
self._save_model_invocation(
arguments=payload,
response=response,
)

session_parameters = response["parameters"]
size = f"{session_parameters['width']}*{session_parameters['height']}"
image_count = (
session_parameters["batch_size"] * session_parameters["n_iter"]
)

# step3: parse the response
return self._parse_response(response)
self.monitor.update_image_tokens(
model_name=self.model_name,
image_count=image_count,
resolution=size,
)

def format(self, *args: Union[Msg, Sequence[Msg]]) -> str:
# This is a temporary implementation to focus on the prompt
Expand Down
2 changes: 2 additions & 0 deletions src/agentscope/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
openai_edit_image,
openai_create_image_variation,
)
from .multi_modality.stablediffusion_services import sd_text_to_image

from .service_response import ServiceResponse
from .service_toolkit import ServiceToolkit
Expand Down Expand Up @@ -117,6 +118,7 @@ def get_help() -> None:
"openai_image_to_text",
"openai_edit_image",
"openai_create_image_variation",
"sd_text_to_image",
"tripadvisor_search",
"tripadvisor_search_location_photos",
"tripadvisor_search_location_details",
Expand Down
Loading

0 comments on commit e49e50e

Please sign in to comment.