Skip to content

Commit

Permalink
re-introduce raw_app to avoid recomputing dependent type
Browse files Browse the repository at this point in the history
Also replaces some raw_app%s with app%s in autodiff
  • Loading branch information
leissa committed Nov 28, 2022
1 parent 11e373f commit ba3f6ef
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 82 deletions.
2 changes: 1 addition & 1 deletion dialects/autodiff/auxiliary/autodiff_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ const Def* zero_def(const Def* T) {
const Def* op_sum(const Def* T, DefArray defs) {
// TODO: assert all are of type T
auto& world = T->world();
return world.app<false>(world.app<false>(world.ax<sum>(), {world.lit_nat(defs.size()), T}), world.tuple(defs));
return world.app(world.app(world.ax<sum>(), {world.lit_nat(defs.size()), T}), world.tuple(defs));
}

} // namespace thorin::autodiff
Expand Down
6 changes: 3 additions & 3 deletions dialects/autodiff/auxiliary/autodiff_rewrite_inner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ const Def* AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
// TODO: special case for const width (special tuple)

// <i:n, cps2ds body_pb (s#i)>
app_pb->set(world.app<false>(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));
app_pb->set(world.app(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));

world.DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());

auto sumup = world.app<false>(world.ax<sum>(), {aug_shape, f_arg_ty_diff});
auto sumup = world.app(world.ax<sum>(), {aug_shape, f_arg_ty_diff});
world.DLOG("sumup: {} : {}", sumup, sumup->type());

pb->app(true, pb->var(1), world.app<false>(sumup, app_pb));
pb->app(true, pb->var(1), world.app(sumup, app_pb));

partial_pullback[aug_pack] = pb;

Expand Down
10 changes: 5 additions & 5 deletions dialects/autodiff/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ namespace thorin::autodiff {
/// TODO: Maybe we want to handle trivial lookup replacements here.
const Def* normalize_ad(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& world = type->world();
return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

const Def* normalize_AD(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& world = type->world();
auto ad_ty = autodiff_type_fun(arg);
if (ad_ty) return ad_ty;
return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*) { return tangent_type_fun(arg); }
Expand All @@ -30,7 +30,7 @@ const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*)
/// A high-level addition with zero can be shortened directly.
const Def* normalize_zero(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& world = type->world();
return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

/// Currently resolved the full addition.
Expand Down Expand Up @@ -85,7 +85,7 @@ const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, con
}
// TODO: mem stays here (only resolved after direct simplification)

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Expand All @@ -105,7 +105,7 @@ const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, con
}
assert(0);

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

THORIN_autodiff_NORMALIZER_IMPL
Expand Down
2 changes: 1 addition & 1 deletion dialects/clos/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace thorin::clos {
template<attr o>
const Def* normalize_clos(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& w = type->world();
return o == attr::bot ? arg : w.app<false>(callee, arg, dbg);
return o == attr::bot ? arg : w.raw_app(type, callee, arg, dbg);
}

THORIN_clos_NORMALIZER_IMPL
Expand Down
50 changes: 25 additions & 25 deletions dialects/core/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ const Def* normalize_nop(const Def* type, const Def* callee, const Def* arg, con
}
}

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

template<ncmp id>
Expand Down Expand Up @@ -219,7 +219,7 @@ const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg, co
}
}

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

template<icmp id>
Expand All @@ -236,7 +236,7 @@ const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg, const D
if (id == icmp::ne) return world.lit_ff();
}

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<bit1 id>
Expand All @@ -255,7 +255,7 @@ const Def* normalize_bit1(const Def* type, const Def* c, const Def* a, const Def

if (l) return world.lit_idx_mod(*s, ~*l);

return world.app<false>(callee, a, dbg);
return world.raw_app(type, callee, a, dbg);
}

template<class Id>
Expand Down Expand Up @@ -358,7 +358,7 @@ const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg, const D

if (auto res = reassociate<bit2>(id, world, callee, a, b, dbg)) return res;

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<shr id>
Expand Down Expand Up @@ -391,7 +391,7 @@ const Def* normalize_shr(const Def* type, const Def* c, const Def* arg, const De
if (ls && lb->get() > *ls) return world.bot(type, dbg);
}

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<wrap id>
Expand Down Expand Up @@ -454,7 +454,7 @@ const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg, const D

