Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image Preview #416

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class StableDiffusionGGML {
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
}
if(!backend) {
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
#endif
Expand All @@ -181,7 +181,7 @@ class StableDiffusionGGML {
backend = ggml_backend_cpu_init();
}
#ifdef SD_USE_FLASH_ATTENTION
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN)
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
LOG_WARN("Flash Attention not supported with GPU Backend");
#else
LOG_INFO("Flash Attention enabled");
Expand Down Expand Up @@ -762,7 +762,8 @@ class StableDiffusionGGML {
sample_method_t method,
const std::vector<float>& sigmas,
int start_merge_step,
SDCondition id_cond) {
SDCondition id_cond,
size_t batch_num = 0) {
size_t steps = sigmas.size() - 1;
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(noise);
Expand Down Expand Up @@ -885,6 +886,9 @@ class StableDiffusionGGML {
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
}

send_result_step_callback(denoised, batch_num, step);

return denoised;
};

Expand Down Expand Up @@ -998,6 +1002,47 @@ class StableDiffusionGGML {
// Decode a latent tensor `x` into image space, allocating the result in
// `work_ctx`. Thin convenience wrapper over compute_first_stage; the `true`
// flag presumably selects decode (as opposed to encode) mode — confirm
// against compute_first_stage's signature.
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
return compute_first_stage(work_ctx, x, true);
}

sd_result_cb_t result_cb = nullptr;
void* result_cb_data = nullptr;

// Deliver one finished latent to the registered whole-image callback:
// decode `x` to pixels inside `work_ctx` and hand the buffer to the
// callback along with its batch index. No-op when no callback is set.
// NOTE(review): ownership of the pixel buffer appears to pass to the
// callback — confirm against sd_tensor_to_image's allocation contract.
void send_result_callback(ggml_context* work_ctx, ggml_tensor* x, size_t number) {
    if (result_cb != nullptr) {
        struct ggml_tensor* decoded = decode_first_stage(work_ctx, x);
        auto pixels                 = sd_tensor_to_image(decoded);
        result_cb(number, pixels, result_cb_data);
    }
}

sd_result_step_cb_t result_step_cb = nullptr;
void* result_step_cb_data = nullptr;

// Stream a per-step preview to the registered step callback: copy the
// in-flight latent `x` into a private scratch context, decode it, and
// invoke the callback with (batch number, step, pixels). Best-effort —
// any allocation/decode failure silently skips this step's preview
// instead of crashing the sampling loop.
void send_result_step_callback(ggml_tensor* x, size_t number, size_t step) {
    if (result_step_cb == nullptr) {
        return;
    }

    // Scratch arena for the duplicated latent plus the decoded image.
    // NOTE(review): fixed 10 MiB — presumably sized for typical preview
    // resolutions; may be too small for large outputs. Confirm and size
    // from the tensor dimensions if needed.
    struct ggml_init_params params {};
    params.mem_size = static_cast<size_t>(10 * 1024) * 1024;
    params.mem_buffer = nullptr;
    params.no_alloc = false;

    struct ggml_context* work_ctx = ggml_init(params);
    if (!work_ctx) {
        return;  // preview is optional: skip quietly on OOM
    }

    // Duplicate so the sampler's own tensor is never touched by decoding.
    struct ggml_tensor* result = ggml_dup_tensor(work_ctx, x);
    if (result == nullptr) {
        // Allocation inside the scratch context can fail; don't crash the
        // whole generation for a missing preview frame.
        ggml_free(work_ctx);
        return;
    }
    copy_ggml_tensor(result, x);

    struct ggml_tensor* img = decode_first_stage(work_ctx, result);
    if (img != nullptr) {  // guard: never hand a null tensor to the converter
        result_step_cb(number, step, sd_tensor_to_image(img), result_step_cb_data);
    }

    ggml_free(work_ctx);  // releases result/img storage on every path
}
};

/*================================================= SD API ==================================================*/
Expand Down Expand Up @@ -1081,6 +1126,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}

// Register (or clear, with cb == nullptr) the callback invoked with each
// finished image during generation. `data` is passed through to the
// callback untouched. Defensive: a null context is ignored instead of
// dereferenced, since this is a public C API entry point.
void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;  // nothing to attach to; avoid a null-pointer crash
    }
    sd_ctx->sd->result_cb      = cb;
    sd_ctx->sd->result_cb_data = data;
}

// Register (or clear, with cb == nullptr) the callback invoked with a
// decoded preview after each sampling step. `data` is passed through to
// the callback untouched. Defensive: a null context is ignored instead of
// dereferenced, since this is a public C API entry point.
void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;  // nothing to attach to; avoid a null-pointer crash
    }
    sd_ctx->sd->result_step_cb      = cb;
    sd_ctx->sd->result_step_cb_data = data;
}

sd_image_t* generate_image(sd_ctx_t* sd_ctx,
struct ggml_context* work_ctx,
ggml_tensor* init_latent,
Expand Down Expand Up @@ -1308,11 +1363,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
sample_method,
sigmas,
start_merge_step,
id_cond);
id_cond,
b);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);

if (sd_ctx->sd->result_cb != nullptr) {
sd_ctx->sd->send_result_callback(work_ctx, x_0, b);
continue;
}

final_latents.push_back(x_0);
}

Expand All @@ -1322,6 +1384,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int64_t t3 = ggml_time_ms();
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);

if (sd_ctx->sd->result_cb != nullptr) {
return nullptr;
}

// Decode to image
LOG_INFO("decoding %zu latents", final_latents.size());
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
Expand Down
4 changes: 4 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ enum sd_log_level_t {

typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
typedef void (*sd_result_cb_t)(size_t number, uint8_t* image_data, void* data);
typedef void (*sd_result_step_cb_t)(size_t number, size_t step, uint8_t* image_data, void* data);

SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
Expand Down Expand Up @@ -144,6 +146,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
bool keep_vae_on_cpu);

SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data);
SD_API void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data);

SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
const char* prompt,
Expand Down
Loading