Skip to content

Commit a330129

Browse files
committed
start of real tests
1 parent d1568c4 commit a330129

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

dialects/autodiff/auxiliary/autodiff_aux.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,13 @@ const Def* autodiff_type_fun(const Def* ty) {
140140
auto body_ad = autodiff_type_fun(body);
141141
if (!body_ad) return nullptr;
142142
return world.arr(shape, body_ad);
143-
}
144-
if (auto sig = ty->isa<Sigma>()) {
143+
} else if (auto sig = ty->isa<Sigma>()) {
145144
// TODO: nom sigma
146145
DefArray ops(sig->ops(), [&](const Def* op) { return autodiff_type_fun(op); });
147146
world.DLOG("ops: {,}", ops);
148147
return world.sigma(ops);
148+
} else if (auto real = match<core::Real>(ty)) {
149+
return ty;
149150
}
150151
// mem
151152
if (match<mem::M>(ty)) return ty;
@@ -171,6 +172,11 @@ const Def* zero_def(const Def* T) {
171172
auto zero = world.lit_idx(T, 0, world.dbg("zero"));
172173
world.DLOG("zero_def for int is {}", zero);
173174
return zero;
175+
} else if (auto real = match<core::Real>(T)) {
176+
auto width = as_lit<nat_t>(real->arg());
177+
auto zero = core::lit_real(T->world(), width, 0.0);
178+
world.DLOG("zero_def for real is {}", zero);
179+
return zero;
174180
} else if (auto sig = T->isa<Sigma>()) {
175181
DefArray ops(sig->ops(), [&](const Def* op) { return world.app(world.ax<zero>(), op); });
176182
return world.tuple(ops);

dialects/autodiff/normalizers.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, con
8484
world.app(world.app(world.ax(core::wrap::add), {world.lit_nat_0(), world.lit_nat(width)}), {a, b});
8585
world.DLOG("int add {} : {}", int_add, world.iinfer(int_add));
8686
return int_add;
87+
} else if (auto real = match<core::Real>(T)) {
88+
auto width = as_lit<nat_t>(real->arg());
89+
world.DLOG("width {}", width);
90+
auto real_add =
91+
world.app(world.app(world.ax(core::rop::add), {world.lit_nat_0(), world.lit_nat(width)}), {a, b});
92+
world.DLOG("real add {} : {}", real_add, real_add->type());
93+
return real_add;
94+
8795
} else if (auto app = T->isa<App>()) {
8896
auto callee = app->callee();
8997
assert(0 && "not handled");

lit/autodiff/square_real.thorin

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: rm -f %t.ll ; \
2+
// RUN: %thorin -d tool -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
3+
4+
.import core;
5+
.import autodiff;
6+
.import mem;
7+
8+
.let _32 = 32;
9+
.let R32 = %core.Real _32;
10+
11+
12+
.lam .extern internal_diff_core_rop_mul
13+
![m:.Nat, w:.Nat] ->
14+
(.Cn[[%core.Real w, %core.Real w], .Cn[%core.Real w, .Cn[%core.Real w, .Cn[%core.Real w, %core.Real w]]]])
15+
= {
16+
.cn inner_mul_deriv_cps
17+
![[a:%core.Real w, b:%core.Real w], ret:.Cn[%core.Real w, .Cn[%core.Real w, .Cn[%core.Real w, %core.Real w]]]]@(.tt)
18+
= {
19+
.let result = %core.rop.mul (m,w) (a,b);
20+
.cn mul_pb ![s:(%core.Real w), pb_ret:(.Cn [%core.Real w, %core.Real w])]@(.tt) = {
21+
.let lhs = %core.rop.mul (m,w) (s,b);
22+
.let rhs = %core.rop.mul (m,w) (s,a);
23+
pb_ret (lhs, rhs)
24+
};
25+
ret (result,mul_pb)
26+
};
27+
inner_mul_deriv_cps
28+
};
29+
30+
31+
32+
33+
.cn f [a:R32, ret: .Cn [R32]] = {
34+
.let b = %core.rop.mul (0, _32) (a, a);
35+
ret b
36+
};
37+
38+
// .cn g [a:R32, ret: .Cn [R32, .Cn [R32, .Cn [R32]]]] = {
39+
// .let f_diff = %autodiff.autodiff (.Cn [R32,.Cn[R32]]) f;
40+
// f_diff (a,ret)
41+
// };
42+
43+
.cn .extern test [[], ret: .Cn [R32]] = {
44+
// .cn inner [r:R32, pb:.Cn [R32, .Cn [R32]]] = {
45+
// .cn pb_rec p::[R32] = {
46+
// // ret (5.0:R32)
47+
// ret r
48+
// };
49+
// pb (1.0:R32, pb_rec)
50+
// };
51+
// g (5.0:R32, inner)
52+
53+
54+
// .cn pb_rec p::[R32] = {
55+
// ret p
56+
// };
57+
58+
// f (5.0:R32, pb_rec)
59+
60+
61+
// ret (6.0:R32)
62+
63+
.let b = %core.rop.mul (0, _32) (2.0:R32, 3.0:R32);
64+
ret b
65+
};
66+
67+
// CHECK-DAG: return{{.*}}176484

0 commit comments

Comments
 (0)