Skip to content

vae tiling improvements: encoding support and adaptative overlap #484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
157 changes: 131 additions & 26 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,10 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
struct ggml_tensor* output,
int x,
int y,
int overlap) {
int overlap_x,
int overlap_y,
int x_skip = 0,
int y_skip = 0) {
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
Expand All @@ -472,17 +475,17 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
int64_t img_height = output->ne[1];

GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int iy = y_skip; iy < height; iy++) {
for (int ix = x_skip; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
if (overlap > 0) { // blend colors in overlapped area
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);

const float x_f_0 = (x > 0) ? ix / float(overlap) : 1;
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1;
const float y_f_0 = (y > 0) ? iy / float(overlap) : 1;
const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1;
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;

const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
Expand Down Expand Up @@ -595,20 +598,96 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {

typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

__STATIC_INLINE__ void
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {

int tile_overlap = (tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;

num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_dim++;
}

tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
if (num_tiles_dim <= 2) {
if (small_dim <= tile_size) {
num_tiles_dim = 1;
tile_overlap_factor_dim = 0;
} else {
num_tiles_dim = 2;
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
}
}
}

// Tiling
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
const int p_tile_size_x, const int p_tile_size_y,
const float tile_overlap_factor, on_tile_process on_processing) {

output = ggml_set_f32(output, 0);

int input_width = (int)input->ne[0];
int input_height = (int)input->ne[1];
int output_width = (int)output->ne[0];
int output_height = (int)output->ne[1];

GGML_ASSERT(input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
GGML_ASSERT(input_width / output_width == scale || output_width / input_width == scale);

int small_width = output_width;
int small_height = output_height;

bool big_out = output_width > input_width;
if (big_out) {
// Ex: decode
small_width = input_width;
small_height = input_height;
}

int num_tiles_x;
float tile_overlap_factor_x;
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);

int num_tiles_y;
float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);

LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);

GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2

int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;

int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;

int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;

int input_tile_size_x = tile_size_x;
int input_tile_size_y = tile_size_y;
int output_tile_size_x = tile_size_x;
int output_tile_size_y = tile_size_y;

if (big_out) {
output_tile_size_x *= scale;
output_tile_size_y *= scale;
} else {
input_tile_size_x *= scale;
input_tile_size_y *= scale;
}

struct ggml_init_params params = {};
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
params.mem_size += 3 * ggml_tensor_overhead();
params.mem_buffer = NULL;
params.no_alloc = false;
Expand All @@ -623,29 +702,50 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
}

// tiling
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
on_processing(input_tile, NULL, true);
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_INFO("processing %i tiles", num_tiles);
pretty_progress(1, num_tiles, 0.0f);
pretty_progress(0, num_tiles, 0.0f);
int tile_count = 1;
bool last_y = false, last_x = false;
float last_time = 0.0f;
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
if (y + tile_size >= input_height) {
y = input_height - tile_size;
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
int dy = 0;
if (y + tile_size_y >= small_height) {
int _y = y;
y = small_height - tile_size_y;
dy = _y - y;
if (big_out) {
dy *= scale;
}
last_y = true;
}
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
if (x + tile_size >= input_width) {
x = input_width - tile_size;
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
int dx = 0;
if (x + tile_size_x >= small_width) {
int _x = x;
x = small_width - tile_size_x;
dx = _x - x;
if (big_out) {
dx *= scale;
}
last_x = true;
}

int x_in = big_out ? x : scale * x;
int y_in = big_out ? y : scale * y;
int x_out = big_out ? x * scale : x;
int y_out = big_out ? y * scale : y;

int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;

int64_t t1 = ggml_time_ms();
ggml_split_tensor_2d(input, input_tile, x, y);
ggml_split_tensor_2d(input, input_tile, x_in, y_in);
on_processing(input_tile, output_tile, false);
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);

int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f;
pretty_progress(tile_count, num_tiles, last_time);
Expand All @@ -659,6 +759,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
ggml_free(tiles_ctx);
}

__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
struct ggml_tensor* a) {
const float eps = 1e-6f; // default eps parameter
Expand Down
99 changes: 96 additions & 3 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,18 +1054,111 @@ class StableDiffusionGGML {
decode ? 3 : C,
x->ne[3]); // channels
int64_t t0 = ggml_time_ms();

// TODO: args instead of env for tile size / overlap?

float tile_overlap = 0.5f;
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
if (SD_TILE_OVERLAP != nullptr) {
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
try {
tile_overlap = std::stof(sd_tile_overlap_str);
if (tile_overlap < 0.0) {
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
tile_overlap = 0.0;
}
else if (tile_overlap > 0.5) {
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
tile_overlap = 0.5;
}
} catch (const std::invalid_argument&) {
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
} catch (const std::out_of_range&) {
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
}
}

int tile_size_x = 32;
int tile_size_y = 32;
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
if (SD_TILE_SIZE != nullptr) {
// format is AxB, or just A (equivalent to AxA)
// A and B can be integers (tile size) or floating point
// floating point <= 1 means simple fraction of the latent dimension
// floating point > 1 means number of tiles across that dimension
// a single number gets applied to both
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
float factor = std::stof(factor_str);
if (factor > 1.0)
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
return factor;
};
const int latent_x = W / (decode ? 1 : 8);
const int latent_y = H / (decode ? 1 : 8);
const int min_tile_dimension = 4;
std::string sd_tile_size_str = SD_TILE_SIZE;
size_t x_pos = sd_tile_size_str.find('x');
try {
int tmp_x = tile_size_x, tmp_y = tile_size_y;
if (x_pos != std::string::npos) {
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
if (tile_x_str.find('.') != std::string::npos) {
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
}
else {
tmp_x = std::stoi(tile_x_str);
}
if (tile_y_str.find('.') != std::string::npos) {
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
}
else {
tmp_y = std::stoi(tile_y_str);
}
}
else {
if (sd_tile_size_str.find('.') != std::string::npos) {
float tile_factor = get_tile_factor(sd_tile_size_str);
tmp_x = std::round(latent_x * tile_factor);
tmp_y = std::round(latent_y * tile_factor);
}
else {
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
}
}
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
} catch (const std::invalid_argument&) {
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
} catch (const std::out_of_range&) {
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
}
}

if(!decode){
// TODO: also use and arg for this one?
// to keep the compute buffer size consistent
tile_size_x*=1.30539;
tile_size_y*=1.30539;
}
if (!use_tiny_autoencoder) {
if (decode) {
ggml_tensor_scale(x, 1.0f / scale_factor);
} else {
ggml_tensor_scale_input(x);
}
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
if (SD_TILE_SIZE != nullptr) {
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
}
if (SD_TILE_OVERLAP != nullptr) {
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
}
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, decode, &result);
}
Expand All @@ -1074,7 +1167,7 @@ class StableDiffusionGGML {
ggml_tensor_scale_output(result);
}
} else {
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, decode, &out);
Expand Down
Loading