Skip to content

Commit

Permalink
add datalog parser for block
Browse files Browse the repository at this point in the history
  • Loading branch information
KannarFr committed Apr 8, 2024
1 parent 45a163c commit e622e1d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.biscuitsec.biscuit.token.builder.parser;

import biscuit.format.schema.Schema;
import io.vavr.collection.Stream;
import org.biscuitsec.biscuit.crypto.PublicKey;
import org.biscuitsec.biscuit.datalog.SymbolTable;
import org.biscuitsec.biscuit.token.Policy;
import io.vavr.Tuple2;
import io.vavr.Tuple4;
Expand All @@ -10,12 +12,48 @@

import java.time.OffsetDateTime;
import java.time.format.DateTimeParseException;
import java.util.ArrayList;
import java.util.List;
import java.util.HashSet;
import java.util.*;
import java.util.function.Function;

public class Parser {
/**
* Takes a datalog string with <code>\n</code> as datalog line separator. It tries to parse
* each line using fact, rule, check and scope sequentially.
*
* If one succeed it returns Right(Block)
* else it returns a Map[lineNumber, List[Error]]
*
* @param index block index
* @param baseSymbols symbols table
* @param s datalog string to parse
* @return Either<Map<Integer, List<Error>>, Block>
*/
public static Either<Map<Integer, List<Error>>, Block> datalog(long index, SymbolTable baseSymbols, String s) {
Block blockBuilder = new Block(index, baseSymbols);
Map<Integer, List<Error>> errors = new HashMap<>();

Stream.of(s.split("\n")).zipWithIndex().forEach(indexedLine -> {
Integer lineNumber = indexedLine._2;
String codeLine = indexedLine._1;
List<Error> lineErrors = new ArrayList<>();

fact(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_fact);
rule(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_rule);
check(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_check);
scope(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_scope);

if (lineErrors.size() > 3) {
errors.put(lineNumber, lineErrors);
}
});

if (!errors.isEmpty()) {
return Either.left(errors);
}

return Either.right(blockBuilder);
}

public static Either<Error, Tuple2<String, Fact>> fact(String s) {
Either<Error, Tuple2<String, Predicate>> res = fact_predicate(s);
if (res.isLeft()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.biscuitsec.biscuit.datalog.SymbolTable;
import org.biscuitsec.biscuit.datalog.TemporarySymbolTable;
import org.biscuitsec.biscuit.datalog.expressions.Op;
import org.biscuitsec.biscuit.token.Biscuit;
import org.biscuitsec.biscuit.token.builder.parser.Error;
import org.biscuitsec.biscuit.token.builder.parser.Parser;
import io.vavr.Tuple2;
Expand All @@ -15,7 +16,7 @@
import org.junit.jupiter.api.Test;

import static org.biscuitsec.biscuit.datalog.Check.Kind.One;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.*;

import java.util.*;

Expand Down Expand Up @@ -360,4 +361,40 @@ void testParens() throws org.biscuitsec.biscuit.error.Error.Execution {
assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(9), value2);
assertEquals("(1 + 2) * 3", ex2.print(s2).get());
}

@Test
void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser {
SymbolTable symbols = Biscuit.default_symbol_table();

String l1 = "fact1(1)";
String l2 = "fact2(\"2\")";
String l3 = "rule1(2) <- fact2(\"2\")";
String l4 = "check if rule1(2)";
String toParse = String.join("\n", Arrays.asList(l1, l2, l3, l4));

Either<Map<Integer, List<Error>>, Block> output = Parser.datalog(1, symbols, toParse);
assertTrue(output.isRight());

Block validBlock = new Block(1, symbols);
validBlock.add_fact(l1);
validBlock.add_fact(l2);
validBlock.add_rule(l3);
validBlock.add_check(l4);

output.forEach(block ->
assertArrayEquals(block.build().to_bytes().get(), validBlock.build().to_bytes().get())
);
}

@Test
void testDatalogFailed() {
SymbolTable symbols = Biscuit.default_symbol_table();

String l1 = "fact(1)";
String l2 = "check fact(1)"; // typo missing "if"
String toParse = String.join("\n", Arrays.asList(l1, l2));

Either<Map<Integer, List<Error>>, Block> output = Parser.datalog(1, symbols, toParse);
assertTrue(output.isLeft());
}
}

0 comments on commit e622e1d

Please sign in to comment.