Skip to content

Commit

Permalink
refactor the builder API to avoid symbol table manipulation (#103)
Browse files Browse the repository at this point in the history
* refactor the builder API to avoid symbol table manipulation

the builder API was using symbol tables internally, which meant that
datalog elements were converted too early, and that made the API too
complicate. This aligns the builder implementation with the approach use
in the rust implementation, where the builder types only contain other
builder types, and the conversion with the symbol table is only done
when calling the build() method.
This also removes some functions from the public API that could be
misused to get invalid symbol tables

* Fix parser precedence

The datalog parser had operator precedence issues that are now fixed by following closely the rust implementation
  • Loading branch information
Geal authored Jun 15, 2024
1 parent ff3cb37 commit ad9f0e8
Show file tree
Hide file tree
Showing 17 changed files with 546 additions and 254 deletions.
41 changes: 26 additions & 15 deletions src/main/java/org/biscuitsec/biscuit/datalog/Check.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,6 @@ public List<Rule> queries() {
return queries;
}

@Override
public int hashCode() {
return Objects.hash(queries);
}

@Override
public boolean equals(Object o) {
return super.equals(o);
}

@Override
public String toString() {
return super.toString();
}

public Schema.CheckV2 serialize() {
Schema.CheckV2.Builder b = Schema.CheckV2.newBuilder();

Expand Down Expand Up @@ -95,4 +80,30 @@ static public Either<Error.FormatError, Check> deserializeV2(Schema.CheckV2 chec

return Right(new Check(kind, queries));
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Check check = (Check) o;

if (kind != check.kind) return false;
return Objects.equals(queries, check.queries);
}

@Override
public int hashCode() {
int result = kind != null ? kind.hashCode() : 0;
result = 31 * result + (queries != null ? queries.hashCode() : 0);
return result;
}

@Override
public String toString() {
return "Check{" +
"kind=" + kind +
", queries=" + queries +
'}';
}
}
32 changes: 32 additions & 0 deletions src/main/java/org/biscuitsec/biscuit/datalog/Rule.java
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,36 @@ static public Either<Error.FormatError, Rule> deserializeV2(Schema.RuleV2 rule)
return Right(new Rule(res.get(), body, expressions, scopes));
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Rule rule = (Rule) o;

if (!Objects.equals(head, rule.head)) return false;
if (!Objects.equals(body, rule.body)) return false;
if (!Objects.equals(expressions, rule.expressions)) return false;
return Objects.equals(scopes, rule.scopes);
}

@Override
public int hashCode() {
int result = head != null ? head.hashCode() : 0;
result = 31 * result + (body != null ? body.hashCode() : 0);
result = 31 * result + (expressions != null ? expressions.hashCode() : 0);
result = 31 * result + (scopes != null ? scopes.hashCode() : 0);
return result;
}

@Override
public String toString() {
return "Rule{" +
"head=" + head +
", body=" + body +
", expressions=" + expressions +
", scopes=" + scopes +
'}';
}
}
18 changes: 18 additions & 0 deletions src/main/java/org/biscuitsec/biscuit/datalog/Scope.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,22 @@ public String toString() {
", publicKey=" + publicKey +
'}';
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Scope scope = (Scope) o;

if (publicKey != scope.publicKey) return false;
return kind == scope.kind;
}

@Override
public int hashCode() {
int result = kind != null ? kind.hashCode() : 0;
result = 31 * result + (int) (publicKey ^ (publicKey >>> 32));
return result;
}
}
28 changes: 28 additions & 0 deletions src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,32 @@ public List<String> getAllSymbols() {
allSymbols.addAll(symbols);
return allSymbols;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

SymbolTable that = (SymbolTable) o;

if (!dateTimeFormatter.equals(that.dateTimeFormatter)) return false;
if (!symbols.equals(that.symbols)) return false;
return publicKeys.equals(that.publicKeys);
}

@Override
public int hashCode() {
int result = dateTimeFormatter.hashCode();
result = 31 * result + symbols.hashCode();
result = 31 * result + publicKeys.hashCode();
return result;
}

