diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index e45fb773d..2e4f5b688 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -5,9 +5,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from iree.build import * -from iree.build.executor import FileNamespace +from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile import itertools import os +import urllib import shortfin.array as sfnp import copy @@ -162,25 +163,23 @@ def needs_update(ctx): return False -def needs_file(filename, ctx, namespace=FileNamespace.GEN): +def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() if os.path.exists(out_file): - needed = False - else: - # name_path = "bin" if namespace == FileNamespace.BIN else "" - # if name_path: - # filename = os.path.join(name_path, filename) - filekey = os.path.join(ctx.path, filename) - ctx.executor.all[filekey] = None - needed = True - return needed + if url: + needed = not is_valid_size(out_file, url) + if not needed: + return False + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + return True def needs_compile(filename, target, ctx): device = "amdgpu" if "gfx" in target else "llvmcpu" vmfb_name = f"{filename}_{device}-{target}.vmfb" namespace = FileNamespace.BIN - return needs_file(vmfb_name, ctx, namespace) + return needs_file(vmfb_name, ctx, namespace=namespace) def get_cached_vmfb(filename, target, ctx): @@ -190,6 +189,69 @@ def get_cached_vmfb(filename, target, ctx): return ctx.file(vmfb_name) +def is_valid_size(file_path, url): + if not url: + return True + with urllib.request.urlopen(url) as response: + content_length = response.getheader("Content-Length") + local_size = get_file_size(str(file_path)) + if content_length: + content_length = int(content_length) + if content_length != local_size: + return False + return True + + +def get_file_size(file_path): + """Gets the size of a local file in bytes as an integer.""" + + file_stats = os.stat(file_path) + return file_stats.st_size + + +def fetch_http_check_size(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpWithCheckAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpWithCheckAction(BuildAction): + def __init__(self, url: str, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.url = url + self.output_file = output_file + + def _invoke(self, retries=4): + path = self.output_file.get_fs_path() + self.executor.write_status(f"Fetching URL: {self.url} -> {path}") + try: + urllib.request.urlretrieve(self.url, str(path)) + except urllib.error.HTTPError as e: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + else: + raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None + local_size = get_file_size(str(path)) + try: + with urllib.request.urlopen(self.url) as response: + content_length = response.getheader("Content-Length") + if content_length: + content_length = int(content_length) + if content_length != local_size: + raise IOError( + f"Size of downloaded artifact does not match content-length header! {content_length} != {local_size}" + ) + except IOError: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + + @entrypoint(description="Retreives a set of SDXL submodels.") def sdxl( model_json=cl_arg( @@ -224,7 +286,7 @@ def sdxl( mlir_filenames = get_mlir_filenames(model_params, model) mlir_urls = get_url_map(mlir_filenames, mlir_bucket) for f, url in mlir_urls.items(): - if update or needs_file(f, ctx): + if update or needs_file(f, ctx, url): fetch_http(name=f, url=url) vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target) @@ -244,15 +306,14 @@ def sdxl( vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx) else: for f, url in vmfb_urls.items(): - if update or needs_file(f, ctx): + if update or needs_file(f, ctx, url): fetch_http(name=f, url=url) params_filenames = get_params_filenames(model_params, model=model, splat=splat) params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) for f, url in params_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) - if needs_file(f, ctx): - fetch_http(name=f, url=url) + if needs_file(f, ctx, url): + fetch_http_check_size(name=f, url=url) filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] return filenames