Skip to content

Commit

Permalink
Implement inverseSqrt (#1029)
Browse files Browse the repository at this point in the history
* Implement `inverseSqrt`

* Force the rounding to match between JS and RS
  • Loading branch information
dfellis authored Dec 24, 2024
1 parent cc86ef3 commit 385824e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
8 changes: 8 additions & 0 deletions alan/src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,14 @@ test_gpgpu!(gpu_dot => r#"
stdout "25.0\n";
);

test_gpgpu!(gpu_inverse_sqrt => r#"
export fn main {
let b = GBuffer([4.0.f32, 25.0.f32]);
b.map(fn (val: gf32) = val.inverseSqrt).read{f32}.map(fn (v: f32) = v.string(1)).print;
}"#;
stdout "[0.5, 0.2]\n";
);

// TODO: Fix u64 numeric constants to get u64 bitwise tests in the new test suite
test!(u64_bitwise => r#"
prefix u64 as ~ precedence 10
Expand Down
12 changes: 8 additions & 4 deletions alan/test.ln
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ export fn{Test} main {
.assert(eq, (-0.5.f32).saturate, 0.0.f32)
.assert(eq, 0.5.f32.saturate, 0.5.f32)
.assert(eq, 1.5.f32.saturate, 1.0.f32))
.it('dot').assert(eq, {f32[2]}(3.f32, 4.f32) *. {f32[2]}(3.f32, 4.f32), 25.f32);
.it('dot').assert(eq, {f32[2]}(3.f32, 4.f32) *. {f32[2]}(3.f32, 4.f32), 25.f32)
.it('inverseSqrt').assert(eq, 4.f32.inverseSqrt, 0.5.f32);

test.describe("Basic math tests f64")
.it("add")
Expand Down Expand Up @@ -493,7 +494,8 @@ export fn{Test} main {
.assert(eq, (-0.5).saturate, 0.0)
.assert(eq, 0.5.saturate, 0.5)
.assert(eq, 1.5.saturate, 1.0))
.it('dot').assert(eq, {f64[2]}(3.0, 4.0) *. {f64[2]}(3.0, 4.0), 25.0);
.it('dot').assert(eq, {f64[2]}(3.0, 4.0) *. {f64[2]}(3.0, 4.0), 25.0)
.it('inverseSqrt').assert(eq, 25.0.inverseSqrt, 0.2);

test.describe("Basic math tests")
.it("grouping")
Expand Down Expand Up @@ -961,7 +963,8 @@ export fn{Test} main {
})
.it('normalize', fn (test: Mut{Testing}) = test
.assert(eq, [1.0, 0.0, 0.0].normalize.map(fn (v: f64) = string(v)).join(', '), '1, 0, 0')
.assert(eq, [3.0, 4.0].normalize.map(fn (v: f64) = string(v)).join(', '), '0.6, 0.8'));
.assert(eq, [3.0, 4.0].normalize.map(fn (v: f64) = string(v)).join(', '), '0.6, 0.8'))
.it('inverseSqrt').assert(eq, [4.0, 25.0].inverseSqrt.map(fn (v: f64) = string(v)).join(', '), '0.5, 0.2');

test.describe("Buffers")
.it("join", fn (test: Mut{Testing}) {
Expand Down Expand Up @@ -1038,7 +1041,8 @@ export fn{Test} main {
})
.it('normalize', fn (test: Mut{Testing}) = test
.assert(eq, {f64[3]}(1.0, 0.0, 0.0).normalize.map(fn (v: f64) = string(v)).join(', '), '1, 0, 0')
.assert(eq, {f64[2]}(3.0, 4.0).normalize.map(fn (v: f64) = string(v)).join(', '), '0.6, 0.8'));
.assert(eq, {f64[2]}(3.0, 4.0).normalize.map(fn (v: f64) = string(v)).join(', '), '0.6, 0.8'))
.it('inverseSqrt').assert(eq, {f64[2]}(4.0, 25.0).inverseSqrt.map(fn (v: f64) = string(v)).join(', '), '0.5, 0.2');

