diff --git a/README.md b/README.md index 95fd5e66..e5282536 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,7 @@ arguments: -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate. --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete) + --prediction {eps, v, flow} Prediction mode (default: eps) --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index ceae27b8..820a8a6c 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -51,6 +51,13 @@ const char* schedule_str[] = { "gits", }; +const char* prediction_str[] = { + "default", + "eps", + "v", + "flow", +}; + const char* modes_str[] = { "txt2img", "img2img", @@ -105,6 +112,7 @@ struct SDParams { sample_method_t sample_method = EULER_A; schedule_t schedule = DEFAULT; + prediction_t prediction = DEFAULT_PRED; int sample_steps = 20; float strength = 0.75f; float control_strength = 0.9f; @@ -156,6 +164,7 @@ void print_params(SDParams params) { printf(" height: %d\n", params.height); printf(" sample_method: %s\n", sample_method_str[params.sample_method]); printf(" schedule: %s\n", schedule_str[params.schedule]); + printf(" prediction: %s\n", prediction_str[params.prediction]); printf(" sample_steps: %d\n", params.sample_steps); printf(" strength(img2img): %.2f\n", params.strength); printf(" rng: %s\n", rng_type_to_str[params.rng_type]); @@ -208,6 +217,8 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate.\n"); printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); + printf(" --prediction {eps, v, flow}\n"); + printf(" prediction mode (default: eps)\n"); printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); @@ -496,6 +507,23 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.schedule = (schedule_t)schedule_found; + } else if (arg == "--prediction") { + if (++i >= argc) { + invalid_arg = true; + break; + } + const char* prediction_selected = argv[i]; + int prediction_found = -1; + for (int n = 0; n < N_PREDICTIONS; n++) { + if (!strcmp(prediction_selected, prediction_str[n])) { + prediction_found = n; + } + } + if (prediction_found == -1) { + invalid_arg = true; + break; + } + params.prediction = (prediction_t)prediction_found; } else if (arg == "-s" || arg == "--seed") { if (++i >= argc) { invalid_arg = true; @@ -780,6 +808,7 @@ int main(int argc, const char* argv[]) { params.wtype, params.rng_type, params.schedule, + params.prediction, params.clip_on_cpu, params.control_net_cpu, params.vae_on_cpu); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 07b59bb8..586c6d76 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -149,6 +149,7 @@ class StableDiffusionGGML { bool vae_tiling_, ggml_type wtype, schedule_t schedule, + prediction_t prediction, bool clip_on_cpu, bool control_net_cpu, bool vae_on_cpu) { @@ -500,32 +501,51 @@ class StableDiffusionGGML { int64_t t1 = ggml_time_ms(); LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); - // check is_using_v_parameterization_for_sd2 - bool is_using_v_parameterization = false; - if (version == VERSION_SD2) { - if (is_using_v_parameterization_for_sd2(ctx)) { + if (prediction != DEFAULT_PRED) { + switch (prediction) { + case EPS_PRED: + LOG_INFO("running in eps-prediction mode"); + break; + case V_PRED: + LOG_INFO("running in v-prediction mode"); + denoiser = std::make_shared(); + break; + case FLOW_PRED: + LOG_INFO("running in FLOW mode"); + denoiser = std::make_shared(); + break; + default: + LOG_ERROR("Unknown parametrization %i", prediction); + abort(); + } + } else { + // check is_using_v_parameterization_for_sd2 + bool is_using_v_parameterization = false; + if (version == VERSION_SD2) { + if (is_using_v_parameterization_for_sd2(ctx)) { + is_using_v_parameterization = true; + } + } else if (version == VERSION_SVD) { + // TODO: V_PREDICTION_EDM is_using_v_parameterization = true; } - } else if (version == VERSION_SVD) { - // TODO: V_PREDICTION_EDM - is_using_v_parameterization = true; - } - if (version == VERSION_SD3_2B) { - LOG_INFO("running in FLOW mode"); - denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { - LOG_INFO("running in Flux FLOW mode"); - float shift = 1.15f; - if (version == VERSION_FLUX_SCHNELL) { - shift = 1.0f; // TODO: validate + if (version == VERSION_SD3_2B) { + LOG_INFO("running in FLOW mode"); + denoiser = std::make_shared(); + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + LOG_INFO("running in Flux FLOW mode"); + float shift = 1.15f; + if (version == VERSION_FLUX_SCHNELL) { + shift = 1.0f; // TODO: validate + } + denoiser = std::make_shared(shift); + } else if (is_using_v_parameterization) { + LOG_INFO("running in v-prediction mode"); + denoiser = std::make_shared(); + } else { + LOG_INFO("running in eps-prediction mode"); } - denoiser = std::make_shared(shift); - } else if (is_using_v_parameterization) { - LOG_INFO("running in v-prediction mode"); - denoiser = std::make_shared(); - } else { - LOG_INFO("running in eps-prediction mode"); } if (schedule != DEFAULT) { @@ -1023,6 +1043,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, enum sd_type_t wtype, enum rng_type_t rng_type, enum schedule_t s, + enum prediction_t p, bool keep_clip_on_cpu, bool keep_control_net_cpu, bool keep_vae_on_cpu) { @@ -1062,6 +1083,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, vae_tiling, (ggml_type)wtype, s, + p, keep_clip_on_cpu, keep_control_net_cpu, keep_vae_on_cpu)) { diff --git a/stable-diffusion.h b/stable-diffusion.h index 0d4cc1fd..74fd7fc0 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -57,6 +57,14 @@ enum schedule_t { N_SCHEDULES }; +enum prediction_t { + DEFAULT_PRED, + EPS_PRED, + V_PRED, + FLOW_PRED, + N_PREDICTIONS +}; + // same as enum ggml_type enum sd_type_t { SD_TYPE_F32 = 0, @@ -139,6 +147,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, enum sd_type_t wtype, enum rng_type_t rng_type, enum schedule_t s, + enum prediction_t p, bool keep_clip_on_cpu, bool keep_control_net_cpu, bool keep_vae_on_cpu);