From 537a3901e50913065ccb104cc45accbe84e91f9f Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Mon, 20 May 2024 12:10:20 -0400
Subject: [PATCH] Fix some issues with the exe

---
 apps/shark_studio/api/llm.py                  | 21 ++++++------
 apps/shark_studio/api/sd.py                   | 34 ++++++++++++++++----
 apps/shark_studio/modules/ckpt_processing.py  | 24 ++++++++++++++
 apps/shark_studio/modules/pipeline.py         |  2 +-
 apps/shark_studio/web/index.py                |  7 ++++
 apps/shark_studio/web/utils/file_utils.py     | 14 +++++---
 6 files changed, 81 insertions(+), 21 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 6ee80ae49e..34e6230ba6 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -90,9 +90,10 @@ def __init__(
         self.file_spec += "_" + self.quantization
 
         if external_weights in ["safetensors", "gguf"]:
-            self.external_weight_file = get_resource_path(
-                os.path.join("..", self.file_spec + "." + external_weights)
-            )
+            # self.external_weight_file = get_resource_path(
+            #     os.path.join(cmd_opts.model_dir, self.file_spec + "." + external_weights)
+            # )
+            self.external_weight_file = os.path.join(cmd_opts.model_dir, self.file_spec + "." + external_weights)
         else:
             self.external_weights = None
             self.external_weight_file = None
@@ -102,14 +103,16 @@ def __init__(
             self.file_spec += "_streaming"
         self.streaming_llm = streaming_llm
 
-        self.tempfile_name = get_resource_path(
-            os.path.join("..", f"{self.file_spec}.tempfile")
-        )
+        # self.tempfile_name = get_resource_path(
+        #     os.path.join(cmd_opts.tmp_dir, f"{self.file_spec}.tempfile")
+        # )
+        self.tempfile_name = os.path.join(cmd_opts.tmp_dir, f"{self.file_spec}.tempfile")
         # TODO: Tag vmfb with target triple of device instead of HAL backend
         self.vmfb_name = str(
-            get_resource_path(
-                os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
-            )
+            # get_resource_path(
+            #     os.path.join(cmd_opts.tmp_dir, f"{self.file_spec}_{self.backend}.vmfb.tempfile")
+            # )
+            os.path.join(cmd_opts.tmp_dir, f"{self.file_spec}_{self.backend}.vmfb.tempfile")
         )
         self.max_tokens = llm_model_map[model_name]["max_tokens"]
 
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index b4f0f0ddc0..01b3bfaca7 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -16,6 +16,7 @@
     get_resource_path,
     get_checkpoints_path,
 )
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 from apps.shark_studio.modules.pipeline import SharkPipelineBase
 from apps.shark_studio.modules.schedulers import get_schedulers
 from apps.shark_studio.modules.prompt_encoding import (
@@ -30,10 +31,12 @@
 
 from apps.shark_studio.modules.ckpt_processing import (
     preprocessCKPT,
+    load_model_from_ckpt,
     process_custom_pipe_weights,
 )
 from transformers import CLIPTokenizer
 from diffusers.image_processor import VaeImageProcessor
+from safetensors.torch import save_file
 
 sd_model_map = {
     "clip": {
@@ -147,13 +150,31 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
         for model in adapters:
             self.model_map[model] = adapters[model]
 
+        if custom_weights:
+            custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
+            custom_weights_file = os.path.join(cmd_opts.model_dir, "checkpoints", os.path.basename(str(self.base_model_id)), custom_weights_params)
+            custom_weights_subdir = os.path.join(cmd_opts.model_dir, "checkpoints", os.path.basename(str(self.base_model_id)), custom_weights_params + ".d")
+            if not os.path.exists(custom_weights_subdir):
+                os.mkdir(custom_weights_subdir)
+            model = load_model_from_ckpt(custom_weights_file)
+            submodel_aliases = {
+                "unet": "unet",
+                "vae_decode": "vae",
+                "vae_encode": "text_encoder",
+            }
+            for submodel in self.static_kwargs:
+                if submodel not in submodel_aliases:
+                    continue
+                if submodel_aliases[submodel] in model.__dict__:
+                    save_file(model.__dict__[submodel_aliases[submodel]].state_dict(), os.path.join(custom_weights_subdir, submodel+".safetensors"))
+
         for submodel in self.static_kwargs:
             if custom_weights:
-                custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
-                if submodel not in ["clip", "clip2"]:
-                    self.static_kwargs[submodel][
-                        "external_weights"
-                    ] = custom_weights_params
+                if submodel in ["unet", "vae_decode"]: #submodel not in ["clip", "clip2"]:
+                    # self.static_kwargs[submodel][
+                    #     "external_weights"
+                    # ] = custom_weights_params
+                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(custom_weights_subdir, submodel+".safetensors")
             else:
                 self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
                     self.weights_path, submodel + ".safetensors"
@@ -603,7 +624,8 @@ def view_json_file(file_path):
     global_obj._init()
 
     sd_json = view_json_file(
-        get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
+        # get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
+        os.path.join(cmd_opts.config_dir, "default_sd_config.json")
     )
     sd_kwargs = json.loads(sd_json)
     for arg in vars(cmd_opts):
diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py
index fc0bd3b7b8..2d2e0609e8 100644
--- a/apps/shark_studio/modules/ckpt_processing.py
+++ b/apps/shark_studio/modules/ckpt_processing.py
@@ -54,6 +54,30 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
     pipe.save_pretrained(path_to_diffusers)
     print("Loading complete")
 
+def load_model_from_ckpt(custom_weights, is_inpaint=False):
+    path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
+    # if next(Path(path_to_diffusers).iterdir(), None):
+    #     print("Checkpoint already loaded at : ", path_to_diffusers)
+    #     return
+    # else:
+    #     print(
+    #         "Diffusers' checkpoint will be identified here : ",
+    #         path_to_diffusers,
+    #     )
+    from_safetensors = (
+        True if custom_weights.lower().endswith(".safetensors") else False
+    )
+    extract_ema = False
+    print("Loading diffusers' pipeline from original stable diffusion checkpoint")
+    num_in_channels = 9 if is_inpaint else 4
+    pipe = download_from_original_stable_diffusion_ckpt(
+        checkpoint_path_or_dict=custom_weights,
+        extract_ema=extract_ema,
+        from_safetensors=from_safetensors,
+        num_in_channels=num_in_channels,
+    )
+    return pipe
+
 
 def convert_original_vae(vae_checkpoint):
     vae_state_dict = {}
diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py
index 2daedc3352..efad5df0e5 100644
--- a/apps/shark_studio/modules/pipeline.py
+++ b/apps/shark_studio/modules/pipeline.py
@@ -41,7 +41,7 @@ def __init__(
         self.device, self.device_id = clean_device_info(device)
         self.import_mlir = import_mlir
         self.iree_module_dict = {}
-        self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
+        self.tmp_dir = cmd_opts.tmp_dir #get_resource_path(cmd_opts.tmp_dir)
         if not os.path.exists(self.tmp_dir):
             os.mkdir(self.tmp_dir)
         self.tempfiles = {}
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
index d1b97c2f78..e2e8ad9079 100644
--- a/apps/shark_studio/web/index.py
+++ b/apps/shark_studio/web/index.py
@@ -215,6 +215,13 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
 
 if __name__ == "__main__":
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    import shutil
+
+    if cmd_opts.clear_all:
+        shutil.rmtree(cmd_opts.tmp_dir, ignore_errors=True)
+        for file in os.listdir(cmd_opts.model_dir):
+            if file not in ["checkpoints"]:
+                shutil.rmtree(file, ignore_errors=True)
 
     if cmd_opts.webui == False:
         api_only()
diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py
index 9617c16565..80813b0557 100644
--- a/apps/shark_studio/web/utils/file_utils.py
+++ b/apps/shark_studio/web/utils/file_utils.py
@@ -66,21 +66,24 @@ def get_resource_path(path):
 
 
 def get_configs_path() -> Path:
-    configs = get_resource_path(cmd_opts.config_dir)
+    # configs = get_resource_path(cmd_opts.config_dir)
+    configs = cmd_opts.config_dir
     if not os.path.exists(configs):
         os.mkdir(configs)
     return Path(configs)
 
 
 def get_generated_imgs_path() -> Path:
-    outputs = get_resource_path(cmd_opts.output_dir)
+    # outputs = get_resource_path(cmd_opts.output_dir)
+    outputs = cmd_opts.output_dir
     if not os.path.exists(outputs):
         os.mkdir(outputs)
     return Path(outputs)
 
 
 def get_tmp_path() -> Path:
-    tmpdir = get_resource_path(cmd_opts.model_dir)
+    # tmpdir = get_resource_path(cmd_opts.model_dir)
+    tmpdir = cmd_opts.model_dir
     if not os.path.exists(tmpdir):
         os.mkdir(tmpdir)
     return Path(tmpdir)
@@ -106,7 +109,8 @@ def create_model_folders():
 
 
 def get_checkpoints_path(model_type=""):
-    return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
+    # return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
+    return os.path.join(cmd_opts.model_dir, model_type)
 
 
 def get_checkpoints(model_type="checkpoints"):
@@ -119,7 +123,7 @@ def get_checkpoints(model_type="checkpoints"):
             os.path.basename(x)
             for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
         ]
-    ckpt_files.extend(files)
+        ckpt_files.extend(files)
     return sorted(ckpt_files, key=str.casefold)
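
For reviewers, a minimal standalone sketch of the checkpoint-splitting flow that prepare_pipe() now performs via load_model_from_ckpt() and save_file(). The checkpoint path, output directory name, and the unet/vae_decode selection below are illustrative assumptions for a local test, not code from this patch:

import os

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from safetensors.torch import save_file

# Hypothetical input checkpoint and per-checkpoint output directory.
custom_weights = "models/checkpoints/example_sd15.safetensors"
custom_weights_subdir = custom_weights + ".d"
os.makedirs(custom_weights_subdir, exist_ok=True)

# Build an in-memory diffusers pipeline from the original SD checkpoint,
# mirroring load_model_from_ckpt() above.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict=custom_weights,
    from_safetensors=custom_weights.lower().endswith(".safetensors"),
    extract_ema=False,
    num_in_channels=4,  # 9 for inpainting checkpoints
)

# Write one .safetensors file per submodel that the SD pipeline loads
# through "external_weight_path" (unet and vae_decode in this patch).
for submodel, module in (("unet", pipe.unet), ("vae_decode", pipe.vae)):
    save_file(
        module.state_dict(),
        os.path.join(custom_weights_subdir, submodel + ".safetensors"),
    )

The "<checkpoint>.d" directory presumably keeps the extracted submodel weights next to the original checkpoint so later runs can reuse them instead of re-converting the checkpoint.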