From be696d772cc7d7b11ffd51db7b8cce493a3486fe Mon Sep 17 00:00:00 2001 From: kannar Date: Mon, 8 Apr 2024 18:06:41 +0200 Subject: [PATCH] wip: add block comparison between sample definition and its serialization --- .../org/biscuitsec/biscuit/token/Block.java | 7 +- .../biscuit/token/builder/Block.java | 2 - .../builder/parser/ExpressionParser.java | 3 + .../biscuit/token/builder/parser/Parser.java | 125 ++++++++++++++++-- .../biscuit/builder/parser/ParserTest.java | 42 +++++- .../biscuit/datalog/ExpressionTest.java | 24 ++++ .../biscuitsec/biscuit/token/SamplesTest.java | 98 ++++++++++++-- 7 files changed, 272 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/biscuitsec/biscuit/token/Block.java b/src/main/java/org/biscuitsec/biscuit/token/Block.java index 382224a8..2309fad7 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Block.java @@ -90,11 +90,16 @@ public String print(SymbolTable symbol_table) { s.append(this.symbols.symbols); s.append("\n\t\tcontext: "); s.append(this.context); + s.append("\n\t\tscopes: ["); + for (Scope scope : this.scopes) { + s.append("\n\t\t\t"); + s.append(symbol_table.print_scope(scope)); + } if(this.externalKey.isDefined()) { s.append("\n\t\texternal key: "); s.append(this.externalKey.get().toString()); } - s.append("\n\t\tfacts: ["); + s.append("\n\t\t]\n\t\tfacts: ["); for (Fact f : this.facts) { s.append("\n\t\t\t"); s.append(symbol_table.print_fact(f)); diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java index 6a17bbdf..c4591e03 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java @@ -150,8 +150,6 @@ public org.biscuitsec.biscuit.token.Block build() { publicKeys.add(this.symbols.publicKeys().get(i)); } - publicKeys.addAll(this.publicKeys); - SchemaVersion schemaVersion = new SchemaVersion(this.facts, this.rules, this.checks, this.scopes); return new org.biscuitsec.biscuit.token.Block(symbols, this.context, this.facts, this.rules, this.checks, diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java index 9c3c00a2..7050bfa5 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java @@ -517,6 +517,9 @@ public static Either> binary_op6(String s) } public static Either> binary_op7(String s) { + if(s.startsWith("intersection")) { + return Either.right(new Tuple2<>(s.substring(12), Expression.Op.Intersection)); + } if(s.startsWith("contains")) { return Either.right(new Tuple2<>(s.substring(8), Expression.Op.Contains)); } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index f8d21676..fee51e66 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -16,6 +16,10 @@ import java.util.function.Function; public class Parser { + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + return datalog(index, baseSymbols, null, s); + } + /** * Takes a datalog string with \n as datalog line separator. It tries to parse * each line using fact, rule, check and scope sequentially. @@ -25,26 +29,82 @@ public class Parser { * * @param index block index * @param baseSymbols symbols table + * @param blockSymbols block's custom symbols table (added to baseSymbols) * @param s datalog string to parse * @return Either>, Block> */ - public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, SymbolTable blockSymbols, String s) { Block blockBuilder = new Block(index, baseSymbols); - Map> errors = new HashMap<>(); - Stream.of(s.split("\n")).zipWithIndex().forEach(indexedLine -> { - Integer lineNumber = indexedLine._2; - String codeLine = indexedLine._1; - List lineErrors = new ArrayList<>(); + // empty block code + if (s.isEmpty()) { + return Either.right(blockBuilder); + } - fact(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_fact); - rule(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_rule); - check(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_check); - scope(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_scope); + if (blockSymbols != null) { + blockSymbols.symbols.forEach(blockBuilder::addSymbol); + } - if (lineErrors.size() > 3) { - errors.put(lineNumber, lineErrors); - } + Map> errors = new HashMap<>(); + + s = removeCommentsAndWhitespaces(s); + String[] codeLines = s.split(";"); + + Stream.of(codeLines) + .zipWithIndex() + .forEach(indexedLine -> { + String code = indexedLine._1.strip(); + + if (!code.isEmpty()) { + int lineNumber = indexedLine._2; + System.out.println("NEW CODE LINE"); + System.out.println(code); + List lineErrors = new ArrayList<>(); + + boolean parsed = false; + parsed = rule(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_rule(r._2); + return true; + }); + + if (!parsed) { + parsed = scope(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_scope(r._2); + return true; + }); + } + + if (!parsed) { + parsed = fact(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_fact(r._2); + return true; + }); + } + + if (!parsed) { + parsed = check(code).fold(e -> { + lineErrors.add(e); + return false; + }, r -> { + blockBuilder.add_check(r._2); + return true; + }); + } + + if (!parsed) { + lineErrors.forEach(System.out::println); + errors.put(lineNumber, lineErrors); + } + } }); if (!errors.isEmpty()) { @@ -709,4 +769,43 @@ public static Tuple2 take_while(String s, Function(s.substring(0, index), s.substring(index)); } + + public static String removeCommentsAndWhitespaces(String s) { + s = removeComments(s); + s = s.replace("\n", "").replace("\\\"", "\"").strip(); + return s; + } + + public static String removeComments(String str) { + StringBuilder result = new StringBuilder(); + String remaining = str; + + while (!remaining.isEmpty()) { + remaining = space(remaining); // Skip leading whitespace + if (remaining.startsWith("/*")) { + // Find the end of the multiline comment + remaining = remaining.substring(2); // Skip "/*" + String finalRemaining = remaining; + Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("*/")); + remaining = split._2.length() > 2 ? split._2.substring(2) : ""; // Skip "*/" + } else if (remaining.startsWith("//")) { + // Find the end of the single-line comment + remaining = remaining.substring(2); // Skip "//" + Tuple2 split = take_while(remaining, c -> c != '\n' && c != '\r'); + remaining = split._2; + if (!remaining.isEmpty()) { + result.append(remaining.charAt(0)); // Preserve line break + remaining = remaining.substring(1); + } + } else { + // Take non-comment text until the next comment or end of string + String finalRemaining = remaining; + Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("/*") && !finalRemaining.startsWith("//")); + result.append(split._1); + remaining = split._2; + } + } + + return result.toString(); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index fff15f76..f0b777dc 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -133,6 +133,23 @@ void testRuleWithExpressionOrdering() { res); } + @Test + void expressionIntersectionAndContainsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).contains(1)"); + + assertEquals(Either.right(new Tuple2<>("", + new Expression.Binary( + Expression.Op.Contains, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) + ), + new Expression.Value(Utils.integer(1)) + ))), res); + } + @Test void ruleWithFreeExpressionVariables() { Either> res = @@ -366,11 +383,11 @@ void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { SymbolTable symbols = Biscuit.default_symbol_table(); - String l1 = "fact1(1)"; + String l1 = "fact1(1, 2)"; String l2 = "fact2(\"2\")"; String l3 = "rule1(2) <- fact2(\"2\")"; String l4 = "check if rule1(2)"; - String toParse = String.join("\n", Arrays.asList(l1, l2, l3, l4)); + String toParse = String.join(";", Arrays.asList(l1, l2, l3, l4)); Either>, Block> output = Parser.datalog(1, symbols, toParse); assertTrue(output.isRight()); @@ -392,9 +409,28 @@ void testDatalogFailed() { String l1 = "fact(1)"; String l2 = "check fact(1)"; // typo missing "if" - String toParse = String.join("\n", Arrays.asList(l1, l2)); + String toParse = String.join(";", Arrays.asList(l1, l2)); Either>, Block> output = Parser.datalog(1, symbols, toParse); assertTrue(output.isLeft()); } + + @Test + void testDatalogRemoveComment() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l0 = "// test comment"; + String l1 = "fact1(1, 2);"; + String l2 = "fact2(\"2\");"; + String l3 = "rule1(2) <- fact2(\"2\");"; + String l4 = "// another comment"; + String l5 = "/* test multiline"; + String l6 = "comment */ check if rule1(2);"; + String l7 = " /* another multiline"; + String l8 = "comment */"; + String toParse = String.join("", Arrays.asList(l0, l1, l2, l3, l4, l5, l6, l7, l8)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + } } \ No newline at end of file diff --git a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java index c11fbebb..5e38166b 100644 --- a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java +++ b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; public class ExpressionTest { @@ -114,4 +115,27 @@ public void testNegativeContainsStr() throws Error.Execution { e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) ); } + + @Test + public void testIntersectionAndContains() throws Error.Execution { + SymbolTable symbols = new SymbolTable(); + + Expression e = new Expression(new ArrayList(Arrays.asList( + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3))))), + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2))))), + new Op.Binary(Op.BinaryOp.Intersection), + new Op.Value(new Term.Integer(1)), + new Op.Binary(Op.BinaryOp.Contains) + ))); + + assertEquals( + "[1, 2, 3].intersection([1, 2]).contains(1)", + e.print(symbols).get() + ); + + assertEquals( + new Term.Bool(true), + e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) + ); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java index a670c875..26f611db 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java @@ -4,15 +4,20 @@ import com.google.gson.*; import com.google.protobuf.MapEntry; import io.vavr.Tuple2; +import io.vavr.control.Option; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.Rule; import org.biscuitsec.biscuit.datalog.RunLimits; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TrustedOrigins; import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; import io.vavr.control.Try; import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.parser.ExpressionParser; +import org.biscuitsec.biscuit.token.builder.parser.Parser; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.TestFactory; @@ -22,6 +27,7 @@ import java.time.Duration; import java.util.*; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.*; @@ -39,6 +45,48 @@ Stream jsonTest() { return sample.testcases.stream().map(t -> process_testcase(t, publicKey, keyPair)); } + void compareBlocks(List sampleBlocks, List blocks) { + assertEquals(sampleBlocks.size(), blocks.size()); + List> comparisonList = IntStream.range(0, sampleBlocks.size()) + .mapToObj(i -> new Tuple2<>(sampleBlocks.get(i), blocks.get(i))) + .collect(Collectors.toList()); + + // for each token we start from the base symbol table + SymbolTable baseSymbols = new SymbolTable(); + + io.vavr.collection.Stream.ofAll(comparisonList).zipWithIndex().forEach(indexedItem -> { + compareBlock(baseSymbols, indexedItem._2, indexedItem._1._1, indexedItem._1._2); + }); + } + + void compareBlock(SymbolTable baseSymbols, long sampleBlockIndex, Block sampleBlock, org.biscuitsec.biscuit.token.Block block) { + Option sampleExternalKey = sampleBlock.getExternalKey(); + List samplePublicKeys = sampleBlock.getPublicKeys(); + String sampleDatalog = sampleBlock.getCode().replace("\"","\\\""); + SymbolTable sampleSymbols = new SymbolTable(sampleBlock.symbols); + + Either>, org.biscuitsec.biscuit.token.builder.Block> outputSample = Parser.datalog(sampleBlockIndex, baseSymbols, sampleDatalog); + assertTrue(outputSample.isRight()); + + if (!block.publicKeys.isEmpty()) { + outputSample.get().addPublicKeys(samplePublicKeys); + } + + if (!block.externalKey.isDefined()) { + sampleSymbols.symbols.forEach(baseSymbols::add); + } else { + SymbolTable freshSymbols = new SymbolTable(); + sampleSymbols.symbols.forEach(freshSymbols::add); + outputSample.get().setExternalKey(sampleExternalKey); + } + + System.out.println("mdr"); + System.out.println(outputSample.get().build().print(sampleSymbols)); + System.out.println(block.symbols.symbols); + System.out.println(block.print(sampleSymbols)); + assertArrayEquals(outputSample.get().build().to_bytes().get(), block.to_bytes().get()); + } + DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, final KeyPair privateKey) { return DynamicTest.dynamicTest(testCase.title + ": "+testCase.filename, () -> { System.out.println("Testcase name: \""+testCase.title+"\""); @@ -57,6 +105,12 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, Biscuit token = Biscuit.from_bytes(data, publicKey); assertArrayEquals(token.serialize(), data); + List allBlocks = new ArrayList<>(); + allBlocks.add(token.authority); + allBlocks.addAll(token.blocks); + + compareBlocks(testCase.token, allBlocks); + List revocationIds = token.revocation_identifiers(); JsonArray validationRevocationIds = validation.getAsJsonArray("revocation_ids"); assertEquals(revocationIds.size(), validationRevocationIds.size()); @@ -142,19 +196,11 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, }); } - class Block { List symbols; - - public String getCode() { - return code; - } - - public void setCode(String code) { - this.code = code; - } - String code; + List public_keys; + String external_key; public List getSymbols() { return symbols; @@ -163,6 +209,38 @@ public List getSymbols() { public void setSymbols(List symbols) { this.symbols = symbols; } + + public String getCode() { return code; } + + public void setCode(String code) { this.code = code; } + + public List getPublicKeys() { + return this.public_keys.stream() + .map(pk -> + Parser.publicKey(pk).fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2) + ) + .collect(Collectors.toList()); + } + + public void setPublicKeys(List publicKeys) { + this.public_keys = publicKeys.stream() + .map(PublicKey::toString) + .collect(Collectors.toList()); + } + + public Option getExternalKey() { + if (this.external_key != null) { + PublicKey externalKey = Parser.publicKey(this.external_key) + .fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2); + return Option.of(externalKey); + } else { + return Option.none(); + } + } + + public void setExternalKey(Option externalKey) { + this.external_key = externalKey.map(PublicKey::toString).getOrElse((String) null); + } } class TestCase {