From 28b73d7bd633b89e4f31a7bdca8757fe029c9b0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jegor=20Kit=C5=A1kerkin?=
Date: Mon, 10 Jun 2024 19:04:39 +0300
Subject: [PATCH] Add division gate (#583)

---
 src/arraymancer/autograd/gates_basic.nim | 42 ++++++++++++++++++++++++
 tests/autograd/test_gate_basic.nim       | 18 ++++++++++
 2 files changed, 60 insertions(+)

diff --git a/src/arraymancer/autograd/gates_basic.nim b/src/arraymancer/autograd/gates_basic.nim
index e023b5ea8..ac383b357 100644
--- a/src/arraymancer/autograd/gates_basic.nim
+++ b/src/arraymancer/autograd/gates_basic.nim
@@ -96,3 +96,45 @@ proc `-`*[TT](a, b: Variable[TT]): Variable[TT] =
   # Caching for backprop
   if a.is_grad_needed or b.is_grad_needed:
     result.sub_cache(a, b)
+
+type DivGate*[TT] {.final.} = ref object of Gate[TT]
+  a: Variable[TT]
+  b: Variable[TT]
+
+proc div_backward_ag[TT](self: Gate[TT], payload: Payload[TT]): SmallDiffs[TT] =
+  let self = DivGate[TT](self)
+  let gradient = payload.variable.grad
+  result = newDiffs[TT](2)
+  result[0] = gradient /. self.b.value
+  result[1] = - gradient *. self.a.value /. self.b.value ^. 2
+
+proc div_cache[TT](result: Variable[TT], a, b: Variable[TT]) =
+  # Gate
+  var gate: DivGate[TT]
+  new gate
+  gate.a = a
+  gate.b = b
+
+  # Result setup
+  result.grad = zeros_like result.value
+  result.requires_grad = true
+
+  # Add to graph
+  register_node(
+    "Div",
+    gate,
+    div_backward_ag[TT],
+    result,
+    a, b
+  )
+
+proc `/.`*[TT](a, b: Variable[TT]): Variable[TT] =
+  when compileOption("boundChecks"):
+    check_ctx(a, b)
+
+  new result
+  result.context = a.context
+  result.value = a.value /. b.value
+
+  if a.is_grad_needed or b.is_grad_needed:
+    result.div_cache(a, b)
diff --git a/tests/autograd/test_gate_basic.nim b/tests/autograd/test_gate_basic.nim
index 4de4c2e2e..7e1498a5e 100644
--- a/tests/autograd/test_gate_basic.nim
+++ b/tests/autograd/test_gate_basic.nim
@@ -71,5 +71,23 @@ proc main() =
       let constantTensor1 = ones[float32](2, 4) / 2.0
       check: va.grad == constantTensor1
 
+    test "Gradient of division":
+      let a = toSeq(1..8).toTensor.reshape(2,4).asType(float32)
+      let b = toSeq(11..18).toTensor.reshape(2,4).asType(float32)
+
+      let ctx = newContext Tensor[float32]
+
+      let va = ctx.variable(a, requires_grad = true)
+      let vb = ctx.variable(b, requires_grad = true)
+
+      let vc = va /. vb
+
+      vc.backprop()
+
+      let gradC = ones[float32](2, 4)
+
+      check: va.grad == gradC /. b
+      check: vb.grad == - gradC *. a /. b ^. 2
+
 main()
 GC_fullCollect()
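
The backward pass in div_backward_ag implements the element-wise division gradients: for c = a /. b, dL/da = dL/dc /. b and dL/db = - dL/dc *. a /. b^2, which is exactly what the new test asserts. The standalone Nim sketch below (not part of the patch) recomputes those formulas with Arraymancer's public Tensor API and cross-checks one entry against a central finite difference; the `loss` helper and variable names are illustrative only, and it assumes the same broadcasting operators (/., *., ^.) used in the patch.

# Illustrative sketch, not part of the patch: verify the division gradients
# used by div_backward_ag on plain Tensors.
import std/sequtils
import arraymancer

proc loss(x, y: Tensor[float64]): float64 =
  # Scalar loss whose gradients we want: L = sum(x /. y), so dL/dc is all ones
  (x /. y).sum

let a = toSeq(1..8).toTensor.reshape(2, 4).asType(float64)
let b = toSeq(11..18).toTensor.reshape(2, 4).asType(float64)
let upstream = ones[float64](2, 4)   # dL/dc for L = sum(c)

# Analytic gradients, same formulas as the backward pass above
let dA = upstream /. b
let dB = - upstream *. a /. b ^. 2

# Cross-check one entry with a central finite difference
let eps = 1e-6
var aPlus = a.clone
var aMinus = a.clone
aPlus[0, 0] = aPlus[0, 0] + eps
aMinus[0, 0] = aMinus[0, 0] - eps
echo "analytic dL/da[0,0]: ", dA[0, 0]   # 1 / b[0,0] = 1/11
echo "numeric  dL/da[0,0]: ", (loss(aPlus, b) - loss(aMinus, b)) / (2.0 * eps)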