Skip to content

Commit bdbc485

Browse files
committed
compiler: handle eval branch quota in memoized calls
In a `memoized_call`, store how many backwards braches the call performs. Add this to `sema.branch_count` when using a memoized call. If this exceeds the quota, perform a non-memoized call to get a correct "exceeded X backwards branches" error. Also, do not memoize calls which do `@setEvalBranchQuota` or similar, as this affects global state which must apply to the caller. Change some eval branch quotas so that the compiler itself still builds correctly. This commit manually changes a file in Aro which is automatically generated. The sources which generate the file are not in this repo. Upstream Aro should make the suitable changes on their end before the next sync of Aro sources into the Zig repo.
1 parent 0a70455 commit bdbc485

File tree

6 files changed

+50
-11
lines changed

6 files changed

+50
-11
lines changed

lib/compiler/aro/aro/Builtins/Builtin.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5165,7 +5165,7 @@ const dafsa = [_]Node{
51655165
.{ .char = 'e', .end_of_word = false, .end_of_list = true, .number = 1, .child_index = 4913 },
51665166
};
51675167
pub const data = blk: {
5168-
@setEvalBranchQuota(3986);
5168+
@setEvalBranchQuota(30_000);
51695169
break :blk [_]@This(){
51705170
// _Block_object_assign
51715171
.{ .tag = @enumFromInt(0), .properties = .{ .param_str = "vv*vC*iC", .header = .blocks, .attributes = .{ .lib_function_without_prefix = true } } },

lib/compiler/aro/aro/Parser.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4802,6 +4802,7 @@ const CallExpr = union(enum) {
48024802
}
48034803

48044804
fn shouldPromoteVarArg(self: CallExpr, arg_idx: u32) bool {
4805+
@setEvalBranchQuota(2000);
48054806
return switch (self) {
48064807
.standard => true,
48074808
.builtin => |builtin| switch (builtin.tag) {
@@ -4902,6 +4903,7 @@ const CallExpr = union(enum) {
49024903
}
49034904

49044905
fn returnType(self: CallExpr, p: *Parser, callable_ty: Type) Type {
4906+
@setEvalBranchQuota(6000);
49054907
return switch (self) {
49064908
.standard => callable_ty.returnType(),
49074909
.builtin => |builtin| switch (builtin.tag) {

lib/std/crypto/pcurves/p384.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ pub const P384 = struct {
393393
}
394394

395395
const basePointPc = pc: {
396-
@setEvalBranchQuota(50000);
396+
@setEvalBranchQuota(70000);
397397
break :pc precompute(P384.basePoint, 15);
398398
};
399399

src/InternPool.zig

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2391,6 +2391,7 @@ pub const Key = union(enum) {
23912391
func: Index,
23922392
arg_values: []const Index,
23932393
result: Index,
2394+
branch_count: u32,
23942395
};
23952396

23962397
pub fn hash32(key: Key, ip: *const InternPool) u32 {
@@ -6157,6 +6158,7 @@ pub const MemoizedCall = struct {
61576158
func: Index,
61586159
args_len: u32,
61596160
result: Index,
6161+
branch_count: u32,
61606162
};
61616163

61626164
pub fn init(ip: *InternPool, gpa: Allocator, available_threads: usize) !void {
@@ -6785,6 +6787,7 @@ pub fn indexToKey(ip: *const InternPool, index: Index) Key {
67856787
.func = extra.data.func,
67866788
.arg_values = @ptrCast(extra_list.view().items(.@"0")[extra.end..][0..extra.data.args_len]),
67876789
.result = extra.data.result,
6790+
.branch_count = extra.data.branch_count,
67886791
} };
67896792
},
67906793
};
@@ -7955,6 +7958,7 @@ pub fn get(ip: *InternPool, gpa: Allocator, tid: Zcu.PerThread.Id, key: Key) All
79557958
.func = memoized_call.func,
79567959
.args_len = @intCast(memoized_call.arg_values.len),
79577960
.result = memoized_call.result,
7961+
.branch_count = memoized_call.branch_count,
79587962
}),
79597963
});
79607964
extra.appendSliceAssumeCapacity(.{@ptrCast(memoized_call.arg_values)});

src/Sema.zig

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ type_references: std.AutoArrayHashMapUnmanaged(InternPool.Index, void) = .{},
113113
/// `AnalUnit` multiple times.
114114
dependencies: std.AutoArrayHashMapUnmanaged(InternPool.Dependee, void) = .{},
115115

116+
/// Whether memoization of this call is permitted. Operations with side effects global
117+
/// to the `Sema`, such as `@setEvalBranchQuota`, set this to `false`. It is observed
118+
/// by `analyzeCall`.
119+
allow_memoize: bool = true,
120+
116121
const MaybeComptimeAlloc = struct {
117122
/// The runtime index of the `alloc` instruction.
118123
runtime_index: Value.RuntimeIndex,
@@ -5524,6 +5529,7 @@ fn zirSetEvalBranchQuota(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Compi
55245529
.needed_comptime_reason = "eval branch quota must be comptime-known",
55255530
}));
55265531
sema.branch_quota = @max(sema.branch_quota, quota);
5532+
sema.allow_memoize = false;
55275533
}
55285534

55295535
fn zirStoreNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void {
@@ -6416,6 +6422,7 @@ fn zirSetAlignStack(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.Inst
64166422
}
64176423

64186424
zcu.intern_pool.funcMaxStackAlignment(sema.func_index, alignment);
6425+
sema.allow_memoize = false;
64196426
}
64206427

64216428
fn zirSetCold(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!void {
@@ -6434,6 +6441,7 @@ fn zirSetCold(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData)
64346441
.cau => return, // does nothing outside a function
64356442
};
64366443
ip.funcSetCold(func, is_cold);
6444+
sema.allow_memoize = false;
64376445
}
64386446

64396447
fn zirDisableInstrumentation(sema: *Sema) CompileError!void {
@@ -6445,6 +6453,7 @@ fn zirDisableInstrumentation(sema: *Sema) CompileError!void {
64456453
.cau => return, // does nothing outside a function
64466454
};
64476455
ip.funcSetDisableInstrumentation(func);
6456+
sema.allow_memoize = false;
64486457
}
64496458

64506459
fn zirSetFloatMode(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!void {
@@ -7727,15 +7736,25 @@ fn analyzeCall(
77277736
// This `res2` is here instead of directly breaking from `res` due to a stage1
77287737
// bug generating invalid LLVM IR.
77297738
const res2: Air.Inst.Ref = res2: {
7730-
if (should_memoize and is_comptime_call) {
7731-
if (zcu.intern_pool.getIfExists(.{ .memoized_call = .{
7732-
.func = module_fn_index,
7733-
.arg_values = memoized_arg_values,
7734-
.result = .none,
7735-
} })) |memoized_call_index| {
7736-
const memoized_call = zcu.intern_pool.indexToKey(memoized_call_index).memoized_call;
7737-
break :res2 Air.internedToRef(memoized_call.result);
7739+
memoize: {
7740+
if (!should_memoize) break :memoize;
7741+
if (!is_comptime_call) break :memoize;
7742+
const memoized_call_index = ip.getIfExists(.{
7743+
.memoized_call = .{
7744+
.func = module_fn_index,
7745+
.arg_values = memoized_arg_values,
7746+
.result = undefined, // ignored by hash+eql
7747+
.branch_count = undefined, // ignored by hash+eql
7748+
},
7749+
}) orelse break :memoize;
7750+
const memoized_call = ip.indexToKey(memoized_call_index).memoized_call;
7751+
if (sema.branch_count + memoized_call.branch_count > sema.branch_quota) {
7752+
// Let the call play out se we get the correct source location for the
7753+
// "evaluation exceeded X backwards branches" error.
7754+
break :memoize;
77387755
}
7756+
sema.branch_count += memoized_call.branch_count;
7757+
break :res2 Air.internedToRef(memoized_call.result);
77397758
}
77407759

77417760
new_fn_info.return_type = sema.fn_ret_ty.toIntern();
@@ -7773,6 +7792,17 @@ fn analyzeCall(
77737792
child_block.error_return_trace_index = error_return_trace_index;
77747793
}
77757794

7795+
// We temporarily set `allow_memoize` to `true` to track this comptime call.
7796+
// It is restored after this call finishes analysis, so that a caller may
7797+
// know whether an in-progress call (containing this call) may be memoized.
7798+
const old_allow_memoize = sema.allow_memoize;
7799+
defer sema.allow_memoize = old_allow_memoize and sema.allow_memoize;
7800+
sema.allow_memoize = true;
7801+
7802+
// Store the current eval branch count so we can find out how many eval branches
7803+
// the comptime call caused.
7804+
const old_branch_count = sema.branch_count;
7805+
77767806
const result = result: {
77777807
sema.analyzeFnBody(&child_block, fn_info.body) catch |err| switch (err) {
77787808
error.ComptimeReturn => break :result inlining.comptime_result,
@@ -7792,11 +7822,12 @@ fn analyzeCall(
77927822
// a reference to `comptime_allocs` so is not stable across instances of `Sema`.
77937823
// TODO: check whether any external comptime memory was mutated by the
77947824
// comptime function call. If so, then do not memoize the call here.
7795-
if (should_memoize and !Value.fromInterned(result_interned).canMutateComptimeVarState(zcu)) {
7825+
if (should_memoize and sema.allow_memoize and !Value.fromInterned(result_interned).canMutateComptimeVarState(zcu)) {
77967826
_ = try pt.intern(.{ .memoized_call = .{
77977827
.func = module_fn_index,
77987828
.arg_values = memoized_arg_values,
77997829
.result = result_transformed,
7830+
.branch_count = sema.branch_count - old_branch_count,
78007831
} });
78017832
}
78027833

src/register_manager.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ pub fn RegisterManager(
9393
comptime set: []const Register,
9494
reg: Register,
9595
) ?std.math.IntFittingRange(0, set.len - 1) {
96+
@setEvalBranchQuota(3000);
97+
9698
const Id = @TypeOf(reg.id());
9799
comptime var min_id: Id = std.math.maxInt(Id);
98100
comptime var max_id: Id = std.math.minInt(Id);

0 commit comments

Comments
 (0)