if (auto res = reassociate<wrap>(id, world, callee, a, b, dbg)) return res;

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<div id>
Expand Down Expand Up @@ -493,30 +493,30 @@ const Def* normalize_div(const Def* type, const Def* c, const Def* arg, const De
}
}

return world.app<false>(callee, {mem, a, b}, dbg);
return world.raw_app(type, callee, {mem, a, b}, dbg);
}

template<conv id>
const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const Def* dbg) {
auto& world = dst_ty->world();
const Def* normalize_conv(const Def* dst_t, const Def* c, const Def* x, const Def* dbg) {
auto& world = dst_t->world();
auto callee = c->as<App>();
auto s_ty = x->type()->as<App>();
auto d_ty = dst_ty->as<App>();
auto s = s_ty->arg();
auto d = d_ty->arg();
auto s_t = x->type()->as<App>();
auto d_t = dst_t->as<App>();
auto s = s_t->arg();
auto d = d_t->arg();
auto ls = isa_lit(s);
auto ld = isa_lit(d);

if (s_ty == d_ty) return x;
if (x->isa<Bot>()) return world.bot(d_ty, dbg);
if (s_t == d_t) return x;
if (x->isa<Bot>()) return world.bot(d_t, dbg);
if constexpr (id == conv::s2s) {
if (ls && ld && *ld < *ls) return op(conv::u2u, d_ty, x, dbg); // just truncate - we don't care for signedness
if (ls && ld && *ld < *ls) return op(conv::u2u, d_t, x, dbg); // just truncate - we don't care for signedness
}

if (auto l = isa_lit(x); l && ls && ld) {
if constexpr (id == conv::u2u) {
if (*ld == 0) return world.lit(d_ty, *l); // I64
return world.lit(d_ty, *l % *ld);
if (*ld == 0) return world.lit(d_t, *l); // I64
return world.lit(d_t, *l % *ld);
}

auto sw = Idx::size2bitwidth(*ls);
Expand All @@ -525,7 +525,7 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D
// clang-format off
if (false) {}
#define M(S, D) \
else if (S == sw && D == dw) return world.lit(d_ty, w2s<D>(thorin::bitcast<w2s<S>>(*l)), dbg);
else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(thorin::bitcast<w2s<S>>(*l)), dbg);
M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
M( 8, 16) M( 8, 32) M( 8, 64)
M(16, 32) M(16, 64)
Expand All @@ -534,7 +534,7 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D
// clang-format on
}

return world.app<false>(callee, x, dbg);
return world.raw_app(dst_t, callee, x, dbg);
}

const Def* normalize_bitcast(const Def* dst_t, const Def* callee, const Def* src, const Def* dbg) {
Expand All @@ -551,7 +551,7 @@ const Def* normalize_bitcast(const Def* dst_t, const Def* callee, const Def* src
if (Idx::size(dst_t)) return world.lit(dst_t, lit->get(), dbg);
}

return world.app<false>(callee, src, dbg);
return world.raw_app(dst_t, callee, src, dbg);
}

// TODO I guess we can do that with C++20 <bit>
Expand Down Expand Up @@ -609,7 +609,7 @@ const Def* normalize_trait(const Def*, const Def* callee, const Def* type, const
}

out:
return world.app<false>(callee, type, dbg);
return world.raw_app(type, callee, type, dbg);
}

const Def* normalize_zip(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
Expand Down Expand Up @@ -650,7 +650,7 @@ const Def* normalize_zip(const Def* type, const Def* c, const Def* arg, const De
}
}

return w.app<false>(callee, arg, dbg);
return w.raw_app(type, callee, arg, dbg);
}

template<pe id>
Expand All @@ -662,7 +662,7 @@ const Def* normalize_pe(const Def* type, const Def* callee, const Def* arg, cons
if (arg->dep_const()) return world.lit_tt();
}

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

THORIN_core_NORMALIZER_IMPL
Expand Down
4 changes: 2 additions & 2 deletions dialects/direct/direct.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ inline const Def* op_cps2ds_dep(const Def* f) {
Uf->set_filter(true);
Uf->set_body(rewritten_codom);

auto ax_app = world.app<false>(world.ax<direct::cps2ds_dep>(), {T, Uf});
auto ax_app = world.app(world.ax<direct::cps2ds_dep>(), {T, Uf});

world.DLOG("axiom app: {} : {}", ax_app, ax_app->type());

return world.app<false>(ax_app, f);
return world.app(ax_app, f);
}

} // namespace thorin::direct
2 changes: 1 addition & 1 deletion dialects/direct/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const Def* normalize_cps2ds(const Def* type, const Def* callee, const Def* arg,
for (size_t i = 2; i < curry_args.size(); ++i) r = world.app(r, curry_args[i]);
return r;

return world.app<false>(callee, arg, dbg);
return world.raw_app(type, callee, arg, dbg);
}

