diff --git a/src/aro/Parser.zig b/src/aro/Parser.zig index dbb9eeac..a639c981 100644 --- a/src/aro/Parser.zig +++ b/src/aro/Parser.zig @@ -6604,7 +6604,11 @@ fn eqExpr(p: *Parser) Error!Result { if (try lhs.adjustTypes(ne.?, &rhs, p, .equality)) { const op: std.math.CompareOperator = if (tag == .equal_expr) .eq else .neq; - const res = lhs.val.compareExtra(op, rhs.val, p.comp); + const res: ?bool = if (lhs.ty.isPtr() or rhs.ty.isPtr()) + lhs.val.comparePointers(op, rhs.val, p.comp) + else + lhs.val.compare(op, rhs.val, p.comp); + lhs.val = if (res) |val| Value.fromBool(val) else .{}; } else { lhs.val.boolCast(p.comp); @@ -6635,7 +6639,10 @@ fn compExpr(p: *Parser) Error!Result { .greater_than_equal_expr => .gte, else => unreachable, }; - const res = lhs.val.compareExtra(op, rhs.val, p.comp); + const res: ?bool = if (lhs.ty.isPtr() or rhs.ty.isPtr()) + lhs.val.comparePointers(op, rhs.val, p.comp) + else + lhs.val.compare(op, rhs.val, p.comp); lhs.val = if (res) |val| Value.fromBool(val) else .{}; } else { lhs.val.boolCast(p.comp); diff --git a/src/aro/Value.zig b/src/aro/Value.zig index 9f6091ee..21efd232 100644 --- a/src/aro/Value.zig +++ b/src/aro/Value.zig @@ -626,7 +626,7 @@ pub fn sub(res: *Value, lhs: Value, rhs: Value, ty: Type, rhs_ty: Type, comp: *C var total_offset: Value = undefined; const mul_overflow = try total_offset.mul(elem_size, rhs, comp.types.ptrdiff, comp); const old_offset = try int(rel.offset, comp); - const add_overflow = try total_offset.sub(total_offset, old_offset, comp.types.ptrdiff, undefined, comp); + const add_overflow = try total_offset.sub(old_offset, total_offset, comp.types.ptrdiff, undefined, comp); _ = try total_offset.intCast(comp.types.ptrdiff, comp); res.* = try reloc(.{ .name = rel.name, .offset = total_offset.toInt(i64, comp).? }, comp); return mul_overflow or add_overflow; @@ -961,13 +961,17 @@ pub fn complexConj(val: Value, ty: Type, comp: *Compilation) !Value { return intern(comp, .{ .complex = cf }); } -/// Returns null for values that cannot be compared at compile time (e.g. `&x < &y`) for globals `x` and `y`. -pub fn compareExtra(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: *const Compilation) ?bool { +fn shallowCompare(lhs: Value, op: std.math.CompareOperator, rhs: Value) ?bool { if (op == .eq) { return lhs.opt_ref == rhs.opt_ref; } else if (lhs.opt_ref == rhs.opt_ref) { return std.math.Order.eq.compare(op); } + return null; +} + +pub fn compare(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: *const Compilation) bool { + if (lhs.shallowCompare(op, rhs)) |val| return val; const lhs_key = comp.interner.get(lhs.ref()); const rhs_key = comp.interner.get(rhs.ref()); @@ -982,6 +986,21 @@ pub fn compareExtra(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: const imag_equal = std.math.compare(lhs.imag(f128, comp), .eq, rhs.imag(f128, comp)); return !real_equal or !imag_equal; } + + var lhs_bigint_space: BigIntSpace = undefined; + var rhs_bigint_space: BigIntSpace = undefined; + const lhs_bigint = lhs.toBigInt(&lhs_bigint_space, comp); + const rhs_bigint = rhs.toBigInt(&rhs_bigint_space, comp); + return lhs_bigint.order(rhs_bigint).compare(op); +} + +/// Returns null for values that cannot be compared at compile time (e.g. `&x < &y`) for globals `x` and `y`. +pub fn comparePointers(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: *const Compilation) ?bool { + if (lhs.shallowCompare(op, rhs)) |val| return val; + + const lhs_key = comp.interner.get(lhs.ref()); + const rhs_key = comp.interner.get(rhs.ref()); + if (lhs_key == .global_var_offset and rhs_key == .global_var_offset) { const lhs_reloc = lhs_key.global_var_offset; const rhs_reloc = rhs_key.global_var_offset; @@ -992,19 +1011,8 @@ pub fn compareExtra(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: } return std.math.compare(lhs_reloc.offset, op, rhs_reloc.offset); - } else if (lhs_key == .global_var_offset or rhs_key == .global_var_offset) { - return null; } - - var lhs_bigint_space: BigIntSpace = undefined; - var rhs_bigint_space: BigIntSpace = undefined; - const lhs_bigint = lhs.toBigInt(&lhs_bigint_space, comp); - const rhs_bigint = rhs.toBigInt(&rhs_bigint_space, comp); - return lhs_bigint.order(rhs_bigint).compare(op); -} - -pub fn compare(lhs: Value, op: std.math.CompareOperator, rhs: Value, comp: *const Compilation) bool { - return lhs.compareExtra(op, rhs, comp).?; + return null; } fn twosCompIntLimit(limit: std.math.big.int.TwosCompIntLimit, ty: Type, comp: *Compilation) !Value { diff --git a/test/cases/relocations.c b/test/cases/relocations.c index d275f9bb..0e0a25d6 100644 --- a/test/cases/relocations.c +++ b/test/cases/relocations.c @@ -41,6 +41,8 @@ _Static_assert(&packed.x - &packed.y == -1); char *p = (char*)(&x + 100); _Static_assert((char*)(&x+100) - (char*)&x == 400,""); +_Static_assert(&x - 2 != &x + 2, ""); +_Static_assert(&x - 2 == -2 + &x, ""); #define EXPECTED_ERRORS "relocations.c:24:1: error: static assertion failed" \ "relocations.c:29:16: error: static_assert expression is not an integral constant expression" \