diff --git a/src/main/java/org/biscuitsec/biscuit/token/Block.java b/src/main/java/org/biscuitsec/biscuit/token/Block.java index 382224a8..659ee075 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Block.java @@ -90,11 +90,16 @@ public String print(SymbolTable symbol_table) { s.append(this.symbols.symbols); s.append("\n\t\tcontext: "); s.append(this.context); + s.append("\n\t\tscopes: ["); + for (Scope scope : this.scopes) { + s.append("\n\t\t\t"); + s.append(symbol_table.print_scope(scope)); + } if(this.externalKey.isDefined()) { s.append("\n\t\texternal key: "); s.append(this.externalKey.get().toString()); } - s.append("\n\t\tfacts: ["); + s.append("\n\t\t]\n\t\tfacts: ["); for (Fact f : this.facts) { s.append("\n\t\t\t"); s.append(symbol_table.print_fact(f)); @@ -208,6 +213,7 @@ static public Either deserialize(Schema.Block b, Optio } ArrayList scopes = new ArrayList<>(); + System.out.println(b.getScopeList()); for (Schema.Scope scope: b.getScopeList()) { Either res = Scope.deserialize(scope); if(res.isLeft()) { diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java index 6a17bbdf..c4591e03 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java @@ -150,8 +150,6 @@ public org.biscuitsec.biscuit.token.Block build() { publicKeys.add(this.symbols.publicKeys().get(i)); } - publicKeys.addAll(this.publicKeys); - SchemaVersion schemaVersion = new SchemaVersion(this.facts, this.rules, this.checks, this.scopes); return new org.biscuitsec.biscuit.token.Block(symbols, this.context, this.facts, this.rules, this.checks, diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java index 9c3c00a2..7050bfa5 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java @@ -517,6 +517,9 @@ public static Either> binary_op6(String s) } public static Either> binary_op7(String s) { + if(s.startsWith("intersection")) { + return Either.right(new Tuple2<>(s.substring(12), Expression.Op.Intersection)); + } if(s.startsWith("contains")) { return Either.right(new Tuple2<>(s.substring(8), Expression.Op.Contains)); } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index f8d21676..a1481035 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -16,6 +16,10 @@ import java.util.function.Function; public class Parser { + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + return datalog(index, baseSymbols, null, s); + } + /** * Takes a datalog string with \n as datalog line separator. It tries to parse * each line using fact, rule, check and scope sequentially. @@ -25,24 +29,48 @@ public class Parser { * * @param index block index * @param baseSymbols symbols table + * @param blockSymbols block's custom symbols table (added to baseSymbols) * @param s datalog string to parse * @return Either>, Block> */ - public static Either>, Block> datalog(long index, SymbolTable baseSymbols, String s) { + public static Either>, Block> datalog(long index, SymbolTable baseSymbols, SymbolTable blockSymbols, String s) { Block blockBuilder = new Block(index, baseSymbols); + + // empty block code + if (s.isEmpty()) { + return Either.right(blockBuilder); + } + + if (blockSymbols != null) { + blockSymbols.symbols.forEach(blockBuilder::addSymbol); + } + Map> errors = new HashMap<>(); - Stream.of(s.split("\n")).zipWithIndex().forEach(indexedLine -> { + s = removeComments(s); + s = s.replace("\n", ""); + s = s.replace("\\\"","\""); + s = s.strip(); + String[] codeLines = s.split(";"); + codeLines = Arrays.stream(codeLines) + .filter(codeLine -> !codeLine.isEmpty()) + .map(String::strip) + .toArray(String[]::new); + + Stream.of(codeLines).zipWithIndex().forEach(indexedLine -> { Integer lineNumber = indexedLine._2; String codeLine = indexedLine._1; + System.out.println("NEW CODE LINE"); + System.out.println(codeLine); List lineErrors = new ArrayList<>(); - fact(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_fact); - rule(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_rule); - check(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_check); - scope(codeLine).bimap(lineErrors::add, r -> r._2).map(blockBuilder::add_scope); + rule(codeLine).bimap(lineErrors::add, r -> r._2).forEach(blockBuilder::add_rule); + scope(codeLine).bimap(lineErrors::add, r -> r._2).forEach(blockBuilder::add_scope); + fact(codeLine).bimap(lineErrors::add, r -> r._2).forEach(blockBuilder::add_fact); + check(codeLine).bimap(lineErrors::add, r -> r._2).forEach(blockBuilder::add_check); if (lineErrors.size() > 3) { + lineErrors.forEach(System.out::println); errors.put(lineNumber, lineErrors); } }); @@ -709,4 +737,43 @@ public static Tuple2 take_while(String s, Function(s.substring(0, index), s.substring(index)); } + + public static String removeComments(String str) { + StringBuilder result = new StringBuilder(); + boolean inMultilineComment = false; + boolean inSingleLineComment = false; + + for (int i = 0; i < str.length(); i++) { + // Check for start of multiline comment if not already in a single-line comment + if (!inSingleLineComment && i < str.length() - 1 && str.charAt(i) == '/' && str.charAt(i + 1) == '*') { + inMultilineComment = true; + i++; // Skip next character to avoid false ending + } + // Check for end of multiline comment + else if (inMultilineComment && i < str.length() - 1 && str.charAt(i) == '*' && str.charAt(i + 1) == '/') { + inMultilineComment = false; + i++; // Skip next character to move past the comment end + continue; // Skip adding the comment end to the result + } + // Check for start of single-line comment if not already in a multiline comment + else if (!inMultilineComment && i < str.length() - 1 && str.charAt(i) == '/' && str.charAt(i + 1) == '/') { + inSingleLineComment = true; + i++; // Skip next character + continue; // Skip adding the start of single-line comment + } + // Check for end of line, marking end of single-line comment + else if (inSingleLineComment && (str.charAt(i) == '\n')) { + inSingleLineComment = false; + // Optionally, add the newline character to the result to preserve line structure + result.append(str.charAt(i)); + continue; // Move to the next character + } + // If not in any comment, append the current character + if (!inMultilineComment && !inSingleLineComment) { + result.append(str.charAt(i)); + } + } + + return result.toString(); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index fff15f76..f0b777dc 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -133,6 +133,23 @@ void testRuleWithExpressionOrdering() { res); } + @Test + void expressionIntersectionAndContainsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).contains(1)"); + + assertEquals(Either.right(new Tuple2<>("", + new Expression.Binary( + Expression.Op.Contains, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) + ), + new Expression.Value(Utils.integer(1)) + ))), res); + } + @Test void ruleWithFreeExpressionVariables() { Either> res = @@ -366,11 +383,11 @@ void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { SymbolTable symbols = Biscuit.default_symbol_table(); - String l1 = "fact1(1)"; + String l1 = "fact1(1, 2)"; String l2 = "fact2(\"2\")"; String l3 = "rule1(2) <- fact2(\"2\")"; String l4 = "check if rule1(2)"; - String toParse = String.join("\n", Arrays.asList(l1, l2, l3, l4)); + String toParse = String.join(";", Arrays.asList(l1, l2, l3, l4)); Either>, Block> output = Parser.datalog(1, symbols, toParse); assertTrue(output.isRight()); @@ -392,9 +409,28 @@ void testDatalogFailed() { String l1 = "fact(1)"; String l2 = "check fact(1)"; // typo missing "if" - String toParse = String.join("\n", Arrays.asList(l1, l2)); + String toParse = String.join(";", Arrays.asList(l1, l2)); Either>, Block> output = Parser.datalog(1, symbols, toParse); assertTrue(output.isLeft()); } + + @Test + void testDatalogRemoveComment() throws org.biscuitsec.biscuit.error.Error.Parser { + SymbolTable symbols = Biscuit.default_symbol_table(); + + String l0 = "// test comment"; + String l1 = "fact1(1, 2);"; + String l2 = "fact2(\"2\");"; + String l3 = "rule1(2) <- fact2(\"2\");"; + String l4 = "// another comment"; + String l5 = "/* test multiline"; + String l6 = "comment */ check if rule1(2);"; + String l7 = " /* another multiline"; + String l8 = "comment */"; + String toParse = String.join("", Arrays.asList(l0, l1, l2, l3, l4, l5, l6, l7, l8)); + + Either>, Block> output = Parser.datalog(1, symbols, toParse); + assertTrue(output.isRight()); + } } \ No newline at end of file diff --git a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java index c11fbebb..5e38166b 100644 --- a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java +++ b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; public class ExpressionTest { @@ -114,4 +115,27 @@ public void testNegativeContainsStr() throws Error.Execution { e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) ); } + + @Test + public void testIntersectionAndContains() throws Error.Execution { + SymbolTable symbols = new SymbolTable(); + + Expression e = new Expression(new ArrayList(Arrays.asList( + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3))))), + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2))))), + new Op.Binary(Op.BinaryOp.Intersection), + new Op.Value(new Term.Integer(1)), + new Op.Binary(Op.BinaryOp.Contains) + ))); + + assertEquals( + "[1, 2, 3].intersection([1, 2]).contains(1)", + e.print(symbols).get() + ); + + assertEquals( + new Term.Bool(true), + e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) + ); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java index a670c875..465bced1 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java @@ -4,15 +4,20 @@ import com.google.gson.*; import com.google.protobuf.MapEntry; import io.vavr.Tuple2; +import io.vavr.control.Option; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.Rule; import org.biscuitsec.biscuit.datalog.RunLimits; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TrustedOrigins; import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; import io.vavr.control.Try; import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.parser.ExpressionParser; +import org.biscuitsec.biscuit.token.builder.parser.Parser; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.TestFactory; @@ -22,6 +27,7 @@ import java.time.Duration; import java.util.*; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.*; @@ -39,6 +45,48 @@ Stream jsonTest() { return sample.testcases.stream().map(t -> process_testcase(t, publicKey, keyPair)); } + void compareBlocks(List sampleBlocks, List blocks) { + assertEquals(sampleBlocks.size(), blocks.size()); + List> comparisonList = IntStream.range(0, sampleBlocks.size()) + .mapToObj(i -> new Tuple2<>(sampleBlocks.get(i), blocks.get(i))) + .collect(Collectors.toList()); + + // for each token we start from the base symbol table + SymbolTable baseSymbols = new SymbolTable(); + + io.vavr.collection.Stream.ofAll(comparisonList).zipWithIndex().forEach(indexedItem -> { + compareBlock(baseSymbols, indexedItem._2, indexedItem._1._1, indexedItem._1._2); + }); + } + + void compareBlock(SymbolTable baseSymbols, long sampleBlockIndex, Block sampleBlock, org.biscuitsec.biscuit.token.Block block) { + Option sampleExternalKey = sampleBlock.getExternalKey(); + List samplePublicKeys = sampleBlock.getPublicKeys(); + String sampleDatalog = sampleBlock.getCode().replace("\"","\\\""); + SymbolTable sampleSymbols = new SymbolTable(sampleBlock.symbols); + + Either>, org.biscuitsec.biscuit.token.builder.Block> outputSample = Parser.datalog(sampleBlockIndex, baseSymbols, sampleDatalog); + assertTrue(outputSample.isRight()); + + if (!block.publicKeys.isEmpty()) { + outputSample.get().addPublicKeys(samplePublicKeys); + } + + if (!block.externalKey.isDefined()) { + sampleSymbols.symbols.forEach(baseSymbols::add); + } else { + SymbolTable freshSymbols = new SymbolTable(); + sampleSymbols.symbols.forEach(freshSymbols::add); + outputSample.get().setExternalKey(sampleExternalKey); + } + + System.out.println("mdr"); + System.out.println(outputSample.get().build().print(sampleSymbols)); + System.out.println(block.symbols.symbols); + System.out.println(block.print(sampleSymbols)); + assertArrayEquals(outputSample.get().build().to_bytes().get(), block.to_bytes().get()); + } + DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, final KeyPair privateKey) { return DynamicTest.dynamicTest(testCase.title + ": "+testCase.filename, () -> { System.out.println("Testcase name: \""+testCase.title+"\""); @@ -57,6 +105,12 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, Biscuit token = Biscuit.from_bytes(data, publicKey); assertArrayEquals(token.serialize(), data); + List allBlocks = new ArrayList<>(); + allBlocks.add(token.authority); + allBlocks.addAll(token.blocks); + + compareBlocks(testCase.token, allBlocks); + List revocationIds = token.revocation_identifiers(); JsonArray validationRevocationIds = validation.getAsJsonArray("revocation_ids"); assertEquals(revocationIds.size(), validationRevocationIds.size()); @@ -142,19 +196,11 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, }); } - class Block { List symbols; - - public String getCode() { - return code; - } - - public void setCode(String code) { - this.code = code; - } - String code; + List public_keys; + String external_key; public List getSymbols() { return symbols; @@ -163,6 +209,38 @@ public List getSymbols() { public void setSymbols(List symbols) { this.symbols = symbols; } + + public String getCode() { return code; } + + public void setCode(String code) { this.code = code; } + + public List getPublicKeys() { + return this.public_keys.stream() + .map(pk -> { + String[] parts = pk.split("/"); + return new PublicKey(Schema.PublicKey.Algorithm.Ed25519, parts[1]); + }) + .collect(Collectors.toList()); + } + + public void setPublicKeys(List publicKeys) { + this.public_keys = publicKeys.stream() + .map(PublicKey::toString) + .collect(Collectors.toList()); + } + + public Option getExternalKey() { + if (this.external_key != null) { + String[] parts = this.external_key.split("/"); + return Option.of(new PublicKey(Schema.PublicKey.Algorithm.Ed25519, parts[1])); + } else { + return Option.none(); + } + } + + public void setExternalKey(Option externalKey) { + this.external_key = externalKey.map(PublicKey::toString).getOrElse((String) null); + } } class TestCase {