Skip to content

Commit ba3f6ef

Browse files
author
Roland Leißa
committed
re-introduce raw_app to avoid recomputing dependent type
Also replaces some raw_app%s with app%s in autodiff
1 parent 11e373f commit ba3f6ef

File tree

13 files changed

+87
-82
lines changed

13 files changed

+87
-82
lines changed

dialects/autodiff/auxiliary/autodiff_aux.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ const Def* zero_def(const Def* T) {
146146
const Def* op_sum(const Def* T, DefArray defs) {
147147
// TODO: assert all are of type T
148148
auto& world = T->world();
149-
return world.app<false>(world.app<false>(world.ax<sum>(), {world.lit_nat(defs.size()), T}), world.tuple(defs));
149+
return world.app(world.app(world.ax<sum>(), {world.lit_nat(defs.size()), T}), world.tuple(defs));
150150
}
151151

152152
} // namespace thorin::autodiff

dialects/autodiff/auxiliary/autodiff_rewrite_inner.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,14 @@ const Def* AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
198198
// TODO: special case for const width (special tuple)
199199

200200
// <i:n, cps2ds body_pb (s#i)>
201-
app_pb->set(world.app<false>(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));
201+
app_pb->set(world.app(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));
202202

203203
world.DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
204204

205-
auto sumup = world.app<false>(world.ax<sum>(), {aug_shape, f_arg_ty_diff});
205+
auto sumup = world.app(world.ax<sum>(), {aug_shape, f_arg_ty_diff});
206206
world.DLOG("sumup: {} : {}", sumup, sumup->type());
207207

208-
pb->app(true, pb->var(1), world.app<false>(sumup, app_pb));
208+
pb->app(true, pb->var(1), world.app(sumup, app_pb));
209209

210210
partial_pullback[aug_pack] = pb;
211211

dialects/autodiff/normalizers.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ namespace thorin::autodiff {
1313
/// TODO: Maybe we want to handle trivial lookup replacements here.
1414
const Def* normalize_ad(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
1515
auto& world = type->world();
16-
return world.app<false>(callee, arg, dbg);
16+
return world.raw_app(type, callee, arg, dbg);
1717
}
1818

1919
const Def* normalize_AD(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
2020
auto& world = type->world();
2121
auto ad_ty = autodiff_type_fun(arg);
2222
if (ad_ty) return ad_ty;
23-
return world.app<false>(callee, arg, dbg);
23+
return world.raw_app(type, callee, arg, dbg);
2424
}
2525

2626
const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*) { return tangent_type_fun(arg); }
@@ -30,7 +30,7 @@ const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*)
3030
/// A high-level addition with zero can be shortened directly.
3131
const Def* normalize_zero(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
3232
auto& world = type->world();
33-
return world.app<false>(callee, arg, dbg);
33+
return world.raw_app(type, callee, arg, dbg);
3434
}
3535

3636
/// Currently resolved the full addition.
@@ -85,7 +85,7 @@ const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, con
8585
}
8686
// TODO: mem stays here (only resolved after direct simplification)
8787

88-
return world.app<false>(callee, arg, dbg);
88+
return world.raw_app(type, callee, arg, dbg);
8989
}
9090

9191
const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
@@ -105,7 +105,7 @@ const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, con
105105
}
106106
assert(0);
107107

108-
return world.app<false>(callee, arg, dbg);
108+
return world.raw_app(type, callee, arg, dbg);
109109
}
110110

111111
THORIN_autodiff_NORMALIZER_IMPL

dialects/clos/normalizers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace thorin::clos {
55
template<attr o>
66
const Def* normalize_clos(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
77
auto& w = type->world();
8-
return o == attr::bot ? arg : w.app<false>(callee, arg, dbg);
8+
return o == attr::bot ? arg : w.raw_app(type, callee, arg, dbg);
99
}
1010

1111
THORIN_clos_NORMALIZER_IMPL

dialects/core/normalizers.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ const Def* normalize_nop(const Def* type, const Def* callee, const Def* arg, con
190190
}
191191
}
192192

193-
return world.app<false>(callee, arg, dbg);
193+
return world.raw_app(type, callee, arg, dbg);
194194
}
195195

196196
template<ncmp id>
@@ -219,7 +219,7 @@ const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg, co
219219
}
220220
}
221221