@Override
public String toString() {
return "SymbolTable{" +
"symbols=" + symbols +
", publicKeys=" + publicKeys +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import io.vavr.control.Either;
import io.vavr.control.Option;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Map;
import java.util.*;

import static io.vavr.API.Left;
import static io.vavr.API.Right;
Expand Down Expand Up @@ -78,4 +75,26 @@ static public Either<Error.FormatError, Expression> deserializeV2(Schema.Express

return Right(new Expression(ops));
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Expression that = (Expression) o;

return Objects.equals(ops, that.ops);
}

@Override
public int hashCode() {
return ops != null ? ops.hashCode() : 0;
}

@Override
public String toString() {
return "Expression{" +
"ops=" + ops +
'}';
}
}
47 changes: 17 additions & 30 deletions src/main/java/org/biscuitsec/biscuit/token/Biscuit.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Biscuit extends UnverifiedBiscuit {
* @return
*/
public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final KeyPair root) {
return new org.biscuitsec.biscuit.token.builder.Biscuit(new SecureRandom(), root, default_symbol_table());
return new org.biscuitsec.biscuit.token.builder.Biscuit(new SecureRandom(), root);
}

/**
Expand All @@ -42,34 +42,20 @@ public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final KeyPair
* @return
*/
public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final KeyPair root) {
return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, default_symbol_table());
return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root);
}

/**
* Creates a token builder
*
* @param rng random number generator
* @param root root private key
* @param symbols symbol table
* @return
*/
public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final KeyPair root, final Option<Integer> root_key_id, SymbolTable symbols) {
return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, root_key_id, symbols);
public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final KeyPair root, final Option<Integer> root_key_id) {
return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, root_key_id);
}

/**
* Creates a token builder
*
* @param rng random number generator
* @param root root private key
* @param symbols symbol table
* @return
*/
public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final KeyPair root, SymbolTable symbols) {
return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, symbols);
}


