Skip to content

Commit

Permalink
Concat
Browse files Browse the repository at this point in the history
  • Loading branch information
malcolmstill committed Mar 28, 2024
1 parent 5ca2107 commit 270e63c
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 30 deletions.
4 changes: 2 additions & 2 deletions biscuit-datalog/src/combinator.zig
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ pub const Combinator = struct {
current_bindings: ?std.AutoHashMap(u64, Term) = null,
facts: *const FactSet,
trusted_fact_iterator: FactSet.TrustedIterator,
symbols: SymbolTable,
symbols: *SymbolTable,
trusted_origins: TrustedOrigins,

pub fn init(id: usize, allocator: mem.Allocator, variables: MatchedVariables, predicates: []Predicate, expressions: []Expression, all_facts: *const FactSet, symbols: SymbolTable, trusted_origins: TrustedOrigins) Combinator {
pub fn init(id: usize, allocator: mem.Allocator, variables: MatchedVariables, predicates: []Predicate, expressions: []Expression, all_facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) Combinator {
return .{
.id = id,
.allocator = allocator,
Expand Down
14 changes: 9 additions & 5 deletions biscuit-datalog/src/expression.zig
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub const Expression = struct {
expression.ops.deinit();
}

pub fn evaluate(expr: Expression, allocator: mem.Allocator, values: std.AutoHashMap(u32, Term), symbols: SymbolTable) !Term {
pub fn evaluate(expr: Expression, allocator: mem.Allocator, values: std.AutoHashMap(u32, Term), symbols: *SymbolTable) !Term {
var stack = std.ArrayList(Term).init(allocator);
defer stack.deinit();

Expand Down Expand Up @@ -173,7 +173,7 @@ const Unary = enum {
parens,
length,

pub fn evaluate(expr: Unary, value: Term, symbols: SymbolTable) !Term {
pub fn evaluate(expr: Unary, value: Term, symbols: *SymbolTable) !Term {
_ = symbols; // Different type instead of SymbolTable
//
return switch (expr) {
Expand Down Expand Up @@ -207,7 +207,7 @@ const Binary = enum {
bitwise_xor,
not_equal,

pub fn evaluate(expr: Binary, allocator: std.mem.Allocator, left: Term, right: Term, symbols: SymbolTable) !Term {
pub fn evaluate(expr: Binary, allocator: std.mem.Allocator, left: Term, right: Term, symbols: *SymbolTable) !Term {
// Integer operands
if (left == .integer and right == .integer) {
const i = left.integer;
Expand Down Expand Up @@ -238,7 +238,7 @@ const Binary = enum {
.suffix => .{ .bool = mem.endsWith(u8, sl, sr) },
.regex => .{ .bool = try match(allocator, sr, sl) },
.contains => .{ .bool = mem.containsAtLeast(u8, sl, 1, sr) },
.add => return error.StringConcatNotImplemented,
.add => .{ .string = try symbols.insert(try concat(allocator, sl, sr)) },
.equal => .{ .bool = mem.eql(u8, sl, sr) },
.not_equal => .{ .bool = !mem.eql(u8, sl, sr) },
else => return error.UnexpectedOperationForStringOperands,
Expand Down Expand Up @@ -296,7 +296,11 @@ const Binary = enum {
fn match(allocator: std.mem.Allocator, regex: []const u8, string: []const u8) !bool {
var re = try Regex.compile(allocator, regex);

return re.match(string);
return re.partialMatch(string);
}

fn concat(allocator: std.mem.Allocator, left: []const u8, right: []const u8) ![]const u8 {
return try std.mem.concat(allocator, u8, &[_][]const u8{ left, right });
}

test {
Expand Down
2 changes: 1 addition & 1 deletion biscuit-datalog/src/matched_variables.zig
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub const MatchedVariables = struct {
matchced_variables: *const MatchedVariables,
allocator: std.mem.Allocator,
expressions: []Expression,
symbols: SymbolTable,
symbols: *SymbolTable,
) !bool {
const variables = try matchced_variables.complete(allocator) orelse return error.IncompleteVariables;

Expand Down
18 changes: 11 additions & 7 deletions biscuit-datalog/src/rule.zig
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub const Rule = struct {
/// ```
///
/// ...and we add it to the set of facts (the set will take care of deduplication)
pub fn apply(rule: *const Rule, allocator: mem.Allocator, origin_id: u64, facts: *const FactSet, new_facts: *FactSet, symbols: SymbolTable, trusted_origins: TrustedOrigins) !void {
pub fn apply(rule: *const Rule, allocator: mem.Allocator, origin_id: u64, facts: *const FactSet, new_facts: *FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !void {
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();

Expand Down Expand Up @@ -173,7 +173,7 @@ pub const Rule = struct {
///
/// Note: whilst the combinator may return multiple valid matches, `findMatch` only requires a single match
/// so stopping on the first `it.next()` that returns not-null is enough.
pub fn findMatch(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: SymbolTable, trusted_origins: TrustedOrigins) !bool {
pub fn findMatch(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool {
std.debug.print("\nrule.findMatch on {any} ({any})\n", .{ rule, trusted_origins });
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();
Expand Down Expand Up @@ -208,7 +208,7 @@ pub const Rule = struct {
}
}

pub fn checkMatchAll(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: SymbolTable, trusted_origins: TrustedOrigins) !bool {
pub fn checkMatchAll(rule: *Rule, allocator: mem.Allocator, facts: *const FactSet, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool {
std.debug.print("\nrule.checkMatchAll on {any} ({any})\n", .{ rule, trusted_origins });
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();
Expand Down Expand Up @@ -252,11 +252,15 @@ pub const Rule = struct {
if (i < rule.body.items.len - 1) try writer.print(", ", .{});
}

if (rule.expressions.items.len > 0) try writer.print(", ", .{});
if (rule.expressions.items.len > 0) {
try writer.print(", [", .{});

for (rule.expressions.items, 0..) |*expression, i| {
try writer.print("{any}", .{expression.*});
if (i < rule.expressions.items.len - 1) try writer.print(", ", .{});
for (rule.expressions.items, 0..) |*expression, i| {
try writer.print("{any}", .{expression.*});
if (i < rule.expressions.items.len - 1) try writer.print(", ", .{});
}

try writer.print("]", .{});
}

if (rule.scopes.items.len > 0) try writer.print(", ", .{});
Expand Down
8 changes: 4 additions & 4 deletions biscuit-datalog/src/world.zig
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ pub const World = struct {
world.fact_set.deinit();
}

pub fn run(world: *World, symbols: SymbolTable) !void {
pub fn run(world: *World, symbols: *SymbolTable) !void {
try world.runWithLimits(symbols, .{});
}

pub fn runWithLimits(world: *World, symbols: SymbolTable, limits: RunLimits) !void {
pub fn runWithLimits(world: *World, symbols: *SymbolTable, limits: RunLimits) !void {
for (0..limits.max_iterations) |iteration| {
std.debug.print("\nrunWithLimits[{}]\n", .{iteration});
const starting_fact_count = world.fact_set.count();
Expand Down Expand Up @@ -109,11 +109,11 @@ pub const World = struct {
try world.rule_set.add(origin_id, scope, rule);
}

pub fn queryMatch(world: *World, rule: *Rule, symbols: SymbolTable, trusted_origins: TrustedOrigins) !bool {
pub fn queryMatch(world: *World, rule: *Rule, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool {
return rule.findMatch(world.allocator, &world.fact_set, symbols, trusted_origins);
}

pub fn queryMatchAll(world: *World, rule: *Rule, symbols: SymbolTable, trusted_origins: TrustedOrigins) !bool {
pub fn queryMatchAll(world: *World, rule: *Rule, symbols: *SymbolTable, trusted_origins: TrustedOrigins) !bool {
return rule.checkMatchAll(world.allocator, &world.fact_set, symbols, trusted_origins);
}
};
22 changes: 11 additions & 11 deletions biscuit/src/authorizer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ pub const Authorizer = struct {

// 2. Run the world to generate all facts
std.debug.print("\nGENERATING NEW FACTS\n", .{});
try authorizer.world.run(authorizer.symbols);
try authorizer.world.run(&authorizer.symbols);
std.debug.print("\nEND GENERATING NEW FACTS\n", .{});

// 3. Run checks that have been added to this authorizer
Expand All @@ -242,8 +242,8 @@ pub const Authorizer = struct {
);

const is_match = switch (check.kind) {
.one => try authorizer.world.queryMatch(query, authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, authorizer.symbols, rule_trusted_origins),
.one => try authorizer.world.queryMatch(query, &authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins),
};

if (!is_match) try errors.append(.{ .failed_authorizer_check = .{ .check_id = check_id } });
Expand All @@ -262,11 +262,11 @@ pub const Authorizer = struct {
authorizer.public_key_to_block_id,
);

for (biscuit.authority.checks.items) |c| {
for (biscuit.authority.checks.items, 0..) |c, check_id| {
const check = try c.convert(&biscuit.symbols, &authorizer.symbols);
std.debug.print("{any}\n", .{check});
std.debug.print("{}: {any}\n", .{ check_id, check });

for (check.queries.items, 0..) |*query, check_id| {
for (check.queries.items) |*query| {
const rule_trusted_origins = try TrustedOrigins.fromScopes(
authorizer.allocator,
query.scopes.items,
Expand All @@ -276,8 +276,8 @@ pub const Authorizer = struct {
);

const is_match = switch (check.kind) {
.one => try authorizer.world.queryMatch(query, authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, authorizer.symbols, rule_trusted_origins),
.one => try authorizer.world.queryMatch(query, &authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins),
};

if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = 0, .check_id = check_id } });
Expand All @@ -302,7 +302,7 @@ pub const Authorizer = struct {
authorizer.public_key_to_block_id,
);

const is_match = try authorizer.world.queryMatch(&query, authorizer.symbols, rule_trusted_origins);
const is_match = try authorizer.world.queryMatch(&query, &authorizer.symbols, rule_trusted_origins);
std.debug.print("match {any} = {}\n", .{ query, is_match });

if (is_match) {
Expand Down Expand Up @@ -349,8 +349,8 @@ pub const Authorizer = struct {
);

const is_match = switch (check.kind) {
.one => try authorizer.world.queryMatch(query, authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, authorizer.symbols, rule_trusted_origins),
.one => try authorizer.world.queryMatch(query, &authorizer.symbols, rule_trusted_origins),
.all => try authorizer.world.queryMatchAll(query, &authorizer.symbols, rule_trusted_origins),
};

if (!is_match) try errors.append(.{ .failed_block_check = .{ .block_id = block_id, .check_id = check_id } });
Expand Down

0 comments on commit 270e63c

Please sign in to comment.