222-
return world.app<false>(callee, arg, dbg);
222+
return world.raw_app(type, callee, arg, dbg);
223223
}
224224

225225
template<icmp id>
@@ -236,7 +236,7 @@ const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg, const D
236236
if (id == icmp::ne) return world.lit_ff();
237237
}
238238

239-
return world.app<false>(callee, {a, b}, dbg);
239+
return world.raw_app(type, callee, {a, b}, dbg);
240240
}
241241

242242
template<bit1 id>
@@ -255,7 +255,7 @@ const Def* normalize_bit1(const Def* type, const Def* c, const Def* a, const Def
255255

256256
if (l) return world.lit_idx_mod(*s, ~*l);
257257

258-
return world.app<false>(callee, a, dbg);
258+
return world.raw_app(type, callee, a, dbg);
259259
}
260260

261261
template<class Id>
@@ -358,7 +358,7 @@ const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg, const D
358358

359359
if (auto res = reassociate<bit2>(id, world, callee, a, b, dbg)) return res;
360360

361-
return world.app<false>(callee, {a, b}, dbg);
361+
return world.raw_app(type, callee, {a, b}, dbg);
362362
}
363363

364364
template<shr id>
@@ -391,7 +391,7 @@ const Def* normalize_shr(const Def* type, const Def* c, const Def* arg, const De
391391
if (ls && lb->get() > *ls) return world.bot(type, dbg);
392392
}
393393

394-
return world.app<false>(callee, {a, b}, dbg);
394+
return world.raw_app(type, callee, {a, b}, dbg);
395395
}
396396

397397
template<wrap id>
@@ -454,7 +454,7 @@ const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg, const D
454454

455455
if (auto res = reassociate<wrap>(id, world, callee, a, b, dbg)) return res;
456456

457-
return world.app<false>(callee, {a, b}, dbg);
457+
return world.raw_app(type, callee, {a, b}, dbg);
458458
}
459459

460460
template<div id>
@@ -493,30 +493,30 @@ const Def* normalize_div(const Def* type, const Def* c, const Def* arg, const De
493493
}
494494
}
495495

496-
return world.app<false>(callee, {mem, a, b}, dbg);
496+
return world.raw_app(type, callee, {mem, a, b}, dbg);
497497
}
498498

499499
template<conv id>
500-
const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const Def* dbg) {
501-
auto& world = dst_ty->world();
500+
const Def* normalize_conv(const Def* dst_t, const Def* c, const Def* x, const Def* dbg) {
501+
auto& world = dst_t->world();
502502
auto callee = c->as<App>();
503-
auto s_ty = x->type()->as<App>();
504-
auto d_ty = dst_ty->as<App>();
505-
auto s = s_ty->arg();
506-
auto d = d_ty->arg();
503+
auto s_t = x->type()->as<App>();
504+
auto d_t = dst_t->as<App>();
505+
auto s = s_t->arg();
506+
auto d = d_t->arg();
507507
auto ls = isa_lit(s);
508508
auto ld = isa_lit(d);
509509

510-
if (s_ty == d_ty) return x;
511-
if (x->isa<Bot>()) return world.bot(d_ty, dbg);
510+
if (s_t == d_t) return x;
511+
if (x->isa<Bot>()) return world.bot(d_t, dbg);
512512
if constexpr (id == conv::s2s) {
513-
if (ls && ld && *ld < *ls) return op(conv::u2u, d_ty, x, dbg); // just truncate - we don't care for signedness
513+
if (ls && ld && *ld < *ls) return op(conv::u2u, d_t, x, dbg); // just truncate - we don't care for signedness
514514
}
515515

516516
if (auto l = isa_lit(x); l && ls && ld) {
517517
if constexpr (id == conv::u2u) {
518-
if (*ld == 0) return world.lit(d_ty, *l); // I64
519-
return world.lit(d_ty, *l % *ld);
518+
if (*ld == 0) return world.lit(d_t, *l); // I64
519+
return world.lit(d_t, *l % *ld);
520520
}
521521

522522
auto sw = Idx::size2bitwidth(*ls);
@@ -525,7 +525,7 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D
525525
// clang-format off
526526
if (false) {}
527527
#define M(S, D) \
528-
else if (S == sw && D == dw) return world.lit(d_ty, w2s<D>(thorin::bitcast<w2s<S>>(*l)), dbg);
528+
else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(thorin::bitcast<w2s<S>>(*l)), dbg);
529529
M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
530530
M( 8, 16) M( 8, 32) M( 8, 64)
531531
M(16, 32) M(16, 64)
@@ -534,7 +534,7 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D
534534
// clang-format on
535535
}
536536

537-
return world.app<false>(callee, x, dbg);
537+
return world.raw_app(dst_t, callee, x, dbg);
538538
}
539539

