diff --git a/common.hpp b/common.hpp index 4a423d5a..b79a3c92 100644 --- a/common.hpp +++ b/common.hpp @@ -465,7 +465,7 @@ struct SpatialTransformer { #if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] #else - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] + struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); kq = ggml_soft_max_inplace(ctx, kq); diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 7acc4449..50a234b7 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include #include @@ -10,10 +12,12 @@ #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION + #include "stb_image.h" #define STB_IMAGE_WRITE_IMPLEMENTATION #define STB_IMAGE_WRITE_STATIC + #include "stb_image_write.h" const char* rng_type_to_str[] = { @@ -50,6 +54,7 @@ enum SDMode { TXT2IMG, IMG2IMG, CONVERT, + STREAM, MODE_COUNT }; @@ -59,6 +64,8 @@ struct SDParams { std::string model_path; std::string vae_path; + std::string clip_path; + std::string unet_path; std::string taesd_path; std::string esrgan_path; std::string controlnet_path; @@ -86,6 +93,7 @@ struct SDParams { int64_t seed = 42; bool verbose = false; bool vae_tiling = false; + bool vae_decode_only = false; bool control_net_cpu = false; bool canny_preprocess = false; }; @@ -97,6 +105,8 @@ void print_params(SDParams params) { printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? 
sd_type_name(params.wtype) : "unspecified"); printf(" vae_path: %s\n", params.vae_path.c_str()); + printf(" clip_path: %s\n", params.clip_path.c_str()); + printf(" unet_path: %s\n", params.unet_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" controlnet_path: %s\n", params.controlnet_path.c_str()); @@ -127,11 +137,14 @@ void print_usage(int argc, const char* argv[]) { printf("\n"); printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n"); + printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert or stream, default: txt2img)\n"); printf(" -t, --threads N number of threads to use during computation (default: -1).\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to model\n"); + printf(" If the path is directory, support load model from \"unet/diffusion_pytorch_model.safetensors\", \"vae/diffusion_pytorch_model.safetensors\",\"text_encoder/model.safetensors\"\n"); printf(" --vae [VAE] path to vae\n"); + printf(" --clip [CLIP] path to clip\n"); + printf(" --unet [UNET] path to unet\n"); printf(" --taesd [TAESD_PATH] path to taesd. 
Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); @@ -150,7 +163,7 @@ void print_usage(int argc, const char* argv[]) { printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); @@ -207,6 +220,18 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.vae_path = argv[i]; + } else if (arg == "--clip") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_path = argv[i]; + } else if (arg == "--unet") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.unet_path = argv[i]; } else if (arg == "--taesd") { if (++i >= argc) { invalid_arg = true; break; @@ -416,52 +441,46 @@ void parse_args(int argc, const char** argv, SDParams& params) { print_usage(argc, argv); exit(1); } +} + +bool check_params(SDParams& params) { + std::vector<std::string> required_args; + std::vector<std::string> invalid_args; + if (params.n_threads <= 0) { params.n_threads = get_num_physical_cores(); } if (params.mode != CONVERT && params.prompt.length() == 0) { - fprintf(stderr, "error: the following arguments are required: prompt\n"); - print_usage(argc, argv); - exit(1); + required_args.emplace_back("prompt"); } if (params.model_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: model_path\n"); - print_usage(argc, argv); - exit(1); + 
required_args.emplace_back("model_path"); } if (params.mode == IMG2IMG && params.input_path.length() == 0) { - fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n"); - print_usage(argc, argv); - exit(1); + required_args.emplace_back("init-img"); } if (params.output_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: output_path\n"); - print_usage(argc, argv); - exit(1); + required_args.emplace_back("output_path"); } if (params.width <= 0 || params.width % 64 != 0) { - fprintf(stderr, "error: the width must be a multiple of 64\n"); - exit(1); + invalid_args.emplace_back("the width must be a multiple of 64"); } if (params.height <= 0 || params.height % 64 != 0) { - fprintf(stderr, "error: the height must be a multiple of 64\n"); - exit(1); + invalid_args.emplace_back("the height must be a multiple of 64"); } if (params.sample_steps <= 0) { - fprintf(stderr, "error: the sample_steps must be greater than 0\n"); - exit(1); + invalid_args.emplace_back("the sample_steps must be greater than 0"); } if (params.strength < 0.f || params.strength > 1.f) { - fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n"); - exit(1); + invalid_args.emplace_back("can only work with strength in [0.0, 1.0]"); } if (params.seed < 0) { @@ -474,6 +493,36 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.output_path = "output.gguf"; } } + + if ((!invalid_args.empty()) || (!required_args.empty())) { + if (!invalid_args.empty()) { + std::ostringstream oss; + for (int i = 0; i < invalid_args.size(); i++) { + if (i > 0) { + oss << ",\n"; + } + oss << invalid_args[i]; + } + std::string invalid_args_str = oss.str(); + std::cout << "error: " << invalid_args_str << std::endl; + } + + if (!required_args.empty()) { + std::ostringstream oss; + for (int i = 0; i < required_args.size(); i++) { + if (i > 0) { + oss << ","; + } + oss << required_args[i]; + } + std::string 
required_args_str = oss.str(); + std::cout << "require: " << required_args_str << std::endl; + } + + return false; + } + + return true; } std::string get_image_params(SDParams params, int64_t seed) { @@ -510,181 +559,534 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { } } -int main(int argc, const char* argv[]) { - SDParams params; - parse_args(argc, argv, params); +std::vector parse_cin(std::string& input, std::set ignore_args) { + std::vector inputTokens; + std::string token; + std::istringstream iss(input); - sd_set_log_callback(sd_log_cb, (void*)¶ms); + std::string word; + bool in_stmt = false; + std::string stmt; + inputTokens.emplace_back("fake run path, no use!"); + while (iss >> word) { + if (word[0] == '"') { + in_stmt = true; + } - if (params.verbose) { - print_params(params); - printf("%s", sd_get_system_info()); + if (word[word.length() - 1] == '"') { + stmt += word; + word = stmt.substr(1, stmt.length() - 2); + stmt = ""; + in_stmt = false; + } + + if (in_stmt) { + stmt += word; + stmt += " "; + continue; + } + inputTokens.push_back(word); } - if (params.mode == CONVERT) { - bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype); - if (!success) { - fprintf(stderr, - "convert '%s'/'%s' to '%s' failed\n", - params.model_path.c_str(), - params.vae_path.c_str(), - params.output_path.c_str()); - return 1; - } else { - printf("convert '%s'/'%s' to '%s' success\n", - params.model_path.c_str(), - params.vae_path.c_str(), - params.output_path.c_str()); - return 0; + std::vector commands; + for (int i = 0; i < inputTokens.size(); i++) { + if (ignore_args.find(inputTokens[i]) != ignore_args.end()) { + i++; + continue; + } + commands.push_back(inputTokens[i]); + } + return commands; +} + +SDParams merge_params(SDParams dst, SDParams src) { + if (dst.n_threads != src.n_threads) { + if (src.n_threads > 0) { + dst.n_threads = src.n_threads; + } + } + + if (dst.mode != src.mode) 
{ + if (src.mode == TXT2IMG || src.mode == IMG2IMG) { + dst.mode = src.mode; + if (dst.mode == IMG2IMG) { + dst.vae_decode_only = false; + } + } + } + + if (dst.model_path != src.model_path) { + if (!src.model_path.empty()) { + dst.model_path = src.model_path; + } + } + + if (dst.vae_path != src.vae_path) { + if (!src.vae_path.empty()) { + dst.vae_path = src.vae_path; + } + } + + if (dst.clip_path != src.clip_path) { + if (!src.clip_path.empty()) { + dst.clip_path = src.clip_path; + } + } + + if (dst.unet_path != src.unet_path) { + if (!src.unet_path.empty()) { + dst.unet_path = src.unet_path; + } + } + + if (dst.taesd_path != src.taesd_path) { + if (!src.taesd_path.empty()) { + dst.taesd_path = src.taesd_path; + } + } + + if (dst.esrgan_path != src.esrgan_path) { + if (!src.esrgan_path.empty()) { + dst.esrgan_path = src.esrgan_path; + } + } + + if (dst.controlnet_path != src.controlnet_path) { + if (!src.controlnet_path.empty()) { + dst.controlnet_path = src.controlnet_path; + } + } + + if (dst.embeddings_path != src.embeddings_path) { + if (!src.embeddings_path.empty()) { + dst.embeddings_path = src.embeddings_path; + } + } + + if (dst.wtype != src.wtype) { + dst.wtype = src.wtype; + } + + if (dst.lora_model_dir != src.lora_model_dir) { + if (!src.lora_model_dir.empty()) { + dst.lora_model_dir = src.lora_model_dir; + } + } + + if (dst.output_path != src.output_path) { + if (!src.output_path.empty()) { + dst.output_path = src.output_path; + } + } + + if (dst.input_path != src.input_path) { + if (!src.input_path.empty()) { + dst.input_path = src.input_path; + } + } + + if (dst.control_image_path != src.control_image_path) { + if (!src.control_image_path.empty()) { + dst.control_image_path = src.control_image_path; + } + } + + if (dst.prompt != src.prompt) { + if (!src.prompt.empty()) { + dst.prompt = src.prompt; + } + } + + if (dst.negative_prompt != src.negative_prompt) { + if (!src.negative_prompt.empty()) { + dst.negative_prompt = src.negative_prompt; + } + } 
+ + if (dst.cfg_scale != src.cfg_scale) { + if (src.cfg_scale >= 0) { + dst.cfg_scale = src.cfg_scale; + } + } + + if (dst.clip_skip != src.clip_skip) { + dst.clip_skip = src.clip_skip; + } + + if (dst.width != src.width) { + if (src.width > 0 && src.width % 64 == 0) { + dst.width = src.width; + } + } + + if (dst.height != src.height) { + if (src.height > 0 && src.height % 64 == 0) { + dst.height = src.height; + } + } + + if (dst.batch_count != src.batch_count) { + if (src.batch_count > 0) { + dst.batch_count = src.batch_count; + } + } + + if (dst.batch_count != src.batch_count) { + if (src.batch_count > 0) { + dst.batch_count = src.batch_count; + } + } + + if (dst.sample_method != src.sample_method) { + if (src.sample_method < N_SAMPLE_METHODS) { + dst.sample_method = src.sample_method; + } + } + + if (dst.schedule != src.schedule) { + if (src.schedule < N_SCHEDULES) { + dst.schedule = src.schedule; + } + } + + if (dst.sample_steps != src.sample_steps) { + if (src.sample_steps > 0) { + dst.sample_steps = src.sample_steps; + } + } + + if (dst.strength != src.strength) { + if (src.strength >= 0.f && src.strength <= 1.f) { + dst.strength = src.strength; + } + } + + if (dst.control_strength != src.control_strength) { + if (src.control_strength >= 0.f && src.control_strength <= 1.f) { + dst.control_strength = src.control_strength; + } + } + + if (dst.rng_type != src.rng_type) { + if (src.rng_type <= CUDA_RNG) { + dst.rng_type = src.rng_type; } } - bool vae_decode_only = true; - uint8_t* input_image_buffer = NULL; - if (params.mode == IMG2IMG) { - vae_decode_only = false; + if (dst.seed != src.seed) { + if (src.seed > 0) { + dst.seed = src.seed; + } + } + + if (dst.verbose != src.verbose) { + dst.verbose = src.verbose; + } + + if (dst.vae_tiling != src.vae_tiling) { + dst.vae_tiling = src.vae_tiling; + } + + if (dst.vae_decode_only != src.vae_decode_only) { + dst.vae_decode_only = src.vae_decode_only; + } + + if (dst.control_net_cpu != src.control_net_cpu) { + 
dst.control_net_cpu = src.control_net_cpu; + } + + if (dst.canny_preprocess != src.canny_preprocess) { + dst.canny_preprocess = src.canny_preprocess; + } + return dst; +} + +class CliInstance { +public: + sd_ctx_t* sd_ctx; + + ~CliInstance() { + free_sd_ctx(sd_ctx); + } + + CliInstance(const SDParams& params) { + sd_ctx = new_sd_ctx( + params.n_threads, + params.vae_decode_only, + false, + params.lora_model_dir.c_str(), + params.rng_type, + params.vae_tiling, + params.wtype, + params.schedule, + params.control_net_cpu, + true); + } + + bool load_from_file(SDParams& params) { + // free api always check if the following methods can free, so we can always free the model before load it. + free_diffusions_params(sd_ctx); + auto load_status = load_diffusions_from_file(sd_ctx, params.model_path.c_str()); + + if (load_status && !params.clip_path.empty()) { + free_clip_params(sd_ctx); + load_status = load_clip_from_file(sd_ctx, params.clip_path.c_str()); + } + + if (load_status && !params.vae_path.empty()) { + free_vae_params(sd_ctx); + load_status = load_vae_from_file(sd_ctx, params.vae_path.c_str()); + } + + if (load_status && !params.unet_path.empty()) { + free_unet_params(sd_ctx); + load_status = load_unet_from_file(sd_ctx, params.unet_path.c_str()); + } + + return load_status; + } + + void txtimg(SDParams& params) { + set_options(sd_ctx, params.n_threads, + params.vae_decode_only, + true, + params.lora_model_dir.c_str(), + params.rng_type, + params.vae_tiling, + params.wtype, + params.schedule); + int c = 0; + uint8_t* input_image_buffer = stbi_load(params.control_image_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); + if (input_image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); + return; + } + if (c != 3) { + fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c); + free(input_image_buffer); + return; + } + + sd_image_t input_image = {(uint32_t)params.width, + 
(uint32_t)params.height, + 3, + input_image_buffer}; + + sd_image_t* results = txt2img(sd_ctx, + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.seed, + params.batch_count, + &input_image, + params.control_strength); + + results = upscaler(params, results); + save_image(params, results); + } + + void imgimg(SDParams& params) { + set_options(sd_ctx, params.n_threads, + params.vae_decode_only, + true, + params.lora_model_dir.c_str(), + params.rng_type, + params.vae_tiling, + params.wtype, + params.schedule); + uint8_t* input_image_buffer = NULL; int c = 0; input_image_buffer = stbi_load(params.input_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); if (input_image_buffer == NULL) { fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); - return 1; + return; } if (c != 3) { fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c); free(input_image_buffer); - return 1; + return; } if (params.width <= 0 || params.width % 64 != 0) { fprintf(stderr, "error: the width of image must be a multiple of 64\n"); free(input_image_buffer); - return 1; + return; } + if (params.height <= 0 || params.height % 64 != 0) { fprintf(stderr, "error: the height of image must be a multiple of 64\n"); free(input_image_buffer); - return 1; + return; } - } - - sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), - params.vae_path.c_str(), - params.taesd_path.c_str(), - params.controlnet_path.c_str(), - params.lora_model_dir.c_str(), - params.embeddings_path.c_str(), - vae_decode_only, - params.vae_tiling, - true, - params.n_threads, - params.wtype, - params.rng_type, - params.schedule, - params.control_net_cpu); - - if (sd_ctx == NULL) { - printf("new_sd_ctx_t failed\n"); - return 1; - } - sd_image_t* results; - if (params.mode == TXT2IMG) { - sd_image_t* control_image = NULL; - if 
(params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) { - int c = 0; - input_image_buffer = stbi_load(params.control_image_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); - if (input_image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); - return 1; - } - control_image = new sd_image_t{(uint32_t)params.width, - (uint32_t)params.height, - 3, - input_image_buffer}; - if (params.canny_preprocess) { // apply preprocessor - LOG_INFO("Applying canny preprocessor"); - control_image->data = preprocess_canny(control_image->data, control_image->width, control_image->height); - } - } - results = txt2img(sd_ctx, - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.seed, - params.batch_count, - control_image, - params.control_strength); - } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, 3, input_image_buffer}; - results = img2img(sd_ctx, - input_image, - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.strength, - params.seed, - params.batch_count); - } - - if (results == NULL) { - printf("generate failed\n"); - free_sd_ctx(sd_ctx); - return 1; + sd_image_t* results = img2img(sd_ctx, + input_image, + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.strength, + params.seed, + params.batch_count); + results = upscaler(params, results); + save_image(params, results); } - int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth - if (params.esrgan_path.size() > 0) { - upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(), - 
params.n_threads, - params.wtype); +protected: + void save_image(const SDParams& params, sd_image_t* results) { + size_t last = params.output_path.find_last_of("."); + std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; + for (int i = 0; i < params.batch_count; i++) { + if (results[i].data == NULL) { + continue; + } + std::string final_image_path = + i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png"; + stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 0, get_image_params(params, params.seed + i).c_str()); + printf("save result image to '%s'\n", final_image_path.c_str()); + free(results[i].data); + results[i].data = NULL; + } + free(results); + } - if (upscaler_ctx == NULL) { - printf("new_upscaler_ctx failed\n"); - } else { - for (int i = 0; i < params.batch_count; i++) { - if (results[i].data == NULL) { - continue; - } - sd_image_t upscaled_image = upscale(upscaler_ctx, results[i], upscale_factor); - if (upscaled_image.data == NULL) { - printf("upscale failed\n"); - continue; + sd_image_t* upscaler(const SDParams& params, sd_image_t* results) { + int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth + if (params.esrgan_path.size() > 0) { + upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(), + params.n_threads, + params.wtype); + if (upscaler_ctx == NULL) { + printf("new_upscaler_ctx failed\n"); + } else { + for (int i = 0; i < params.batch_count; i++) { + if (results[i].data == NULL) { + continue; + } + sd_image_t upscaled_image = upscale(upscaler_ctx, results[i], upscale_factor); + if (upscaled_image.data == NULL) { + printf("upscale failed\n"); + continue; + } + free(results[i].data); + results[i] = upscaled_image; } - free(results[i].data); - results[i] = upscaled_image; + free_upscaler_ctx(upscaler_ctx); } } + return results; } +}; - size_t last = 
params.output_path.find_last_of("."); - std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; - for (int i = 0; i < params.batch_count; i++) { - if (results[i].data == NULL) { - continue; +int main(int argc, const char* argv[]) { + SDParams params; + + parse_args(argc, argv, params); + + if (params.mode != STREAM && !check_params(params)) { + return 1; + } + + sd_set_log_callback(sd_log_cb, (void*)¶ms); + + if (params.verbose) { + print_params(params); + printf("%s", sd_get_system_info()); + } + + if (params.mode == CONVERT) { + bool success = convert(params.model_path.c_str(), + params.vae_path.c_str(), + params.output_path.c_str(), + params.wtype); + if (!success) { + fprintf(stderr, + "convert '%s'/'%s' to '%s' failed\n", + params.model_path.c_str(), + params.vae_path.c_str(), + params.output_path.c_str()); + return 1; + } else { + printf("convert '%s'/'%s' to '%s' success\n", + params.model_path.c_str(), + params.vae_path.c_str(), + params.output_path.c_str()); + return 0; } - std::string final_image_path = i > 0 ? 
dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png"; - stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); - printf("save result image to '%s'\n", final_image_path.c_str()); - free(results[i].data); - results[i].data = NULL; } - free(results); - free_sd_ctx(sd_ctx); + auto instance = new CliInstance(params); + + if (params.mode == STREAM) { + std::cout << "you are in stream model, feel free to use txt2img or img2img" << std::endl; + while (true) { + std::string input; + std::cout << "please input args: " << std::endl; + std::getline(std::cin, input); + // hold an ignore cmd for feature to ignore the cmd not support + std::set ignore_cmd = {""}; + std::vector args = parse_cin(input, ignore_cmd); + SDParams stream_params; + const char** args_c_arr = new const char*[args.size()]; + for (int i = 0; i < args.size(); i++) { + std::string arg = args[i]; + char* c_str = new char[args[i].length() + 1]; + std::strcpy(c_str, arg.c_str()); + args_c_arr[i] = c_str; + } + parse_args(args.size(), args_c_arr, stream_params); + if (params.model_path != stream_params.model_path || + params.clip_path != stream_params.clip_path || + params.vae_path != stream_params.vae_path || + params.unet_path != stream_params.unet_path) { + instance->load_from_file(stream_params); + } + params = merge_params(params, stream_params); + if (!check_params(params)) { + continue; + } + if (params.mode == TXT2IMG) { + instance->txtimg(params); + } else if (params.mode == IMG2IMG) { + instance->imgimg(params); + } else { + return 1; + } + } + } else { + if (!params.model_path.empty()) { + if (!instance->load_from_file(params)) { + return 1; + } + } else { + if (!params.clip_path.empty() && !params.vae_path.empty() && !params.unet_path.empty()) { + if (!instance->load_from_file(params)) { + return 1; + } + } + } + if (params.mode == TXT2IMG) { + 
instance->txtimg(params); + } else if (params.mode == IMG2IMG) { + instance->imgimg(params); + } else { + return 0; + } + } return 0; } diff --git a/model.cpp b/model.cpp index b89edf27..59e788b1 100644 --- a/model.cpp +++ b/model.cpp @@ -1391,7 +1391,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend bool ModelLoader::load_tensors(std::map& tensors, ggml_backend_t backend, - std::set ignore_tensors) { + std::set ignore_tensors, + bool standalone) { std::set tensor_names_in_file; auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; @@ -1402,7 +1403,11 @@ bool ModelLoader::load_tensors(std::map& tenso real = tensors[name]; } else { if (ignore_tensors.find(name) == ignore_tensors.end()) { - LOG_WARN("unknown tensor '%s' in model file", name.c_str()); + if (standalone) { + LOG_WARN("unknown tensor '%s' in model file", name.c_str()); + } else { + LOG_DEBUG("unknown tensor '%s' in model file", name.c_str()); + } } return true; } diff --git a/model.h b/model.h index 13665a7e..7450f97a 100644 --- a/model.h +++ b/model.h @@ -119,7 +119,8 @@ class ModelLoader { bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); bool load_tensors(std::map& tensors, ggml_backend_t backend, - std::set ignore_tensors = {}); + std::set ignore_tensors = {}, + bool standalone = true); bool save_to_gguf_file(const std::string& file_path, ggml_type type); int64_t cal_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 8dd5f16e..9f556dc4 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -75,25 +75,42 @@ class StableDiffusionGGML { std::map loras; std::shared_ptr denoiser = std::make_shared(); - ggml_backend_t backend = NULL; // general backend - ggml_type model_data_type = GGML_TYPE_COUNT; + schedule_t schedule = DEFAULT; + + 
ggml_backend_t backend = NULL; // general backend + ggml_type model_data_type = GGML_TYPE_COUNT; // runtime weight type + ggml_type wtype = GGML_TYPE_COUNT; // options weight type TinyAutoEncoder tae_first_stage; + + std::string model_path; + std::string clip_path; + std::string vae_path; + std::string unet_path; std::string taesd_path; ControlNet control_net; + ModelLoader model_loader; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, bool vae_decode_only, bool free_params_immediately, std::string lora_model_dir, - rng_type_t rng_type) + rng_type_t rng_type, + bool vae_tiling, + ggml_type wtype, + schedule_t schedule, + bool init_backend_immediately = true) : n_threads(n_threads), vae_decode_only(vae_decode_only), free_params_immediately(free_params_immediately), - lora_model_dir(lora_model_dir) { + lora_model_dir(lora_model_dir), + vae_tiling(vae_tiling), + wtype(wtype), + schedule(schedule) { first_stage_model.decode_only = vae_decode_only; tae_first_stage.decode_only = vae_decode_only; if (rng_type == STD_DEFAULT_RNG) { @@ -101,22 +118,473 @@ class StableDiffusionGGML { } else if (rng_type == CUDA_RNG) { rng = std::make_shared(); } + if (init_backend_immediately) { + init_backend(); + } } ~StableDiffusionGGML() { ggml_backend_free(backend); } + void init_backend() { +#ifdef SD_USE_CUBLAS + LOG_DEBUG("Using CUDA backend"); + backend = ggml_backend_cuda_init(0); +#endif +#ifdef SD_USE_METAL + LOG_DEBUG("Using Metal backend"); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + backend = ggml_backend_metal_init(); +#endif + + if (!backend) { + LOG_DEBUG("Using CPU backend"); + backend = ggml_backend_cpu_init(); + } +#ifdef SD_USE_FLASH_ATTENTION +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + LOG_WARN("Flash Attention not supported with GPU Backend"); +#else + LOG_INFO("Flash Attention enabled"); +#endif +#endif + } + + void set_options(int n_threads, + bool vae_decode_only, + bool free_params_immediately, + std::string 
lora_model_dir, + rng_type_t rng_type, + bool vae_tiling, + sd_type_t wtype, + schedule_t schedule) { + this->n_threads = n_threads; + bool standalone = clip_path != vae_path && vae_path != unet_path; + + std::string model_path; + if (!standalone && clip_path == vae_path) { + model_path = clip_path; + } + + if (!standalone && vae_path == unet_path) { + model_path = vae_path; + } + + if (this->vae_decode_only != vae_decode_only) { + this->vae_decode_only = vae_decode_only; + if (!vae_path.empty() && first_stage_model.params_buffer_size > 0) { + free_vae_params(); + std::string prefix; + if (standalone) { + prefix = ".vae"; + } + load_vae_from_file(vae_path, standalone, prefix); + } + } + + this->free_params_immediately = free_params_immediately; + this->lora_model_dir = std::move(lora_model_dir); + if (rng_type == STD_DEFAULT_RNG) { + rng = std::make_shared(); + } else if (rng_type == CUDA_RNG) { + rng = std::make_shared(); + } + this->vae_tiling = vae_tiling; + + if (this->wtype != (ggml_type)wtype) { + this->wtype = (ggml_type)wtype; + // TODO: can reload weight + // if (!standalone) { + // free_diffusions_params(); + // load_diffusions_from_file(model_path); + // } + } + + if (this->schedule != schedule) { + this->schedule = schedule; + apply_schedule(); + } + } + + bool load_clip_from_file(const std::string& model_path, bool standalone = true, const std::string& prefix = "te.") { + if (backend == NULL) { + LOG_ERROR("if you set init_backend_immediately false, please call init_backend first"); + return false; + } + + if (!model_path.empty()) { + LOG_INFO("loading clip from '%s'", model_path.c_str()); + if (!model_loader.init_from_file(model_path, prefix)) { + LOG_WARN("loading clip from '%s' failed", model_path.c_str()); + return false; + } + } + + version = model_loader.get_sd_version(); + if (version == VERSION_COUNT) { + LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); + return false; + } + + if (version == VERSION_XL) { + scale_factor = 
0.13025f; + } + + cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); + + LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); + + auto autodiscover_wtype = model_loader.get_sd_wtype(); + + if (wtype == GGML_TYPE_COUNT) { + model_data_type = autodiscover_wtype; + } else { + if (wtype > autodiscover_wtype) { + LOG_WARN("Stable Diffusion weight type can't set to %s, so set default: %s", + ggml_type_name(wtype), + ggml_type_name(model_data_type)); + model_data_type = autodiscover_wtype; + } else { + model_data_type = wtype; + } + } + + LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type)); + + LOG_DEBUG("loading vocab"); + std::string merges_utf8_str = model_loader.load_merges(); + if (merges_utf8_str.size() == 0) { + LOG_ERROR("get merges failed: '%s'", model_path.c_str()); + return false; + } + + cond_stage_model.tokenizer.load_from_merges(merges_utf8_str); + + if (!cond_stage_model.alloc_params_buffer(backend, model_data_type)) { + return false; + } + + LOG_DEBUG("preparing memory for clip weights"); + // prepare memory for the weights + { + cond_stage_model.init_params(); + cond_stage_model.map_by_name(tensors, "cond_stage_model."); + } + + struct ggml_init_params params; + params.mem_size = static_cast(3 * 1024) * 1024; // 3M + params.mem_buffer = NULL; + params.no_alloc = false; + // LOG_DEBUG("mem_size %u ", params.mem_size); + struct ggml_context* ctx = ggml_init(params); // for alphas_cumprod and is_using_v_parameterization check + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + + // load weights + LOG_DEBUG("loading clip weights"); + int64_t t0 = ggml_time_ms(); + + std::map tensors_need_to_load; + std::set ignore_tensors; + + for (auto& pair : tensors) { + tensors_need_to_load.insert(pair); + } + + bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors, standalone); + if (!success) { + LOG_ERROR("load tensors from clip model failed"); + ggml_free(ctx); + 
return false; + } + + LOG_INFO("clip memory buffer size = %.2fMB", cond_stage_model.params_buffer_size / 1024.0 / 1024.0); + int64_t t1 = ggml_time_ms(); + LOG_INFO("loading clip model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); + ggml_free(ctx); + clip_path = model_path; + return true; + } + + void free_clip_params() { + if (cond_stage_model.params_buffer_size > 0) { + cond_stage_model.free_params_buffer(); + } + } + + bool load_unet_from_file(const std::string& model_path, + bool standalone = true, + const std::string& prefix = "unet.") { + if (backend == NULL) { + LOG_ERROR("if you set init_backend_immediately false, please call init_backend first"); + return false; + } + + if (version == VERSION_COUNT) { + LOG_ERROR("get sd version from file failed: '%s' ,make sure clip model has loaded", model_path.c_str()); + return false; + } + + if (!model_path.empty() && standalone) { + LOG_INFO("loading unet from '%s'", model_path.c_str()); + if (!model_loader.init_from_file(model_path, prefix)) { + LOG_WARN("loading unet from '%s' failed", model_path.c_str()); + return false; + } + } + + diffusion_model = UNetModel(version); + if (!diffusion_model.alloc_params_buffer(backend, model_data_type)) { + return false; + } + + LOG_DEBUG("preparing memory for unet weights"); + // prepare memory for the weights + { + // diffusion_model(UNetModel) + diffusion_model.init_params(); + diffusion_model.map_by_name(tensors, "model.diffusion_model."); + } + + struct ggml_init_params params; + params.mem_size = static_cast(3 * 1024) * 1024; // 3M + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* ctx = ggml_init(params); // for alphas_cumprod and is_using_v_parameterization check + + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + + // load weights + LOG_DEBUG("loading weights"); + int64_t t0 = ggml_time_ms(); + + std::map tensors_need_to_load; + std::set ignore_tensors; + ggml_tensor* 
alphas_cumprod_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, TIMESTEPS); + calculate_alphas_cumprod((float*)alphas_cumprod_tensor->data); + tensors_need_to_load["alphas_cumprod"] = alphas_cumprod_tensor; + for (auto& pair : tensors) { + const std::string& name = pair.first; + if (starts_with(name, "cond_stage_model.") || starts_with(name, "first_stage_model.")) { + ignore_tensors.insert(name); + continue; + } + tensors_need_to_load.insert(pair); + } + bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors, standalone); + if (!success) { + LOG_ERROR("load unet tensors from model loader failed"); + ggml_free(ctx); + return false; + } + LOG_INFO("unet memory buffer size = %.2fMB", diffusion_model.params_buffer_size / 1024.0 / 1024.0); + int64_t t1 = ggml_time_ms(); + LOG_INFO("loading unet model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); + + bool is_using_v_parameterization = false; + if (version == VERSION_2_x) { + if (is_using_v_parameterization_for_sd2(ctx)) { + is_using_v_parameterization = true; + } + } + + if (is_using_v_parameterization) { + denoiser = std::make_shared(); + LOG_INFO("running in v-prediction mode"); + } else { + LOG_INFO("running in eps-prediction mode"); + } + + apply_schedule(); + ggml_free(ctx); + unet_path = model_path; + return true; + } + + void free_unet_params() { + if (diffusion_model.params_buffer_size > 0) { + diffusion_model.free_params_buffer(); + } + } + + bool load_vae_from_file(const std::string& model_path, + bool standalone = true, + const std::string& prefix = "vae.") { + if (backend == NULL) { + LOG_ERROR("if you set init_backend_immediately false, please call init_backend first"); + return false; + } + + if (version == VERSION_COUNT) { + LOG_ERROR("get sd version from file failed: '%s' ,please call load_clip_from_file first", + model_path.c_str()); + return false; + } + + if (!model_path.empty() && standalone) { + LOG_INFO("loading vae from '%s'", 
model_path.c_str()); + if (!model_loader.init_from_file(model_path, prefix)) { + LOG_WARN("loading vae from '%s' failed", model_path.c_str()); + return false; + } + } + + ggml_type vae_type = model_data_type; + if (version == VERSION_XL) { + vae_type = GGML_TYPE_F32; // avoid nan, not work... + } + + if (!first_stage_model.alloc_params_buffer(backend, vae_type)) { + return false; + } + + LOG_DEBUG("preparing memory for vae weights"); + // prepare memory for the weights + { + first_stage_model.init_params(); + first_stage_model.map_by_name(tensors, "first_stage_model."); + } + + struct ggml_init_params params; + params.mem_size = static_cast(10 * 1024) * 1024; // 3M + params.mem_buffer = NULL; + params.no_alloc = false; + // LOG_DEBUG("mem_size %u ", params.mem_size); + struct ggml_context* ctx = ggml_init(params); // for alphas_cumprod and is_using_v_parameterization check + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + + // load weights + LOG_DEBUG("loading weights"); + int64_t t0 = ggml_time_ms(); + + std::map tensors_need_to_load; + std::set ignore_tensors; + for (auto& pair : tensors) { + const std::string& name = pair.first; + if (vae_decode_only && + (starts_with(name, "first_stage_model.encoder") || starts_with(name, "first_stage_model.quant"))) { + ignore_tensors.insert(name); + continue; + } + tensors_need_to_load.insert(pair); + } + bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors, standalone); + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + ggml_free(ctx); + return false; + } + LOG_INFO("vae memory buffer size = %.2fMB", first_stage_model.params_buffer_size / 1024.0 / 1024.0); + int64_t t1 = ggml_time_ms(); + LOG_INFO("loading vae model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); + ggml_free(ctx); + vae_path = model_path; + return true; + } + + void free_vae_params() { + if (first_stage_model.params_buffer_size > 0) { + 
first_stage_model.free_params_buffer(); + } + } + + // load the all model from one file + bool load_diffusions_from_file(const std::string& model_path) { + LOG_INFO("loading model from '%s'", model_path.c_str()); + if (!load_clip_from_file(model_path, false, "")) { + free_clip_params(); + return false; + } + + if (!load_unet_from_file(model_path, false, "")) { + free_clip_params(); + free_unet_params(); + return false; + } + + if (!load_vae_from_file(model_path, false, "")) { + free_clip_params(); + free_unet_params(); + free_vae_params(); + return false; + } + + return true; + } + + void free_diffusions_params() { + free_clip_params(); + LOG_INFO("free clip params"); + + free_unet_params(); + LOG_INFO("free unet params"); + + free_vae_params(); + LOG_INFO("free vae params"); + } + + bool load_taesd_from_file(const std::string& taesd_path) { + if (first_stage_model.params_buffer_size > 0) { + free_vae_params(); + } + if (taesd_path.empty() || !tae_first_stage.load_from_file(taesd_path, backend)) { + return false; + } + + this->taesd_path = taesd_path; + use_tiny_autoencoder = true; + return true; + } + + void free_taesd_params() { + if (tae_first_stage.params_buffer_size > 0) { + tae_first_stage.free_params_buffer(); + } + } + + bool load_control_net_from_file(const std::string& control_net_path, const std::string& embeddings_path, bool control_net_cpu) { + if (!control_net_path.empty()) { + ggml_backend_t cn_backend = NULL; + if (control_net_cpu && !ggml_backend_is_cpu(backend)) { + LOG_DEBUG("ControlNet: Using CPU backend"); + cn_backend = ggml_backend_cpu_init(); + } else { + cn_backend = backend; + } + if (!control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */)) { + return false; + } + } + } + + void free_control_net_params() { + if (control_net.params_buffer_size > 0) { + control_net.free_params_buffer(); + } + } + bool load_from_file(const std::string& model_path, const std::string& vae_path, - const 
std::string control_net_path, - const std::string embeddings_path, + const std::string& control_net_path, + const std::string& embeddings_path, const std::string& taesd_path, - bool vae_tiling_, + bool vae_tiling, ggml_type wtype, schedule_t schedule, bool control_net_cpu) { - use_tiny_autoencoder = taesd_path.size() > 0; + this->use_tiny_autoencoder = taesd_path.size() > 0; + this->taesd_path = taesd_path; + this->vae_tiling = vae_tiling; #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -141,8 +609,6 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s'", model_path.c_str()); ModelLoader model_loader; - vae_tiling = vae_tiling_; - if (!model_loader.init_from_file(model_path)) { LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); return false; @@ -164,7 +630,6 @@ class StableDiffusionGGML { scale_factor = 0.13025f; } cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); - diffusion_model = UNetModel(version); LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { @@ -374,6 +839,36 @@ class StableDiffusionGGML { return result < -1; } + void apply_schedule() const { + float alphas_cumprod_tensor[TIMESTEPS]; + calculate_alphas_cumprod(alphas_cumprod_tensor); + if (schedule != DEFAULT) { + switch (schedule) { + case DISCRETE: + LOG_INFO("running with discrete schedule"); + denoiser->schedule = std::make_shared(); + break; + case KARRAS: + LOG_INFO("running with Karras schedule"); + denoiser->schedule = std::make_shared(); + break; + case DEFAULT: + // Don't touch anything. 
+ break; + default: + LOG_ERROR("Unknown schedule %i", schedule); + abort(); + } + } + + for (int i = 0; i < TIMESTEPS; i++) { + denoiser->schedule->alphas_cumprod[i] = alphas_cumprod_tensor[i]; + denoiser->schedule->sigmas[i] = std::sqrt( + (1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]); + denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]); + } + } + void apply_lora(const std::string& lora_name, float multiplier) { int64_t t0 = ggml_time_ms(); std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors"); @@ -1136,54 +1631,31 @@ struct sd_ctx_t { StableDiffusionGGML* sd = NULL; }; -sd_ctx_t* new_sd_ctx(const char* model_path_c_str, - const char* vae_path_c_str, - const char* taesd_path_c_str, - const char* control_net_path_c_str, - const char* lora_model_dir_c_str, - const char* embed_dir_c_str, +sd_ctx_t* new_sd_ctx(int n_threads, bool vae_decode_only, - bool vae_tiling, bool free_params_immediately, - int n_threads, - enum sd_type_t wtype, + const char* lora_model_dir_c_str, enum rng_type_t rng_type, + bool vae_tiling, + enum sd_type_t wtype, enum schedule_t s, - bool keep_control_net_cpu) { + bool keep_control_net_cpu, + bool init_backend_immediately) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; } - std::string model_path(model_path_c_str); - std::string vae_path(vae_path_c_str); - std::string taesd_path(taesd_path_c_str); - std::string control_net_path(control_net_path_c_str); - std::string embd_path(embed_dir_c_str); std::string lora_model_dir(lora_model_dir_c_str); sd_ctx->sd = new StableDiffusionGGML(n_threads, vae_decode_only, free_params_immediately, lora_model_dir, - rng_type); - if (sd_ctx->sd == NULL) { - return NULL; - } - - if (!sd_ctx->sd->load_from_file(model_path, - vae_path, - control_net_path, - embd_path, - taesd_path, - vae_tiling, - (ggml_type)wtype, - s, - keep_control_net_cpu)) { - delete sd_ctx->sd; - sd_ctx->sd 
= NULL; - free(sd_ctx); - return NULL; - } + rng_type, + vae_tiling, + static_cast(wtype), + s, + init_backend_immediately); return sd_ctx; } @@ -1195,6 +1667,120 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +void init_backend(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->init_backend(); +} + +void set_options(sd_ctx_t* sd_ctx, + int n_threads, + bool vae_decode_only, + bool free_params_immediately, + const char* lora_model_dir, + rng_type_t rng_type, + bool vae_tiling, + sd_type_t wtype, + schedule_t schedule) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->set_options( + n_threads, + vae_decode_only, + free_params_immediately, + std::string(lora_model_dir), + rng_type, + vae_tiling, + wtype, + schedule); +} + +bool load_clip_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return false; + } + return sd_ctx->sd->load_clip_from_file(std::string(model_path), true, std::string(prefix)); +} + +void free_clip_params(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->free_clip_params(); +} + +bool load_unet_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return false; + } + return sd_ctx->sd->load_unet_from_file(std::string(model_path), true, std::string(prefix)); +} + +void free_unet_params(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->free_unet_params(); +} + +bool load_vae_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + 
LOG_ERROR("must call new_sd_ctx first"); + return false; + } + return sd_ctx->sd->load_vae_from_file(std::string(model_path), true, std::string(prefix)); +} + +void free_vae_params(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->free_vae_params(); +} + +bool load_taesd_from_file(sd_ctx_t* sd_ctx, const char* model_path) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return false; + } + return sd_ctx->sd->load_taesd_from_file(std::string(model_path)); +} + +void free_taesd_params(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + sd_ctx->sd->free_taesd_params(); +} + +// load all model from one file +bool load_diffusions_from_file(sd_ctx_t* sd_ctx, const char* model_path) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return false; + } + return sd_ctx->sd->load_diffusions_from_file(std::string(model_path)); +} + +// free all model from one file +void free_diffusions_params(sd_ctx_t* sd_ctx) { + if (sd_ctx == NULL || sd_ctx->sd == NULL) { + LOG_ERROR("must call new_sd_ctx first"); + return; + } + return sd_ctx->sd->free_diffusions_params(); +} + sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* prompt_c_str, const char* negative_prompt_c_str, diff --git a/stable-diffusion.h b/stable-diffusion.h index a8c9f532..bc3637da 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -102,20 +102,16 @@ typedef struct { typedef struct sd_ctx_t sd_ctx_t; -SD_API sd_ctx_t* new_sd_ctx(const char* model_path, - const char* vae_path, - const char* taesd_path, - const char* control_net_path_c_str, - const char* lora_model_dir, - const char* embed_dir_c_str, +SD_API sd_ctx_t* new_sd_ctx(int n_threads, bool vae_decode_only, - bool vae_tiling, bool free_params_immediately, - int n_threads, - enum sd_type_t wtype, + const char* 
lora_model_dir_c_str, enum rng_type_t rng_type, + bool vae_tiling, + enum sd_type_t wtype, enum schedule_t s, - bool keep_control_net_cpu); + bool keep_control_net_cpu, + bool init_backend_immediately); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); @@ -156,6 +152,38 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor); +SD_API void init_backend(sd_ctx_t* sd_ctx); + +SD_API void set_options(sd_ctx_t* sd_ctx, + int n_threads, + bool vae_decode_only, + bool free_params_immediately, + const char* lora_model_dir, + rng_type_t rng_type, + bool vae_tiling, + sd_type_t wtype, + schedule_t schedule); + +SD_API bool load_clip_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix = "te."); + +SD_API void free_clip_params(sd_ctx_t* sd_ctx); + +SD_API bool load_unet_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix = "unet."); + +SD_API void free_unet_params(sd_ctx_t* sd_ctx); + +SD_API bool load_vae_from_file(sd_ctx_t* sd_ctx, const char* model_path, const char* prefix = "vae."); + +SD_API void free_vae_params(sd_ctx_t* sd_ctx); + +SD_API bool load_taesd_from_file(sd_ctx_t* sd_ctx, const char* model_path); + +SD_API void free_taesd_params(sd_ctx_t* sd_ctx); + +SD_API bool load_diffusions_from_file(sd_ctx_t* sd_ctx, const char* model_path); + +SD_API void free_diffusions_params(sd_ctx_t* sd_ctx); + SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type); #ifdef __cplusplus