(shortfin-sd) Validate size of downloaded artifacts. (#560)
This commit introduces a simple file-size/Content-Length comparison as a basic
validation of downloaded artifacts.
While I'm not sold on this as a long-term validation solution, it is
certainly better than nothing.
monorimet authored Nov 18, 2024
1 parent 5d453c3 commit 99c5279
Showing 1 changed file with 78 additions and 17 deletions.
95 changes: 78 additions & 17 deletions shortfin/python/shortfin_apps/sd/components/builders.py
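The substance of the change is a comparison between an artifact's size on disk and the Content-Length header reported by the server hosting it. A minimal sketch of that idea, separated from the iree.build machinery (matches_content_length is a hypothetical helper for illustration, not part of this commit):

import os
import urllib.request

def matches_content_length(path: str, url: str) -> bool:
    # Open the URL and read only the response headers.
    with urllib.request.urlopen(url) as response:
        content_length = response.getheader("Content-Length")
    if content_length is None:
        # The server reported no size; there is nothing to compare against.
        return True
    return int(content_length) == os.stat(path).st_size

A file that exists but fails this comparison is treated as a truncated or corrupted download and fetched again.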
@@ -5,9 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 from iree.build import *
-from iree.build.executor import FileNamespace
+from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile
 import itertools
 import os
+import urllib
 import shortfin.array as sfnp
 import copy

@@ -162,25 +163,23 @@ def needs_update(ctx):
     return False


-def needs_file(filename, ctx, namespace=FileNamespace.GEN):
+def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN):
     out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
     if os.path.exists(out_file):
         needed = False
-    else:
-        # name_path = "bin" if namespace == FileNamespace.BIN else ""
-        # if name_path:
-        #     filename = os.path.join(name_path, filename)
-        filekey = os.path.join(ctx.path, filename)
-        ctx.executor.all[filekey] = None
-        needed = True
-    return needed
+        if url:
+            needed = not is_valid_size(out_file, url)
+        if not needed:
+            return False
+    filekey = os.path.join(ctx.path, filename)
+    ctx.executor.all[filekey] = None
+    return True


 def needs_compile(filename, target, ctx):
     device = "amdgpu" if "gfx" in target else "llvmcpu"
     vmfb_name = f"{filename}_{device}-{target}.vmfb"
     namespace = FileNamespace.BIN
-    return needs_file(vmfb_name, ctx, namespace)
+    return needs_file(vmfb_name, ctx, namespace=namespace)


 def get_cached_vmfb(filename, target, ctx):
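The url parameter added to needs_file above means a cached artifact is re-downloaded not only when it is missing but also when its on-disk size fails the Content-Length comparison. A usage sketch with an invented filename and URL (ctx is an iree.build BuildContext, as in the sdxl entrypoint below):

url = "https://example.com/sdxl_unet_fp16.mlir"  # illustrative only
if needs_file("sdxl_unet_fp16.mlir", ctx, url=url):
    fetch_http(name="sdxl_unet_fp16.mlir", url=url)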
@@ -190,6 +189,69 @@ def get_cached_vmfb(filename, target, ctx):
     return ctx.file(vmfb_name)


+def is_valid_size(file_path, url):
+    if not url:
+        return True
+    with urllib.request.urlopen(url) as response:
+        content_length = response.getheader("Content-Length")
+    local_size = get_file_size(str(file_path))
+    if content_length:
+        content_length = int(content_length)
+        if content_length != local_size:
+            return False
+    return True
+
+
+def get_file_size(file_path):
+    """Gets the size of a local file in bytes as an integer."""
+
+    file_stats = os.stat(file_path)
+    return file_stats.st_size
+
+
+def fetch_http_check_size(*, name: str, url: str) -> BuildFile:
+    context = BuildContext.current()
+    output_file = context.allocate_file(name)
+    action = FetchHttpWithCheckAction(
+        url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor
+    )
+    output_file.deps.add(action)
+    return output_file
+
+
+class FetchHttpWithCheckAction(BuildAction):
+    def __init__(self, url: str, output_file: BuildFile, **kwargs):
+        super().__init__(**kwargs)
+        self.url = url
+        self.output_file = output_file
+
+    def _invoke(self, retries=4):
+        path = self.output_file.get_fs_path()
+        self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
+        try:
+            urllib.request.urlretrieve(self.url, str(path))
+        except urllib.error.HTTPError as e:
+            if retries > 0:
+                retries -= 1
+                self._invoke(retries=retries)
+            else:
+                raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
+        local_size = get_file_size(str(path))
+        try:
+            with urllib.request.urlopen(self.url) as response:
+                content_length = response.getheader("Content-Length")
+                if content_length:
+                    content_length = int(content_length)
+                    if content_length != local_size:
+                        raise IOError(
+                            f"Size of downloaded artifact does not match content-length header! {content_length} != {local_size}"
+                        )
+        except IOError:
+            if retries > 0:
+                retries -= 1
+                self._invoke(retries=retries)
+
+
 @entrypoint(description="Retrieves a set of SDXL submodels.")
 def sdxl(
     model_json=cl_arg(
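FetchHttpWithCheckAction._invoke above downloads the file, re-opens the URL, and raises IOError when the reported Content-Length disagrees with the bytes written, retrying up to four times; note that a mismatch persisting past the final retry is swallowed by the except IOError handler rather than re-raised. A calling sketch with an invented artifact name and URL, assuming an active BuildContext (fetch_http_check_size calls BuildContext.current() internally):

output_file = fetch_http_check_size(
    name="sdxl_vae_fp16.irpa",  # illustrative name
    url="https://example.com/sdxl_vae_fp16.irpa",  # illustrative URL
)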
@@ -224,7 +286,7 @@ def sdxl(
     mlir_filenames = get_mlir_filenames(model_params, model)
     mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
     for f, url in mlir_urls.items():
-        if update or needs_file(f, ctx):
+        if update or needs_file(f, ctx, url):
             fetch_http(name=f, url=url)

     vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target)
@@ -244,15 +306,14 @@ def sdxl(
                 vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx)
     else:
         for f, url in vmfb_urls.items():
-            if update or needs_file(f, ctx):
+            if update or needs_file(f, ctx, url):
                 fetch_http(name=f, url=url)

     params_filenames = get_params_filenames(model_params, model=model, splat=splat)
     params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
     for f, url in params_urls.items():
         out_file = os.path.join(ctx.executor.output_dir, f)
-        if needs_file(f, ctx):
-            fetch_http(name=f, url=url)
+        if needs_file(f, ctx, url):
+            fetch_http_check_size(name=f, url=url)
     filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
     return filenames
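The commit message's hedge about this being a stopgap is fair: a size match catches truncation but not altered content, and Content-Length can be absent (chunked transfers) or reflect a compressed size. The usual stronger check is digest verification; a minimal sketch, assuming SHA-256 digests were published for the bucket artifacts, which this commit does not provide:

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream in chunks so multi-gigabyte weight files need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def is_valid_digest(path: str, expected_hex: str) -> bool:
    return sha256_of(path) == expected_hex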