540540
const Def* normalize_bitcast(const Def* dst_t, const Def* callee, const Def* src, const Def* dbg) {
@@ -551,7 +551,7 @@ const Def* normalize_bitcast(const Def* dst_t, const Def* callee, const Def* src
551551
if (Idx::size(dst_t)) return world.lit(dst_t, lit->get(), dbg);
552552
}
553553

554-
return world.app<false>(callee, src, dbg);
554+
return world.raw_app(dst_t, callee, src, dbg);
555555
}
556556

557557
// TODO I guess we can do that with C++20 <bit>
@@ -609,7 +609,7 @@ const Def* normalize_trait(const Def*, const Def* callee, const Def* type, const
609609
}
610610

611611
out:
612-
return world.app<false>(callee, type, dbg);
612+
return world.raw_app(type, callee, type, dbg);
613613
}
614614

615615
const Def* normalize_zip(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
@@ -650,7 +650,7 @@ const Def* normalize_zip(const Def* type, const Def* c, const Def* arg, const De
650650
}
651651
}
652652

653-
return w.app<false>(callee, arg, dbg);
653+
return w.raw_app(type, callee, arg, dbg);
654654
}
655655

656656
template<pe id>
@@ -662,7 +662,7 @@ const Def* normalize_pe(const Def* type, const Def* callee, const Def* arg, cons
662662
if (arg->dep_const()) return world.lit_tt();
663663
}
664664

665-
return world.app<false>(callee, arg, dbg);
665+
return world.raw_app(type, callee, arg, dbg);
666666
}
667667

668668
THORIN_core_NORMALIZER_IMPL

dialects/direct/direct.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ inline const Def* op_cps2ds_dep(const Def* f) {
3333
Uf->set_filter(true);
3434
Uf->set_body(rewritten_codom);
3535

36-
auto ax_app = world.app<false>(world.ax<direct::cps2ds_dep>(), {T, Uf});
36+
auto ax_app = world.app(world.ax<direct::cps2ds_dep>(), {T, Uf});
3737

3838
world.DLOG("axiom app: {} : {}", ax_app, ax_app->type());
3939

40-
return world.app<false>(ax_app, f);
40+
return world.app(ax_app, f);
4141
}
4242

4343
} // namespace thorin::direct

dialects/direct/normalizers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const Def* normalize_cps2ds(const Def* type, const Def* callee, const Def* arg,
3939
for (size_t i = 2; i < curry_args.size(); ++i) r = world.app(r, curry_args[i]);
4040
return r;
4141

42-
return world.app<false>(callee, arg, dbg);
42+
return world.raw_app(type, callee, arg, dbg);
4343
}
4444

4545
THORIN_direct_NORMALIZER_IMPL

dialects/math/normalizers.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ const Def* normalize_arith(const Def* type, const Def* c, const Def* arg, const
276276

277277
if (auto res = reassociate<arith>(id, world, callee, a, b, dbg)) return res;
278278

279-
return world.app<false>(callee, {a, b}, dbg);
279+
return world.raw_app(type, callee, {a, b}, dbg);
280280
}
281281

282282
template<extrema id>
@@ -297,49 +297,49 @@ const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg, cons
297297
}
298298
}
299299

300-
return world.app<false>(c, arg, dbg);
300+
return world.raw_app(type, c, arg, dbg);
301301
}
302302

303303
template<tri id>
304304
const Def* normalize_tri(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
305305
auto& world = type->world();
306306
if (auto lit = fold<tri, id>(world, type, arg, dbg)) return lit;
307-
return world.app<false>(c, arg, dbg);
307+
return world.raw_app(type, c, arg, dbg);
308308
}
309309

310310
const Def* normalize_pow(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
311311
auto& world = type->world();
312312
auto [a, b] = arg->projs<2>();
313313
if (auto lit = fold<pow, /*dummy*/ pow(0)>(world, type, a, b, dbg)) return lit;
314-
return world.app<false>(c, arg, dbg);
314+
return world.raw_app(type, c, arg, dbg);
315315
}
316316