/**
* Creates a token
*
Expand All @@ -78,8 +64,8 @@ public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureR
* @param authority authority block
* @return Biscuit
*/
static public Biscuit make(final SecureRandom rng, final KeyPair root, final SymbolTable symbols, final Block authority) throws Error.SymbolTableOverlap, Error.FormatError {
return Biscuit.make(rng, root, Option.none(), symbols, authority);
public static Biscuit make(final SecureRandom rng, final KeyPair root, final Block authority) throws Error.FormatError {
return Biscuit.make(rng, root, Option.none(), authority);
}

/**
Expand All @@ -90,8 +76,8 @@ static public Biscuit make(final SecureRandom rng, final KeyPair root, final Sym
* @param authority authority block
* @return Biscuit
*/
static public Biscuit make(final SecureRandom rng, final KeyPair root, final Integer root_key_id, final SymbolTable symbols, final Block authority) throws Error.SymbolTableOverlap, Error.FormatError {
return Biscuit.make(rng, root, Option.of(root_key_id), symbols, authority);
public static Biscuit make(final SecureRandom rng, final KeyPair root, final Integer root_key_id, final Block authority) throws Error.FormatError {
return Biscuit.make(rng, root, Option.of(root_key_id), authority);
}

/**
Expand All @@ -102,12 +88,7 @@ static public Biscuit make(final SecureRandom rng, final KeyPair root, final Int
* @param authority authority block
* @return Biscuit
*/
static private Biscuit make(final SecureRandom rng, final KeyPair root, final Option<Integer> root_key_id, final SymbolTable symbols, final Block authority) throws Error.SymbolTableOverlap, Error.FormatError {
if (!Collections.disjoint(symbols.symbols, authority.symbols.symbols)) {
throw new Error.SymbolTableOverlap();
}

symbols.symbols.addAll(authority.symbols.symbols);
static private Biscuit make(final SecureRandom rng, final KeyPair root, final Option<Integer> root_key_id, final Block authority) throws Error.FormatError {
ArrayList<Block> blocks = new ArrayList<>();

KeyPair next = new KeyPair(rng);
Expand All @@ -122,7 +103,7 @@ static private Biscuit make(final SecureRandom rng, final KeyPair root, final Op
HashMap<Long, List<Long>> publicKeyToBlockId = new HashMap<>();

Option<SerializedBiscuit> c = Option.some(s);
return new Biscuit(authority, blocks, symbols, s, publicKeyToBlockId, revocation_ids, root_key_id);
return new Biscuit(authority, blocks, authority.symbols, s, publicKeyToBlockId, revocation_ids, root_key_id);
}
}

Expand Down Expand Up @@ -325,7 +306,13 @@ public String serialize_b64url() throws Error.FormatError.SerializationError {
public Biscuit attenuate(org.biscuitsec.biscuit.token.builder.Block block) throws Error {
SecureRandom rng = new SecureRandom();
KeyPair keypair = new KeyPair(rng);
return attenuate(rng, keypair, block.build());
SymbolTable builderSymbols = new SymbolTable(this.symbols);
return attenuate(rng, keypair, block.build(builderSymbols));
}

public Biscuit attenuate(final SecureRandom rng, final KeyPair keypair,org.biscuitsec.biscuit.token.builder.Block block) throws Error {
SymbolTable builderSymbols = new SymbolTable(this.symbols);
return attenuate(rng, keypair, block.build(builderSymbols));
}

/**
Expand Down
47 changes: 47 additions & 0 deletions src/main/java/org/biscuitsec/biscuit/token/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,5 +313,52 @@ public Either<Error.FormatError, byte[]> to_bytes() {
return Left(new Error.FormatError.SerializationError(e.toString()));
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Block block = (Block) o;

if (version != block.version) return false;
if (!Objects.equals(symbols, block.symbols)) return false;
if (!Objects.equals(context, block.context)) return false;
if (!Objects.equals(facts, block.facts)) return false;
if (!Objects.equals(rules, block.rules)) return false;
if (!Objects.equals(checks, block.checks)) return false;
if (!Objects.equals(scopes, block.scopes)) return false;
if (!Objects.equals(publicKeys, block.publicKeys)) return false;
return Objects.equals(externalKey, block.externalKey);
}

@Override
public int hashCode() {
int result = symbols != null ? symbols.hashCode() : 0;
result = 31 * result + (context != null ? context.hashCode() : 0);
result = 31 * result + (facts != null ? facts.hashCode() : 0);
result = 31 * result + (rules != null ? rules.hashCode() : 0);
result = 31 * result + (checks != null ? checks.hashCode() : 0);
result = 31 * result + (scopes != null ? scopes.hashCode() : 0);
result = 31 * result + (publicKeys != null ? publicKeys.hashCode() : 0);
result = 31 * result + (externalKey != null ? externalKey.hashCode() : 0);
result = 31 * result + (int) (version ^ (version >>> 32));
return result;
}

@Override
public String toString() {
return "Block{" +
"symbols=" + symbols +
", context='" + context + '\'' +
", facts=" + facts +
", rules=" + rules +
", checks=" + checks +
", scopes=" + scopes +
", publicKeys=" + publicKeys +
", externalKey=" + externalKey +
", version=" + version +
'}';
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public String serialize_b64url() throws Error.FormatError.SerializationError {
* @return
*/
public org.biscuitsec.biscuit.token.builder.Block create_block() {
return new org.biscuitsec.biscuit.token.builder.Block(1 + this.blocks.size(), new SymbolTable(this.symbols));
return new org.biscuitsec.biscuit.token.builder.Block(1 + this.blocks.size());
}

/**
Expand All @@ -141,7 +141,13 @@ public org.biscuitsec.biscuit.token.builder.Block create_block() {
public UnverifiedBiscuit attenuate(org.biscuitsec.biscuit.token.builder.Block block) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error {
SecureRandom rng = new SecureRandom();
KeyPair keypair = new KeyPair(rng);
return attenuate(rng, keypair, block.build());
SymbolTable builderSymbols = new SymbolTable(this.symbols);
return attenuate(rng, keypair, block.build(builderSymbols));
}

public UnverifiedBiscuit attenuate(final SecureRandom rng, final KeyPair keypair, org.biscuitsec.biscuit.token.builder.Block block) throws Error {
SymbolTable builderSymbols = new SymbolTable(this.symbols);
return attenuate(rng, keypair, block.build(builderSymbols));
}

/**
Expand Down
Loading

0 comments on commit ad9f0e8

Please sign in to comment.