Skip to content

Commit

Permalink
refactor jpeg encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Langwen Huang committed Nov 14, 2024
1 parent 820910a commit 7d2ba6b
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "src/zstd"]
path = src/zstd
url = https://github.com/facebook/zstd.git
[submodule "src/imshrinker"]
path = src/imshrinker
url = https://github.com/gilson27/imshrinker.git
7 changes: 5 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@ if(NOT CMAKE_BUILD_TYPE)
endif()

set(CMAKE_C_FLAGS_RELEASE "-O3")
set(CMAKE_C_FLAGS_DEBUG "-g -Og -DDEBUG")
set(BUILD_THIRDPARTY True)

set(CMAKE_POSITION_INDEPENDENT_CODE True)

add_library(h5z_j2k SHARED h5z_j2k.c)
add_library(spiht_static STATIC spiht.cc)

add_subdirectory(zstd/build/cmake)
add_subdirectory(wavelib/src)
add_subdirectory(openjpeg)
add_subdirectory(imshrinker)

find_package(HDF5 REQUIRED)

include_directories(${HDF5_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/hdf5/src build/hdf5/src zstd/lib openjpeg/src/lib/openjp2 build/openjpeg/src/lib/openjp2 wavelib/header)
target_link_libraries(h5z_j2k PRIVATE libzstd_static openjp2_static wavelib ${HDF5_LIBRARIES})
include_directories(${HDF5_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/hdf5/src build/hdf5/src zstd/lib openjpeg/src/lib/openjp2 build/openjpeg/src/lib/openjp2 wavelib/header imshrinker/src)
target_link_libraries(h5z_j2k PRIVATE libzstd_static openjp2_static spiht_static wavelib ${HDF5_LIBRARIES})

install(
TARGETS h5z_j2k
Expand Down
1 change: 1 addition & 0 deletions src/imshrinker
Submodule imshrinker added at 0c795e
162 changes: 134 additions & 28 deletions src/j2k_codec.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ void print_config(codec_config_t *config) {
void spiht_encode(float *buffer, size_t height, size_t width, float bit_rate, void **out_buffer, size_t *out_size);
void spiht_decode(void *buffer, size_t size, float *out_buffer);

/*Value-Range Relative error*/
float get_max_relative_error(float *data, float *decoded, float *residual, size_t tot_size) {
float cur_max_error = 0, max_data = data[0], min_data = data[0], data_range = 0;
assert(tot_size > 0);
float get_data_range(const float *data, const size_t tot_size) {
float max_data = data[0], min_data = data[0];
for (size_t i = 1; i < tot_size; ++i) {
if (data[i] > max_data) {
max_data = data[i];
Expand All @@ -226,23 +224,28 @@ float get_max_relative_error(float *data, float *decoded, float *residual, size_
min_data = data[i];
}
}
data_range = max_data - min_data;
return max_data - min_data;
}

/*Value-Range Relative error*/
float get_max_relative_error(const float *data, const float *decoded, const float *residual, const size_t tot_size, const float data_range) {
float cur_max_error = 0;
assert(tot_size > 0);
for (size_t i = 0; i < tot_size; ++i) {
float cur_error = fabsf(data[i] - decoded[i] - residual[i]) / data_range;
float residual_value = residual ? residual[i] : 0;
float cur_error = fabsf(data[i] - decoded[i] - residual_value) / data_range;
if (cur_error > cur_max_error) {
cur_max_error = cur_error;
}
}
return cur_max_error;
}

float get_max_error(residual_t error_type, float *data, float *decoded, float *residual, size_t tot_size) {
float get_max_error(const float *data, const float *decoded, const float *residual, const size_t tot_size) {
float cur_max_error = 0;
if (error_type == RELATIVE_ERROR) {
return get_max_relative_error(data, decoded, residual, tot_size);
}
for (size_t i = 0; i < tot_size; ++i) {
float cur_error = fabsf(data[i] - decoded[i] - residual[i]);
float residual_value = residual ? residual[i] : 0;
float cur_error = fabsf(data[i] - decoded[i] - residual_value);
/* this is pointwise relative error
if (error_type == RELATIVE_ERROR) {
cur_error /= fabsf(data[i]);
Expand All @@ -252,10 +255,77 @@ float get_max_error(residual_t error_type, float *data, float *decoded, float *r
cur_max_error = cur_error;
}
}

return cur_max_error;
}

double get_error_target_quantile(const float *data, const float *decoded, const float *residual, const size_t tot_size, const float error_target) {
size_t n = 0;
for (size_t i = 0; i < tot_size; ++i) {
float residual_value = residual ? residual[i] : 0;
float cur_error = fabsf(data[i] - decoded[i] - residual_value);
if (cur_error > error_target) {
n++;
}
}
return 1. - ((double) n / tot_size);
}

void sparsify_coefficients(const double *coeffs, double *coeffs_copy, const size_t coeffs_size, float *residual, const size_t image_dims[2],
const float *data, const float *decoded, const codec_config_t *config, const size_t tot_size) {
float residual_cr = 2000, stop_cr = 50; /*50*/
float cur_max_error = 0, best_max_error = 0;
double quantile, data_range = 1.0, *coeffs_best;

coeffs_best = (double *) malloc(coeffs_size * sizeof(double));

memset(coeffs_copy, 0, coeffs_size * sizeof(double));
memset(coeffs_best, 0, coeffs_size * sizeof(double));
if (config->residual_compression_type == RELATIVE_ERROR) {
data_range = get_data_range(data, tot_size);
cur_max_error = get_max_relative_error(data, decoded, NULL, tot_size, data_range);
} else {
cur_max_error = get_max_error(data, decoded, NULL, tot_size);
}
best_max_error = cur_max_error;

while (cur_max_error > config->error && residual_cr >= stop_cr) {
memcpy(coeffs_copy, coeffs, coeffs_size * sizeof(double));
double q_ratio = 1. - (1. / residual_cr);
quantile = zero_out_quantile(coeffs_copy, coeffs_size, q_ratio);
wavelib_backward(residual, image_dims[0], image_dims[1], WAVELET_LEVELS, coeffs_copy);

if (config->residual_compression_type == RELATIVE_ERROR) {
cur_max_error = get_max_relative_error(data, decoded, residual, tot_size, data_range);
} else {
cur_max_error = get_max_error(data, decoded, residual, tot_size);
}
if (cur_max_error < best_max_error) {
best_max_error = cur_max_error;
memcpy(coeffs_best, coeffs_copy, coeffs_size * sizeof(double));
}
#ifdef DEBUG
printf("Current max error: %f (ABS %f), residual_cr: %f\n", cur_max_error, cur_max_error*data_range, residual_cr);
#endif
residual_cr /= sqrtf(2.f);
}

if (cur_max_error > config->error) {
fprintf(stderr, "Could not reach error target of %f (%f instead).\n", config->error, best_max_error);
memcpy(coeffs_copy, coeffs_best, coeffs_size * sizeof(double));
}
free(coeffs_best);
}

double emulate_j2k_compression(uint16_t *scaled_data, size_t *image_dims, size_t *tile_dims, float current_cr,
codec_data_buffer_t *codec_data_buffer, float **decoded, float minval, float maxval,
float *data, size_t tot_size, float error_target) {
codec_data_buffer_init(codec_data_buffer);
j2k_encode_internal(scaled_data, image_dims, tile_dims, current_cr, codec_data_buffer);
codec_data_buffer_reset(codec_data_buffer);
j2k_decode_internal(decoded, NULL, NULL, minval, maxval, codec_data_buffer);
return get_error_target_quantile(data, *decoded, NULL, tot_size, error_target);
}

size_t encode_climate_variable(float *data, codec_config_t *config, uint8_t **out_buffer) {
#ifdef DEBUG
print_config(config);
Expand Down Expand Up @@ -293,7 +363,7 @@ size_t encode_climate_variable(float *data, codec_config_t *config, uint8_t **ou
// encode using jpeg2000
j2k_encode_internal(scaled_data, image_dims, tile_dims, config->base_cr, &codec_data_buffer);

free(scaled_data);

codec_data_buffer_reset(&codec_data_buffer);

size_t compressed_size = 0;
Expand All @@ -303,7 +373,7 @@ size_t encode_climate_variable(float *data, codec_config_t *config, uint8_t **ou
double *coeffs = NULL;
size_t coeffs_size = 0;
size_t coo_size = 0;
double quantile = config->quantile;
double quantile = config->quantile, base_quantile_target = 1., eps=1e-8;
if (config->residual_compression_type != NONE) {
// decode back the image
float *residual = (float *) malloc(tot_size * sizeof(float));
Expand All @@ -325,25 +395,56 @@ size_t encode_climate_variable(float *data, codec_config_t *config, uint8_t **ou
} else if (config->residual_compression_type == MAX_ERROR ||
config->residual_compression_type == RELATIVE_ERROR) {
double *coeffs_copy = (double *) malloc(coeffs_size * sizeof(double));
float residual_cr = 2000;
float cur_max_error = 0;
do {
memcpy(coeffs_copy, coeffs, coeffs_size * sizeof(double));
double q_ratio = 1. - (1. / residual_cr);
quantile = zero_out_quantile(coeffs_copy, coeffs_size, q_ratio);
wavelib_backward(residual, image_dims[0], image_dims[1], WAVELET_LEVELS, coeffs_copy);

cur_max_error = get_max_error(config->residual_compression_type, data, decoded, residual, tot_size);

residual_cr /= sqrtf(2.f);
} while (cur_max_error > config->error && residual_cr >= 50);
if (cur_max_error > config->error) {
fprintf(stderr, "Could not reach error target of %f (%f instead).", config->error, cur_max_error);
double error_target_quantile, error_target_quantile_prev = 0;
float error_target = config->error, current_cr = config->base_cr, cr_lo, cr_hi;
if (config->residual_compression_type == RELATIVE_ERROR) {
error_target *= get_data_range(data, tot_size);
}
error_target_quantile = get_error_target_quantile(data, decoded, NULL, tot_size, error_target);
error_target_quantile_prev = error_target_quantile;
cr_lo = current_cr;
cr_hi = current_cr;
while (error_target_quantile < base_quantile_target) {
cr_lo /= 2;
error_target_quantile = emulate_j2k_compression(scaled_data, image_dims, tile_dims, cr_lo, &codec_data_buffer, &decoded, minval, maxval, data, tot_size, error_target);
#ifdef DEBUG
printf("cr_lo: %f, error_target_quantile: %f, jp2_length: %lu\n", cr_lo, error_target_quantile, codec_data_buffer.length);
#endif
}
error_target_quantile = error_target_quantile_prev;
while (error_target_quantile >= base_quantile_target) {
cr_hi *= 2;
error_target_quantile = emulate_j2k_compression(scaled_data, image_dims, tile_dims, cr_hi, &codec_data_buffer, &decoded, minval, maxval, data, tot_size, error_target);
#ifdef DEBUG
printf("cr_hi: %f, error_target_quantile: %f, jp2_length: %lu\n", cr_hi, error_target_quantile, codec_data_buffer.length);
#endif
}
error_target_quantile = error_target_quantile_prev;

assert(cr_lo <= cr_hi);
while (fabs(error_target_quantile - base_quantile_target) > eps || cr_hi - cr_lo > 1.) {
current_cr = (cr_lo + cr_hi) / 2;
error_target_quantile = emulate_j2k_compression(scaled_data, image_dims, tile_dims, current_cr, &codec_data_buffer, &decoded, minval, maxval, data, tot_size, error_target);
#ifdef DEBUG
printf("current_cr: %f, error_target_quantile: %f, jp2_length: %lu\n", current_cr, error_target_quantile, codec_data_buffer.length);
#endif
if (error_target_quantile < base_quantile_target) {
cr_hi = current_cr;
} else {
cr_lo = current_cr;
}
}
for (size_t i = 0; i < tot_size; ++i) {
residual[i] = data[i] - decoded[i];
}
wavelib_forward(residual, image_dims[0], image_dims[1], WAVELET_LEVELS, &coeffs, &coeffs_size);
sparsify_coefficients(coeffs, coeffs_copy, coeffs_size, residual, image_dims, data, decoded, config, tot_size);
free(coeffs);
coeffs = coeffs_copy;
}

free(scaled_data);

size_t zero_count = 0;
for (size_t i = 0; i < coeffs_size; ++i) {
if (coeffs[i] == 0) {
Expand Down Expand Up @@ -380,9 +481,14 @@ size_t encode_climate_variable(float *data, codec_config_t *config, uint8_t **ou
assert(coo_iter <= coo_allocated_size);
coo_size = coo_iter;



compressed_size = ZSTD_compressBound(coo_size * sizeof(coo_t));
compressed_coefficients = (uint8_t *) malloc(compressed_size);
compressed_size = ZSTD_compress(compressed_coefficients, compressed_size, coo, coo_size * sizeof(coo_t), 22);
#ifdef DEBUG
printf("coeffs_size: %lu, coo_size: %lu, compressed_size: %lu, jp2_length: %lu\n", coeffs_size, coo_size, compressed_size, codec_data_buffer.length);
#endif

free(coo);
free(coeffs);
Expand Down
16 changes: 16 additions & 0 deletions src/spiht.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "Encoder.h"
#include "Decoder.h"
#include "spiht.h"


extern "C"
void spiht_encode(float *buffer, size_t height, size_t width, uint8_t **out_buffer, size_t *output_size, int num_stages, float* max_val) {
Encoder encoder;
encoder.encode_image(buffer, height, width, out_buffer, output_size, num_stages, max_val);
}

extern "C"
void spiht_decode(uint8_t *buffer, size_t size, float *out_buffer, size_t height, size_t width, int num_bits, float max_val) {
Decoder decoder;
decoder.decode_image(buffer, size, out_buffer, height, width, num_bits, max_val);
}
12 changes: 12 additions & 0 deletions src/spiht.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef SPIHT_H
#define SPIHT_H

#include <cstddef>
#include <cstdint>

extern "C" {
void spiht_encode(float *buffer, size_t height, size_t width, uint8_t **out_buffer, size_t *output_size, int num_stages, float* max_val);
void spiht_decode(uint8_t *buffer, size_t size, float *out_buffer, size_t height, size_t width, int num_bits, float max_val);
}

#endif // SPIHT_H

0 comments on commit 7d2ba6b

Please sign in to comment.