THORIN_direct_NORMALIZER_IMPL
Expand Down
43 changes: 21 additions & 22 deletions dialects/math/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ const Def* normalize_arith(const Def* type, const Def* c, const Def* arg, const

if (auto res = reassociate<arith>(id, world, callee, a, b, dbg)) return res;

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<extrema id>
Expand All @@ -297,49 +297,49 @@ const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg, cons
}
}

return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<tri id>
const Def* normalize_tri(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
if (auto lit = fold<tri, id>(world, type, arg, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

const Def* normalize_pow(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
auto [a, b] = arg->projs<2>();
if (auto lit = fold<pow, /*dummy*/ pow(0)>(world, type, a, b, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<rt id>
const Def* normalize_rt(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
if (auto lit = fold<rt, id>(world, type, arg, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<exp id>
const Def* normalize_exp(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
if (auto lit = fold<exp, id>(world, type, arg, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<er id>
const Def* normalize_er(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
if (auto lit = fold<er, id>(world, type, arg, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<gamma id>
const Def* normalize_gamma(const Def* type, const Def* c, const Def* arg, const Def* dbg) {
auto& world = type->world();
if (auto lit = fold<gamma, id>(world, type, arg, dbg)) return lit;
return world.app<false>(c, arg, dbg);
return world.raw_app(type, c, arg, dbg);
}

template<cmp id>
Expand All @@ -352,7 +352,7 @@ const Def* normalize_cmp(const Def* type, const Def* c, const Def* arg, const De
if (id == cmp::f) return world.lit_ff();
if (id == cmp::t) return world.lit_tt();

return world.app<false>(callee, {a, b}, dbg);
return world.raw_app(type, callee, {a, b}, dbg);
}

template<class Id, Id id, nat_t sw, nat_t dw>
Expand All @@ -363,26 +363,26 @@ Res fold(u64 a) {
}

template<conv id>
const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const Def* dbg) {
auto& world = dst_ty->world();
const Def* normalize_conv(const Def* dst_t, const Def* c, const Def* x, const Def* dbg) {
auto& world = dst_t->world();
auto callee = c->as<App>();
auto s_ty = x->type()->as<App>();
auto d_ty = dst_ty->as<App>();
auto s = s_ty->arg();
auto d = d_ty->arg();
auto s_t = x->type()->as<App>();
auto d_t = dst_t->as<App>();
auto s = s_t->arg();
auto d = d_t->arg();
auto ls = isa_lit(s);
auto ld = isa_lit(d);

if (s_ty == d_ty) return x;
if (x->isa<Bot>()) return world.bot(d_ty, dbg);
if (s_t == d_t) return x;
if (x->isa<Bot>()) return world.bot(d_t, dbg);

if (auto l = isa_lit(x); l && ls && ld) {
constexpr bool sf = id == conv::f2f || id == conv::f2s || id == conv::f2u;
constexpr bool df = id == conv::f2f || id == conv::s2f || id == conv::u2f;
constexpr nat_t min_s = sf ? 16 : 1;
constexpr nat_t min_d = df ? 16 : 1;
auto sw = sf ? isa_f(s_ty) : Idx::size2bitwidth(*ls);
auto dw = df ? isa_f(d_ty) : Idx::size2bitwidth(*ld);
auto sw = sf ? isa_f(s_t) : Idx::size2bitwidth(*ls);
auto dw = df ? isa_f(d_t) : Idx::size2bitwidth(*ld);

if (sw && dw) {
Res res;
Expand All @@ -403,12 +403,11 @@ const Def* normalize_conv(const Def* dst_ty, const Def* c, const Def* x, const D

else unreachable();
// clang-format on

return world.lit(d_ty, *res, dbg);
return world.lit(d_t, *res, dbg);
}
}
out:
return world.app<false>(callee, x, dbg);
return world.raw_app(dst_t, callee, x, dbg);
}

// TODO I guess we can do that with C++20 <bit>
Expand Down
Loading

0 comments on commit ba3f6ef

Please sign in to comment.