Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(shortfin-sd) Validate size of downloaded artifacts. #560

Merged
merged 6 commits into from
Nov 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 78 additions & 17 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.build import *
from iree.build.executor import FileNamespace
from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile
import itertools
import os
import urllib
import shortfin.array as sfnp
import copy

Expand Down Expand Up @@ -162,25 +163,23 @@ def needs_update(ctx):
return False


def needs_file(filename, ctx, namespace=FileNamespace.GEN):
def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
if os.path.exists(out_file):
needed = False
else:
# name_path = "bin" if namespace == FileNamespace.BIN else ""
# if name_path:
# filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
return needed
if url:
needed = not is_valid_size(out_file, url)
if not needed:
return False
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
return True


def needs_compile(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
return needs_file(vmfb_name, ctx, namespace)
return needs_file(vmfb_name, ctx, namespace=namespace)


def get_cached_vmfb(filename, target, ctx):
Expand All @@ -190,6 +189,69 @@ def get_cached_vmfb(filename, target, ctx):
return ctx.file(vmfb_name)


def is_valid_size(file_path, url):
if not url:
return True
with urllib.request.urlopen(url) as response:
content_length = response.getheader("Content-Length")
local_size = get_file_size(str(file_path))
if content_length:
content_length = int(content_length)
if content_length != local_size:
return False
return True


def get_file_size(file_path):
"""Gets the size of a local file in bytes as an integer."""

file_stats = os.stat(file_path)
return file_stats.st_size


def fetch_http_check_size(*, name: str, url: str) -> BuildFile:
context = BuildContext.current()
output_file = context.allocate_file(name)
action = FetchHttpWithCheckAction(
url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor
)
output_file.deps.add(action)
return output_file


class FetchHttpWithCheckAction(BuildAction):
def __init__(self, url: str, output_file: BuildFile, **kwargs):
super().__init__(**kwargs)
self.url = url
self.output_file = output_file

def _invoke(self, retries=4):
path = self.output_file.get_fs_path()
self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
try:
urllib.request.urlretrieve(self.url, str(path))
except urllib.error.HTTPError as e:
if retries > 0:
retries -= 1
self._invoke(retries=retries)
else:
raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
local_size = get_file_size(str(path))
try:
with urllib.request.urlopen(self.url) as response:
content_length = response.getheader("Content-Length")
if content_length:
content_length = int(content_length)
if content_length != local_size:
raise IOError(
f"Size of downloaded artifact does not match content-length header! {content_length} != {local_size}"
)
except IOError:
if retries > 0:
retries -= 1
self._invoke(retries=retries)


@entrypoint(description="Retreives a set of SDXL submodels.")
def sdxl(
model_json=cl_arg(
Expand Down Expand Up @@ -224,7 +286,7 @@ def sdxl(
mlir_filenames = get_mlir_filenames(model_params, model)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
for f, url in mlir_urls.items():
if update or needs_file(f, ctx):
if update or needs_file(f, ctx, url):
fetch_http(name=f, url=url)

vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target)
Expand All @@ -244,15 +306,14 @@ def sdxl(
vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx)
else:
for f, url in vmfb_urls.items():
if update or needs_file(f, ctx):
if update or needs_file(f, ctx, url):
fetch_http(name=f, url=url)

params_filenames = get_params_filenames(model_params, model=model, splat=splat)
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if needs_file(f, ctx):
fetch_http(name=f, url=url)
if needs_file(f, ctx, url):
fetch_http_check_size(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames

Expand Down
Loading