test.describe("Conditionals")
.it("if function")
Expand Down
17 changes: 17 additions & 0 deletions alan_compiler/src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ fn acoth(x: f32) = ln((x + 1.0.f32) / (x - 1.0.f32)) / 2.0.f32;
fn{Rs} round Method{"round_ties_even"} :: Deref{f32} -> f32;
fn{Js} round Method{"roundTiesEven"} :: f32 -> f32;
fn magnitude f32 = arg0.abs;
fn{Rs} inverseSqrt (v: f32) = {Method{"recip"} :: f32 -> f32}(sqrt(v));
fn{Js} inverseSqrt (v: f32) = 1.0.f32 / sqrt(v);

fn{Rs} add Infix{"+"} :: (f64, f64) -> f64;
fn{Js} add "((a, b) => new alan_std.F64(a.val + b.val))" <- RootBacking :: (f64, f64) -> f64;
Expand Down Expand Up @@ -699,6 +701,8 @@ fn acoth(x: f64) = ln((x + 1.0) / (x - 1.0)) / 2.0;
fn{Rs} round Method{"round_ties_even"} :: Deref{f64} -> f64;
fn{Js} round Method{"roundTiesEven"} :: f64 -> f64;
fn magnitude f64 = arg0.abs;
fn{Rs} inverseSqrt (v: f64) = {Method{"recip"} :: f64 -> f64}(sqrt(v));
fn{Js} inverseSqrt (v: f64) = 1.0 / sqrt(v);

/// Unsigned Integer-related functions and function bindings
fn{Rs} add Method{"wrapping_add"} :: (u8, Deref{u8}) -> u8;
Expand Down Expand Up @@ -1329,6 +1333,8 @@ fn normalize (arr: f64[]) {
let arr1 = arr.clone; // TODO: Needed for Rust codegen, but should not
return if(mag == 0.0, fn = arr1.clone, fn = arr.map(fn (v: f64) = v / mag));
}
fn inverseSqrt(arr: f32[]) = arr.map(inverseSqrt);
fn inverseSqrt(arr: f64[]) = arr.map(inverseSqrt);

/// Buffer related bindings
fn{Rs} get{T, S} "alan_std::getbuffer" <- RootBacking :: (T[S], i64) -> T?;
Expand Down Expand Up @@ -1390,6 +1396,8 @@ fn normalize{S} (arr: f64[S]) {
let arr1 = arr.clone; // TODO: Needed for Rust codegen, but should not
return if(mag == 0.0, fn = arr1.clone, fn = arr.map(fn (v: f64) = v / mag));
}
fn inverseSqrt{S}(buf: f32[S]) = buf.map(inverseSqrt);
fn inverseSqrt{S}(buf: f64[S]) = buf.map(inverseSqrt);

/// Dictionary-related bindings
fn{Rs} Dict{K, V} "alan_std::OrderedHashMap::new" <- RootBacking :: () -> Dict{K, V};
Expand Down Expand Up @@ -3971,6 +3979,15 @@ fn magnitude(v: gvec2f) = gMagnitude(v);
fn magnitude(v: gvec3f) = gMagnitude(v);
fn magnitude(v: gvec4f) = gMagnitude(v);

fn gInverseSqrt{I}(v: I) {
let varName = 'inverseSqrt('.concat(v.varName).concat(')');
return {I}(varName, v.statements, v.buffers);
}
fn inverseSqrt(v: gf32) = gInverseSqrt(v);
fn inverseSqrt(v: gvec2f) = gInverseSqrt(v);
fn inverseSqrt(v: gvec3f) = gInverseSqrt(v);
fn inverseSqrt(v: gvec4f) = gInverseSqrt(v);

fn gNormalize{I}(v: I) {
let varName = 'normalize('.concat(v.varName).concat(')');
return {I}(varName, v.statements, v.buffers);
Expand Down

0 comments on commit 385824e

Please sign in to comment.