Skip to content

Commit

Permalink
Add division gate (#583)
Browse files Browse the repository at this point in the history
  • Loading branch information
jegork committed Jun 10, 2024
1 parent dcc9546 commit 28b73d7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/arraymancer/autograd/gates_basic.nim
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,45 @@ proc `-`*[TT](a, b: Variable[TT]): Variable[TT] =
# Caching for backprop
if a.is_grad_needed or b.is_grad_needed:
result.sub_cache(a, b)

type DivGate*[TT] {.final.} = ref object of Gate[TT]
a: Variable[TT]
b: Variable[TT]

proc div_backward_ag[TT](self: Gate[TT], payload: Payload[TT]): SmallDiffs[TT] =
let self = DivGate[TT](self)
let gradient = payload.variable.grad
result = newDiffs[TT](2)
result[0] = gradient /. self.b.value
result[1] = - gradient *. self.a.value /. self.b.value ^. 2

proc div_cache[TT](result: Variable[TT], a, b: Variable[TT]) =
# Gate
var gate: DivGate[TT]
new gate
gate.a = a
gate.b = b

# Result setup
result.grad = zeros_like result.value
result.requires_grad = true

# Add to graph
register_node(
"Div",
gate,
div_backward_ag[TT],
result,
a, b
)

proc `/.`*[TT](a, b: Variable[TT]): Variable[TT] =
when compileOption("boundChecks"):
check_ctx(a, b)

new result
result.context = a.context
result.value = a.value /. b.value

if a.is_grad_needed or b.is_grad_needed:
result.div_cache(a, b)
18 changes: 18 additions & 0 deletions tests/autograd/test_gate_basic.nim
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,23 @@ proc main() =
let constantTensor1 = ones[float32](2, 4) / 2.0
check: va.grad == constantTensor1

test "Gradient of division":
let a = toSeq(1..8).toTensor.reshape(2,4).asType(float32)
let b = toSeq(11..18).toTensor.reshape(2,4).asType(float32)

let ctx = newContext Tensor[float32]

let va = ctx.variable(a, requires_grad = true)
let vb = ctx.variable(b, requires_grad = true)

let vc = va /. vb

vc.backprop()

let gradC = ones[float32](2, 4)

check: va.grad == gradC /. b
check: vb.grad == - gradC *. a /. b ^. 2

main()
GC_fullCollect()

0 comments on commit 28b73d7

Please sign in to comment.