317317
template<rt id>
318318
const Def* normalize_rt(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
319319
auto& world = type->world();
320320
if (auto lit = fold<rt, id>(world, type, arg, dbg)) return lit;
321-
return world.app<false>(c, arg, dbg);
321+
return world.raw_app(type, c, arg, dbg);
322322
}
323323

324324
template<exp id>
325325
const Def* normalize_exp(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
326326
auto& world = type->world();
327327
if (auto lit = fold<exp, id>(world, type, arg, dbg)) return lit;
328-
return world.app<false>(c, arg, dbg);
328+
return world.raw_app(type, c, arg, dbg);
329329
}
330330

331331
template<er id>
332332
const Def* normalize_er(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
333333
auto& world = type->world();
334334
if (auto lit = fold<er, id>(world, type, arg, dbg)) return lit;
335-
return world.app<false>(c, arg, dbg);
335+
return world.raw_app(type, c, arg, dbg);
336336
}
337337

338338
template<gamma id>
339339
const Def* normalize_gamma(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
340340
auto& world = type->world();
341341
if (auto lit = fold<gamma, id>(world, type, arg, dbg)) return lit;
342-
return world.app<false>(c, arg, dbg);
342+
return world.raw_app(type, c, arg, dbg);
343343
}
344344

345345
template<cmp id>
@@ -352,7 +352,7 @@ const Def* normalize_cmp(const Def* type, const Def* c, const Def* arg, const De
352352
if (id == cmp::f) return world.lit_ff();
353353
if (id == cmp::t) return world.lit_tt();
354354

355-
return world.app<false>(callee, {a, b}, dbg);
355+
return world.raw_app(type, callee, {a, b}, dbg);
356356
}
357357

358358
template<class Id, Id id, nat_t sw, nat_t dw>
@@ -363,26 +363,26 @@ Res fold(u64 a) {
363363
}
364364

365365
template<conv id>
366-
const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const Def* dbg) {
367-
auto& world = dst_ty->world();
366+
const Def* normalize_conv(const Def* dst_t, const Def* c, const Def* x, const Def* dbg) {
367+
auto& world = dst_t->world();
368368
auto callee = c->as<App>();
369-
auto s_ty = x->type()->as<App>();
370-
auto d_ty = dst_ty->as<App>();
371-
auto s = s_ty->arg();
372-
auto d = d_ty->arg();
369+
auto s_t = x->type()->as<App>();
370+
auto d_t = dst_t->as<App>();
371+
auto s = s_t->arg();
372+
auto d = d_t->arg();
373373
auto ls = isa_lit(s);
374374
auto ld = isa_lit(d);
375375

376-
if (s_ty == d_ty) return x;
377-
if (x->isa<Bot>()) return world.bot(d_ty, dbg);
376+
if (s_t == d_t) return x;
377+
if (x->isa<Bot>()) return world.bot(d_t, dbg);
378378

379379
if (auto l = isa_lit(x); l && ls && ld) {
380380
constexpr bool sf = id == conv::f2f || id == conv::f2s || id == conv::f2u;
381381
constexpr bool df = id == conv::f2f || id == conv::s2f || id == conv::u2f;
382382
constexpr nat_t min_s = sf ? 16 : 1;
383383
constexpr nat_t min_d = df ? 16 : 1;
384-
auto sw = sf ? isa_f(s_ty) : Idx::size2bitwidth(*ls);
385-
auto dw = df ? isa_f(d_ty) : Idx::size2bitwidth(*ld);
384+
auto sw = sf ? isa_f(s_t) : Idx::size2bitwidth(*ls);
385+
auto dw = df ? isa_f(d_t) : Idx::size2bitwidth(*ld);
386386

387387
if (sw && dw) {
388388
Res res;
@@ -403,12 +403,11 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D
403403

404404
else unreachable();
405405
// clang-format on
406-
407-
return world.lit(d_ty, *res, dbg);
406+
return world.lit(d_t, *res, dbg);
408407
}
409408
}
410409
out:
411-
return world.app<false>(callee, x, dbg);
410+
return world.raw_app(dst_t, callee, x, dbg);
412411
}
413412

414413
// TODO I guess we can do that with C++20 <bit>

0 commit comments

Comments
 (0)