Skip to content

Commit

Permalink
add boolean support to jit sreg expression
Browse files Browse the repository at this point in the history
  • Loading branch information
usstq committed Dec 27, 2024
1 parent a941617 commit efc7c57
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ static std::shared_ptr<SIMDJit> jit_compile_gemmRegBlk(int rows, int cols, int p
B_stride = B_stride * 4;
dst_stride = dst_stride * 4;

jit->if_(accumulate == 0,
jit->if_(
accumulate == 0,
[&] {
// initialize C to zero
for (int r = 0; r < rows; r++)
Expand Down Expand Up @@ -279,12 +280,12 @@ static std::shared_ptr<SIMDJit> jit_compile_accumulate_weight(WeightCompressionT
// load all arguments into register
auto dst = jit->get_sreg(0); // float*
auto OC = jit->get_sreg(1);
auto gate_ids = jit->get_sreg(2); // int32_t *
auto gate_cnt = jit->get_sreg(3); // int
auto pw0 = jit->get_sreg(4); // ov::float16* / uint8_t*
auto dense_x = jit->get_sreg(5); //
auto scales = jit->get_sreg(6); // float*
auto zero_points = jit->get_sreg(7); // float*
auto gate_ids = jit->get_sreg(2); // int32_t *
auto gate_cnt = jit->get_sreg(3); // int
auto pw0 = jit->get_sreg(4); // ov::float16* / uint8_t*
auto dense_x = jit->get_sreg(5); //
auto scales = jit->get_sreg(6); // float*
auto zero_points = jit->get_sreg(7); // float*

auto g = jit->get_sreg();
auto i = jit->get_sreg();
Expand Down Expand Up @@ -500,7 +501,7 @@ static std::shared_ptr<SIMDJit> jit_compile_repack_3xsimdw_1xsimdw(bool with_zp)
});

dst = repacked_B_nx1;
dst_stride = K *(simd_width * 1 * sizeof(float));
dst_stride = K * (simd_width * 1 * sizeof(float));

jit->for_loop(n0, n0, N, simd_width, [&]() {
jit->simd_loadu_ps(scale0, jit->ptr[scales + n0 * sizeof(float)]);
Expand Down Expand Up @@ -534,8 +535,8 @@ static std::shared_ptr<SIMDJit> jit_compile_repack_2xsimdw(WeightCompressionType
auto src_stride = jit->get_sreg(1); // in unit of f16 or bytes (int8/int4)
auto dst = jit->get_sreg(2); // float*
auto bK = jit->get_sreg(3);
auto scales = jit->get_sreg(4); // scales
auto zero_point = jit->get_sreg(5); // zero-point
auto scales = jit->get_sreg(4); // scales
auto zero_point = jit->get_sreg(5); // zero-point

auto k = jit->get_sreg();

Expand Down
Loading

0 comments on commit efc7c57

Please sign in to comment.