Skip to content

Commit

Permalink
Implement a few pylint suggestions.
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Nov 18, 2024
1 parent 370766a commit 917e1d4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 56 deletions.
68 changes: 33 additions & 35 deletions shortfin/python/shortfin_apps/sd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any

import argparse
import logging
from pathlib import Path
import sys
import os
import copy
import subprocess
from contextlib import asynccontextmanager
import uvicorn

# Import first as it does dep checking and reporting.
from shortfin.interop.fastapi import FastAPIResponder

from contextlib import asynccontextmanager
from shortfin.support.logging_setup import native_handler

from fastapi import FastAPI, Request, Response
import uvicorn

from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
Expand All @@ -29,14 +28,40 @@
from .components.service import GenerateService
from .components.tokenizer import Tokenizer

from shortfin.support.logging_setup import native_handler

logger = logging.getLogger("shortfin-sd")
logger.addHandler(native_handler)
logger.propagate = False

THIS_DIR = Path(__file__).resolve().parent

UVICORN_LOG_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"format": "[{asctime}] {message}",
"datefmt": "%Y-%m-%d %H:%M:%S",
"style": "{",
"use_colors": True,
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"loggers": {
"uvicorn": {
"handlers": ["console"],
"level": "INFO",
"propagate": False,
},
},
}


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand Down Expand Up @@ -233,16 +258,14 @@ def get_modules(args, model_config, flagfile, td_spec):
for name in filenames:
for key in vmfbs.keys():
if key in name.lower():
if any([x in name for x in [".irpa", ".safetensors", ".gguf"]]):
if any(x in name for x in [".irpa", ".safetensors", ".gguf"]):
params[key].extend([name])
elif "vmfb" in name:
vmfbs[key].extend([name])
return vmfbs, params


def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
from pathlib import Path

def main(argv, log_config=UVICORN_LOG_CONFIG):
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
Expand Down Expand Up @@ -394,30 +417,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
main(
sys.argv[1:],
# Make logging defer to the default shortfin logging config.
log_config={
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"format": "[{asctime}] {message}",
"datefmt": "%Y-%m-%d %H:%M:%S",
"style": "{",
"use_colors": True,
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"loggers": {
"uvicorn": {
"handlers": ["console"],
"level": "INFO",
"propagate": False,
},
},
},
log_config=UVICORN_LOG_CONFIG,
)
39 changes: 18 additions & 21 deletions shortfin/python/shortfin_apps/sd/simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from datetime import datetime as dt
import os
import sys
import time
import json
import requests
import argparse
import base64
import time
import asyncio
import aiohttp
import sys
import os
import requests

from datetime import datetime as dt
from PIL import Image

sample_request = {
Expand All @@ -32,10 +32,10 @@
}


def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024):
def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024):
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
image = Image.frombytes(
mode="RGB", size=(width, height), data=base64.b64decode(bytes)
mode="RGB", size=(width, height), data=base64.b64decode(in_bytes)
)
if not os.path.isdir(outputdir):
os.mkdir(outputdir)
Expand Down Expand Up @@ -65,7 +65,6 @@ async def send_request(session, rep, args, data):
# Check if the response was successful
if response.status == 200:
response.raise_for_status() # Raise an error for bad responses
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
res_json = await response.json(content_type=None)
if args.save:
for idx, item in enumerate(res_json["images"]):
Expand All @@ -78,9 +77,8 @@ async def send_request(session, rep, args, data):
latency = end - start
print("Responses processed.")
return latency, len(data["prompt"])
else:
print(f"Error: Received {response.status} from server")
raise Exception
print(f"Error: Received {response.status} from server")
raise Exception


async def static(args):
Expand Down Expand Up @@ -116,7 +114,7 @@ async def static(args):
latencies.append(latency)
sample_counts.append(num_samples)
end = time.time()
if not any([i is None for i in [latencies, sample_counts]]):
if not any(i is None for i in [latencies, sample_counts]):
total_num_samples = sum(sample_counts)
sps = str(total_num_samples / (end - start))
# Until we have better measurements, don't report the throughput that includes saving images.
Expand Down Expand Up @@ -163,9 +161,9 @@ async def interactive(args):
pending, return_when=asyncio.ALL_COMPLETED
)
for task in done:
latency, num_samples = await task
_, _ = await task
pending = []
if any([i is None for i in [latencies, sample_counts]]):
if any(i is None for i in [latencies, sample_counts]):
raise ValueError("Received error response from server.")


Expand All @@ -175,28 +173,27 @@ async def ainput(prompt: str) -> str:

async def async_range(count):
for i in range(count):
yield (i)
yield i
await asyncio.sleep(0.0)


def check_health(url):
ready = False
print(f"Waiting for server.", end=None)
print("Waiting for server.", end=None)
while not ready:
try:
if requests.get(f"{url}/health").status_code == 200:
if requests.get(f"{url}/health", timeout=20).status_code == 200:
print("Successfully connected to server.")
ready = True
return
else:
time.sleep(2)
print(".", end=None)
time.sleep(2)
print(".", end=None)
except:
time.sleep(2)
print(".", end=None)


def main(argv):
def main():
p = argparse.ArgumentParser()
p.add_argument(
"--file",
Expand Down

0 comments on commit 917e1d4

Please sign in to comment.