diff --git a/docs/demo/nnb_js.wasm b/docs/demo/nnb_js.wasm index a67e24e..3be0c55 100644 Binary files a/docs/demo/nnb_js.wasm and b/docs/demo/nnb_js.wasm differ diff --git a/src/nn-builder/src/snippet/matrix.cc b/src/nn-builder/src/snippet/matrix.cc index 31ff0b2..a4a7d9f 100644 --- a/src/nn-builder/src/snippet/matrix.cc +++ b/src/nn-builder/src/snippet/matrix.cc @@ -1217,15 +1217,15 @@ ExprList* MatrixSnippetSimd::MatrixAddRightSignScale(ds::NDArray *lhs, ds::NDArr auto lhs_addr = MakeBinary(Opcode::I32Add, MakeI32Const(lhs->Memory()->Begin()), MakeLocalGet(addr)); auto rhs_addr = MakeBinary(Opcode::I32Add, MakeI32Const(rhs->Memory()->Begin()), MakeLocalGet(addr)); // Compute right sign scale - // 1) [-1, 2, -3, 4] >= [0, 0, 0, 0] = [0, -1, 0, -1] - // 2) [0, -1, 0, -1] to-float = [0.0, -1.0. 0.0, -1.0] - // 3) [0.0, -1.0, 0.0, 1.0] * [2s, 2s, 2s, 2s] = [0, -2s, 0, -2s] - // 4) [0, -2s, 0, -2s] + [s, s, s, s] = [-s, s, -s, s] + // 1) [-1, 2, -3, 4] >= [0, 0, 0, 0] = [0, -1, 0, -1] + // 2) [0, -1, 0, -1] to-float = [0.0, -1.0. 0.0, -1.0] + // 3) [0.0, -1.0, 0.0, 1.0] * [-2s, -2s, -2s, -2s] = [0, 2s, 0, 2s] + // 4) [0, 2s, 0, 2s] - [s, s, s, s] = [-s, s, -s, s] auto rhs_ge = MakeBinary(Opcode::F32X4Ge, MakeV128Load(rhs_addr), MakeUnary(Opcode::F32X4Splat, MakeF32Const(0))); auto rhs_cnvt = MakeUnary(Opcode::F32X4ConvertI32X4S, rhs_ge); - auto rhs_mul = MakeBinary(Opcode::F32X4Mul, rhs_cnvt, MakeUnary(Opcode::F32X4Splat, MakeF32Const(2*scale))); - auto rhs_add = MakeBinary(Opcode::F32X4Add, rhs_mul, MakeUnary(Opcode::F32X4Splat, MakeF32Const(scale))); - b->Insert(MakeV128Store(MakeLocalGet(dst_addr), MakeBinary(Opcode::F32X4Add, MakeV128Load(lhs_addr), rhs_add))); + auto rhs_mul = MakeBinary(Opcode::F32X4Mul, rhs_cnvt, MakeUnary(Opcode::F32X4Splat, MakeF32Const(-2*scale))); + auto rhs_sub = MakeBinary(Opcode::F32X4Sub, rhs_mul, MakeUnary(Opcode::F32X4Splat, MakeF32Const(scale))); + b->Insert(MakeV128Store(MakeLocalGet(dst_addr), MakeBinary(Opcode::F32X4Add, MakeV128Load(lhs_addr), rhs_sub))); // Move to next elements b->Insert(GenerateCompoundAssignment(addr, Opcode::I32Add, MakeI32Const(simd_type_size))); })); @@ -1280,16 +1280,16 @@ ExprList* MatrixSnippetSimd::MatrixAddRightSignScaleAddRightScale(nn::ds::NDArra // Cache rhs val b->Insert(MakeLocalSet(rhs_v128_cache, MakeV128Load(rhs_addr))); // Compute right sign scale - // 1) [-1, 2, -3, 4] >= [0, 0, 0, 0] = [0, -1, 0, -1] - // 2) [0, -1, 0, -1] to-float = [0.0, -1.0. 0.0, -1.0] - // 3) [0.0, -1.0, 0.0, 1.0] * [2s, 2s, 2s, 2s] = [0, -2s, 0, -2s] - // 4) [0, -2s, 0, -2s] + [s, s, s, s] = [-s, s, -s, s] + // 1) [-1, 2, -3, 4] >= [0, 0, 0, 0] = [0, -1, 0, -1] + // 2) [0, -1, 0, -1] to-float = [0.0, -1.0. 0.0, -1.0] + // 3) [0.0, -1.0, 0.0, 1.0] * [-2s, -2s, -2s, -2s] = [0, 2s, 0, 2s] + // 4) [0, 2s, 0, 2s] - [s, s, s, s] = [-s, s, -s, s] auto rhs_ge = MakeBinary(Opcode::F32X4Ge, MakeLocalGet(rhs_v128_cache), MakeUnary(Opcode::F32X4Splat, MakeF32Const(0))); auto rhs_cnvt = MakeUnary(Opcode::F32X4ConvertI32X4S, rhs_ge); - auto rhs_mul = MakeBinary(Opcode::F32X4Mul, rhs_cnvt, MakeUnary(Opcode::F32X4Splat, MakeF32Const(2*scale1))); - auto rhs_add = MakeBinary(Opcode::F32X4Add, rhs_mul, MakeUnary(Opcode::F32X4Splat, MakeF32Const(scale1))); + auto rhs_mul = MakeBinary(Opcode::F32X4Mul, rhs_cnvt, MakeUnary(Opcode::F32X4Splat, MakeF32Const(-2*scale1))); + auto rhs_sub = MakeBinary(Opcode::F32X4Sub, rhs_mul, MakeUnary(Opcode::F32X4Splat, MakeF32Const(scale1))); auto rhs_scale2 = MakeBinary(Opcode::F32X4Mul, MakeLocalGet(rhs_v128_cache), MakeUnary(Opcode::F32X4Splat, MakeF32Const(scale2))); - auto rhs_val = MakeBinary(Opcode::F32X4Add, rhs_add, rhs_scale2); + auto rhs_val = MakeBinary(Opcode::F32X4Add, rhs_sub, rhs_scale2); b->Insert(MakeV128Store(MakeLocalGet(dst_addr), MakeBinary(Opcode::F32X4Add, MakeV128Load(lhs_addr), rhs_val))); // Move to next elements b->Insert(GenerateCompoundAssignment(addr, Opcode::I32Add, MakeI32Const(simd_type_size)));