Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add prediction argument #334

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ arguments:
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate.
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
--prediction {eps, v, flow} Prediction mode (default: auto-detected from the model version)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage
Expand Down
29 changes: 29 additions & 0 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ const char* schedule_str[] = {
"gits",
};

// Command-line spellings accepted by --prediction and echoed by print_params.
// Entries are indexed by prediction_t (DEFAULT_PRED, EPS_PRED, V_PRED,
// FLOW_PRED), so this list must stay in the same order as that enum.
const char* prediction_str[] = {"default", "eps", "v", "flow"};

const char* modes_str[] = {
"txt2img",
"img2img",
Expand Down Expand Up @@ -105,6 +112,7 @@ struct SDParams {

sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT;
prediction_t prediction = DEFAULT_PRED;
int sample_steps = 20;
float strength = 0.75f;
float control_strength = 0.9f;
Expand Down Expand Up @@ -156,6 +164,7 @@ void print_params(SDParams params) {
printf(" height: %d\n", params.height);
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
printf(" schedule: %s\n", schedule_str[params.schedule]);
printf(" prediction: %s\n", prediction_str[params.prediction]);
printf(" sample_steps: %d\n", params.sample_steps);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
Expand Down Expand Up @@ -208,6 +217,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate.\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --prediction {eps, v, flow}\n");
printf(" prediction mode (default: eps)\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
Expand Down Expand Up @@ -496,6 +507,23 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.schedule = (schedule_t)schedule_found;
} else if (arg == "--prediction") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* prediction_selected = argv[i];
int prediction_found = -1;
for (int n = 0; n < N_PREDICTIONS; n++) {
if (!strcmp(prediction_selected, prediction_str[n])) {
prediction_found = n;
}
}
if (prediction_found == -1) {
invalid_arg = true;
break;
}
params.prediction = (prediction_t)prediction_found;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -780,6 +808,7 @@ int main(int argc, const char* argv[]) {
params.wtype,
params.rng_type,
params.schedule,
params.prediction,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
Expand Down
66 changes: 44 additions & 22 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class StableDiffusionGGML {
bool vae_tiling_,
ggml_type wtype,
schedule_t schedule,
prediction_t prediction,
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu) {
Expand Down Expand Up @@ -500,32 +501,51 @@ class StableDiffusionGGML {
int64_t t1 = ggml_time_ms();
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);

// check is_using_v_parameterization_for_sd2
bool is_using_v_parameterization = false;
if (version == VERSION_SD2) {
if (is_using_v_parameterization_for_sd2(ctx)) {
if (prediction != DEFAULT_PRED) {
switch (prediction) {
case EPS_PRED:
LOG_INFO("running in eps-prediction mode");
break;
case V_PRED:
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
break;
case FLOW_PRED:
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
break;
default:
LOG_ERROR("Unknown parametrization %i", prediction);
abort();
}
} else {
// check is_using_v_parameterization_for_sd2
bool is_using_v_parameterization = false;
if (version == VERSION_SD2) {
if (is_using_v_parameterization_for_sd2(ctx)) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
// TODO: V_PREDICTION_EDM
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
// TODO: V_PREDICTION_EDM
is_using_v_parameterization = true;
}

if (version == VERSION_SD3_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
shift = 1.0f; // TODO: validate
if (version == VERSION_SD3_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
shift = 1.0f; // TODO: validate
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
} else {
LOG_INFO("running in eps-prediction mode");
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
} else {
LOG_INFO("running in eps-prediction mode");
}

if (schedule != DEFAULT) {
Expand Down Expand Up @@ -1023,6 +1043,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum sd_type_t wtype,
enum rng_type_t rng_type,
enum schedule_t s,
enum prediction_t p,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
Expand Down Expand Up @@ -1062,6 +1083,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
vae_tiling,
(ggml_type)wtype,
s,
p,
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu)) {
Expand Down
9 changes: 9 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ enum schedule_t {
N_SCHEDULES
};

// Prediction (parameterization) mode requested when creating a context.
// The numeric values double as indices into the CLI's name table, so new
// modes must be added immediately before N_PREDICTIONS.
enum prediction_t {
    DEFAULT_PRED  = 0,  // no override: the loader chooses based on the model version
    EPS_PRED      = 1,  // force eps-prediction
    V_PRED        = 2,  // force v-prediction
    FLOW_PRED     = 3,  // force flow prediction
    N_PREDICTIONS = 4   // number of modes above; not a valid mode itself
};

// same as enum ggml_type
enum sd_type_t {
SD_TYPE_F32 = 0,
Expand Down Expand Up @@ -139,6 +147,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum sd_type_t wtype,
enum rng_type_t rng_type,
enum schedule_t s,
enum prediction_t p,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
Expand Down
Loading