From e7dc3231483b3087c849848467edce266f57b4d4 Mon Sep 17 00:00:00 2001 From: Malcolm Still Date: Sat, 23 Mar 2024 12:31:41 +0000 Subject: [PATCH] WIP parser --- biscuit-builder/src/check.zig | 4 + biscuit-builder/src/fact.zig | 4 +- biscuit-builder/src/predicate.zig | 16 +- biscuit-builder/src/root.zig | 3 + biscuit-builder/src/rule.zig | 14 +- biscuit-builder/src/term.zig | 12 ++ biscuit-datalog/src/check.zig | 2 +- biscuit-datalog/src/main.zig | 3 + biscuit-parser/build.zig | 2 + biscuit-parser/build.zig.zon | 1 + biscuit-parser/src/main.zig | 240 +++++++++++++++++++++++++++--- biscuit-samples/src/main.zig | 4 +- biscuit/src/authorizer.zig | 7 +- 13 files changed, 283 insertions(+), 29 deletions(-) diff --git a/biscuit-builder/src/check.zig b/biscuit-builder/src/check.zig index 903fcad..56a85f2 100644 --- a/biscuit-builder/src/check.zig +++ b/biscuit-builder/src/check.zig @@ -2,8 +2,12 @@ const std = @import("std"); const datalog = @import("biscuit-datalog"); const Predicate = @import("predicate.zig").Predicate; const Term = @import("term.zig").Term; +const Rule = @import("rule.zig").Rule; pub const Check = struct { + kind: datalog.Check.Kind, + queries: std.ArrayList(Rule), + pub fn deinit(_: Check) void { // } diff --git a/biscuit-builder/src/fact.zig b/biscuit-builder/src/fact.zig index bafb106..0b1cd19 100644 --- a/biscuit-builder/src/fact.zig +++ b/biscuit-builder/src/fact.zig @@ -12,8 +12,8 @@ pub const Fact = struct { } /// convert to datalog fact - pub fn convert(fact: Fact) datalog.Fact { - return .{ .predicate = fact.predicate.convert() }; + pub fn convert(fact: Fact, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Fact { + return .{ .predicate = try fact.predicate.convert(allocator, symbols) }; } pub fn format(fact: Fact, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { diff --git a/biscuit-builder/src/predicate.zig b/biscuit-builder/src/predicate.zig index 86ef1fb..8eec66f 100644 --- a/biscuit-builder/src/predicate.zig +++ b/biscuit-builder/src/predicate.zig @@ -7,12 +7,24 @@ pub const Predicate = struct { terms: std.ArrayList(Term), pub fn deinit(predicate: Predicate) void { + for (predicate.terms.items) |term| { + term.deinit(); + } + predicate.terms.deinit(); } /// convert to datalog predicate - pub fn convert(_: Predicate) datalog.Predicate { - unreachable; + pub fn convert(predicate: Predicate, allocator: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Predicate { + const name = try symbols.insert(predicate.name); + + var terms = std.ArrayList(datalog.Term).init(allocator); + + for (predicate.terms.items) |term| { + try terms.append(try term.convert(allocator, symbols)); + } + + return .{ .name = name, .terms = terms }; } pub fn format(predicate: Predicate, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { diff --git a/biscuit-builder/src/root.zig b/biscuit-builder/src/root.zig index 22f678b..8ab25be 100644 --- a/biscuit-builder/src/root.zig +++ b/biscuit-builder/src/root.zig @@ -2,3 +2,6 @@ pub const Fact = @import("fact.zig").Fact; pub const Predicate = @import("predicate.zig").Predicate; pub const Term = @import("term.zig").Term; pub const Check = @import("check.zig").Check; +pub const Rule = @import("rule.zig").Rule; +pub const Expression = @import("expression.zig").Expression; +pub const Scope = @import("scope.zig").Scope; diff --git a/biscuit-builder/src/rule.zig b/biscuit-builder/src/rule.zig index adc69f6..5a23a53 100644 --- a/biscuit-builder/src/rule.zig +++ b/biscuit-builder/src/rule.zig @@ -9,9 +9,21 @@ pub const Rule = struct { head: Predicate, body: std.ArrayList(Predicate), expressions: std.ArrayList(Expression), - variables: ?std.AutoHashMap([]const u8, ?Term), + variables: ?std.StringHashMap(?Term), scopes: std.ArrayList(Scope), + pub fn deinit(rule: Rule) void { + rule.head.deinit(); + + for (rule.body.items) |predicate| { + predicate.deinit(); + } + + rule.body.deinit(); + rule.expressions.deinit(); + rule.scopes.deinit(); + } + /// convert to datalog predicate pub fn convert(_: Rule) datalog.Rule { unreachable; diff --git a/biscuit-builder/src/term.zig b/biscuit-builder/src/term.zig index dcf1277..50be1ac 100644 --- a/biscuit-builder/src/term.zig +++ b/biscuit-builder/src/term.zig @@ -1,3 +1,6 @@ +const std = @import("std"); +const datalog = @import("biscuit-datalog"); + const TermTag = enum(u8) { string, bool, @@ -6,4 +9,13 @@ const TermTag = enum(u8) { pub const Term = union(TermTag) { string: []const u8, bool: bool, + + pub fn deinit(_: Term) void {} + + pub fn convert(term: Term, _: std.mem.Allocator, symbols: *datalog.SymbolTable) !datalog.Term { + return switch (term) { + .string => |s| .{ .string = try symbols.insert(s) }, + .bool => |b| .{ .bool = b }, + }; + } }; diff --git a/biscuit-datalog/src/check.zig b/biscuit-datalog/src/check.zig index 6aa6c67..8e78d6c 100644 --- a/biscuit-datalog/src/check.zig +++ b/biscuit-datalog/src/check.zig @@ -6,7 +6,7 @@ pub const Check = struct { queries: std.ArrayList(Rule), kind: Kind, - const Kind = enum(u8) { one, all }; + pub const Kind = enum(u8) { one, all }; pub fn fromSchema(allocator: std.mem.Allocator, schema_check: schema.CheckV2) !Check { var rules = std.ArrayList(Rule).init(allocator); diff --git a/biscuit-datalog/src/main.zig b/biscuit-datalog/src/main.zig index 4735ab7..5a5918f 100644 --- a/biscuit-datalog/src/main.zig +++ b/biscuit-datalog/src/main.zig @@ -5,6 +5,9 @@ pub const Predicate = @import("predicate.zig").Predicate; pub const rule = @import("rule.zig"); pub const check = @import("check.zig"); pub const symbol_table = @import("symbol_table.zig"); +pub const SymbolTable = @import("symbol_table.zig").SymbolTable; +pub const Term = @import("term.zig").Term; +pub const Check = @import("check.zig").Check; pub const world = @import("world.zig"); test { diff --git a/biscuit-parser/build.zig b/biscuit-parser/build.zig index ce05399..cbb23e9 100644 --- a/biscuit-parser/build.zig +++ b/biscuit-parser/build.zig @@ -19,6 +19,7 @@ pub fn build(b: *std.Build) void { const schema = b.dependency("biscuit-schema", .{ .target = target, .optimize = optimize }); const format = b.dependency("biscuit-format", .{ .target = target, .optimize = optimize }); const builder = b.dependency("biscuit-builder", .{ .target = target, .optimize = optimize }); + const datalog = b.dependency("biscuit-datalog", .{ .target = target, .optimize = optimize }); _ = b.addModule("biscuit-parser", .{ .root_source_file = .{ .path = "src/main.zig" }, @@ -26,6 +27,7 @@ pub fn build(b: *std.Build) void { .{ .name = "biscuit-schema", .module = schema.module("biscuit-schema") }, .{ .name = "biscuit-format", .module = format.module("biscuit-format") }, .{ .name = "biscuit-builder", .module = builder.module("biscuit-builder") }, + .{ .name = "biscuit-datalog", .module = datalog.module("biscuit-datalog") }, .{ .name = "ziglyph", .module = ziglyph.module("ziglyph") }, }, }); diff --git a/biscuit-parser/build.zig.zon b/biscuit-parser/build.zig.zon index eb581fc..54f2748 100644 --- a/biscuit-parser/build.zig.zon +++ b/biscuit-parser/build.zig.zon @@ -41,6 +41,7 @@ .@"biscuit-schema" = .{ .path = "../biscuit-schema" }, .@"biscuit-format" = .{ .path = "../biscuit-format" }, .@"biscuit-builder" = .{ .path = "../biscuit-builder" }, + .@"biscuit-datalog" = .{ .path = "../biscuit-datalog" }, .ziglyph = .{ .url = "https://codeberg.org/dude_the_builder/ziglyph/archive/947ed39203bf90412e3d16cbcf936518b6f23af0.tar.gz", .hash = "12208b23d1eb6dcb929e85346524db8f8b8aa1401bdf8a97dee1e0cfb55da8d5fb42", diff --git a/biscuit-parser/src/main.zig b/biscuit-parser/src/main.zig index d5bb804..e037d72 100644 --- a/biscuit-parser/src/main.zig +++ b/biscuit-parser/src/main.zig @@ -1,7 +1,13 @@ const std = @import("std"); const ziglyph = @import("ziglyph"); +const datalog = @import("biscuit-datalog"); const Term = @import("biscuit-builder").Term; const Fact = @import("biscuit-builder").Fact; +const Check = @import("biscuit-builder").Check; +const Rule = @import("biscuit-builder").Rule; +const Predicate = @import("biscuit-builder").Predicate; +const Expression = @import("biscuit-builder").Expression; +const Scope = @import("biscuit-builder").Scope; pub const Parser = struct { input: []const u8, @@ -12,12 +18,14 @@ pub const Parser = struct { } pub fn fact(parser: *Parser, allocator: std.mem.Allocator) !Fact { - return try parser.factPredicate(allocator); + return .{ .predicate = try parser.factPredicate(allocator), .variables = null }; } - pub fn factPredicate(parser: *Parser, allocator: std.mem.Allocator) !Fact { + pub fn factPredicate(parser: *Parser, allocator: std.mem.Allocator) !Predicate { const name = parser.readName(); + std.debug.print("name = {s}\n", .{name}); + parser.skipWhiteSpace(); // Consume left paren @@ -27,8 +35,8 @@ pub const Parser = struct { var terms = std.ArrayList(Term).init(allocator); var it = parser.factTermsIterator(); - while (try it.next()) |term| { - try terms.append(term); + while (try it.next()) |trm| { + try terms.append(trm); if (parser.peek()) |peeked| { if (peeked != ',') break; @@ -39,10 +47,7 @@ pub const Parser = struct { try parser.expect(')'); - return .{ - .predicate = .{ .name = name, .terms = terms }, - .variables = null, - }; + return .{ .name = name, .terms = terms }; } const FactTermIterator = struct { @@ -51,9 +56,7 @@ pub const Parser = struct { pub fn next(it: *FactTermIterator) !?Term { it.parser.skipWhiteSpace(); - const term = try it.parser.factTerm(); - - return term; + return try it.parser.factTerm(); } }; @@ -61,6 +64,46 @@ pub const Parser = struct { return .{ .parser = parser }; } + const TermIterator = struct { + parser: *Parser, + + pub fn next(it: *TermIterator) !?Term { + it.parser.skipWhiteSpace(); + + return try it.parser.term(); + } + }; + + pub fn termsIterator(parser: *Parser) TermIterator { + return .{ .parser = parser }; + } + + pub fn term(parser: *Parser) !Term { + const rst = parser.rest(); + + string_blk: { + var term_parser = Parser.init(rst); + + const value = term_parser.string() catch break :string_blk; + + parser.offset += term_parser.offset; + + return .{ .string = value }; + } + + bool_blk: { + var term_parser = Parser.init(rst); + + const value = term_parser.boolean() catch break :bool_blk; + + parser.offset += term_parser.offset; + + return .{ .bool = value }; + } + + return error.NoFactTermFound; + } + pub fn factTerm(parser: *Parser) !Term { const rst = parser.rest(); @@ -68,31 +111,66 @@ pub const Parser = struct { string_blk: { var term_parser = Parser.init(rst); - const s = term_parser.string() catch { - break :string_blk; - }; + + const value = term_parser.string() catch break :string_blk; parser.offset += term_parser.offset; - return .{ .string = s }; + + return .{ .string = value }; } bool_blk: { var term_parser = Parser.init(rst); - const b = term_parser.boolean() catch { - break :bool_blk; - }; + + const value = term_parser.boolean() catch break :bool_blk; parser.offset += term_parser.offset; - return .{ .bool = b }; + return .{ .bool = value }; } return error.NoFactTermFound; } + pub fn predicate(parser: *Parser, allocator: std.mem.Allocator) !Predicate { + const name = parser.readName(); + + parser.skipWhiteSpace(); + + // Consume left paren + try parser.expect('('); + + // Parse terms + var terms = std.ArrayList(Term).init(allocator); + + var it = parser.termsIterator(); + while (try it.next()) |trm| { + try terms.append(trm); + + if (parser.peek()) |peeked| { + if (peeked != ',') break; + } else { + break; + } + } + + try parser.expect(')'); + + return .{ .name = name, .terms = terms }; + } + fn string(parser: *Parser) ![]const u8 { try parser.expect('"'); + const start = parser.offset; + + while (parser.peek()) |peeked| { + defer parser.offset += 1; + if (peeked == '"') { + return parser.input[start..parser.offset]; + } + } + return error.ExpectedStringTerm; } @@ -110,6 +188,116 @@ pub const Parser = struct { return error.ExpectedBooleanTerm; } + pub fn check(parser: *Parser) !Check { + var kind: datalog.Check.Kind = undefined; + + if (std.mem.startsWith(u8, parser.rest(), "check if")) { + parser.offset += "check if".len; + kind = .one; + } else if (std.mem.startsWith(u8, parser.rest(), "check all")) { + parser.offset += "check all".len; + kind = .all; + } else { + return error.UnexpectedCheckKind; + } + + const queries = try parser.checkBody(); + + return .{ .kind = kind, .queries = queries }; + } + + fn checkBody(_: *Parser) !std.ArrayList(Rule) { + unreachable; + } + + pub fn rule(parser: *Parser, allocator: std.mem.Allocator) !Rule { + const head = try parser.predicate(allocator); + + parser.skipWhiteSpace(); + + if (!std.mem.startsWith(u8, parser.rest(), "<-")) return error.ExpectedArrow; + + parser.offset += "<-".len; + + const body = try parser.ruleBody(allocator); + + return .{ + .head = head, + .body = body.predicates, + .expressions = body.expressions, + .scopes = body.scopes, + .variables = null, + }; + } + + pub fn ruleBody(parser: *Parser, allocator: std.mem.Allocator) !struct { predicates: std.ArrayList(Predicate), expressions: std.ArrayList(Expression), scopes: std.ArrayList(Scope) } { + var predicates = std.ArrayList(Predicate).init(allocator); + var expressions = std.ArrayList(Expression).init(allocator); + var scopes = std.ArrayList(Scope).init(allocator); + + while (true) { + parser.skipWhiteSpace(); + + // Try parsing a predicate + predicate_blk: { + var predicate_parser = Parser.init(parser.rest()); + + const p = predicate_parser.predicate(allocator) catch break :predicate_blk; + + parser.offset += predicate_parser.offset; + + try predicates.append(p); + + parser.skipWhiteSpace(); + + if (parser.peek()) |peeked| { + if (peeked == ',') continue; + } + } + + // Otherwise try parsing an expression + expression_blk: { + var expression_parser = Parser.init(parser.rest()); + + const e = expression_parser.expression(allocator) catch break :expression_blk; + + parser.offset += expression_parser.offset; + + try expressions.append(e); + + parser.skipWhiteSpace(); + + if (parser.peek()) |peeked| { + if (peeked == ',') continue; + } + } + + // We haven't found a predicate or expression so we're done, + // other than potentially parsing a scope + break; + } + + scope_blk: { + var scope_parser = Parser.init(parser.rest()); + + const s = scope_parser.scope(allocator) catch break :scope_blk; + + parser.offset += scope_parser.offset; + + try scopes.append(s); + } + + return .{ .predicates = predicates, .expressions = expressions, .scopes = scopes }; + } + + fn expression(_: *Parser, _: std.mem.Allocator) !Expression { + return error.Unimplemented; + } + + fn scope(_: *Parser, _: std.mem.Allocator) !Scope { + return error.Unimplemented; + } + fn peek(parser: *Parser) ?u8 { if (parser.input[parser.offset..].len == 0) return null; @@ -173,3 +361,17 @@ test "parse fact predicate" { defer f.deinit(); std.debug.print("{any}\n", .{f}); } + +test "parse rule body" { + const testing = std.testing; + const rule_body: []const u8 = + \\query(false) <- read(true), write(false) + ; + + var parser = Parser.init(rule_body); + + const r = try parser.rule(testing.allocator); + defer r.deinit(); + + std.debug.print("{any}\n", .{r}); +} diff --git a/biscuit-samples/src/main.zig b/biscuit-samples/src/main.zig index 737dba8..03989e8 100644 --- a/biscuit-samples/src/main.zig +++ b/biscuit-samples/src/main.zig @@ -123,11 +123,11 @@ pub fn runValidation(alloc: mem.Allocator, token: []const u8, public_key: std.cr var it = std.mem.split(u8, authorizer_code, ";"); while (it.next()) |code| { - const text = std.mem.trim(u8, code, " "); + const text = std.mem.trim(u8, code, " \n"); if (text.len == 0) continue; if (std.mem.startsWith(u8, text, "check if") or std.mem.startsWith(u8, text, "check all")) { - // try a.addCheck(text); + try a.addCheck(text); } else if (std.mem.startsWith(u8, text, "allow if") or std.mem.startsWith(u8, text, "deny if")) { // try a.addPolicy(text); } else if (std.mem.startsWith(u8, text, "revocation_id")) { diff --git a/biscuit/src/authorizer.zig b/biscuit/src/authorizer.zig index ef7d133..dbdbb5f 100644 --- a/biscuit/src/authorizer.zig +++ b/biscuit/src/authorizer.zig @@ -35,14 +35,17 @@ pub const Authorizer = struct { } pub fn addFact(authorizer: *Authorizer, input: []const u8) !void { + std.debug.print("addFact = \"{s}\"\n", .{input}); var parser = Parser.init(input); const fact = try parser.fact(authorizer.allocator); - try authorizer.world.addFact(fact.convert()); + std.debug.print("fact = {any}\n", .{fact}); + + try authorizer.world.addFact(try fact.convert(authorizer.allocator, &authorizer.symbols)); } - pub fn addCheck(authorizer: *Authorizer, input: []const u8) void { + pub fn addCheck(authorizer: *Authorizer, input: []const u8) !void { var parser = Parser.init(input); const check = try parser.check();