diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index e3bac7cb..54eee4d5 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -16,6 +16,104 @@ import java.util.function.Function; public class Parser { + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + return datalog(index, baseSymbols, null, s); + } + + /** + * Takes a datalog string with \n as datalog line separator. It tries to parse + * each line using fact, rule, check and scope sequentially. + * + * If one succeeds it returns Right(Block) + * else it returns a Map[lineNumber, List[Error]] + * + * @param index block index + * @param baseSymbols symbols table + * @param blockSymbols block's custom symbols table (added to baseSymbols) + * @param s datalog string to parse + * @return Either>, Block> + */ + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, SymbolTable blockSymbols, String s) { + Block blockBuilder = new Block(index, baseSymbols); + + // empty block code + if (s.isEmpty()) { + return Either.right(blockBuilder); + } + + if (blockSymbols != null) { + blockSymbols.symbols.forEach(blockBuilder::addSymbol); + } + + Map> errors = new HashMap<>(); + + s = removeCommentsAndWhitespaces(s); + String[] codeLines = s.split(";"); + + Stream.of(codeLines) + .zipWithIndex() + .forEach(indexedLine -> { + String code = indexedLine._1.strip(); + + if (!code.isEmpty()) { + int lineNumber = indexedLine._2; + System.out.println("NEW CODE LINE"); + System.out.println(code); + List lineErrors = new ArrayList<>(); + + boolean parsed = false; + parsed = rule(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_rule(r._2); + return true; + }); + + if (!parsed) { + parsed = scope(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_scope(r._2); + return true; + }); + } + + if (!parsed) { + parsed = fact(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_fact(r._2); + return true; + }); + } + + if (!parsed) { + parsed = check(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_check(r._2); + return true; + }); + } + + if (!parsed) { + lineErrors.forEach(System.out::println); + errors.put(lineNumber, lineErrors); + } + } + }); + + if (!errors.isEmpty()) { + return Either.left(errors); + } + + return Either.right(blockBuilder); + } + public static Either> fact(String s) { Either> res = fact_predicate(s); if (res.isLeft()) { @@ -671,4 +769,43 @@ public static Tuple2 take_while(String s, Function(s.substring(0, index), s.substring(index)); } + + public static String removeCommentsAndWhitespaces(String s) { + s = removeComments(s); + s = s.replace("\n", "").replace("\\\"", "\"").strip(); + return s; + } + + public static String removeComments(String str) { + StringBuilder result = new StringBuilder(); + String remaining = str; + + while (!remaining.isEmpty()) { + remaining = space(remaining); // Skip leading whitespace + if (remaining.startsWith("/*")) { + // Find the end of the multiline comment + remaining = remaining.substring(2); // Skip "/*" + String finalRemaining = remaining; + Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("*/")); + remaining = split._2.length() > 2 ? split._2.substring(2) : ""; // Skip "*/" + } else if (remaining.startsWith("//")) { + // Find the end of the single-line comment + remaining = remaining.substring(2); // Skip "//" + Tuple2 split = take_while(remaining, c -> c != '\n' && c != '\r'); + remaining = split._2; + if (!remaining.isEmpty()) { + result.append(remaining.charAt(0)); // Preserve line break + remaining = remaining.substring(1); + } + } else { + // Take non-comment text until the next comment or end of string + String finalRemaining = remaining; + Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("/*") && !finalRemaining.startsWith("//")); + result.append(split._1); + remaining = split._2; + } + } + + return result.toString(); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index 0230b235..9ec36a1f 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -398,4 +398,95 @@ void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(9), value2); assertEquals("(1 + 2) * 3", ex2.print(s2).get()); } + + @Test + void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "fact1(1, 2)"; + String l2 = "fact2(\"2\")"; + String l3 = "rule1(2) <- fact2(\"2\")"; + String l4 = "check if rule1(2)"; + String toParse = String.join(";", Arrays.asList(l1, l2, l3, l4)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(1, symbols); + validBlock.add_fact(l1); + validBlock.add_fact(l2); + validBlock.add_rule(l3); + validBlock.add_check(l4); + + output.forEach(block -> + assertArrayEquals(block.build().to_bytes().get(), validBlock.build().to_bytes().get()) + ); + } + + @Test + void testDatalogSucceedsArrays() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "check if [2, 3].union([2])"; + String toParse = String.join(";", List.of(l1)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(1, symbols); + validBlock.add_check(l1); + + output.forEach(block -> + assertArrayEquals(block.build().to_bytes().get(), validBlock.build().to_bytes().get()) + ); + } + + @Test + void testDatalogSucceedsArraysContains() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z)"; + String toParse = String.join(";", List.of(l1)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(1, symbols); + validBlock.add_check(l1); + + output.forEach(block -> + assertArrayEquals(block.build().to_bytes().get(), validBlock.build().to_bytes().get()) + ); + } + + @Test + void testDatalogFailed() { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l1 = "fact(1)"; + String l2 = "check fact(1)"; // typo missing "if" + String toParse = String.join(";", Arrays.asList(l1, l2)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isLeft()); + } + + @Test + void testDatalogRemoveComment() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l0 = "// test comment"; + String l1 = "fact1(1, 2);"; + String l2 = "fact2(\"2\");"; + String l3 = "rule1(2) <- fact2(\"2\");"; + String l4 = "// another comment"; + String l5 = "/* test multiline"; + String l6 = "comment */ check if rule1(2);"; + String l7 = " /* another multiline"; + String l8 = "comment */"; + String toParse = String.join("", Arrays.asList(l0, l1, l2, l3, l4, l5, l6, l7, l8)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + } } \ No newline at end of file diff --git a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java index bb5907c4..ff097c0a 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java @@ -45,6 +45,47 @@ Stream jsonTest() { return sample.testcases.stream().map(t -> process_testcase(t, publicKey, keyPair)); } + void compareBlocks(List sampleBlocks, List blocks) { + assertEquals(sampleBlocks.size(), blocks.size()); + List> comparisonList = IntStream.range(0, sampleBlocks.size()) + .mapToObj(i -> new Tuple2<>(sampleBlocks.get(i), blocks.get(i))) + .collect(Collectors.toList()); + + // for each token we start from the base symbol table + SymbolTable baseSymbols = new SymbolTable(); + + io.vavr.collection.Stream.ofAll(comparisonList).zipWithIndex().forEach(indexedItem -> { + compareBlock(baseSymbols, indexedItem._2, indexedItem._1._1, indexedItem._1._2); + }); + } + + void compareBlock(SymbolTable baseSymbols, long sampleBlockIndex, Block sampleBlock, org.biscuitsec.biscuit.token.Block block) { + Option sampleExternalKey = sampleBlock.getExternalKey(); + List samplePublicKeys = sampleBlock.getPublicKeys(); + String sampleDatalog = sampleBlock.getCode().replace("\"","\\\""); + SymbolTable sampleSymbols = new SymbolTable(sampleBlock.symbols); + + Either>, org.biscuitsec.biscuit.token.builder.Block> outputSample = Parser.datalog(sampleBlockIndex, baseSymbols, sampleDatalog); + assertTrue(outputSample.isRight()); + + if (!block.publicKeys.isEmpty()) { + outputSample.get().addPublicKeys(samplePublicKeys); + } + + if (!block.externalKey.isDefined()) { + sampleSymbols.symbols.forEach(baseSymbols::add); + } else { + SymbolTable freshSymbols = new SymbolTable(); + sampleSymbols.symbols.forEach(freshSymbols::add); + outputSample.get().setExternalKey(sampleExternalKey); + } + + System.out.println(outputSample.get().build().print(sampleSymbols)); + System.out.println(block.symbols.symbols); + System.out.println(block.print(sampleSymbols)); + assertArrayEquals(outputSample.get().build().to_bytes().get(), block.to_bytes().get()); + } + DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, final KeyPair privateKey) { return DynamicTest.dynamicTest(testCase.title + ": "+testCase.filename, () -> { System.out.println("Testcase name: \""+testCase.title+"\""); @@ -63,6 +104,12 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, Biscuit token = Biscuit.from_bytes(data, publicKey); assertArrayEquals(token.serialize(), data); + List allBlocks = new ArrayList<>(); + allBlocks.add(token.authority); + allBlocks.addAll(token.blocks); + + compareBlocks(testCase.token, allBlocks); + List revocationIds = token.revocation_identifiers(); JsonArray validationRevocationIds = validation.getAsJsonArray("revocation_ids"); assertEquals(revocationIds.size(), validationRevocationIds.size());