From 9aaff357a8021661278dc7ebea9d0db280f5ff9f Mon Sep 17 00:00:00 2001 From: Malcolm Still Date: Sun, 31 Mar 2024 04:39:24 +0100 Subject: [PATCH] Describe memory management (arena) + more complete parser + std.log (#8) --- DESIGN.md | 33 + biscuit-builder/src/check.zig | 28 +- biscuit-builder/src/expression.zig | 38 +- biscuit-builder/src/fact.zig | 8 +- biscuit-builder/src/policy.zig | 14 +- biscuit-builder/src/predicate.zig | 14 +- biscuit-builder/src/root.zig | 1 + biscuit-builder/src/rule.zig | 32 +- biscuit-builder/src/scope.zig | 2 +- biscuit-builder/src/term.zig | 79 +- biscuit-datalog/src/check.zig | 8 +- biscuit-datalog/src/combinator.zig | 4 +- biscuit-datalog/src/expression.zig | 42 +- biscuit-datalog/src/fact.zig | 8 +- biscuit-datalog/src/fact_set.zig | 8 +- biscuit-datalog/src/main.zig | 6 + biscuit-datalog/src/matched_variables.zig | 4 +- biscuit-datalog/src/origin.zig | 4 +- biscuit-datalog/src/predicate.zig | 18 +- biscuit-datalog/src/rule.zig | 103 +- biscuit-datalog/src/rule_set.zig | 10 +- biscuit-datalog/src/scope.zig | 2 +- biscuit-datalog/src/set.zig | 4 +- biscuit-datalog/src/symbol_table.zig | 8 +- biscuit-datalog/src/term.zig | 4 +- biscuit-datalog/src/trusted_origins.zig | 6 +- biscuit-datalog/src/world.zig | 71 +- biscuit-format/src/serialized_biscuit.zig | 20 +- biscuit-format/src/signed_block.zig | 2 +- biscuit-parser/build.zig | 9 - biscuit-parser/build.zig.zon | 1 - biscuit-parser/src/parser.zig | 1474 +++++++++++++-------- biscuit-samples/src/main.zig | 6 +- biscuit/src/authorizer.zig | 527 ++++---- biscuit/src/biscuit.zig | 146 +- biscuit/src/block.zig | 10 +- 36 files changed, 1623 insertions(+), 1131 deletions(-) create mode 100644 DESIGN.md diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 0000000..bcfa3a2 --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,33 @@ +# Design + +## Memory Management + +Authorizing a biscuit is a short-lived single pass operation. This makes arena allocation +a great candidate for large parts of biscuit memory management: + +- It can greatly reduce the complexity of code that would otherwise have to carefully clone + / move memory to avoid leaks and double frees. +- Potentially code can be faster because we don't do granular deallocation and we can avoid + some copying. +- Again, because we're not necessarily copying, there is the potential for reduced memory usage + in places. + +The disadvantage would be that: + +- We potentially over allocate, i.e. we are storing more in memory than we technically need to. + +We create a toplevel arena and into it allocate all: + +- facts +- predicates +- rules +- queries +- policies +- expressions +- ops +- terms + +i.e. all the domain level objects are arena allocated. This means we don't have to do +complex reasoning about scope / lifetimes of these objects, they are all valid until +the toplevel arena is deallocated. If we are careful to always copy when modifying +one of these resources we can also share resources. diff --git a/biscuit-builder/src/check.zig b/biscuit-builder/src/check.zig index 6e3f2e9..963fe1d 100644 --- a/biscuit-builder/src/check.zig +++ b/biscuit-builder/src/check.zig @@ -5,25 +5,35 @@ const Term = @import("term.zig").Term; const Rule = @import("rule.zig").Rule; pub const Check = struct { - kind: datalog.Check.Kind, + kind: Kind, queries: std.ArrayList(Rule), - pub fn deinit(check: Check) void { - for (check.queries.items) |query| { - query.deinit(); - } + pub const Kind = enum { + one, + all, + }; - check.queries.deinit(); + pub fn deinit(_: Check) void { + // for (check.queries.items) |query| { + // query.deinit(); + // } + + // check.queries.deinit(); } - pub fn convert(check: Check, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Check { + pub fn toDatalog(check: Check, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Check { var queries = std.ArrayList(datalog.Rule).init(allocator); for (check.queries.items) |query| { - try queries.append(try query.convert(allocator, symbols)); + try queries.append(try query.toDatalog(allocator, symbols)); } - return .{ .kind = check.kind, .queries = queries }; + const kind: datalog.Check.Kind = switch (check.kind) { + .one => .one, + .all => .all, + }; + + return .{ .kind = kind, .queries = queries }; } pub fn format(check: Check, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { diff --git a/biscuit-builder/src/expression.zig b/biscuit-builder/src/expression.zig index e324e26..cc475e5 100644 --- a/biscuit-builder/src/expression.zig +++ b/biscuit-builder/src/expression.zig @@ -58,7 +58,7 @@ pub const Expression = union(ExpressionType) { }; /// convert to datalog fact - pub fn convert(expression: Expression, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Expression { + pub fn toDatalog(expression: Expression, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Expression { var ops = std.ArrayList(datalog.Op).init(allocator); try expression.toOpcodes(allocator, &ops, symbols); @@ -68,7 +68,7 @@ pub const Expression = union(ExpressionType) { pub fn toOpcodes(expression: Expression, allocator: std.mem.Allocator, ops: *std.ArrayList(datalog.Op), symbols: *datalog.SymbolTable) !void { switch (expression) { - .value => |v| try ops.append(.{ .value = try v.convert(allocator, symbols) }), + .value => |v| try ops.append(.{ .value = try v.toDatalog(allocator, symbols) }), .unary => |u| { try u.expression.toOpcodes(allocator, ops, symbols); @@ -113,22 +113,22 @@ pub const Expression = union(ExpressionType) { } } - pub fn deinit(expression: *Expression) void { - switch (expression.*) { - .value => |v| v.deinit(), - .unary => |*u| { - u.expression.deinit(); - - u.allocator.destroy(u.expression); - }, - .binary => |*b| { - b.left.deinit(); - b.right.deinit(); - - b.allocator.destroy(b.left); - b.allocator.destroy(b.right); - }, - } + pub fn deinit(_: *Expression) void { + // switch (expression.*) { + // .value => |v| v.deinit(), + // .unary => |*u| { + // u.expression.deinit(); + + // u.allocator.destroy(u.expression); + // }, + // .binary => |*b| { + // b.left.deinit(); + // b.right.deinit(); + + // b.allocator.destroy(b.left); + // b.allocator.destroy(b.right); + // }, + // } } pub fn value(term: Term) !Expression { @@ -159,7 +159,7 @@ pub const Expression = union(ExpressionType) { .value => |v| try writer.print("{any}", .{v}), .unary => |u| { switch (u.op) { - .negate => try writer.print("-{any}", .{u.expression}), + .negate => try writer.print("!{any}", .{u.expression}), .parens => try writer.print("({any})", .{u.expression}), .length => try writer.print("{any}.length()", .{u.expression}), } diff --git a/biscuit-builder/src/fact.zig b/biscuit-builder/src/fact.zig index 0b1cd19..0fcc54c 100644 --- a/biscuit-builder/src/fact.zig +++ b/biscuit-builder/src/fact.zig @@ -7,13 +7,13 @@ pub const Fact = struct { predicate: Predicate, variables: ?std.StringHashMap(?Term), - pub fn deinit(fact: Fact) void { - fact.predicate.deinit(); + pub fn deinit(_: Fact) void { + // fact.predicate.deinit(); } /// convert to datalog fact - pub fn convert(fact: Fact, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Fact { - return .{ .predicate = try fact.predicate.convert(allocator, symbols) }; + pub fn toDatalog(fact: Fact, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Fact { + return .{ .predicate = try fact.predicate.toDatalog(allocator, symbols) }; } pub fn format(fact: Fact, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { diff --git a/biscuit-builder/src/policy.zig b/biscuit-builder/src/policy.zig index 637ac63..badcd45 100644 --- a/biscuit-builder/src/policy.zig +++ b/biscuit-builder/src/policy.zig @@ -11,19 +11,19 @@ pub const Policy = struct { deny, }; - pub fn deinit(policy: Policy) void { - for (policy.queries.items) |query| { - query.deinit(); - } + pub fn deinit(_: Policy) void { + // for (policy.queries.items) |query| { + // query.deinit(); + // } - policy.queries.deinit(); + // policy.queries.deinit(); } - // pub fn convert(policy: Policy, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !Policy { + // pub fn toDatalog(policy: Policy, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !Policy { // var queries = std.ArrayList(Rule).init(allocator); // for (policy.queries.items) |query| { - // try queries.append(try query.convert(allocator, symbols)); + // try queries.append(try query.toDatalog(allocator, symbols)); // } // return .{ .kind = policy.kind, .queries = queries }; diff --git a/biscuit-builder/src/predicate.zig b/biscuit-builder/src/predicate.zig index 8eec66f..739c658 100644 --- a/biscuit-builder/src/predicate.zig +++ b/biscuit-builder/src/predicate.zig @@ -6,22 +6,22 @@ pub const Predicate = struct { name: []const u8, terms: std.ArrayList(Term), - pub fn deinit(predicate: Predicate) void { - for (predicate.terms.items) |term| { - term.deinit(); - } + pub fn deinit(_: Predicate) void { + // for (predicate.terms.items) |term| { + // term.deinit(); + // } - predicate.terms.deinit(); + // predicate.terms.deinit(); } /// convert to datalog predicate - pub fn convert(predicate: Predicate, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Predicate { + pub fn toDatalog(predicate: Predicate, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Predicate { const name = try symbols.insert(predicate.name); var terms = std.ArrayList(datalog.Term).init(allocator); for (predicate.terms.items) |term| { - try terms.append(try term.convert(allocator, symbols)); + try terms.append(try term.toDatalog(allocator, symbols)); } return .{ .name = name, .terms = terms }; diff --git a/biscuit-builder/src/root.zig b/biscuit-builder/src/root.zig index 89c79ff..977df93 100644 --- a/biscuit-builder/src/root.zig +++ b/biscuit-builder/src/root.zig @@ -7,3 +7,4 @@ pub const Expression = @import("expression.zig").Expression; pub const Scope = @import("scope.zig").Scope; pub const Date = @import("date.zig").Date; pub const Policy = @import("policy.zig").Policy; +pub const Set = @import("biscuit-datalog").Set; diff --git a/biscuit-builder/src/rule.zig b/biscuit-builder/src/rule.zig index ab4540a..b9e6238 100644 --- a/biscuit-builder/src/rule.zig +++ b/biscuit-builder/src/rule.zig @@ -12,40 +12,40 @@ pub const Rule = struct { variables: ?std.StringHashMap(?Term), scopes: std.ArrayList(Scope), - pub fn deinit(rule: Rule) void { - rule.head.deinit(); + pub fn deinit(_: Rule) void { + // rule.head.deinit(); - for (rule.body.items) |predicate| { - predicate.deinit(); - } + // for (rule.body.items) |predicate| { + // predicate.deinit(); + // } - for (rule.expressions.items) |*expression| { - expression.deinit(); - } + // for (rule.expressions.items) |*expression| { + // expression.deinit(); + // } - rule.body.deinit(); - rule.expressions.deinit(); - rule.scopes.deinit(); + // rule.body.deinit(); + // rule.expressions.deinit(); + // rule.scopes.deinit(); } /// convert to datalog predicate - pub fn convert(rule: Rule, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Rule { - const head = try rule.head.convert(allocator, symbols); + pub fn toDatalog(rule: Rule, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Rule { + const head = try rule.head.toDatalog(allocator, symbols); var body = std.ArrayList(datalog.Predicate).init(allocator); var expressions = std.ArrayList(datalog.Expression).init(allocator); var scopes = std.ArrayList(datalog.Scope).init(allocator); for (rule.body.items) |predicate| { - try body.append(try predicate.convert(allocator, symbols)); + try body.append(try predicate.toDatalog(allocator, symbols)); } for (rule.expressions.items) |expression| { - try expressions.append(try expression.convert(allocator, symbols)); + try expressions.append(try expression.toDatalog(allocator, symbols)); } for (rule.scopes.items) |scope| { - try scopes.append(try scope.convert(allocator, symbols)); + try scopes.append(try scope.toDatalog(allocator, symbols)); } return .{ diff --git a/biscuit-builder/src/scope.zig b/biscuit-builder/src/scope.zig index eae862c..da52d6c 100644 --- a/biscuit-builder/src/scope.zig +++ b/biscuit-builder/src/scope.zig @@ -19,7 +19,7 @@ pub const Scope = union(ScopeKind) { parameter: []const u8, /// convert to datalog fact - pub fn convert(scope: Scope, _: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Scope { + pub fn toDatalog(scope: Scope, _: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Scope { return switch (scope) { .authority => .{ .authority = {} }, .previous => .{ .previous = {} }, diff --git a/biscuit-builder/src/term.zig b/biscuit-builder/src/term.zig index f3ffa45..fb8e10e 100644 --- a/biscuit-builder/src/term.zig +++ b/biscuit-builder/src/term.zig @@ -8,6 +8,8 @@ const TermTag = enum(u8) { integer, bool, date, + bytes, + set, }; pub const Term = union(TermTag) { @@ -16,17 +18,30 @@ pub const Term = union(TermTag) { integer: i64, bool: bool, date: u64, + bytes: []const u8, + set: datalog.Set(Term), pub fn deinit(_: Term) void {} - pub fn convert(term: Term, _: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Term { - return switch (term) { - .variable => |s| .{ .variable = @truncate(try symbols.insert(s)) }, // FIXME: assert symbol fits in u32 - .string => |s| .{ .string = try symbols.insert(s) }, - .integer => |n| .{ .integer = n }, - .bool => |b| .{ .bool = b }, - .date => |d| .{ .date = d }, - }; + pub fn toDatalog(term: Term, arena: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Term { + switch (term) { + .variable => |s| return .{ .variable = std.math.cast(u32, try symbols.insert(s)) orelse return error.FailedToCastInt }, + .string => |s| return .{ .string = try symbols.insert(s) }, + .integer => |n| return .{ .integer = n }, + .bool => |b| return .{ .bool = b }, + .date => |d| return .{ .date = d }, + .bytes => |b| return .{ .bytes = b }, + .set => |s| { + var datalog_set = datalog.Set(datalog.Term).init(arena); + + var it = s.iterator(); + while (it.next()) |t| { + try datalog_set.add(try t.toDatalog(arena, symbols)); + } + + return .{ .set = datalog_set }; + }, + } } pub fn format(term: Term, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { @@ -36,6 +51,54 @@ pub const Term = union(TermTag) { .integer => |n| try writer.print("{}", .{n}), .bool => |b| if (b) try writer.print("true", .{}) else try writer.print("false", .{}), .date => |n| try writer.print("{}", .{n}), + .bytes => |b| try writer.print("{x}", .{b}), + .set => |s| { + try writer.print("[", .{}); + + const count = s.count(); + + var it = s.iterator(); + + var i: usize = 0; + while (it.next()) |t| { + defer i += 1; + try writer.print("{any}", .{t}); + + if (i < count - 1) try writer.print(", ", .{}); + } + + try writer.print("]", .{}); + }, + } + } + + pub fn eql(term: Term, other_term: Term) bool { + if (std.meta.activeTag(term) != std.meta.activeTag(other_term)) return false; + + return switch (term) { + .variable => |v| std.mem.eql(u8, v, other_term.variable), + .integer => |v| v == other_term.integer, + .string => |v| std.mem.eql(u8, v, other_term.string), + .bool => |v| v == other_term.bool, + .date => |v| v == other_term.date, + .bytes => |v| std.mem.eql(u8, v, other_term.bytes), + .set => |v| v.eql(other_term.set), + }; + } + + pub fn hash(term: Term, hasher: anytype) void { + // Hash the tag type + std.hash.autoHash(hasher, std.meta.activeTag(term)); + + // Hash the value + switch (term) { + .variable => |v| for (v) |b| std.hash.autoHash(hasher, b), + .integer => |v| std.hash.autoHash(hasher, v), + .string => |v| for (v) |b| std.hash.autoHash(hasher, b), + .bool => |v| std.hash.autoHash(hasher, v), + .date => |v| std.hash.autoHash(hasher, v), + .bytes => |v| for (v) |b| std.hash.autoHash(hasher, b), + .set => |v| v.hash(hasher), } } }; diff --git a/biscuit-datalog/src/check.zig b/biscuit-datalog/src/check.zig index 5a4564e..1722732 100644 --- a/biscuit-datalog/src/check.zig +++ b/biscuit-datalog/src/check.zig @@ -24,9 +24,9 @@ pub const Check = struct { return .{ .queries = rules, .kind = kind }; } - pub fn deinit(check: *Check) void { + pub fn testDeinit(check: *Check) void { for (check.queries.items) |*query| { - query.deinit(); + query.testDeinit(); } check.queries.deinit(); @@ -41,11 +41,11 @@ pub const Check = struct { return writer.print("", .{}); } - pub fn convert(check: Check, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Check { + pub fn remap(check: Check, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Check { var queries = try check.queries.clone(); for (queries.items, 0..) |query, i| { - queries.items[i] = try query.convert(old_symbols, new_symbols); + queries.items[i] = try query.remap(old_symbols, new_symbols); } return .{ diff --git a/biscuit-datalog/src/combinator.zig b/biscuit-datalog/src/combinator.zig index baf4ad5..22cb184 100644 --- a/biscuit-datalog/src/combinator.zig +++ b/biscuit-datalog/src/combinator.zig @@ -11,6 +11,8 @@ const Expression = @import("expression.zig").Expression; const MatchedVariables = @import("matched_variables.zig").MatchedVariables; const SymbolTable = @import("symbol_table.zig").SymbolTable; +const log = std.log.scoped(.combinator); + /// Combinator is an iterator that will generate MatchedVariables from /// the body of a rule. /// @@ -103,7 +105,7 @@ pub const Combinator = struct { // Lookup the next (trusted) fact const origin_fact = combinator.trusted_fact_iterator.next() orelse return null; - std.debug.print("combinator next trusted fact: {any}\n", .{origin_fact.fact}); + log.debug("[{}] next trusted fact: {any}", .{ combinator.id, origin_fact.fact }); const origin = origin_fact.origin.*; const fact = origin_fact.fact.*; diff --git a/biscuit-datalog/src/expression.zig b/biscuit-datalog/src/expression.zig index 74ee6ea..47b604c 100644 --- a/biscuit-datalog/src/expression.zig +++ b/biscuit-datalog/src/expression.zig @@ -5,16 +5,18 @@ const Regex = @import("regex").Regex; const Term = @import("term.zig").Term; const SymbolTable = @import("symbol_table.zig").SymbolTable; +const log = std.log.scoped(.expression); + pub const Expression = struct { ops: std.ArrayList(Op), - pub fn fromSchema(allocator: std.mem.Allocator, schema_expression: schema.ExpressionV2) !Expression { - var ops = std.ArrayList(Op).init(allocator); + pub fn fromSchema(arena: std.mem.Allocator, schema_expression: schema.ExpressionV2) !Expression { + var ops = try std.ArrayList(Op).initCapacity(arena, schema_expression.ops.items.len); for (schema_expression.ops.items) |schema_op| { const schema_op_content = schema_op.Content orelse return error.ExpectedOp; const op: Op = switch (schema_op_content) { - .value => |term| .{ .value = try Term.fromSchema(allocator, term) }, + .value => |term| .{ .value = try Term.fromSchema(arena, term) }, .unary => |unary_op| .{ .unary = switch (unary_op.kind) { .Negate => .negate, @@ -57,8 +59,8 @@ pub const Expression = struct { return .{ .ops = ops }; } - pub fn deinit(expression: *Expression) void { - expression.ops.deinit(); + pub fn deinit(_: *Expression) void { + // expression.ops.deinit(); } pub fn evaluate(expr: Expression, allocator: mem.Allocator, values: std.AutoHashMap(u32, Term), symbols: *SymbolTable) !Term { @@ -99,13 +101,12 @@ pub const Expression = struct { return stack.items[0]; } - pub fn convert(expression: Expression, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Expression { - // std.debug.print("Converting expression\n", .{}); + pub fn remap(expression: Expression, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Expression { const ops = try expression.ops.clone(); for (ops.items, 0..) |op, i| { ops.items[i] = switch (op) { - .value => |trm| .{ .value = try trm.convert(old_symbols, new_symbols) }, + .value => |trm| .{ .value = try trm.remap(old_symbols, new_symbols) }, else => op, }; } @@ -311,29 +312,30 @@ fn concat(allocator: std.mem.Allocator, left: []const u8, right: []const u8) ![] test { const testing = std.testing; + const allocator = testing.allocator; + const t1: Term = .{ .integer = 10 }; const t2: Term = .{ .integer = 22 }; - try testing.expectEqual(@as(Term, .{ .bool = false }), try Binary.equal.evaluate(t1, t2, SymbolTable.init("test", testing.allocator))); - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.equal.evaluate(t1, t1, SymbolTable.init("test", testing.allocator))); - - try testing.expectEqual(@as(Term, .{ .integer = 32 }), try Binary.add.evaluate(t1, t2, SymbolTable.init("test", testing.allocator))); - try testing.expectEqual(@as(Term, .{ .integer = 220 }), try Binary.mul.evaluate(t1, t2, SymbolTable.init("test", testing.allocator))); - var symbols = SymbolTable.init("test", testing.allocator); defer symbols.deinit(); + try testing.expectEqual(@as(Term, .{ .bool = false }), try Binary.equal.evaluate(allocator, t1, t2, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.equal.evaluate(allocator, t1, t1, &symbols)); + try testing.expectEqual(@as(Term, .{ .integer = 32 }), try Binary.add.evaluate(allocator, t1, t2, &symbols)); + try testing.expectEqual(@as(Term, .{ .integer = 220 }), try Binary.mul.evaluate(allocator, t1, t2, &symbols)); + const s = .{ .string = try symbols.insert("prefix_middle_suffix") }; const prefix = .{ .string = try symbols.insert("prefix") }; const suffix = .{ .string = try symbols.insert("suffix") }; const middle = .{ .string = try symbols.insert("middle") }; - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.equal.evaluate(s, s, symbols)); - try testing.expectEqual(@as(Term, .{ .bool = false }), try Binary.equal.evaluate(s, prefix, symbols)); - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.not_equal.evaluate(s, prefix, symbols)); - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.prefix.evaluate(s, prefix, symbols)); - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.suffix.evaluate(s, suffix, symbols)); - try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.contains.evaluate(s, middle, symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.equal.evaluate(allocator, s, s, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = false }), try Binary.equal.evaluate(allocator, s, prefix, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.not_equal.evaluate(allocator, s, prefix, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.prefix.evaluate(allocator, s, prefix, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.suffix.evaluate(allocator, s, suffix, &symbols)); + try testing.expectEqual(@as(Term, .{ .bool = true }), try Binary.contains.evaluate(allocator, s, middle, &symbols)); } // test "negate" { diff --git a/biscuit-datalog/src/fact.zig b/biscuit-datalog/src/fact.zig index 27aa944..cbeaf75 100644 --- a/biscuit-datalog/src/fact.zig +++ b/biscuit-datalog/src/fact.zig @@ -17,13 +17,13 @@ pub const Fact = struct { return .{ .predicate = predicate }; } - pub fn deinit(fact: *Fact) void { - fact.predicate.deinit(); + pub fn deinit(_: *Fact) void { + // fact.predicate.deinit(); } /// Convert fact to new symbol space - pub fn convert(fact: Fact, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Fact { - return .{ .predicate = try fact.predicate.convert(old_symbols, new_symbols) }; + pub fn remap(fact: Fact, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Fact { + return .{ .predicate = try fact.predicate.remap(old_symbols, new_symbols) }; } pub fn clone(fact: Fact) !Fact { diff --git a/biscuit-datalog/src/fact_set.zig b/biscuit-datalog/src/fact_set.zig index a0fd9f4..d478a47 100644 --- a/biscuit-datalog/src/fact_set.zig +++ b/biscuit-datalog/src/fact_set.zig @@ -38,11 +38,11 @@ pub const FactSet = struct { // value as a key, and we try to insert into hashmap that already contains that value, // we will leak the key if we don't detect the existing version and deallocate one of the // keys. - pub fn deinit(fact_set: *FactSet) void { + pub fn testDeinit(fact_set: *FactSet) void { var it = fact_set.sets.iterator(); while (it.next()) |origin_facts| { - origin_facts.key_ptr.deinit(); // Okay, in practice this is also giving us incorrect alignment issues + origin_facts.key_ptr.testDeinit(); // Okay, in practice this is also giving us incorrect alignment issues origin_facts.value_ptr.deinit(); } @@ -150,7 +150,7 @@ test "FactSet trustedIterator" { const Term = @import("term.zig").Term; var fs = FactSet.init(testing.allocator); - defer fs.deinit(); + defer fs.testDeinit(); var origin = Origin.init(testing.allocator); try origin.insert(0); @@ -184,7 +184,7 @@ test "FactSet trustedIterator" { // With a trusted iterator only trusting [0] we only expect a single fact { var trusted_origins = try TrustedOrigins.defaultOrigins(testing.allocator); - defer trusted_origins.deinit(); + defer trusted_origins.testDeinit(); var count: usize = 0; diff --git a/biscuit-datalog/src/main.zig b/biscuit-datalog/src/main.zig index bed063a..7fff113 100644 --- a/biscuit-datalog/src/main.zig +++ b/biscuit-datalog/src/main.zig @@ -15,18 +15,24 @@ pub const Check = @import("check.zig").Check; pub const Origin = @import("origin.zig").Origin; pub const TrustedOrigins = @import("trusted_origins.zig").TrustedOrigins; pub const world = @import("world.zig"); +pub const Set = @import("set.zig").Set; test { _ = @import("check.zig"); _ = @import("combinator.zig"); _ = @import("expression.zig"); + _ = @import("fact_set.zig"); _ = @import("fact.zig"); _ = @import("matched_variables.zig"); + _ = @import("origin.zig"); _ = @import("predicate.zig"); + _ = @import("rule_set.zig"); _ = @import("rule.zig"); _ = @import("run_limits.zig"); + _ = @import("scope.zig"); _ = @import("set.zig"); _ = @import("symbol_table.zig"); _ = @import("term.zig"); + _ = @import("trusted_origins.zig"); _ = @import("world.zig"); } diff --git a/biscuit-datalog/src/matched_variables.zig b/biscuit-datalog/src/matched_variables.zig index b22dee6..ef3ab56 100644 --- a/biscuit-datalog/src/matched_variables.zig +++ b/biscuit-datalog/src/matched_variables.zig @@ -49,8 +49,8 @@ pub const MatchedVariables = struct { return .{ .variables = variables }; } - pub fn deinit(matched_variables: *MatchedVariables) void { - matched_variables.variables.deinit(); + pub fn deinit(_: *MatchedVariables) void { + // matched_variables.variables.deinit(); } pub fn clone(matched_variables: *const MatchedVariables) !MatchedVariables { diff --git a/biscuit-datalog/src/origin.zig b/biscuit-datalog/src/origin.zig index 1ed103d..65130c4 100644 --- a/biscuit-datalog/src/origin.zig +++ b/biscuit-datalog/src/origin.zig @@ -23,7 +23,7 @@ pub const Origin = struct { return .{ .block_ids = block_ids }; } - pub fn deinit(origin: *Origin) void { + pub fn testDeinit(origin: *Origin) void { origin.block_ids.deinit(); } @@ -94,7 +94,7 @@ test "Origins" { const testing = std.testing; var origins = Origin.init(testing.allocator); - defer origins.deinit(); + defer origins.testDeinit(); try origins.insert(12); try origins.insert(13); diff --git a/biscuit-datalog/src/predicate.zig b/biscuit-datalog/src/predicate.zig index 612260f..539f6ae 100644 --- a/biscuit-datalog/src/predicate.zig +++ b/biscuit-datalog/src/predicate.zig @@ -26,11 +26,11 @@ pub const Predicate = struct { return writer.print(")", .{}); } - pub fn deinit(predicate: *Predicate) void { - for (predicate.terms.items) |*term| { - term.deinit(); - } - predicate.terms.deinit(); + pub fn deinit(_: *Predicate) void { + // for (predicate.terms.items) |*term| { + // term.deinit(); + // } + // predicate.terms.deinit(); } pub fn eql(predicate: Predicate, other_predicate: Predicate) bool { @@ -67,12 +67,12 @@ pub const Predicate = struct { /// Convert predicate to new symbol space /// /// Equivalent to clone but with the symbol rewriting - pub fn convert(predicate: Predicate, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Predicate { + pub fn remap(predicate: Predicate, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Predicate { const name = try old_symbols.getString(predicate.name); var terms = try predicate.terms.clone(); for (terms.items, 0..) |term, i| { - terms.items[i] = try term.convert(old_symbols, new_symbols); + terms.items[i] = try term.remap(old_symbols, new_symbols); } return .{ @@ -120,6 +120,8 @@ test { const testing = std.testing; const allocator = testing.allocator; + const test_log = std.log.scoped(.test_predicate); + var terms_1 = std.ArrayList(Term).init(allocator); defer terms_1.deinit(); try terms_1.insertSlice(0, &.{ .{ .string = 10 }, .{ .integer = 20 } }); @@ -151,5 +153,5 @@ test { try testing.expect(!p1.match(p4)); try testing.expect(p1.match(p5)); - std.debug.print("predicate = {any}\n", .{p1}); + test_log.debug("predicate = {any}\n", .{p1}); } diff --git a/biscuit-datalog/src/rule.zig b/biscuit-datalog/src/rule.zig index b1fb127..e97bca8 100644 --- a/biscuit-datalog/src/rule.zig +++ b/biscuit-datalog/src/rule.zig @@ -15,6 +15,8 @@ const Scope = @import("scope.zig").Scope; const Expression = @import("expression.zig").Expression; const TrustedOrigins = @import("trusted_origins.zig").TrustedOrigins; +const log = std.log.scoped(.rule); + pub const Rule = struct { head: Predicate, body: std.ArrayList(Predicate), @@ -44,20 +46,20 @@ pub const Rule = struct { return .{ .head = head, .body = body, .expressions = expressions, .scopes = scopes }; } - pub fn deinit(rule: *Rule) void { - rule.head.deinit(); + pub fn deinit(_: *Rule) void { + // rule.head.deinit(); - for (rule.body.items) |*predicate| { - predicate.deinit(); - } + // for (rule.body.items) |*predicate| { + // predicate.deinit(); + // } - for (rule.expressions.items) |*expression| { - expression.deinit(); - } + // for (rule.expressions.items) |*expression| { + // expression.deinit(); + // } - rule.body.deinit(); - rule.expressions.deinit(); - rule.scopes.deinit(); + // rule.body.deinit(); + // rule.expressions.deinit(); + // rule.scopes.deinit(); } /// ### Generate new facts from this rule and the existing facts @@ -115,28 +117,25 @@ pub const Rule = struct { /// ``` /// /// ...and we add it to the set of facts (the set will take care of deduplication) - pub fn apply(rule: *const Rule, allocator: mem.Allocator, origin_id: u64, facts: *const FactSet, new_facts: *FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !void { - var arena = std.heap.ArenaAllocator.init(allocator); - defer arena.deinit(); - - std.debug.print("\napplying rule (from origin {}):\n {any}\n", .{ origin_id, rule }); - const matched_variables = try MatchedVariables.init(arena.allocator(), rule); + pub fn apply(rule: *const Rule, arena: mem.Allocator, origin_id: u64, facts: *const FactSet, new_facts: *FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !void { + log.debug("\napplying rule {any} (from block {})", .{ rule, origin_id }); + const matched_variables = try MatchedVariables.init(arena, rule); // TODO: if body is empty stuff - var it = Combinator.init(0, allocator, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); + var it = Combinator.init(0, arena, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); defer it.deinit(); blk: while (try it.next()) |*origin_bindings| { const origin: Origin = origin_bindings[0]; const bindings: MatchedVariables = origin_bindings[1]; - if (!try bindings.evaluateExpressions(allocator, rule.expressions.items, symbols)) continue; + if (!try bindings.evaluateExpressions(arena, rule.expressions.items, symbols)) continue; // TODO: Describe why clonedWithAllocator? More generally, describe in comment the overall // lifetimes / memory allocation approach during evaluation. - var predicate = try rule.head.cloneWithAllocator(allocator); - defer predicate.deinit(); + var predicate = try rule.head.clone(); + // defer predicate.deinit(); // Loop over terms in head predicate. Update all _variable_ terms with their value // from the binding. @@ -153,18 +152,18 @@ pub const Rule = struct { var new_origin = try origin.clone(); try new_origin.insert(origin_id); - std.debug.print("\nadding new fact:\n {any} with origin {any}\n", .{ fact, new_origin }); + log.debug("apply: adding new fact {any} with origin {any}", .{ fact, new_origin }); // Skip adding fact if we already have generated it. Because the // Set will clobber duplicate facts we'll lose a reference when // inserting a duplicate and then when we loop over the set to // deinit the facts we'll miss some. This ensures that the facts // can be freed purely from the Set. - if (new_facts.contains(new_origin, fact)) { - new_origin.deinit(); - continue; - } + // if (new_facts.contains(new_origin, fact)) { + // // new_origin.deinit(); + // continue; + // } - try new_facts.add(new_origin, try fact.clone()); + try new_facts.add(new_origin, fact); } } @@ -173,17 +172,17 @@ pub const Rule = struct { /// /// Note: whilst the combinator may return multiple valid matches, `findMatch` only requires a single match /// so stopping on the first `it.next()` that returns not-null is enough. - pub fn findMatch(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { - std.debug.print("\nrule.findMatch on {any} ({any})\n", .{ rule, trusted_origins }); - var arena = std.heap.ArenaAllocator.init(allocator); - defer arena.deinit(); + pub fn findMatch(rule: *Rule, arena: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { + log.debug("findMatch({any}, {any})", .{ rule, trusted_origins }); + // var arena = std.heap.ArenaAllocator.init(allocator); + // defer arena.deinit(); - const arena_allocator = arena.allocator(); + // const arena_allocator = arena.allocator(); if (rule.body.items.len == 0) { - const variables = std.AutoHashMap(u32, Term).init(allocator); + const variables = std.AutoHashMap(u32, Term).init(arena); for (rule.expressions.items) |expr| { - const result = try expr.evaluate(arena_allocator, variables, symbols); + const result = try expr.evaluate(arena, variables, symbols); switch (result) { .bool => |b| if (b) continue else return false, @@ -193,32 +192,32 @@ pub const Rule = struct { return true; } else { - const matched_variables = try MatchedVariables.init(arena_allocator, rule); + const matched_variables = try MatchedVariables.init(arena, rule); - var it = Combinator.init(0, allocator, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); + var it = Combinator.init(0, arena, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); defer it.deinit(); while (try it.next()) |*origin_bindings| { const bindings: MatchedVariables = origin_bindings[1]; - if (try bindings.evaluateExpressions(arena_allocator, rule.expressions.items, symbols)) return true; + if (try bindings.evaluateExpressions(arena, rule.expressions.items, symbols)) return true; } return false; } } - pub fn checkMatchAll(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { - std.debug.print("\nrule.checkMatchAll on {any} ({any})\n", .{ rule, trusted_origins }); - var arena = std.heap.ArenaAllocator.init(allocator); - defer arena.deinit(); + pub fn checkMatchAll(rule: *Rule, arena: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { + log.debug("checkMatchAll({any}, {any})", .{ rule, trusted_origins }); + // var arena = std.heap.ArenaAllocator.init(allocator); + // defer arena.deinit(); - const arena_allocator = arena.allocator(); + // const arena_allocator = arena.allocator(); if (rule.body.items.len == 0) { - const variables = std.AutoHashMap(u32, Term).init(allocator); + const variables = std.AutoHashMap(u32, Term).init(arena); for (rule.expressions.items) |expr| { - const result = try expr.evaluate(arena_allocator, variables, symbols); + const result = try expr.evaluate(arena, variables, symbols); switch (result) { .bool => |b| if (b) continue else return false, @@ -228,15 +227,15 @@ pub const Rule = struct { return true; } else { - const matched_variables = try MatchedVariables.init(arena_allocator, rule); + const matched_variables = try MatchedVariables.init(arena, rule); - var it = Combinator.init(0, allocator, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); + var it = Combinator.init(0, arena, matched_variables, rule.body.items, rule.expressions.items, facts, symbols, trusted_origins); defer it.deinit(); while (try it.next()) |*origin_bindings| { const bindings: MatchedVariables = origin_bindings[1]; - if (try bindings.evaluateExpressions(arena_allocator, rule.expressions.items, symbols)) continue; + if (try bindings.evaluateExpressions(arena, rule.expressions.items, symbols)) continue; return false; } @@ -292,25 +291,25 @@ pub const Rule = struct { } // Convert datalog fact from old symbol space to new symbol space - pub fn convert(rule: Rule, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Rule { + pub fn remap(rule: Rule, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Rule { var body = try rule.body.clone(); var expressions = try rule.expressions.clone(); var scopes = try rule.scopes.clone(); for (body.items, 0..) |predicate, i| { - body.items[i] = try predicate.convert(old_symbols, new_symbols); + body.items[i] = try predicate.remap(old_symbols, new_symbols); } for (expressions.items, 0..) |expression, i| { - expressions.items[i] = try expression.convert(old_symbols, new_symbols); + expressions.items[i] = try expression.remap(old_symbols, new_symbols); } for (scopes.items, 0..) |scope, i| { - scopes.items[i] = try scope.convert(old_symbols, new_symbols); + scopes.items[i] = try scope.remap(old_symbols, new_symbols); } return .{ - .head = try rule.head.convert(old_symbols, new_symbols), + .head = try rule.head.remap(old_symbols, new_symbols), .body = body, .expressions = expressions, .scopes = scopes, diff --git a/biscuit-datalog/src/rule_set.zig b/biscuit-datalog/src/rule_set.zig index 5db8adc..491ac2f 100644 --- a/biscuit-datalog/src/rule_set.zig +++ b/biscuit-datalog/src/rule_set.zig @@ -15,11 +15,11 @@ pub const RuleSet = struct { }; } - pub fn deinit(rule_set: *RuleSet) void { + pub fn testDeinit(rule_set: *RuleSet) void { var it = rule_set.rules.iterator(); while (it.next()) |entry| { - entry.key_ptr.deinit(); + entry.key_ptr.testDeinit(); entry.value_ptr.deinit(); } @@ -41,12 +41,14 @@ pub const RuleSet = struct { test "RuleSet" { const testing = std.testing; + const test_log = std.log.scoped(.test_rule_set); + var rs = RuleSet.init(testing.allocator); - defer rs.deinit(); + defer rs.testDeinit(); const default_origins = try TrustedOrigins.defaultOrigins(testing.allocator); const rule: Rule = undefined; try rs.add(0, default_origins, rule); - std.debug.print("rs = {any}\n", .{rs}); + test_log.debug("rs = {any}", .{rs}); } diff --git a/biscuit-datalog/src/scope.zig b/biscuit-datalog/src/scope.zig index 5eb6546..60efa7a 100644 --- a/biscuit-datalog/src/scope.zig +++ b/biscuit-datalog/src/scope.zig @@ -20,7 +20,7 @@ pub const Scope = union(ScopeTag) { }; } - pub fn convert(scope: Scope, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Scope { + pub fn remap(scope: Scope, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Scope { return switch (scope) { .authority => .authority, .previous => .previous, diff --git a/biscuit-datalog/src/set.zig b/biscuit-datalog/src/set.zig index 6d13db5..b8f10d3 100644 --- a/biscuit-datalog/src/set.zig +++ b/biscuit-datalog/src/set.zig @@ -171,6 +171,8 @@ test { const testing = std.testing; const allocator = testing.allocator; + const test_log = std.log.scoped(.test_set); + var s = Set(Fact).init(allocator); defer s.deinit(); @@ -187,7 +189,7 @@ test { try s.add(Fact{ .predicate = Predicate{ .name = 10, .terms = undefined } }); try testing.expectEqual(@as(u32, 2), s.count()); - std.debug.print("set = {any}\n", .{s}); + test_log.debug("set = {any}\n", .{s}); } test "hashing" { diff --git a/biscuit-datalog/src/symbol_table.zig b/biscuit-datalog/src/symbol_table.zig index d5243ab..c1693c2 100644 --- a/biscuit-datalog/src/symbol_table.zig +++ b/biscuit-datalog/src/symbol_table.zig @@ -3,6 +3,8 @@ const mem = std.mem; const Ed25519 = std.crypto.sign.Ed25519; +const log = std.log.scoped(.symbol_table); + pub const SymbolTable = struct { name: []const u8, allocator: mem.Allocator, @@ -41,7 +43,7 @@ pub const SymbolTable = struct { const index = symbol_table.symbols.items.len - 1 + NON_DEFAULT_SYMBOLS_OFFSET; - // std.debug.print("{s}: Inserting \"{s}\" at {}\n", .{ symbol_table.name, symbol, index }); + log.debug("inserting \"{s}\" at {} [{s}]", .{ symbol, index, symbol_table.name }); return index; } @@ -77,18 +79,14 @@ pub const SymbolTable = struct { pub fn getString(symbol_table: *const SymbolTable, sym_index: u64) ![]const u8 { if (indexToDefault(sym_index)) |sym| { - // std.debug.print("Found \"{s}\" at {} (default)\n", .{ sym, sym_index }); return sym; } if (sym_index >= NON_DEFAULT_SYMBOLS_OFFSET and sym_index < NON_DEFAULT_SYMBOLS_OFFSET + symbol_table.symbols.items.len) { const sym = symbol_table.symbols.items[sym_index - NON_DEFAULT_SYMBOLS_OFFSET]; - // std.debug.print("Found \"{s}\" at {}\n", .{ sym, sym_index }); return sym; } - // std.debug.print("Existing sym index {} not found\n", .{sym_index}); - return error.SymbolNotFound; } diff --git a/biscuit-datalog/src/term.zig b/biscuit-datalog/src/term.zig index 01e711a..9cc8f48 100644 --- a/biscuit-datalog/src/term.zig +++ b/biscuit-datalog/src/term.zig @@ -44,7 +44,7 @@ pub const Term = union(TermKind) { }; } - pub fn convert(term: Term, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Term { + pub fn remap(term: Term, old_symbols: *const SymbolTable, new_symbols: *SymbolTable) !Term { return switch (term) { .variable => |id| .{ .variable = std.math.cast(u32, try new_symbols.insert(try old_symbols.getString(id))) orelse return error.VariableIdTooLarge }, .string => |id| .{ .string = try new_symbols.insert(try old_symbols.getString(id)) }, @@ -54,7 +54,7 @@ pub const Term = union(TermKind) { var it = s.iterator(); while (it.next()) |term_ptr| { - try set.add(try term_ptr.convert(old_symbols, new_symbols)); + try set.add(try term_ptr.remap(old_symbols, new_symbols)); } break :blk .{ .set = set }; diff --git a/biscuit-datalog/src/trusted_origins.zig b/biscuit-datalog/src/trusted_origins.zig index fea1f85..eff0067 100644 --- a/biscuit-datalog/src/trusted_origins.zig +++ b/biscuit-datalog/src/trusted_origins.zig @@ -14,7 +14,7 @@ pub const TrustedOrigins = struct { return .{ .ids = InnerSet.init(allocator) }; } - pub fn deinit(trusted_origins: *TrustedOrigins) void { + pub fn testDeinit(trusted_origins: *TrustedOrigins) void { trusted_origins.ids.deinit(); } @@ -124,10 +124,10 @@ test "Trusted origin" { const testing = std.testing; var to = try TrustedOrigins.defaultOrigins(testing.allocator); - defer to.deinit(); + defer to.testDeinit(); var o = Origin.init(testing.allocator); - defer o.deinit(); + defer o.testDeinit(); try o.insert(22); diff --git a/biscuit-datalog/src/world.zig b/biscuit-datalog/src/world.zig index 1e70b39..57cc67d 100644 --- a/biscuit-datalog/src/world.zig +++ b/biscuit-datalog/src/world.zig @@ -10,48 +10,44 @@ const TrustedOrigins = @import("trusted_origins.zig").TrustedOrigins; const RunLimits = @import("run_limits.zig").RunLimits; const SymbolTable = @import("symbol_table.zig").SymbolTable; +const log = std.log.scoped(.world); + pub const World = struct { - allocator: mem.Allocator, + arena: mem.Allocator, fact_set: FactSet, rule_set: RuleSet, - symbols: std.ArrayList([]const u8), - /// init world - /// - /// Note: the allocator we pass in can be any allocator. This allocator - /// is used purely for the toplevel Set and ArrayLists. Any facts allocated - /// during world run will be allocated with a provided allocator that is - /// specifically an arena. The world and rule code will reflect that by - /// not doing explicit deallocation on the fact / predicate / term level. - /// - /// If we ever want to change away from that arena model, we'll have to - /// fix up some code internally to allow that. - pub fn init(allocator: mem.Allocator) World { + pub fn init(arena: mem.Allocator) World { return .{ - .allocator = allocator, - .fact_set = FactSet.init(allocator), - .rule_set = RuleSet.init(allocator), - .symbols = std.ArrayList([]const u8).init(allocator), + .arena = arena, + .fact_set = FactSet.init(arena), + .rule_set = RuleSet.init(arena), }; } - pub fn deinit(world: *World) void { - world.symbols.deinit(); - world.rule_set.deinit(); - world.fact_set.deinit(); + pub fn deinit(_: *World) void { + // world.symbols.deinit(); + // world.rule_set.deinit(); + // world.fact_set.deinit(); } + /// Generate all facts from rules and existing facts + /// + /// Uses default run limits. pub fn run(world: *World, symbols: *SymbolTable) !void { try world.runWithLimits(symbols, .{}); } + /// Generate all facts from rules and existing facts + /// + /// User specifies run limits. pub fn runWithLimits(world: *World, symbols: *SymbolTable, limits: RunLimits) !void { for (0..limits.max_iterations) |iteration| { - std.debug.print("\nrunWithLimits[{}]\n", .{iteration}); + log.debug("runWithLimits[{}]", .{iteration}); const starting_fact_count = world.fact_set.count(); - var new_fact_sets = FactSet.init(world.allocator); - defer new_fact_sets.deinit(); + var new_fact_sets = FactSet.init(world.arena); + // defer new_fact_sets.deinit(); // Iterate over rules to generate new facts { @@ -65,30 +61,23 @@ pub const World = struct { const origin_id: u64 = origin_rule[0]; const rule: Rule = origin_rule[1]; - try rule.apply(world.allocator, origin_id, &world.fact_set, &new_fact_sets, symbols, trusted_origins); + try rule.apply(world.arena, origin_id, &world.fact_set, &new_fact_sets, symbols, trusted_origins); } } } var it = new_fact_sets.iterator(); while (it.next()) |origin_fact| { - const existing_origin = origin_fact.origin.*; + const origin = origin_fact.origin.*; const fact = origin_fact.fact.*; - var origin = try existing_origin.clone(); - - if (world.fact_set.contains(origin, fact)) { - origin.deinit(); - continue; - } - - try world.fact_set.add(origin, try fact.cloneWithAllocator(world.allocator)); + try world.fact_set.add(origin, fact); } - std.debug.print("starting_fact_count = {}, world.facts.count() = {}\n", .{ starting_fact_count, world.fact_set.count() }); + log.debug("starting_fact_count = {}, world.facts.count() = {}", .{ starting_fact_count, world.fact_set.count() }); // If we haven't generated any new facts, we're done. if (starting_fact_count == world.fact_set.count()) { - std.debug.print("No new facts!\n", .{}); + log.debug("No new facts!", .{}); return; } @@ -100,20 +89,22 @@ pub const World = struct { /// Add fact with origin to world pub fn addFact(world: *World, origin: Origin, fact: Fact) !void { - std.debug.print("\nworld: adding fact = {any} ({any}) \n", .{ fact, origin }); + log.debug("adding fact = {any}, origin = ({any})", .{ fact, origin }); + try world.fact_set.add(origin, fact); } + // Add rule trusting scope from origin pub fn addRule(world: *World, origin_id: usize, scope: TrustedOrigins, rule: Rule) !void { - std.debug.print("\nworld: adding rule (origin {}) = {any} (trusts {any})\n", .{ origin_id, rule, scope }); + log.debug("adding rule {any}, origin = {}, trusts {any}", .{ rule, origin_id, scope }); try world.rule_set.add(origin_id, scope, rule); } pub fn queryMatch(world: *World, rule: *Rule, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { - return rule.findMatch(world.allocator, &world.fact_set, symbols, trusted_origins); + return rule.findMatch(world.arena, &world.fact_set, symbols, trusted_origins); } pub fn queryMatchAll(world: *World, rule: *Rule, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool { - return rule.checkMatchAll(world.allocator, &world.fact_set, symbols, trusted_origins); + return rule.checkMatchAll(world.arena, &world.fact_set, symbols, trusted_origins); } }; diff --git a/biscuit-format/src/serialized_biscuit.zig b/biscuit-format/src/serialized_biscuit.zig index c6f2a33..1c9e641 100644 --- a/biscuit-format/src/serialized_biscuit.zig +++ b/biscuit-format/src/serialized_biscuit.zig @@ -5,6 +5,8 @@ const schema = @import("biscuit-schema"); const SignedBlock = @import("signed_block.zig").SignedBlock; const Proof = @import("proof.zig").Proof; +const log = std.log.scoped(.serialized_biscuit); + pub const MIN_SCHEMA_VERSION = 3; pub const MAX_SCHEMA_VERSION = 4; @@ -70,10 +72,15 @@ pub const SerializedBiscuit = struct { /// public key is the public key of the private key in the /// the proof. fn verify(serialized_biscuit: *SerializedBiscuit, root_public_key: Ed25519.PublicKey) !void { + log.debug("verify()", .{}); + defer log.debug("end verify()", .{}); + var pk = root_public_key; // Verify the authority block's signature { + log.debug("verifying authority block", .{}); + errdefer log.debug("failed to verify authority block", .{}); if (serialized_biscuit.authority.external_signature != null) return error.AuthorityBlockMustNotHaveExternalSignature; var verifier = try serialized_biscuit.authority.signature.verifier(pk); @@ -88,9 +95,12 @@ pub const SerializedBiscuit = struct { } // Verify the other blocks' signatures - for (serialized_biscuit.blocks.items) |*block| { + for (serialized_biscuit.blocks.items, 1..) |*block, block_id| { // Verify the block signature { + log.debug("verifying block {}", .{block_id}); + errdefer log.debug("failed to verify block {}", .{block_id}); + var verifier = try block.signature.verifier(pk); verifier.update(block.block); @@ -105,6 +115,9 @@ pub const SerializedBiscuit = struct { // Verify the external signature (where one exists) if (block.external_signature) |external_signature| { + log.debug("verifying external signature on block {}", .{block_id}); + errdefer log.debug("failed to verify external signature on block {}", .{block_id}); + var external_verifier = try external_signature.signature.verifier(external_signature.public_key); external_verifier.update(block.block); external_verifier.update(&block.algorithm2Buf()); @@ -116,13 +129,18 @@ pub const SerializedBiscuit = struct { } // Check the proof + + log.debug("verifying proof", .{}); switch (serialized_biscuit.proof) { .next_secret => |next_secret| { if (!std.mem.eql(u8, &pk.bytes, &next_secret.publicKeyBytes())) { + log.debug("failed to verify proof (sealed)", .{}); return error.SecretKeyProofFailedMismatchedPublicKeys; } }, .final_signature => |final_signature| { + errdefer log.debug("failed to verify proof (attenuated)", .{}); + var last_block = if (serialized_biscuit.blocks.items.len == 0) serialized_biscuit.authority else serialized_biscuit.blocks.items[serialized_biscuit.blocks.items.len - 1]; var verifier = try final_signature.verifier(pk); diff --git a/biscuit-format/src/signed_block.zig b/biscuit-format/src/signed_block.zig index 7d6344e..1c7662c 100644 --- a/biscuit-format/src/signed_block.zig +++ b/biscuit-format/src/signed_block.zig @@ -34,7 +34,7 @@ pub const SignedBlock = struct { const algo = required_block_external_key.algorithm; - std.debug.print("ALGORITHM = {}\n", .{algo}); + _ = algo; // FIXME: we need to use algorithm (at least at the point that support for things other than Ed25519) if (block_external_signature.len != Ed25519.Signature.encoded_length) return error.IncorrectBlockExternalSignatureLength; if (block_external_public_key.len != Ed25519.PublicKey.encoded_length) return error.IncorrectBlockExternalPublicKeyLength; diff --git a/biscuit-parser/build.zig b/biscuit-parser/build.zig index e5e89f2..686f2cb 100644 --- a/biscuit-parser/build.zig +++ b/biscuit-parser/build.zig @@ -16,18 +16,12 @@ pub fn build(b: *std.Build) void { const optimize = b.standardOptimizeOption(.{}); const ziglyph = b.dependency("ziglyph", .{ .optimize = optimize, .target = target }); - const schema = b.dependency("biscuit-schema", .{ .target = target, .optimize = optimize }); - const format = b.dependency("biscuit-format", .{ .target = target, .optimize = optimize }); const builder = b.dependency("biscuit-builder", .{ .target = target, .optimize = optimize }); - const datalog = b.dependency("biscuit-datalog", .{ .target = target, .optimize = optimize }); _ = b.addModule("biscuit-parser", .{ .root_source_file = .{ .path = "src/parser.zig" }, .imports = &.{ - .{ .name = "biscuit-schema", .module = schema.module("biscuit-schema") }, - .{ .name = "biscuit-format", .module = format.module("biscuit-format") }, .{ .name = "biscuit-builder", .module = builder.module("biscuit-builder") }, - .{ .name = "biscuit-datalog", .module = datalog.module("biscuit-datalog") }, .{ .name = "ziglyph", .module = ziglyph.module("ziglyph") }, }, }); @@ -39,10 +33,7 @@ pub fn build(b: *std.Build) void { .target = target, .optimize = optimize, }); - lib_unit_tests.root_module.addImport("biscuit-schema", schema.module("biscuit-schema")); - lib_unit_tests.root_module.addImport("biscuit-format", format.module("biscuit-format")); lib_unit_tests.root_module.addImport("biscuit-builder", builder.module("biscuit-builder")); - lib_unit_tests.root_module.addImport("biscuit-datalog", datalog.module("biscuit-datalog")); lib_unit_tests.root_module.addImport("ziglyph", ziglyph.module("ziglyph")); const run_lib_unit_tests = b.addRunArtifact(lib_unit_tests); diff --git a/biscuit-parser/build.zig.zon b/biscuit-parser/build.zig.zon index 54f2748..eb581fc 100644 --- a/biscuit-parser/build.zig.zon +++ b/biscuit-parser/build.zig.zon @@ -41,7 +41,6 @@ .@"biscuit-schema" = .{ .path = "../biscuit-schema" }, .@"biscuit-format" = .{ .path = "../biscuit-format" }, .@"biscuit-builder" = .{ .path = "../biscuit-builder" }, - .@"biscuit-datalog" = .{ .path = "../biscuit-datalog" }, .ziglyph = .{ .url = "https://codeberg.org/dude_the_builder/ziglyph/archive/947ed39203bf90412e3d16cbcf936518b6f23af0.tar.gz", .hash = "12208b23d1eb6dcb929e85346524db8f8b8aa1401bdf8a97dee1e0cfb55da8d5fb42", diff --git a/biscuit-parser/src/parser.zig b/biscuit-parser/src/parser.zig index 7e8e032..ae000d9 100644 --- a/biscuit-parser/src/parser.zig +++ b/biscuit-parser/src/parser.zig @@ -1,6 +1,5 @@ const std = @import("std"); const ziglyph = @import("ziglyph"); -const datalog = @import("biscuit-datalog"); const Term = @import("biscuit-builder").Term; const Fact = @import("biscuit-builder").Fact; const Check = @import("biscuit-builder").Check; @@ -10,8 +9,11 @@ const Expression = @import("biscuit-builder").Expression; const Scope = @import("biscuit-builder").Scope; const Date = @import("biscuit-builder").Date; const Policy = @import("biscuit-builder").Policy; +const Set = @import("biscuit-builder").Set; const Ed25519 = std.crypto.sign.Ed25519; +const log = std.log.scoped(.parser); + pub const Parser = struct { input: []const u8, offset: usize = 0, @@ -21,221 +23,286 @@ pub const Parser = struct { return .{ .input = input, .allocator = allocator }; } + /// Return a new temporary parser with the current state of parent parser + /// for attempting to parse one of choice of things. + /// + /// E.g. when we parse a term we try parse each subtype of term with a temporary + /// parser. + pub fn temporary(parser: *Parser) Parser { + return Parser.init(parser.allocator, parser.rest()); + } + + /// Try to parse fact + /// + /// E.g. read(1, "hello") will parse successfully, but read($foo, "hello") + /// will fail because it contains a variable `$foo`. pub fn fact(parser: *Parser) !Fact { - return .{ .predicate = try parser.factPredicate(), .variables = null }; + return .{ .predicate = try parser.predicate(.fact), .variables = null }; } - pub fn factPredicate(parser: *Parser) !Predicate { - const name = parser.readName(); + pub fn predicate(parser: *Parser, kind: enum { fact, rule }) !Predicate { + var terms = std.ArrayList(Term).init(parser.allocator); - std.debug.print("name = {s}\n", .{name}); + const predicate_name = try parser.name(); - parser.skipWhiteSpace(); + try parser.consume("("); - // Consume left paren - try parser.expect('('); + while (true) { + parser.skipWhiteSpace(); - // Parse terms - var terms = std.ArrayList(Term).init(parser.allocator); + try terms.append(try parser.term(switch (kind) { + .rule => .allow, + .fact => .disallow, + })); - var it = parser.factTermsIterator(); - while (try it.next()) |trm| { - try terms.append(trm); + parser.skipWhiteSpace(); - if (parser.peek()) |peeked| { - if (peeked != ',') break; - } else { - break; - } + if (!parser.startsWithConsuming(",")) break; } - try parser.expect(')'); + try parser.consume(")"); - return .{ .name = name, .terms = terms }; + return .{ .name = predicate_name, .terms = terms }; } - const FactTermIterator = struct { - parser: *Parser, + /// Try to parse a term + /// + /// Does not consume `parser` input on failure. + fn term(parser: *Parser, variables: AllowVariables) ParserError!Term { + if (variables == .disallow) { + try parser.reject("$"); // Variables are disallowed in a fact term + } else { + blk: { + var tmp = parser.temporary(); + + const value = tmp.variable() catch break :blk; - pub fn next(it: *FactTermIterator) !?Term { - it.parser.skipWhiteSpace(); + parser.offset += tmp.offset; - return try it.parser.factTerm(); + return .{ .variable = value }; + } } - }; - pub fn factTermsIterator(parser: *Parser) FactTermIterator { - return .{ .parser = parser }; - } + blk: { + var tmp = parser.temporary(); - const TermIterator = struct { - parser: *Parser, + const value = tmp.string() catch break :blk; - pub fn next(it: *TermIterator) !?Term { - it.parser.skipWhiteSpace(); + parser.offset += tmp.offset; - return try it.parser.term(); + return .{ .string = value }; } - }; - - pub fn termsIterator(parser: *Parser) TermIterator { - return .{ .parser = parser }; - } - - pub fn term(parser: *Parser) !Term { - const rst = parser.rest(); - variable_blk: { - var term_parser = Parser.init(parser.allocator, rst); + blk: { + var tmp = parser.temporary(); - const value = term_parser.variable() catch break :variable_blk; + const value = tmp.date() catch break :blk; - parser.offset += term_parser.offset; + parser.offset += tmp.offset; - return .{ .variable = value }; + return .{ .date = value }; } - string_blk: { - var term_parser = Parser.init(parser.allocator, rst); + blk: { + var tmp = parser.temporary(); - const value = term_parser.string() catch break :string_blk; + const value = tmp.number(i64) catch break :blk; - parser.offset += term_parser.offset; + parser.offset += tmp.offset; - return .{ .string = value }; + return .{ .integer = value }; } - date_blk: { - var term_parser = Parser.init(parser.allocator, rst); + blk: { + var tmp = parser.temporary(); - const value = term_parser.date() catch break :date_blk; + const value = tmp.boolean() catch break :blk; - parser.offset += term_parser.offset; + parser.offset += tmp.offset; - return .{ .date = value }; + return .{ .bool = value }; } - number_blk: { - var term_parser = Parser.init(parser.allocator, rst); + blk: { + var tmp = parser.temporary(); - const value = term_parser.number(i64) catch break :number_blk; + const value = tmp.bytes() catch break :blk; - parser.offset += term_parser.offset; + parser.offset += tmp.offset; - return .{ .integer = value }; + return .{ .bytes = value }; } - bool_blk: { - var term_parser = Parser.init(parser.allocator, rst); + blk: { + var tmp = parser.temporary(); - const value = term_parser.boolean() catch break :bool_blk; + const value = tmp.set(variables) catch break :blk; - parser.offset += term_parser.offset; + parser.offset += tmp.offset; - return .{ .bool = value }; + return .{ .set = value }; } return error.NoFactTermFound; } - pub fn factTerm(parser: *Parser) !Term { - const rst = parser.rest(); + pub fn policy(parser: *Parser) !Policy { + const kind: Policy.Kind = if (parser.startsWithConsuming("allow if")) + .allow + else if (parser.startsWithConsuming("deny if")) + .deny + else + return error.UnexpectedPolicyKind; - try parser.reject('$'); // Variables are disallowed in a fact term + // FIXME: figure out if the space is required or not + // try parser.requiredWhiteSpace(); - string_blk: { - var term_parser = Parser.init(parser.allocator, rst); + const queries = try parser.checkBody(); - const value = term_parser.string() catch break :string_blk; + return .{ .kind = kind, .queries = queries }; + } - parser.offset += term_parser.offset; + pub fn check(parser: *Parser) !Check { + const kind: Check.Kind = if (parser.startsWithConsuming("check if")) + .one + else if (parser.startsWithConsuming("check all")) + .all + else + return error.UnexpectedCheckKind; - return .{ .string = value }; - } + // FIXME: figure out if the space is required or not + // try parser.requiredWhiteSpace(); - date_blk: { - var term_parser = Parser.init(parser.allocator, rst); + const queries = try parser.checkBody(); - const value = term_parser.date() catch break :date_blk; + return .{ .kind = kind, .queries = queries }; + } - parser.offset += term_parser.offset; + /// Parse check body + /// + /// E.g. given check if right($0, $1), resource($0), operation($1), $0.contains(\"file\") or admin(true) + /// this will (attempt to) parse `right($0, $1), resource($0), operation($1), $0.contains(\"file\") or admin(true)` + /// + /// Requires at least one rule body. + fn checkBody(parser: *Parser) !std.ArrayList(Rule) { + var queries = std.ArrayList(Rule).init(parser.allocator); - return .{ .date = value }; - } + while (true) { + parser.skipWhiteSpace(); - number_blk: { - var term_parser = Parser.init(parser.allocator, rst); + const body = try parser.ruleBody(); - const value = term_parser.number(i64) catch break :number_blk; + try queries.append(.{ + .head = .{ .name = "query", .terms = std.ArrayList(Term).init(parser.allocator) }, + .body = body.predicates, + .expressions = body.expressions, + .scopes = body.scopes, + .variables = null, + }); - parser.offset += term_parser.offset; + parser.skipWhiteSpace(); - return .{ .integer = value }; + if (!parser.startsWithConsuming("or")) break; } - bool_blk: { - var term_parser = Parser.init(parser.allocator, rst); + return queries; + } - const value = term_parser.boolean() catch break :bool_blk; + pub fn rule(parser: *Parser) !Rule { + const head = try parser.predicate(.rule); - parser.offset += term_parser.offset; + parser.skipWhiteSpace(); - return .{ .bool = value }; - } + try parser.consume("<-"); - return error.NoFactTermFound; - } + parser.skipWhiteSpace(); - pub fn predicate(parser: *Parser) !Predicate { - const name = parser.readName(); + const body = try parser.ruleBody(); - parser.skipWhiteSpace(); + return .{ + .head = head, + .body = body.predicates, + .expressions = body.expressions, + .scopes = body.scopes, + .variables = null, + }; + } - // Consume left paren - try parser.expect('('); + fn ruleBody(parser: *Parser) !struct { predicates: std.ArrayList(Predicate), expressions: std.ArrayList(Expression), scopes: std.ArrayList(Scope) } { + var predicates = std.ArrayList(Predicate).init(parser.allocator); + var expressions = std.ArrayList(Expression).init(parser.allocator); + var scps = std.ArrayList(Scope).init(parser.allocator); - // Parse terms - var terms = std.ArrayList(Term).init(parser.allocator); + while (true) { + parser.skipWhiteSpace(); - var it = parser.termsIterator(); - while (try it.next()) |trm| { - try terms.append(trm); + const rule_body = try parser.ruleBodyElement(); - if (parser.peek()) |peeked| { - if (peeked == ',') { - parser.offset += 1; - continue; - } + switch (rule_body) { + .predicate => |p| try predicates.append(p), + .expression => |e| try expressions.append(e), } - break; + parser.skipWhiteSpace(); + + if (!parser.startsWithConsuming(",")) break; } - try parser.expect(')'); + blk: { + var tmp = parser.temporary(); + + const s = tmp.scopes(parser.allocator) catch break :blk; + + parser.offset += tmp.offset; + + scps = s; + } - return .{ .name = name, .terms = terms }; + return .{ .predicates = predicates, .expressions = expressions, .scopes = scps }; } - fn variable(parser: *Parser) ![]const u8 { - try parser.expect('$'); + const BodyElementTag = enum { + predicate, + expression, + }; - const start = parser.offset; + /// Try to parse a rule body element (a predicate or an expression) + /// + /// Does not consume `parser` input on failure. + fn ruleBodyElement(parser: *Parser) !union(BodyElementTag) { predicate: Predicate, expression: Expression } { + blk: { + var tmp = parser.temporary(); - for (parser.rest()) |c| { - if (ziglyph.isAlphaNum(c) or c == '_') { - parser.offset += 1; - continue; - } + const p = tmp.predicate(.rule) catch break :blk; - break; + parser.offset += tmp.offset; + + return .{ .predicate = p }; } - return parser.input[start..parser.offset]; + // Otherwise try parsing an expression + blk: { + var tmp = parser.temporary(); + + const e = tmp.expression() catch break :blk; + + parser.offset += tmp.offset; + + return .{ .expression = e }; + } + + return error.ExpectedPredicateOrExpression; + } + + fn variable(parser: *Parser) ![]const u8 { + try parser.consume("$"); + + return try parser.variableName(); } // FIXME: properly implement string parsing fn string(parser: *Parser) ![]const u8 { - try parser.expect('"'); + try parser.consume("\""); const start = parser.offset; @@ -252,32 +319,32 @@ pub const Parser = struct { fn date(parser: *Parser) !u64 { const year = try parser.number(i32); - try parser.expect('-'); + try parser.consume("-"); const month = try parser.number(u8); if (month < 1 or month > 12) return error.MonthOutOfRange; - try parser.expect('-'); + try parser.consume("-"); const day = try parser.number(u8); if (!Date.isDayMonthYearValid(i32, year, month, day)) return error.InvalidDayMonthYearCombination; - try parser.expect('T'); + try parser.consume("T"); const hour = try parser.number(u8); if (hour > 23) return error.HoyrOutOfRange; - try parser.expect(':'); + try parser.consume(":"); const minute = try parser.number(u8); if (minute > 59) return error.MinuteOutOfRange; - try parser.expect(':'); + try parser.consume(":"); const second = try parser.number(u8); if (second > 59) return error.SecondOutOfRange; - try parser.expect('Z'); + try parser.consume("Z"); const d: Date = .{ .year = year, @@ -296,8 +363,14 @@ pub const Parser = struct { fn number(parser: *Parser, comptime T: type) !T { const start = parser.offset; + if (parser.rest().len == 0) return error.ParsingNumberExpectsAtLeastOneCharacter; + const first_char = parser.rest()[0]; + + if (!(isDigit(first_char) or first_char == '-')) return error.ParsingNameFirstCharacterMustBeLetter; + parser.offset += 1; + for (parser.rest()) |c| { - if (ziglyph.isAsciiDigit(c)) { + if (isDigit(c)) { parser.offset += 1; continue; } @@ -311,202 +384,64 @@ pub const Parser = struct { } fn boolean(parser: *Parser) !bool { - if (std.mem.startsWith(u8, parser.rest(), "true")) { - parser.offset += "term".len; - return true; - } - - if (std.mem.startsWith(u8, parser.rest(), "false")) { - parser.offset += "false".len; - return false; - } + if (parser.startsWithConsuming("true")) return true; + if (parser.startsWithConsuming("false")) return false; return error.ExpectedBooleanTerm; } - pub fn policy(parser: *Parser) !Policy { - var kind: Policy.Kind = undefined; - - if (std.mem.startsWith(u8, parser.rest(), "allow if")) { - parser.offset += "allow if".len; - kind = .allow; - } else if (std.mem.startsWith(u8, parser.rest(), "deny if")) { - parser.offset += "deny if".len; - kind = .deny; - } else { - return error.UnexpectedPolicyKind; - } - - const queries = try parser.checkBody(); + fn bytes(parser: *Parser) ![]const u8 { + try parser.consume("hex:"); - return .{ .kind = kind, .queries = queries }; - } + const hex_string = try parser.hex(); - pub fn check(parser: *Parser) !Check { - var kind: datalog.Check.Kind = undefined; - - if (std.mem.startsWith(u8, parser.rest(), "check if")) { - parser.offset += "check if".len; - kind = .one; - } else if (std.mem.startsWith(u8, parser.rest(), "check all")) { - parser.offset += "check all".len; - kind = .all; - } else { - return error.UnexpectedCheckKind; - } + if (!(hex_string.len % 2 == 0)) return error.ExpectedEvenNumberOfHexDigis; - const queries = try parser.checkBody(); + const out = try parser.allocator.alloc(u8, hex_string.len / 2); - return .{ .kind = kind, .queries = queries }; + return try std.fmt.hexToBytes(out, hex_string); } - fn checkBody(parser: *Parser) !std.ArrayList(Rule) { - var queries = std.ArrayList(Rule).init(parser.allocator); - - const required_body = try parser.ruleBody(); + fn set(parser: *Parser, variables: AllowVariables) !Set(Term) { + var new_set = Set(Term).init(parser.allocator); - try queries.append(.{ - .head = .{ .name = "query", .terms = std.ArrayList(Term).init(parser.allocator) }, - .body = required_body.predicates, - .expressions = required_body.expressions, - .scopes = required_body.scopes, - .variables = null, - }); + try parser.consume("["); while (true) { parser.skipWhiteSpace(); - if (!std.mem.startsWith(u8, parser.rest(), "or")) break; - - parser.offset += "or".len; - - const body = try parser.ruleBody(); - - try queries.append(.{ - .head = .{ .name = "query", .terms = std.ArrayList(Term).init(parser.allocator) }, - .body = body.predicates, - .expressions = body.expressions, - .scopes = body.scopes, - .variables = null, - }); - } - - return queries; - } - - pub fn rule(parser: *Parser) !Rule { - const head = try parser.predicate(); - - parser.skipWhiteSpace(); - - if (!std.mem.startsWith(u8, parser.rest(), "<-")) return error.ExpectedArrow; - - parser.offset += "<-".len; - - const body = try parser.ruleBody(); - - return .{ - .head = head, - .body = body.predicates, - .expressions = body.expressions, - .scopes = body.scopes, - .variables = null, - }; - } - - pub fn ruleBody(parser: *Parser) !struct { predicates: std.ArrayList(Predicate), expressions: std.ArrayList(Expression), scopes: std.ArrayList(Scope) } { - var predicates = std.ArrayList(Predicate).init(parser.allocator); - var expressions = std.ArrayList(Expression).init(parser.allocator); - var scps = std.ArrayList(Scope).init(parser.allocator); + // Try to parse a term. Since sets can be empty we break on catch; + const trm = parser.term(variables) catch break; + try new_set.add(trm); - while (true) { parser.skipWhiteSpace(); - std.debug.print("{s}: \"{s}\"\n", .{ @src().fn_name, parser.rest() }); - - // Try parsing a predicate - predicate_blk: { - var predicate_parser = Parser.init(parser.allocator, parser.rest()); - - const p = predicate_parser.predicate() catch break :predicate_blk; - - parser.offset += predicate_parser.offset; - - try predicates.append(p); - - parser.skipWhiteSpace(); - - if (parser.peek()) |peeked| { - if (peeked == ',') { - parser.offset += 1; - continue; - } - } - } - - // Otherwise try parsing an expression - expression_blk: { - var expression_parser = Parser.init(parser.allocator, parser.rest()); - - const e = expression_parser.expression() catch break :expression_blk; - - parser.offset += expression_parser.offset; - - try expressions.append(e); - - parser.skipWhiteSpace(); - - if (parser.peek()) |peeked| { - if (peeked == ',') { - parser.offset += 1; - continue; - } - } - } - // We haven't found a predicate or expression so we're done, - // other than potentially parsing a scope - break; + if (!parser.startsWithConsuming(",")) break; } - scopes_blk: { - var scope_parser = Parser.init(parser.allocator, parser.rest()); - - const s = scope_parser.scopes(parser.allocator) catch break :scopes_blk; - - parser.offset += scope_parser.offset; - - scps = s; - } + try parser.consume("]"); - return .{ .predicates = predicates, .expressions = expressions, .scopes = scps }; + return new_set; } + /// Parse an expression + /// + /// This is the top-level expression parsing function. Where + /// other parts of the code call `parser.expression` you know + /// they are parsing a "full" expression. + /// + /// The code uses the "precedence climbing" approach. fn expression(parser: *Parser) ParserError!Expression { - std.debug.print("Attempting to parser {s}\n", .{parser.rest()}); - parser.skipWhiteSpace(); - const e = try parser.expr(); - - std.debug.print("parsed expression = {any}\n", .{e}); - - return e; - } - - fn expr(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr1(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp0() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr1(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -515,21 +450,16 @@ pub const Parser = struct { } fn expr1(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr2(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp1() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr2(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -538,21 +468,16 @@ pub const Parser = struct { } fn expr2(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr3(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp2() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr3(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -561,21 +486,16 @@ pub const Parser = struct { } fn expr3(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr4(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp3() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr4(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -584,21 +504,16 @@ pub const Parser = struct { } fn expr4(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr5(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp4() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr5(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -607,14 +522,10 @@ pub const Parser = struct { } fn expr5(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr6(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp5() catch break; @@ -622,8 +533,6 @@ pub const Parser = struct { const e2 = try parser.expr6(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); - e = try Expression.binary(parser.allocator, op, e, e2); } @@ -631,21 +540,16 @@ pub const Parser = struct { } fn expr6(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); var e = try parser.expr7(); - std.debug.print("[{s}] e = {any}\n", .{ @src().fn_name, e }); - while (true) { parser.skipWhiteSpace(); - if (parser.rest().len == 0) break; const op = parser.binaryOp6() catch break; parser.skipWhiteSpace(); const e2 = try parser.expr7(); - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); e = try Expression.binary(parser.allocator, op, e, e2); } @@ -654,239 +558,223 @@ pub const Parser = struct { } fn expr7(parser: *Parser) ParserError!Expression { - std.debug.print("[{s}]\n", .{@src().fn_name}); - const e1 = try parser.exprTerm(); + var e = try parser.expr8(); - std.debug.print("[{s}] e1 = {any}\n", .{ @src().fn_name, e1 }); + while (true) { + parser.skipWhiteSpace(); - parser.skipWhiteSpace(); + const op = parser.binaryOp7() catch break; - if (!parser.startsWith(".")) return e1; - try parser.expect('.'); + parser.skipWhiteSpace(); - const op = try parser.binaryOp7(); - parser.skipWhiteSpace(); + const e2 = try parser.expr8(); + + e = try Expression.binary(parser.allocator, op, e, e2); + } - std.debug.print("[{s}] op = {any}, rest = \"{s}\"\n", .{ @src().fn_name, op, parser.rest() }); + return e; + } - // if (!parser.startsWith("(")) return error.MissingLeftParen; - try parser.expect('('); + fn expr8(parser: *Parser) ParserError!Expression { + blk: { + var tmp = parser.temporary(); - parser.skipWhiteSpace(); + const e = tmp.unaryNegate() catch break :blk; - std.debug.print("here\n", .{}); + parser.offset += tmp.offset; - const e2 = try parser.expr(); + return e; + } - std.debug.print("[{s}] e2 = {any}\n", .{ @src().fn_name, e2 }); + blk: { + var tmp = parser.temporary(); - parser.skipWhiteSpace(); + const e = tmp.expr9() catch break :blk; - // if (!parser.startsWith(")")) return error.MissingRightParen; - try parser.expect(')'); + parser.offset += tmp.offset; - parser.skipWhiteSpace(); + return e; + } - return try Expression.binary(parser.allocator, op, e1, e2); + return error.ExpectedUnaryNegateOrMethod; } - fn binaryOp0(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("&&")) { - try parser.expectString("&&"); - return .@"and"; - } + /// Parse a unary or binary method + fn expr9(parser: *Parser) ParserError!Expression { + var e1 = try parser.exprTerm(); - if (parser.startsWith("||")) { - try parser.expectString("||"); - return .@"or"; - } + parser.skipWhiteSpace(); - return error.UnexpectedOp; - } + if (!parser.startsWithConsuming(".")) return e1; - fn binaryOp1(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("<=")) { - try parser.expectString("<="); - return .less_or_equal; - } + while (true) { + blk: { + var tmp = parser.temporary(); - if (parser.startsWith(">=")) { - try parser.expectString(">="); - return .greater_or_equal; - } + e1 = tmp.binaryMethod(e1) catch break :blk; - if (parser.startsWith("<")) { - try parser.expectString("<"); - return .less_than; - } + parser.offset += tmp.offset; - if (parser.startsWith(">")) { - try parser.expectString(">"); - return .greater_than; - } + if (parser.startsWithConsuming(".")) continue; + } - if (parser.startsWith("==")) { - try parser.expectString("=="); - return .equal; - } + blk: { + var tmp = parser.temporary(); - if (parser.startsWith("!=")) { - try parser.expectString("!="); - return .not_equal; - } + e1 = tmp.unaryMethod(e1) catch break :blk; - return error.UnexpectedOp; - } + parser.offset += tmp.offset; - fn binaryOp2(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("+")) { - try parser.expectString("+"); - return .add; - } + if (parser.startsWithConsuming(".")) continue; + } - if (parser.startsWith("-")) { - try parser.expectString("-"); - return .sub; + break; } - return error.UnexpectedOp; + return e1; } - fn binaryOp3(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("^")) { - try parser.expectString("^"); - return .bitwise_xor; - } + fn exprTerm(parser: *Parser) ParserError!Expression { + blk: { + var tmp = parser.temporary(); - return error.UnexpectedOp; - } + const p = tmp.unaryParens() catch break :blk; - fn binaryOp4(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("|") and !parser.startsWith("||")) { - try parser.expectString("|"); - return .bitwise_or; + parser.offset += tmp.offset; + + return p; } - return error.UnexpectedOp; - } + // Otherwise we expect term + const term1 = try parser.term(.allow); - fn binaryOp5(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("&") and !parser.startsWith("&&")) { - try parser.expectString("&"); - return .bitwise_and; - } - - return error.UnexpectedOp; + return try Expression.value(term1); } - fn binaryOp6(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("*")) { - try parser.expectString("*"); - return .mul; - } + fn binaryMethod(parser: *Parser, e1: Expression) ParserError!Expression { + const op = try parser.binaryOp8(); - if (parser.startsWith("/")) { - try parser.expectString("/"); - return .div; - } + parser.skipWhiteSpace(); - return error.UnexpectedOp; - } + try parser.consume("("); - fn binaryOp7(parser: *Parser) ParserError!Expression.BinaryOp { - if (parser.startsWith("contains")) { - try parser.expectString("contains"); - return .contains; - } + parser.skipWhiteSpace(); - if (parser.startsWith("starts_with")) { - try parser.expectString("starts_with"); - return .prefix; - } + const e2 = try parser.expression(); - if (parser.startsWith("ends_with")) { - try parser.expectString("ends_with"); - return .suffix; - } + parser.skipWhiteSpace(); - if (parser.startsWith("matches")) { - try parser.expectString("matches"); - return .regex; - } + try parser.consume(")"); - return error.UnexpectedOp; + parser.skipWhiteSpace(); + + return try Expression.binary(parser.allocator, op, e1, e2); } - fn exprTerm(parser: *Parser) ParserError!Expression { - // Try to parse unary - unary_blk: { - var unary_parser = Parser.init(parser.allocator, parser.rest()); + fn unaryMethod(parser: *Parser, e1: Expression) ParserError!Expression { + try parser.consume("length()"); - const p = unary_parser.unary() catch break :unary_blk; + return try Expression.unary(parser.allocator, .length, e1); + } - parser.offset += unary_parser.offset; + fn unaryNegate(parser: *Parser) ParserError!Expression { + try parser.consume("!"); - return p; - } + parser.skipWhiteSpace(); - // Otherwise we expect term - const term1 = try parser.term(); + const e = try parser.expression(); - return try Expression.value(term1); + return try Expression.unary(parser.allocator, .negate, e); } - fn unary(parser: *Parser) ParserError!Expression { + fn unaryParens(parser: *Parser) ParserError!Expression { + try parser.consume("("); + parser.skipWhiteSpace(); - if (parser.peek()) |c| { - if (c == '!') { - try parser.expect('!'); - parser.skipWhiteSpace(); + const e = try parser.expression(); - const e = try parser.expr(); + parser.skipWhiteSpace(); - return try Expression.unary(parser.allocator, .negate, e); - } + try parser.consume(")"); - if (c == '(') { - return try parser.unaryParens(); - } - } + return try Expression.unary(parser.allocator, .parens, e); + } - var e: Expression = undefined; - if (parser.term()) |t1| { - parser.skipWhiteSpace(); - e = try Expression.value(t1); - } else |_| { - e = try parser.unaryParens(); - parser.skipWhiteSpace(); + fn binaryOp0(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("||")) return .@"or"; + + return error.UnexpectedOp; + } + + fn binaryOp1(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("&&")) return .@"and"; + + return error.UnexpectedOp; + } + + fn binaryOp2(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("<=")) return .less_or_equal; + if (parser.startsWithConsuming(">=")) return .greater_or_equal; + if (parser.startsWithConsuming("<")) return .less_than; + if (parser.startsWithConsuming(">")) return .greater_than; + if (parser.startsWithConsuming("==")) return .equal; + if (parser.startsWithConsuming("!=")) return .not_equal; + + return error.UnexpectedOp; + } + + fn binaryOp3(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("^")) return .bitwise_xor; + + return error.UnexpectedOp; + } + + fn binaryOp4(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWith("|") and !parser.startsWith("||")) { + try parser.consume("|"); + return .bitwise_or; } - if (parser.expectString(".length()")) |_| { - return try Expression.unary(parser.allocator, .length, e); - } else |_| { - return error.UnexpectedToken; + return error.UnexpectedOp; + } + + fn binaryOp5(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWith("&") and !parser.startsWith("&&")) { + try parser.consume("&"); + return .bitwise_and; } - return error.UnexpectedToken; + return error.UnexpectedOp; } - fn unaryParens(parser: *Parser) ParserError!Expression { - try parser.expectString("("); + fn binaryOp6(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("+")) return .add; + if (parser.startsWithConsuming("-")) return .sub; - parser.skipWhiteSpace(); + return error.UnexpectedOp; + } - const e = try parser.expr(); + fn binaryOp7(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("*")) return .mul; + if (parser.startsWithConsuming("/")) return .div; - parser.skipWhiteSpace(); + return error.UnexpectedOp; + } - try parser.expectString(")"); + fn binaryOp8(parser: *Parser) ParserError!Expression.BinaryOp { + if (parser.startsWithConsuming("contains")) return .contains; + if (parser.startsWithConsuming("starts_with")) return .prefix; + if (parser.startsWithConsuming("ends_with")) return .suffix; + if (parser.startsWithConsuming("matches")) return .regex; + if (parser.startsWithConsuming("intersection")) return .intersection; + if (parser.startsWithConsuming("union")) return .@"union"; - return try Expression.unary(parser.allocator, .parens, e); + return error.UnexpectedOp; } fn scopes(parser: *Parser, allocator: std.mem.Allocator) !std.ArrayList(Scope) { - try parser.expectString("trusting"); + try parser.consume("trusting"); parser.skipWhiteSpace(); @@ -901,35 +789,20 @@ pub const Parser = struct { parser.skipWhiteSpace(); - if (!parser.startsWith(",")) break; - - try parser.expectString(","); + if (!parser.startsWithConsuming(",")) break; } return scps; } fn scope(parser: *Parser, _: std.mem.Allocator) !Scope { - parser.skipWhiteSpace(); + if (parser.startsWithConsuming("authority")) return .{ .authority = {} }; + if (parser.startsWithConsuming("previous")) return .{ .previous = {} }; - if (parser.startsWith("authority")) { - try parser.expectString("authority"); + if (parser.startsWithConsuming("{")) { + const parameter = try parser.name(); - return .{ .authority = {} }; - } - - if (parser.startsWith("previous")) { - try parser.expectString("previous"); - - return .{ .previous = {} }; - } - - if (parser.startsWith("{")) { - try parser.expectString("{"); - - const parameter = parser.readName(); - - try parser.expectString("}"); + try parser.consume("}"); return .{ .parameter = parameter }; } @@ -937,13 +810,12 @@ pub const Parser = struct { return .{ .public_key = try parser.publicKey() }; } + /// Parser a public key. Currently only supports ed25519. fn publicKey(parser: *Parser) !Ed25519.PublicKey { - try parser.expectString("ed25519/"); + try parser.consume("ed25519/"); const h = try parser.hex(); - std.debug.print("publickey = {s}\n", .{h}); - var out_buf: [Ed25519.PublicKey.encoded_length]u8 = undefined; _ = try std.fmt.hexToBytes(out_buf[0..], h); @@ -961,36 +833,87 @@ pub const Parser = struct { return parser.input[parser.offset..]; } - /// Expect (and consume) char. - fn expect(parser: *Parser, char: u8) !void { - const peeked = parser.peek() orelse return error.ExpectedMoreInput; - if (peeked != char) return error.ExpectedChar; - - parser.offset += 1; - } - - /// Expect and consume string. - fn expectString(parser: *Parser, str: []const u8) !void { + /// Expect and consume string. Return error.UnexpectedString if + /// str is not the start of remaining parser input. + fn consume(parser: *Parser, str: []const u8) !void { if (!std.mem.startsWith(u8, parser.rest(), str)) return error.UnexpectedString; parser.offset += str.len; } + /// Returns true if the remaining parser input starts with str + /// + /// Does not consume any input. + /// + /// See also `fn startsWithConsuming` fn startsWith(parser: *Parser, str: []const u8) bool { return std.mem.startsWith(u8, parser.rest(), str); } + /// Returns true if the remaining parse input starts with str. If + /// it does start with that string, the parser consumes the string. + /// + /// See also `fn startsWith` + fn startsWithConsuming(parser: *Parser, str: []const u8) bool { + if (parser.startsWith(str)) { + parser.offset += str.len; // Consume + return true; + } + + return false; + } + /// Reject char. Does not consume the character - fn reject(parser: *Parser, char: u8) !void { - const peeked = parser.peek() orelse return error.ExpectedMoreInput; - if (peeked == char) return error.DisallowedChar; + fn reject(parser: *Parser, str: []const u8) !void { + if (parser.startsWith(str)) return error.DisallowedChar; } - fn hex(parser: *Parser) ![]const u8 { + fn name(parser: *Parser) ![]const u8 { const start = parser.offset; - for (parser.rest()) |c| { - if (ziglyph.isHexDigit(c)) { + if (parser.rest().len == 0) return error.ParsingNameExpectsAtLeastOneCharacter; + + const first_codepoint = try nextCodepoint(parser.rest()); + + if (!ziglyph.isLetter(first_codepoint.codepoint)) return error.ParsingNameFirstCharacterMustBeLetter; + + parser.offset += first_codepoint.len; + + while (true) { + const next_codepoint = try nextCodepoint(parser.rest()); + + if (ziglyph.isAlphaNum(next_codepoint.codepoint)) { + parser.offset += next_codepoint.len; + continue; + } else if (parser.startsWith("_") or parser.startsWith(":")) { + parser.offset += 1; + continue; + } + + break; + } + + return parser.input[start..parser.offset]; + } + + fn variableName(parser: *Parser) ![]const u8 { + const start = parser.offset; + + if (parser.rest().len == 0) return error.ParsingNameExpectsAtLeastOneCharacter; + + const first_codepoint = try nextCodepoint(parser.rest()); + + if (!ziglyph.isAlphaNum(first_codepoint.codepoint)) return error.ParsingNameFirstCharacterMustBeLetter; + + parser.offset += first_codepoint.len; + + while (true) { + const next_codepoint = try nextCodepoint(parser.rest()); + + if (ziglyph.isAlphaNum(next_codepoint.codepoint)) { + parser.offset += next_codepoint.len; + continue; + } else if (parser.startsWith("_") or parser.startsWith(":")) { parser.offset += 1; continue; } @@ -1001,12 +924,11 @@ pub const Parser = struct { return parser.input[start..parser.offset]; } - // FIXME: this should error? - fn readName(parser: *Parser) []const u8 { + fn hex(parser: *Parser) ![]const u8 { const start = parser.offset; for (parser.rest()) |c| { - if (ziglyph.isAlphaNum(c) or c == '_' or c == ':') { + if (isHexDigit(c)) { parser.offset += 1; continue; } @@ -1017,9 +939,17 @@ pub const Parser = struct { return parser.input[start..parser.offset]; } + /// Skip whitespace but the whitespace is required (i.e. we need at least one space, tab or newline) + // fn requiredWhiteSpace(parser: *Parser) !void { + // if (!(parser.startsWith(" ") or parser.startsWith("\t") or parser.startsWith("\n"))) return error.ExpectedWhiteSpace; + + // parser.skipWhiteSpace(); + // } + + /// Skip (optional) whitespace fn skipWhiteSpace(parser: *Parser) void { for (parser.rest()) |c| { - if (ziglyph.isWhiteSpace(c)) { + if (c == ' ' or c == '\t' or c == '\n') { parser.offset += 1; continue; } @@ -1029,11 +959,54 @@ pub const Parser = struct { } }; +const AllowVariables = enum { + allow, + disallow, +}; + +fn isHexDigit(char: u8) bool { + switch (char) { + 'a', 'b', 'c', 'd', 'e', 'f', 'A', 'B', 'C', 'D', 'E', 'F', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' => return true, + else => return false, + } +} + +fn isDigit(char: u8) bool { + switch (char) { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' => return true, + else => return false, + } +} + +/// Try to get the next UTF-8 codepoint from input. Returns the codepoint +/// and the number of bytes that codepoint takes up. +fn nextCodepoint(input: []const u8) !struct { codepoint: u21, len: u32 } { + if (input.len == 0) return error.NextCodePointExpectsAtLeastOneByte; + + const first_byte = input[0]; + + const byte_len = try std.unicode.utf8ByteSequenceLength(first_byte); + + if (input.len < byte_len) return error.NotEnoughInputForCodepoint; + + const codepoint = switch (byte_len) { + 1 => try std.unicode.utf8Decode(input[0..1]), + 2 => try std.unicode.utf8Decode2(input[0..2]), + 3 => try std.unicode.utf8Decode3(input[0..3]), + 4 => try std.unicode.utf8Decode4(input[0..4]), + else => return error.IncorrectUtfDecodeLength, + }; + + return .{ .len = byte_len, .codepoint = codepoint }; +} + const ParserError = error{ ExpectedMoreInput, DisallowedChar, UnexpectedString, ExpectedChar, + ExpectedUnaryNegateOrMethod, + ExpectedUnaryOrBinaryMethod, NoFactTermFound, UnexpectedOp, MissingLeftParen, @@ -1042,72 +1015,423 @@ const ParserError = error{ UnexpectedToken, }; -// test "parse fact predicate" { -// const testing = std.testing; -// const input: []const u8 = -// \\read(true) -// ; +test "parse predicates" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "read(-1, 1, \"hello world\", hex:abcd, true, false, $foo, 2024-03-30T20:48:00Z, [1, 2, 3], [], hex:)"); + const predicate = try parser.predicate(.rule); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(-1, predicate.terms.items[0].integer); + try testing.expectEqual(1, predicate.terms.items[1].integer); + try testing.expectEqualStrings("hello world", predicate.terms.items[2].string); + try testing.expectEqualStrings("\xab\xcd", predicate.terms.items[3].bytes); + try testing.expectEqual(true, predicate.terms.items[4].bool); + try testing.expectEqual(false, predicate.terms.items[5].bool); + try testing.expectEqualStrings("foo", predicate.terms.items[6].variable); + try testing.expectEqual(1711831680, predicate.terms.items[7].date); + + const set = predicate.terms.items[8].set; + try testing.expect(set.contains(.{ .integer = 1 })); + try testing.expect(set.contains(.{ .integer = 2 })); + try testing.expect(set.contains(.{ .integer = 3 })); + + const empty_set = predicate.terms.items[9].set; + try testing.expectEqual(0, empty_set.count()); + + try testing.expectEqualStrings("", predicate.terms.items[10].bytes); + } + + { + var parser = Parser.init(arena, "read(true)"); + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + } + + { + // Names can contain : and _ + var parser = Parser.init(arena, "read:write_admin(true)"); + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("read:write_admin", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + } + + { + var parser = Parser.init(arena, "read(true, false)"); + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + try testing.expectEqual(false, predicate.terms.items[1].bool); + } + + { + var parser = Parser.init(arena, "read(true,false)"); + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + try testing.expectEqual(false, predicate.terms.items[1].bool); + } + + { + // We are allowed spaces around predicate terms + var parser = Parser.init(arena, "read( true , false )"); + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + try testing.expectEqual(false, predicate.terms.items[1].bool); + } + + { + // We don't allow a space between the predicate name and its opening paren + var parser = Parser.init(arena, "read (true, false )"); + + try testing.expectError(error.UnexpectedString, parser.predicate(.fact)); + } + + { + // We don't allow variables in fact predicates + var parser = Parser.init(arena, "read(true, $foo)"); + + try testing.expectError(error.DisallowedChar, parser.predicate(.fact)); + } + + { + // Non-fact predicates can contain variables + var parser = Parser.init(arena, "read(true, $foo)"); + + const predicate = try parser.predicate(.rule); + + try testing.expectEqualStrings("read", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + try testing.expectEqualStrings("foo", predicate.terms.items[1].variable); + } + + { + // Facts must have at least one term + var parser = Parser.init(arena, "read()"); + + try testing.expectError(error.NoFactTermFound, parser.predicate(.fact)); + } + + { + // Facts must start with a (UTF-8) letter + var parser = Parser.init(arena, "3read(true)"); + + try testing.expectError(error.ParsingNameFirstCharacterMustBeLetter, parser.predicate(.fact)); + } + + { + // Names can be UTF-8 + const input = "ビスケット(true)"; + var parser = Parser.init(arena, input); + + const predicate = try parser.predicate(.fact); + + try testing.expectEqualStrings("ビスケット", predicate.name); + try testing.expectEqual(true, predicate.terms.items[0].bool); + } +} + +test "parse numbers" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "1"); + const integer = try parser.number(i64); + + try testing.expectEqual(1, integer); + } + + { + var parser = Parser.init(arena, "12345"); + const integer = try parser.number(i64); + + try testing.expectEqual(12345, integer); + } + + { + var parser = Parser.init(arena, "-1"); + const integer = try parser.number(i64); + + try testing.expectEqual(-1, integer); + } + + { + var parser = Parser.init(arena, "-"); + + try testing.expectError(error.InvalidCharacter, parser.number(i64)); + } +} + +test "parse boolean" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "true"); + const boolean = try parser.boolean(); + + try testing.expectEqual(true, boolean); + } + + { + var parser = Parser.init(arena, "false"); + const boolean = try parser.boolean(); + + try testing.expectEqual(false, boolean); + } +} + +test "parse hex" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "hex:BeEf"); + const bytes = try parser.bytes(); -// var parser = Parser.init(input); + try testing.expectEqualStrings("\xbe\xef", bytes); + } -// const r = try parser.factPredicate(testing.allocator); -// defer r.deinit(); + { + var parser = Parser.init(arena, "hex:BeE"); -// std.debug.print("{any}\n", .{r}); -// } + try testing.expectError(error.ExpectedEvenNumberOfHexDigis, parser.bytes()); + } +} -// test "parse rule body" { -// const testing = std.testing; -// const input: []const u8 = -// \\query(false) <- read(true), write(false) -// ; +test "parse rule" { + const testing = std.testing; -// var parser = Parser.init(testing.allocator, input); + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); -// const r = try parser.rule(); -// defer r.deinit(); + { + var parser = Parser.init(arena, "read($0, $1) <- operation($0), file($1)"); + const rule = try parser.rule(); -// std.debug.print("{any}\n", .{r}); -// } + try testing.expectEqualStrings("read", rule.head.name); + try testing.expectEqualStrings("0", rule.head.terms.items[0].variable); + try testing.expectEqualStrings("1", rule.head.terms.items[1].variable); -// test "parse rule body with variables" { -// const testing = std.testing; -// const input: []const u8 = -// \\query($0) <- read($0), write(false) -// ; + try testing.expectEqualStrings("operation", rule.body.items[0].name); + try testing.expectEqualStrings("0", rule.body.items[0].terms.items[0].variable); -// var parser = Parser.init(testing.allocator, input); + try testing.expectEqualStrings("file", rule.body.items[1].name); + try testing.expectEqualStrings("1", rule.body.items[1].terms.items[0].variable); + } -// const r = try parser.rule(); -// defer r.deinit(); + { + // Remove some spaces + var parser = Parser.init(arena, "read($0, $1)<-operation($0),file($1)"); + const rule = try parser.rule(); -// std.debug.print("{any}\n", .{r}); -// } + try testing.expectEqualStrings("read", rule.head.name); + try testing.expectEqualStrings("0", rule.head.terms.items[0].variable); + try testing.expectEqualStrings("1", rule.head.terms.items[1].variable); -// test "parse check" { -// const testing = std.testing; -// const input: []const u8 = -// \\check if right($0, $1), resource($0), operation($1) -// ; + try testing.expectEqualStrings("operation", rule.body.items[0].name); + try testing.expectEqualStrings("0", rule.body.items[0].terms.items[0].variable); -// var parser = Parser.init(testing.allocator, input); + try testing.expectEqualStrings("file", rule.body.items[1].name); + try testing.expectEqualStrings("1", rule.body.items[1].terms.items[0].variable); + } + + { + // Remove some spaces + var parser = Parser.init(arena, "read($0, $1) <- operation($0), 1 < 3, file($1)"); + const rule = try parser.rule(); -// const r = try parser.check(); -// defer r.deinit(); + try testing.expectEqualStrings("read", rule.head.name); + try testing.expectEqualStrings("0", rule.head.terms.items[0].variable); + try testing.expectEqualStrings("1", rule.head.terms.items[1].variable); -// std.debug.print("{any}\n", .{r}); -// } + try testing.expectEqualStrings("operation", rule.body.items[0].name); + try testing.expectEqualStrings("0", rule.body.items[0].terms.items[0].variable); -test "parse check with expression" { + try testing.expectEqualStrings("file", rule.body.items[1].name); + try testing.expectEqualStrings("1", rule.body.items[1].terms.items[0].variable); + + try testing.expectEqualStrings("1 < 3", try std.fmt.allocPrint(arena, "{any}", .{rule.expressions.items[0]})); + } + + { + // We need at least one predicate or expression in the body + var parser = Parser.init(arena, "read($0, $1) <- "); + + try testing.expectError(error.ExpectedPredicateOrExpression, parser.rule()); + } +} + +test "parse check" { const testing = std.testing; - const input: []const u8 = - \\check if right($0, $1), resource($0), operation($1), $0.contains("file") - ; - var parser = Parser.init(testing.allocator, input); + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "check if right($0, $1), resource($0), operation($1), $0.contains(\"file\")"); + const check = try parser.check(); + + try testing.expectEqual(.one, check.kind); + try testing.expectEqual(1, check.queries.items.len); + + try testing.expectEqualStrings("query", check.queries.items[0].head.name); + try testing.expectEqualStrings("right", check.queries.items[0].body.items[0].name); + try testing.expectEqualStrings("resource", check.queries.items[0].body.items[1].name); + try testing.expectEqualStrings("operation", check.queries.items[0].body.items[2].name); + + try testing.expectEqualStrings("$0.contains(\"file\")", try std.fmt.allocPrint(arena, "{any}", .{check.queries.items[0].expressions.items[0]})); + } + + { + // Check with or + var parser = Parser.init(arena, "check if right($0, $1), resource($0), operation($1), $0.contains(\"file\") or admin(true)"); + const check = try parser.check(); + + try testing.expectEqual(.one, check.kind); + try testing.expectEqual(2, check.queries.items.len); + + try testing.expectEqualStrings("query", check.queries.items[0].head.name); + try testing.expectEqualStrings("right", check.queries.items[0].body.items[0].name); + try testing.expectEqualStrings("resource", check.queries.items[0].body.items[1].name); + try testing.expectEqualStrings("operation", check.queries.items[0].body.items[2].name); + + try testing.expectEqualStrings("query", check.queries.items[1].head.name); + try testing.expectEqualStrings("admin", check.queries.items[1].body.items[0].name); + } + + { + // Check all + var parser = Parser.init(arena, "check all right($0, $1), resource($0), operation($1), $0.contains(\"file\") or admin(true)"); + const check = try parser.check(); + + try testing.expectEqual(.all, check.kind); + try testing.expectEqual(2, check.queries.items.len); - const r = try parser.check(); - defer r.deinit(); + try testing.expectEqualStrings("query", check.queries.items[0].head.name); + try testing.expectEqualStrings("right", check.queries.items[0].body.items[0].name); + try testing.expectEqualStrings("resource", check.queries.items[0].body.items[1].name); + try testing.expectEqualStrings("operation", check.queries.items[0].body.items[2].name); - std.debug.print("{any}\n", .{r}); + try testing.expectEqualStrings("query", check.queries.items[1].head.name); + try testing.expectEqualStrings("admin", check.queries.items[1].body.items[0].name); + } + + { + var parser = Parser.init(arena, "check if"); + + try testing.expectError(error.ExpectedPredicateOrExpression, parser.check()); + } + + { + var parser = Parser.init(arena, "check if "); + + try testing.expectError(error.ExpectedPredicateOrExpression, parser.check()); + } + + { + const input = "check if query(1, 2) trusting ed25519/acdd6d5b53bfee478bf689f8e012fe7988bf755e3d7c5152947abc149bc20189, ed25519/a060270db7e9c9f06e8f9cc33a64e99f6596af12cb01c4b638df8afc7b642463"; + var parser = Parser.init(arena, input); + const check = try parser.check(); + + try testing.expectEqual(.one, check.kind); + try testing.expectEqual(1, check.queries.items.len); + + try testing.expectEqualStrings("query", check.queries.items[0].head.name); + try testing.expectEqualStrings("query", check.queries.items[0].body.items[0].name); + + try testing.expectEqual(2, check.queries.items[0].scopes.items.len); + try testing.expectEqualStrings("acdd6d5b53bfee478bf689f8e012fe7988bf755e3d7c5152947abc149bc20189", &std.fmt.bytesToHex(check.queries.items[0].scopes.items[0].public_key.toBytes(), .lower)); + try testing.expectEqualStrings("a060270db7e9c9f06e8f9cc33a64e99f6596af12cb01c4b638df8afc7b642463", &std.fmt.bytesToHex(check.queries.items[0].scopes.items[1].public_key.toBytes(), .lower)); + } +} + +test "parse policy" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + { + var parser = Parser.init(arena, "allow if right($0, $1), resource($0), operation($1), $0.contains(\"file\")"); + const policy = try parser.policy(); + + try testing.expectEqual(.allow, policy.kind); + try testing.expectEqual(1, policy.queries.items.len); + + try testing.expectEqualStrings("query", policy.queries.items[0].head.name); + try testing.expectEqualStrings("right", policy.queries.items[0].body.items[0].name); + try testing.expectEqualStrings("resource", policy.queries.items[0].body.items[1].name); + try testing.expectEqualStrings("operation", policy.queries.items[0].body.items[2].name); + + try testing.expectEqualStrings("$0.contains(\"file\")", try std.fmt.allocPrint(arena, "{any}", .{policy.queries.items[0].expressions.items[0]})); + } + + { + var parser = Parser.init(arena, "deny if right($0, $1), resource($0), operation($1), $0.contains(\"file\")"); + const policy = try parser.policy(); + + try testing.expectEqual(.deny, policy.kind); + try testing.expectEqual(1, policy.queries.items.len); + + try testing.expectEqualStrings("query", policy.queries.items[0].head.name); + try testing.expectEqualStrings("right", policy.queries.items[0].body.items[0].name); + try testing.expectEqualStrings("resource", policy.queries.items[0].body.items[1].name); + try testing.expectEqualStrings("operation", policy.queries.items[0].body.items[2].name); + + try testing.expectEqualStrings("$0.contains(\"file\")", try std.fmt.allocPrint(arena, "{any}", .{policy.queries.items[0].expressions.items[0]})); + } +} + +test "parse expression" { + const testing = std.testing; + + var arena_state = std.heap.ArenaAllocator.init(testing.allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + const inputs = [_][]const u8{ + "1", + "[2]", + "$0.contains(\"file\")", + "!(1 + 2)", + "1 ^ (4 + 6)", + "[1].intersection([2])", + "[1].intersection([2]).length().union([3])", + "1 + 2 * 3 / (4 + 5)", + "[1].length()", + "\"hello\".length()", + }; + + for (inputs) |input| { + var parser = Parser.init(arena, input); + const expression = try parser.expression(); + + try testing.expectEqualStrings(input, try std.fmt.allocPrint(arena, "{any}", .{expression})); + } } diff --git a/biscuit-samples/src/main.zig b/biscuit-samples/src/main.zig index a7dda48..735f261 100644 --- a/biscuit-samples/src/main.zig +++ b/biscuit-samples/src/main.zig @@ -6,6 +6,8 @@ const AuthorizerError = @import("biscuit").AuthorizerError; const Samples = @import("sample.zig").Samples; const Result = @import("sample.zig").Result; +const log = std.log.scoped(.samples); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; pub fn main() anyerror!void { @@ -41,7 +43,7 @@ pub fn main() anyerror!void { const token = try std.fs.cwd().readFileAlloc(alloc, testcase.filename, 0xFFFFFFF); for (testcase.validations.map.values(), 0..) |validation, i| { - errdefer std.debug.print("Error on validation {} of {s}\n", .{ i, testcase.filename }); + errdefer log.err("Error on validation {} of {s}\n", .{ i, testcase.filename }); try validate(alloc, token, public_key, validation.result, validation.authorizer_code); } } @@ -169,7 +171,7 @@ pub fn runValidation(alloc: mem.Allocator, token: []const u8, public_key: std.cr } _ = a.authorize(errors) catch |err| { - std.debug.print("Authorization failed with errors: {any}\n", .{errors.items}); + log.debug("authorize() returned with errors: {any}\n", .{errors.items}); return err; }; } diff --git a/biscuit/src/authorizer.zig b/biscuit/src/authorizer.zig index 82942ea..a0d1470 100644 --- a/biscuit/src/authorizer.zig +++ b/biscuit/src/authorizer.zig @@ -11,8 +11,10 @@ const Parser = @import("biscuit-parser").Parser; const builder = @import("biscuit-builder"); const PolicyResult = @import("biscuit-builder").PolicyResult; +const log = std.log.scoped(.authorizer); + pub const Authorizer = struct { - allocator: mem.Allocator, + arena: mem.Allocator, checks: std.ArrayList(builder.Check), policies: std.ArrayList(builder.Policy), biscuit: ?Biscuit, @@ -21,9 +23,9 @@ pub const Authorizer = struct { public_key_to_block_id: std.AutoHashMap(usize, std.ArrayList(usize)), scopes: std.ArrayList(Scope), - pub fn init(allocator: std.mem.Allocator, biscuit: Biscuit) !Authorizer { - var symbols = SymbolTable.init("authorizer", allocator); - var public_key_to_block_id = std.AutoHashMap(usize, std.ArrayList(usize)).init(allocator); + pub fn init(arena: std.mem.Allocator, biscuit: Biscuit) !Authorizer { + var symbols = SymbolTable.init("authorizer", arena); + var public_key_to_block_id = std.AutoHashMap(usize, std.ArrayList(usize)).init(arena); // Map public key symbols into authorizer symbols and public_key_to_block_id map var it = biscuit.public_key_to_block_id.iterator(); @@ -39,213 +41,224 @@ pub const Authorizer = struct { } return .{ - .allocator = allocator, - .checks = std.ArrayList(builder.Check).init(allocator), - .policies = std.ArrayList(builder.Policy).init(allocator), + .arena = arena, + .checks = std.ArrayList(builder.Check).init(arena), + .policies = std.ArrayList(builder.Policy).init(arena), .biscuit = biscuit, - .world = World.init(allocator), + .world = World.init(arena), .symbols = symbols, .public_key_to_block_id = public_key_to_block_id, - .scopes = std.ArrayList(Scope).init(allocator), + .scopes = std.ArrayList(Scope).init(arena), }; } - pub fn deinit(authorizer: *Authorizer) void { - authorizer.world.deinit(); - authorizer.symbols.deinit(); - authorizer.scopes.deinit(); - - for (authorizer.checks.items) |check| { - check.deinit(); - } - authorizer.checks.deinit(); - - for (authorizer.policies.items) |policy| { - policy.deinit(); - } - authorizer.policies.deinit(); - - { - var it = authorizer.public_key_to_block_id.valueIterator(); - while (it.next()) |block_ids| { - block_ids.deinit(); - } - authorizer.public_key_to_block_id.deinit(); - } - } - - pub fn authorizerTrustedOrigins(authorizer: *Authorizer) !TrustedOrigins { - return try TrustedOrigins.fromScopes( - authorizer.allocator, - authorizer.scopes.items, - try TrustedOrigins.defaultOrigins(authorizer.allocator), - Origin.AUTHORIZER_ID, - authorizer.public_key_to_block_id, - ); + pub fn deinit(_: *Authorizer) void { + // authorizer.world.deinit(); + // authorizer.symbols.deinit(); + // authorizer.scopes.deinit(); + + // for (authorizer.checks.items) |check| { + // check.deinit(); + // } + // authorizer.checks.deinit(); + + // for (authorizer.policies.items) |policy| { + // policy.deinit(); + // } + // authorizer.policies.deinit(); + + // { + // var it = authorizer.public_key_to_block_id.valueIterator(); + // while (it.next()) |block_ids| { + // block_ids.deinit(); + // } + // authorizer.public_key_to_block_id.deinit(); + // } } - /// Add fact from string to authorizer - pub fn addFact(authorizer: *Authorizer, input: []const u8) !void { - std.debug.print("authorizer.addFact = {s}\n", .{input}); - var parser = Parser.init(authorizer.allocator, input); - - const fact = try parser.fact(); + /// Authorize token with authorizer + /// + /// Will return without error if there is a matching `allow` policy id and the `errors` + /// list is empty. + /// + /// Otherwise will return error.AuthorizationFailed with the reason(s) in `errors` + pub fn authorize(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !usize { + log.debug("Starting authorize()", .{}); + defer log.debug("Finished authorize()", .{}); - std.debug.print("fact = {any}\n", .{fact}); + try authorizer.addTokenFactsAndRules(errors); // Step 1 + try authorizer.generateFacts(); // Step 2 + try authorizer.authorizerChecks(errors); // Step 3 + try authorizer.authorityChecks(errors); // Step 4 + const allowed_policy_id: ?usize = try authorizer.authorizerPolicies(errors); // Step 5 + try authorizer.blockChecks(errors); // Step 6 - const origin = try Origin.initWithId(authorizer.allocator, Origin.AUTHORIZER_ID); + if (allowed_policy_id) |policy_id| { + if (errors.items.len == 0) return policy_id; + } - try authorizer.world.addFact(origin, try fact.convert(authorizer.allocator, &authorizer.symbols)); + return error.AuthorizationFailed; } - /// Add check from string to authorizer - pub fn addCheck(authorizer: *Authorizer, input: []const u8) !void { - var parser = Parser.init(authorizer.allocator, input); + /// Step 1 in authorize() + /// + /// Adds all the facts and rules from the token to the authorizer's world. + /// + /// Note: this should only be called from authorize() + fn addTokenFactsAndRules(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !void { + // Load facts and rules from authority block into world. Our block's facts + // will have a particular symbol table that we map into the symbol table + // of the world. + // + // For example, the token may have a string "user123" which has id 12. But + // when mapped into the world it may have id 5. + const biscuit = authorizer.biscuit orelse return; - const check = try parser.check(); + // Add authority block's facts + for (biscuit.authority.facts.items) |authority_fact| { + const fact = try authority_fact.remap(&biscuit.symbols, &authorizer.symbols); + const origin = try Origin.initWithId(authorizer.arena, 0); - try authorizer.checks.append(check); - } + try authorizer.world.addFact(origin, fact); + } - /// Add policy from string to authorizer - pub fn addPolicy(authorizer: *Authorizer, input: []const u8) !void { - var parser = Parser.init(authorizer.allocator, input); + const authority_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + biscuit.authority.scopes.items, + try TrustedOrigins.defaultOrigins(authorizer.arena), + 0, + authorizer.public_key_to_block_id, + ); - const policy = try parser.policy(); + // Add authority block's rules + for (biscuit.authority.rules.items) |authority_rule| { + // Map from biscuit symbol space to authorizer symbol space + const rule = try authority_rule.remap(&biscuit.symbols, &authorizer.symbols); - try authorizer.policies.append(policy); - } - - /// authorize - /// - /// authorize the Authorizer - /// - /// The following high-level steps take place during authorization: - /// 1. _biscuit_ (where it exists): load _all_ of the facts and rules - /// in the biscuit. We can add all the facts and rules as this time because - /// the facts and rules are scoped, i.e. the facts / rules are added to particular - /// scopes within the world. - /// 2. Run the world to generate new facts. - /// 3. _authorizer_: Run the _authorizer's_ checks - /// 4. _biscuit_ (where it exists): run the authority block's checks - /// 5. _authorizer_: Run the _authorizer's_ policies - /// 6. _biscuit_ (where it exists): run the checks from all the non-authority blocks - pub fn authorize(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !usize { - std.debug.print("\nAuthorizing biscuit:\n", .{}); + if (!rule.validateVariables()) { + try errors.append(.unbound_variable); + } - std.debug.print("authorizer public keys:\n", .{}); - for (authorizer.symbols.public_keys.items, 0..) |pk, i| { - std.debug.print(" [{}]: {x}\n", .{ i, pk.bytes }); - } + // A authority block's rule trusts + const rule_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + rule.scopes.items, + authority_trusted_origins, + 0, + authorizer.public_key_to_block_id, + ); - { - var it = authorizer.public_key_to_block_id.iterator(); - while (it.next()) |entry| { - std.debug.print("public_key_to_block_id: public key id = {}, block_ids = {any}\n", .{ entry.key_ptr.*, entry.value_ptr.items }); - } + try authorizer.world.addRule(0, rule_trusted_origins, rule); } - // 1. - // Load facts and rules from authority block into world. Our block's facts - // will have a particular symbol table that we map into the symvol table - // of the world. - // - // For example, the token may have a string "user123" which has id 12. But - // when mapped into the world it may have id 5. - if (authorizer.biscuit) |biscuit| { - std.debug.print("biscuit token public keys:\n", .{}); - for (biscuit.symbols.public_keys.items, 0..) |pk, i| { - std.debug.print(" [{}]: {x}\n", .{ i, pk.bytes }); - } - for (biscuit.authority.facts.items) |authority_fact| { - const fact = try authority_fact.convert(&biscuit.symbols, &authorizer.symbols); - const origin = try Origin.initWithId(authorizer.allocator, 0); + for (biscuit.blocks.items, 1..) |block, block_id| { + // Add block facts + for (block.facts.items) |block_fact| { + const fact = try block_fact.remap(&biscuit.symbols, &authorizer.symbols); + const origin = try Origin.initWithId(authorizer.arena, block_id); try authorizer.world.addFact(origin, fact); } - const authority_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - biscuit.authority.scopes.items, - try TrustedOrigins.defaultOrigins(authorizer.allocator), - 0, + const block_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + block.scopes.items, + try TrustedOrigins.defaultOrigins(authorizer.arena), + block_id, authorizer.public_key_to_block_id, ); - for (biscuit.authority.rules.items) |authority_rule| { - // Map from biscuit symbol space to authorizer symbol space - const rule = try authority_rule.convert(&biscuit.symbols, &authorizer.symbols); + // Add block rules + for (block.rules.items) |block_rule| { + const rule = try block_rule.remap(&biscuit.symbols, &authorizer.symbols); + log.debug("block rule {any} CONVERTED to rule = {any}", .{ block_rule, rule }); if (!rule.validateVariables()) { try errors.append(.unbound_variable); } - // A authority block's rule trusts - const rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, + const block_rule_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, rule.scopes.items, - authority_trusted_origins, - 0, + block_trusted_origins, + block_id, authorizer.public_key_to_block_id, ); - try authorizer.world.addRule(0, rule_trusted_origins, rule); + try authorizer.world.addRule(block_id, block_rule_trusted_origins, rule); } + } + } - for (biscuit.blocks.items, 1..) |block, block_id| { - for (block.facts.items) |block_fact| { - const fact = try block_fact.convert(&biscuit.symbols, &authorizer.symbols); - const origin = try Origin.initWithId(authorizer.allocator, block_id); + /// Step 2 in authorize() + /// + /// Generate all new facts by running world. + fn generateFacts(authorizer: *Authorizer) !void { + log.debug("Run world", .{}); + defer log.debug("Finished running world", .{}); - try authorizer.world.addFact(origin, fact); - } + try authorizer.world.run(&authorizer.symbols); + } - const block_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - block.scopes.items, - try TrustedOrigins.defaultOrigins(authorizer.allocator), - block_id, - authorizer.public_key_to_block_id, - ); + /// Step 3 in authorize() + /// + /// Runs the authorizer's checks + /// + /// Note: should only be called from authorize() + fn authorizerChecks(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !void { + log.debug("Start authorizerChecks()", .{}); + defer log.debug("End authorizerChecks()", .{}); - for (block.rules.items) |block_rule| { - const rule = try block_rule.convert(&biscuit.symbols, &authorizer.symbols); - std.debug.print("block rule {any} CONVERTED to rule = {any}\n", .{ block_rule, rule }); + for (authorizer.checks.items) |c| { + log.debug("authorizer check = {any}", .{c}); + const check = try c.toDatalog(authorizer.arena, &authorizer.symbols); - if (!rule.validateVariables()) { - try errors.append(.unbound_variable); - } + for (check.queries.items, 0..) |*query, check_id| { + const rule_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + query.scopes.items, + try authorizer.authorizerTrustedOrigins(), + Origin.AUTHORIZER_ID, + authorizer.public_key_to_block_id, + ); - const block_rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - rule.scopes.items, - block_trusted_origins, - block_id, - authorizer.public_key_to_block_id, - ); + const is_match = switch (check.kind) { + .one => try authorizer.world.queryMatch(query, &authorizer.symbols, rule_trusted_origins), + .all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins), + }; - try authorizer.world.addRule(block_id, block_rule_trusted_origins, rule); - } + if (!is_match) try errors.append(.{ .failed_authorizer_check = .{ .check_id = check_id } }); + log.debug("match {any} = {}", .{ query, is_match }); } } + } - // 2. Run the world to generate all facts - std.debug.print("\nGENERATING NEW FACTS\n", .{}); - try authorizer.world.run(&authorizer.symbols); - std.debug.print("\nEND GENERATING NEW FACTS\n", .{}); + /// Step 4 in authorize() + /// + /// Run checks from token's authority block + /// + /// Note: should only be called from authorizer() + fn authorityChecks(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !void { + const biscuit = authorizer.biscuit orelse return; + + const authority_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + biscuit.authority.scopes.items, + try TrustedOrigins.defaultOrigins(authorizer.arena), + 0, + authorizer.public_key_to_block_id, + ); - // 3. Run checks that have been added to this authorizer - std.debug.print("\nAUTHORIZER CHECKS\n", .{}); - for (authorizer.checks.items) |c| { - std.debug.print("authorizer check = {any}\n", .{c}); - const check = try c.convert(authorizer.allocator, &authorizer.symbols); + for (biscuit.authority.checks.items, 0..) |c, check_id| { + const check = try c.remap(&biscuit.symbols, &authorizer.symbols); + log.debug("{}: {any}", .{ check_id, check }); - for (check.queries.items, 0..) |*query, check_id| { + for (check.queries.items) |*query| { const rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, + authorizer.arena, query.scopes.items, - try authorizer.authorizerTrustedOrigins(), - Origin.AUTHORIZER_ID, + authority_trusted_origins, + 0, authorizer.public_key_to_block_id, ); @@ -254,32 +267,80 @@ pub const Authorizer = struct { .all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins), }; - if (!is_match) try errors.append(.{ .failed_authorizer_check = .{ .check_id = check_id } }); - std.debug.print("match {any} = {}\n", .{ query, is_match }); + if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = 0, .check_id = check_id } }); + log.debug("match {any} = {}", .{ query, is_match }); } } - std.debug.print("END AUTHORIZER CHECKS\n", .{}); - - // 4. Run checks in the biscuit's authority block - if (authorizer.biscuit) |biscuit| { - const authority_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - biscuit.authority.scopes.items, - try TrustedOrigins.defaultOrigins(authorizer.allocator), - 0, + } + + /// Step 5 in authorize() + /// + /// Run authorizer's policies. + /// + /// Note: should only be called from authorizer() + fn authorizerPolicies(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !?usize { + for (authorizer.policies.items) |policy| { + log.debug("testing policy {any}", .{policy}); + + for (policy.queries.items, 0..) |*q, policy_id| { + var query = try q.toDatalog(authorizer.arena, &authorizer.symbols); + + const rule_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + query.scopes.items, + try authorizer.authorizerTrustedOrigins(), + Origin.AUTHORIZER_ID, + authorizer.public_key_to_block_id, + ); + + const is_match = try authorizer.world.queryMatch(&query, &authorizer.symbols, rule_trusted_origins); + log.debug("match {any} = {}", .{ query, is_match }); + + if (is_match) { + switch (policy.kind) { + .allow => return policy_id, + .deny => { + try errors.append(.{ .denied_by_policy = .{ .deny_policy_id = policy_id } }); + return null; + }, + } + } + } + } + + try errors.append(.{ .no_matching_policy = {} }); + + return null; + } + + /// Step 6 in authorize() + /// + /// Run checks in all other blocks. + /// + /// Note: should only be called from authorizer() + fn blockChecks(authorizer: *Authorizer, errors: *std.ArrayList(AuthorizerError)) !void { + const biscuit = authorizer.biscuit orelse return; + + for (biscuit.blocks.items, 1..) |block, block_id| { + const block_trusted_origins = try TrustedOrigins.fromScopes( + authorizer.arena, + block.scopes.items, + try TrustedOrigins.defaultOrigins(authorizer.arena), + block_id, authorizer.public_key_to_block_id, ); - for (biscuit.authority.checks.items, 0..) |c, check_id| { - const check = try c.convert(&biscuit.symbols, &authorizer.symbols); - std.debug.print("{}: {any}\n", .{ check_id, check }); + for (block.checks.items, 0..) |c, check_id| { + const check = try c.remap(&biscuit.symbols, &authorizer.symbols); + + log.debug("check = {any}", .{check}); for (check.queries.items) |*query| { const rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, + authorizer.arena, query.scopes.items, - authority_trusted_origins, - 0, + block_trusted_origins, + block_id, authorizer.public_key_to_block_id, ); @@ -288,92 +349,54 @@ pub const Authorizer = struct { .all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins), }; - if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = 0, .check_id = check_id } }); - std.debug.print("match {any} = {}\n", .{ query, is_match }); - } - } - } - - // 5. run policies from the authorizer - const allowed_policy_id: ?usize = policy_blk: { - for (authorizer.policies.items) |policy| { - std.debug.print("authorizer policy = {any}\n", .{policy}); - - for (policy.queries.items, 0..) |*q, policy_id| { - var query = try q.convert(authorizer.allocator, &authorizer.symbols); - - const rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - query.scopes.items, - try authorizer.authorizerTrustedOrigins(), - Origin.AUTHORIZER_ID, - authorizer.public_key_to_block_id, - ); + if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = block_id, .check_id = check_id } }); - const is_match = try authorizer.world.queryMatch(&query, &authorizer.symbols, rule_trusted_origins); - std.debug.print("match {any} = {}\n", .{ query, is_match }); - - if (is_match) { - switch (policy.kind) { - .allow => break :policy_blk policy_id, - .deny => { - try errors.append(.{ .denied_by_policy = .{ .deny_policy_id = policy_id } }); - break :policy_blk null; - }, - } - } + log.debug("match {any} = {}", .{ query, is_match }); } } + } + } - try errors.append(.{ .no_matching_policy = {} }); - break :policy_blk null; - }; + pub fn authorizerTrustedOrigins(authorizer: *Authorizer) !TrustedOrigins { + return try TrustedOrigins.fromScopes( + authorizer.arena, + authorizer.scopes.items, + try TrustedOrigins.defaultOrigins(authorizer.arena), + Origin.AUTHORIZER_ID, + authorizer.public_key_to_block_id, + ); + } - // 6. Run checks in the biscuit's other blocks - if (authorizer.biscuit) |biscuit| { - for (biscuit.blocks.items, 1..) |block, block_id| { - const block_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - block.scopes.items, - try TrustedOrigins.defaultOrigins(authorizer.allocator), - block_id, - authorizer.public_key_to_block_id, - ); + /// Add fact from string to authorizer + pub fn addFact(authorizer: *Authorizer, input: []const u8) !void { + log.debug("addFact = {s}", .{input}); + var parser = Parser.init(authorizer.arena, input); - std.debug.print("block = {any}\n", .{block}); + const fact = try parser.fact(); - for (block.checks.items, 0..) |c, check_id| { - const check = try c.convert(&biscuit.symbols, &authorizer.symbols); + const origin = try Origin.initWithId(authorizer.arena, Origin.AUTHORIZER_ID); - std.debug.print("check = {any}\n", .{check}); + try authorizer.world.addFact(origin, try fact.toDatalog(authorizer.arena, &authorizer.symbols)); + } - for (check.queries.items) |*query| { - const rule_trusted_origins = try TrustedOrigins.fromScopes( - authorizer.allocator, - query.scopes.items, - block_trusted_origins, - block_id, - authorizer.public_key_to_block_id, - ); + /// Add check from string to authorizer + pub fn addCheck(authorizer: *Authorizer, input: []const u8) !void { + log.debug("addCheck = {s}", .{input}); + var parser = Parser.init(authorizer.arena, input); - const is_match = switch (check.kind) { - .one => try authorizer.world.queryMatch(query, &authorizer.symbols, rule_trusted_origins), - .all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins), - }; + const check = try parser.check(); - if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = block_id, .check_id = check_id } }); + try authorizer.checks.append(check); + } - std.debug.print("match {any} = {}\n", .{ query, is_match }); - } - } - } - } + /// Add policy from string to authorizer + pub fn addPolicy(authorizer: *Authorizer, input: []const u8) !void { + log.debug("addPolicy = {s}", .{input}); + var parser = Parser.init(authorizer.arena, input); - if (allowed_policy_id) |policy_id| { - if (errors.items.len == 0) return policy_id; - } + const policy = try parser.policy(); - return error.AuthorizationFailed; + try authorizer.policies.append(policy); } }; @@ -391,4 +414,14 @@ pub const AuthorizerError = union(AuthorizerErrorKind) { failed_authorizer_check: struct { check_id: usize }, failed_block_check: struct { block_id: usize, check_id: usize }, unbound_variable: void, + + pub fn format(authorization_error: AuthorizerError, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + switch (authorization_error) { + .no_matching_policy => try writer.print("no matching policy", .{}), + .denied_by_policy => |e| try writer.print("denied by policy {}", .{e.deny_policy_id}), + .failed_authorizer_check => |e| try writer.print("failed authorizer check {}", .{e.check_id}), + .failed_block_check => |e| try writer.print("failed check {} on block {}", .{ e.check_id, e.block_id }), + .unbound_variable => try writer.print("unbound variable", .{}), + } + } }; diff --git a/biscuit/src/biscuit.zig b/biscuit/src/biscuit.zig index bf00222..631dd5c 100644 --- a/biscuit/src/biscuit.zig +++ b/biscuit/src/biscuit.zig @@ -8,6 +8,8 @@ const SymbolTable = @import("biscuit-datalog").SymbolTable; const World = @import("biscuit-datalog").world.World; const SerializedBiscuit = @import("biscuit-format").SerializedBiscuit; +const log = std.log.scoped(.biscuit); + pub const Biscuit = struct { serialized: SerializedBiscuit, authority: Block, @@ -28,12 +30,12 @@ pub const Biscuit = struct { const authority = try Block.fromBytes(allocator, serialized.authority, &token_symbols); try block_external_keys.append(null); - std.debug.print("authority block =\n{any}\n", .{authority}); + log.debug("authority {any}", .{authority}); var blocks = std.ArrayList(Block).init(allocator); for (serialized.blocks.items) |signed_block| { const block = try Block.fromBytes(allocator, signed_block, &token_symbols); - std.debug.print("non-authority block =\n{any}\n", .{block}); + log.debug("{any}", .{block}); const external_key = if (signed_block.external_signature) |external_signature| external_signature.public_key else null; try block_external_keys.append(external_key); @@ -61,7 +63,7 @@ pub const Biscuit = struct { { var it = public_key_to_block_id.iterator(); while (it.next()) |entry| { - std.debug.print("public_key_to_block_id: public key id = {}, block_ids = {any}\n", .{ entry.key_ptr.*, entry.value_ptr.items }); + log.debug("public_key_to_block_id: public key id = {}, block_ids = {any}", .{ entry.key_ptr.*, entry.value_ptr.items }); } } @@ -96,75 +98,85 @@ pub const Biscuit = struct { } }; -test { - const decode = @import("biscuit-format").decode; - const testing = std.testing; - var allocator = testing.allocator; - - // Key - // private: 83e03c958f83085923f3cd091bab3c3b33a0c7f93f44889739fdb6c6fdb26f5b - // public: 49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da - const tokens: [6][]const u8 = .{ - "EpACCqUBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgEwCgExGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCokCgsIBBIDCIYIEgIYABIHCAISAwiGCBIMCAcSAwiHCBIDCIYIEiQIABIgnSmYbzjEQ2n09JhlmGs6j_ZhKYgj3nRkEMdGJJqQimwaQD4UTmEDtu5G8kRJZbNTcNuGg8Izb5ja2BSV3Rlkv1Y6IV_Nd00sIstiEq1RPH-M8xfFdWaW1gixH54Y5deHzwYiIgogFmxoQyXPm8ccNBKKh0hv8eRwrYjS56s0OTQWZShHoVw=", - "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCyIiCiCyJCJ0e-e00kyM_3O6IbbftDeYAnkoI8-G1x06NK283w==", - "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCxp9ChMKBGFiY2QYAyIJCgcIAhIDGIEIEiQIABIgkJwspMgTz4pW4hQ_Tkua7EdZ5AajdxV35q42IyXzAt0aQBH3kiLfP06W0dPlQeuxgLU26ssrjoK-v1vvw0dzQ2BtaQjPs8eKhsowhFCjQ6nnhSP0p7v4TaJHWeO2fPsbUQwiIgogeuDcbq6waTZ1HpYt_zYNtAy02gbnjV-5-juc9sdXNJg=", - "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCxp9ChMKBGFiY2QYAyIJCgcIAhIDGIEIEiQIABIgkJwspMgTz4pW4hQ_Tkua7EdZ5AajdxV35q42IyXzAt0aQBH3kiLfP06W0dPlQeuxgLU26ssrjoK-v1vvw0dzQ2BtaQjPs8eKhsowhFCjQ6nnhSP0p7v4TaJHWeO2fPsbUQwiQhJAfNph7vZIL6WSLwOCmMHkwb4OmCc5s7EByizwq6HZOF04SRwCF8THWcNImPj-5xWOuI3zVdxg11Qr6d0c5yxuCw==", - "Eq4BCkQKBDEyMzQKBmRvdWJsZQoBeAoBeRgDIgkKBwgKEgMYgAgqIQoNCIEIEgMIgggSAwiDCBIHCAoSAwiCCBIHCAoSAwiDCBIkCAASIHJpGIZ74pbiyybTMn2zrCqHf5t7ZUV9tMnT5xkLq5rsGkCnAznWzInI1-kJGuRUgluqmr96bJwKG3RT3iceJ3kzzzBWGT5dEFXYyIqWxpLDk9Qoy-AWpwS49SA5ynGKb5UGIiIKIESr7u80iDgTstDzVk6obTp6zJmVfBqNcBNtwjOQyVOr", - // Token with check in authority block (that should pass): - "Eq8CCsQBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgExCgEwGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCopChAIBBIDCIYIEgMIhwgSAhgAEgcIAhIDCIcIEgwIBxIDCIYIEgMIhwgyGAoWCgIIGxIQCAQSAxiECBIDGIAIEgIYABIkCAASIGMjO8ucGcxZst9FINaf7EmOsWh8kW039G8TeV9BYIhTGkCrqL87m-bqFGxmNUobqmw7iWHViQN6DRDksNCJMfkC1zvwVdSZwZwtgQmr90amKCPjdXCD0bev53dNyIanRPoPIiIKIMAzV_GYyKdq9NeJ80-E-bGqGYD4nLXCDRnGpzThEglb", - }; - - var public_key_mem: [32]u8 = undefined; - _ = try std.fmt.hexToBytes(&public_key_mem, "49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da"); - const public_key = try Ed25519.PublicKey.fromBytes(public_key_mem); - - for (tokens) |token| { - const bytes = try decode.urlSafeBase64ToBytes(allocator, token); - defer allocator.free(bytes); - - var b = try Biscuit.fromBytes(allocator, bytes, public_key); - defer b.deinit(); - - var a = try b.authorizer(allocator); - defer a.deinit(); - - var errors = std.ArrayList(AuthorizerError).init(allocator); - defer errors.deinit(); - - _ = try a.authorize(&errors); - } -} +// test { +// const decode = @import("biscuit-format").decode; +// const testing = std.testing; +// var allocator = testing.allocator; -test "Tokens that should fail to validate" { - const decode = @import("biscuit-format").decode; - const testing = std.testing; - var allocator = testing.allocator; +// // Key +// // private: 83e03c958f83085923f3cd091bab3c3b33a0c7f93f44889739fdb6c6fdb26f5b +// // public: 49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da +// const tokens: [6][]const u8 = .{ +// "EpACCqUBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgEwCgExGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCokCgsIBBIDCIYIEgIYABIHCAISAwiGCBIMCAcSAwiHCBIDCIYIEiQIABIgnSmYbzjEQ2n09JhlmGs6j_ZhKYgj3nRkEMdGJJqQimwaQD4UTmEDtu5G8kRJZbNTcNuGg8Izb5ja2BSV3Rlkv1Y6IV_Nd00sIstiEq1RPH-M8xfFdWaW1gixH54Y5deHzwYiIgogFmxoQyXPm8ccNBKKh0hv8eRwrYjS56s0OTQWZShHoVw=", +// "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCyIiCiCyJCJ0e-e00kyM_3O6IbbftDeYAnkoI8-G1x06NK283w==", +// "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCxp9ChMKBGFiY2QYAyIJCgcIAhIDGIEIEiQIABIgkJwspMgTz4pW4hQ_Tkua7EdZ5AajdxV35q42IyXzAt0aQBH3kiLfP06W0dPlQeuxgLU26ssrjoK-v1vvw0dzQ2BtaQjPs8eKhsowhFCjQ6nnhSP0p7v4TaJHWeO2fPsbUQwiIgogeuDcbq6waTZ1HpYt_zYNtAy02gbnjV-5-juc9sdXNJg=", +// "En0KEwoEMTIzNBgDIgkKBwgKEgMYgAgSJAgAEiCicdgxKsSQpGYPKcR7hmnI7WcRLaFNUNzqkCc92yZluhpAyMoux34FBhYaTsw32rddToN7qbl-XOAPQcaUALPg_SfmuxfXbU9aEIJGVCANQLUfoQwU1GAa8ZkXESkW1uCdCxp9ChMKBGFiY2QYAyIJCgcIAhIDGIEIEiQIABIgkJwspMgTz4pW4hQ_Tkua7EdZ5AajdxV35q42IyXzAt0aQBH3kiLfP06W0dPlQeuxgLU26ssrjoK-v1vvw0dzQ2BtaQjPs8eKhsowhFCjQ6nnhSP0p7v4TaJHWeO2fPsbUQwiQhJAfNph7vZIL6WSLwOCmMHkwb4OmCc5s7EByizwq6HZOF04SRwCF8THWcNImPj-5xWOuI3zVdxg11Qr6d0c5yxuCw==", +// "Eq4BCkQKBDEyMzQKBmRvdWJsZQoBeAoBeRgDIgkKBwgKEgMYgAgqIQoNCIEIEgMIgggSAwiDCBIHCAoSAwiCCBIHCAoSAwiDCBIkCAASIHJpGIZ74pbiyybTMn2zrCqHf5t7ZUV9tMnT5xkLq5rsGkCnAznWzInI1-kJGuRUgluqmr96bJwKG3RT3iceJ3kzzzBWGT5dEFXYyIqWxpLDk9Qoy-AWpwS49SA5ynGKb5UGIiIKIESr7u80iDgTstDzVk6obTp6zJmVfBqNcBNtwjOQyVOr", +// // Token with check in authority block (that should pass): +// "Eq8CCsQBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgExCgEwGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCopChAIBBIDCIYIEgMIhwgSAhgAEgcIAhIDCIcIEgwIBxIDCIYIEgMIhwgyGAoWCgIIGxIQCAQSAxiECBIDGIAIEgIYABIkCAASIGMjO8ucGcxZst9FINaf7EmOsWh8kW039G8TeV9BYIhTGkCrqL87m-bqFGxmNUobqmw7iWHViQN6DRDksNCJMfkC1zvwVdSZwZwtgQmr90amKCPjdXCD0bev53dNyIanRPoPIiIKIMAzV_GYyKdq9NeJ80-E-bGqGYD4nLXCDRnGpzThEglb", +// }; - // Key - // private: 83e03c958f83085923f3cd091bab3c3b33a0c7f93f44889739fdb6c6fdb26f5b - // public: 49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da - const tokens: [1][]const u8 = .{ - // Token with check (in authority block) that should pass and a check (in the authority block) that should fail - "Es8CCuQBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgExCgEwCgRlcmljGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCopChAIBBIDCIYIEgMIhwgSAhgAEgcIAhIDCIcIEgwIBxIDCIYIEgMIhwgyGAoWCgIIGxIQCAQSAxiECBIDGIAIEgIYADIYChYKAggbEhAIBBIDGIgIEgMYgwgSAhgBEiQIABIgbACOx_sohlqZpzEwG23cKbN5wsUseLHHPt1tM8zVilIaQHMBawtn2NIa0jkJ38FR-uw7ncEAP1Qp_g6zctajVDLo1eMhBzjBO6lCddBHyEgvwZ9bufXYClHAwEZQyGKeEgwiIgogCfqPElEy9fyO6r-E5GT9-io3bhhSSe9wVAn6x6fsM7k=", - }; +// var public_key_mem: [32]u8 = undefined; +// _ = try std.fmt.hexToBytes(&public_key_mem, "49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da"); +// const public_key = try Ed25519.PublicKey.fromBytes(public_key_mem); - var public_key_mem: [32]u8 = undefined; - _ = try std.fmt.hexToBytes(&public_key_mem, "49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da"); - const public_key = try Ed25519.PublicKey.fromBytes(public_key_mem); +// for (tokens) |token| { +// const bytes = try decode.urlSafeBase64ToBytes(allocator, token); +// defer allocator.free(bytes); - for (tokens) |token| { - const bytes = try decode.urlSafeBase64ToBytes(allocator, token); - defer allocator.free(bytes); +// var arena_state = std.heap.ArenaAllocator.init(testing.allocator); +// defer arena_state.deinit(); - var b = try Biscuit.fromBytes(allocator, bytes, public_key); - defer b.deinit(); +// const arena = arena_state.allocator(); - var a = try b.authorizer(allocator); - defer a.deinit(); +// var b = try Biscuit.fromBytes(arena, bytes, public_key); +// defer b.deinit(); - var errors = std.ArrayList(AuthorizerError).init(allocator); - defer errors.deinit(); +// var a = try b.authorizer(arena); +// defer a.deinit(); - try testing.expectError(error.AuthorizationFailed, a.authorize(&errors)); - } -} +// var errors = std.ArrayList(AuthorizerError).init(allocator); +// defer errors.deinit(); + +// _ = try a.authorize(&errors); +// } +// } + +// test "Tokens that should fail to validate" { +// const decode = @import("biscuit-format").decode; +// const testing = std.testing; +// var allocator = testing.allocator; + +// // Key +// // private: 83e03c958f83085923f3cd091bab3c3b33a0c7f93f44889739fdb6c6fdb26f5b +// // public: 49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da +// const tokens: [1][]const u8 = .{ +// // Token with check (in authority block) that should pass and a check (in the authority block) that should fail +// "Es8CCuQBCgFhCgFiCgFjCgFkCgdtYWxjb2xtCgRqb2huCgExCgEwCgRlcmljGAMiCQoHCAISAxiACCIJCgcIAhIDGIEIIgkKBwgCEgMYgggiCQoHCAISAxiDCCIOCgwIBxIDGIQIEgMYgAgiDgoMCAcSAxiECBIDGIIIIg4KDAgHEgMYhQgSAxiBCCopChAIBBIDCIYIEgMIhwgSAhgAEgcIAhIDCIcIEgwIBxIDCIYIEgMIhwgyGAoWCgIIGxIQCAQSAxiECBIDGIAIEgIYADIYChYKAggbEhAIBBIDGIgIEgMYgwgSAhgBEiQIABIgbACOx_sohlqZpzEwG23cKbN5wsUseLHHPt1tM8zVilIaQHMBawtn2NIa0jkJ38FR-uw7ncEAP1Qp_g6zctajVDLo1eMhBzjBO6lCddBHyEgvwZ9bufXYClHAwEZQyGKeEgwiIgogCfqPElEy9fyO6r-E5GT9-io3bhhSSe9wVAn6x6fsM7k=", +// }; + +// var public_key_mem: [32]u8 = undefined; +// _ = try std.fmt.hexToBytes(&public_key_mem, "49fe7ec1972952c8c92119def96235ad622d0d024f3042a49c7317f7d5baf3da"); +// const public_key = try Ed25519.PublicKey.fromBytes(public_key_mem); + +// for (tokens) |token| { +// const bytes = try decode.urlSafeBase64ToBytes(allocator, token); +// defer allocator.free(bytes); + +// var arena_state = std.heap.ArenaAllocator.init(testing.allocator); +// defer arena_state.deinit(); + +// const arena = arena_state.allocator(); + +// var b = try Biscuit.fromBytes(arena, bytes, public_key); +// defer b.deinit(); + +// var a = try b.authorizer(arena); +// defer a.deinit(); + +// var errors = std.ArrayList(AuthorizerError).init(allocator); +// defer errors.deinit(); + +// try testing.expectError(error.AuthorizationFailed, a.authorize(&errors)); +// } +// } diff --git a/biscuit/src/block.zig b/biscuit/src/block.zig index 304dbc3..4e15b60 100644 --- a/biscuit/src/block.zig +++ b/biscuit/src/block.zig @@ -10,6 +10,8 @@ const Check = @import("biscuit-datalog").check.Check; const Scope = @import("biscuit-datalog").Scope; const SymbolTable = @import("biscuit-datalog").symbol_table.SymbolTable; +const log = std.log.scoped(.block); + pub const Block = struct { version: u32, context: []const u8, @@ -34,9 +36,9 @@ pub const Block = struct { } pub fn deinit(block: *Block) void { - for (block.checks.items) |*check| check.deinit(); - for (block.rules.items) |*rule| rule.deinit(); - for (block.facts.items) |*fact| fact.deinit(); + // for (block.checks.items) |*check| check.deinit(); + // for (block.rules.items) |*rule| rule.deinit(); + // for (block.facts.items) |*fact| fact.deinit(); block.checks.deinit(); block.rules.deinit(); @@ -49,7 +51,7 @@ pub const Block = struct { /// Given a blocks contents as bytes, derserialize into runtime block pub fn fromBytes(allocator: std.mem.Allocator, signed_block: SignedBlock, token_symbols: *SymbolTable) !Block { const data = signed_block.block; - std.debug.print("Block.fromBytes\n", .{}); + const decoded_block = try schema.decodeBlock(allocator, data); defer decoded_block.deinit();