diff --git a/README.md b/README.md index d0bf277..b851338 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # BBS+ Signature Verifier Smart Contract Solidity implementation of the BBS+ Signature Verifier smart contract -- To build: `forge build` -- Run test: `forge test` +- To build: `forge build --via-ir` +- Run test: `forge test --via-ir` diff --git a/src/bbs_verify.sol b/src/bbs_verify.sol index 73f7316..896ca74 100644 --- a/src/bbs_verify.sol +++ b/src/bbs_verify.sol @@ -590,11 +590,9 @@ contract BBS_Verifier { uint8[] memory undisclosedIndices = complement(uint8(u), uint8(r), disclosedIndices); uint256 domain = calculate_domain(pk, uint64(l)); - Pairing.G1Point memory t1 = Pairing.scalar_mul(proof.bBar, proof.challenge); - Pairing.G1Point memory t11 = Pairing.scalar_mul(proof.aBar, proof.eCap); - Pairing.G1Point memory t12 = Pairing.scalar_mul(proof.d, proof.r1Cap); - t1 = Pairing.plus(t1, t11); - t1 = Pairing.plus(t1, t12); + Pairing.G1Point memory temp1 = Pairing.plus(Pairing.scalar_mul(proof.aBar, proof.eCap), Pairing.scalar_mul(proof.d, proof.r1Cap)); + + Pairing.G1Point memory t1 = Pairing.plus(Pairing.scalar_mul(proof.bBar, proof.challenge), temp1); Pairing.G1Point memory bv1 = Pairing.scalar_mul(BBS.generators()[0], domain); Pairing.G1Point memory bv = Pairing.plus(BBS.BP1(), bv1); @@ -605,9 +603,7 @@ contract BBS_Verifier { } uint256 challenge = proof.challenge; Pairing.G1Point memory d = proof.d; - Pairing.G1Point memory t21 = Pairing.scalar_mul(bv, challenge); - Pairing.G1Point memory t22 = Pairing.scalar_mul(d, proof.r3Cap); - Pairing.G1Point memory t2 = Pairing.plus(t21, t22); + Pairing.G1Point memory t2 = Pairing.plus(Pairing.scalar_mul(bv, challenge), Pairing.scalar_mul(d, proof.r3Cap)); for (uint256 i = 0; i < u; i++) { t2 = Pairing.plus(t2, Pairing.scalar_mul(BBS.generators()[undisclosedIndices[i] + 1], proof.commitments[i])); @@ -644,23 +640,66 @@ contract BBS_Verifier { ) public pure returns (uint256) { require(disclosedMsg.length == disclosedIndices.length, "invalid length"); - bytes memory serializeBytes = uint64ToBytes(disclosedIndices.length); + uint256 totalLength = 8 + disclosedMsg.length * (8 + 32) + initProof.points.length * 64 + 32 + 8; + bytes memory serializeBytes = new bytes(totalLength); + + uint256 serializeBytesPtr; + assembly { + serializeBytesPtr := add(serializeBytes, 0x20) + } + + bytes memory lengthBytes = uint64ToBytes(disclosedIndices.length); + assembly { + let lenPtr := add(lengthBytes, 0x20) + mstore(serializeBytesPtr, mload(lenPtr)) // Copy the lengthBytes (8 bytes) + serializeBytesPtr := add(serializeBytesPtr, 8) + } + // Serialize disclosedIndices and disclosedMsg for (uint256 i = 0; i < disclosedMsg.length; i++) { - serializeBytes = abi.encodePacked(serializeBytes, uint64ToBytes(uint64(disclosedIndices[i]))); - serializeBytes = abi.encodePacked(serializeBytes, reverseBytes(uintToBytes(disclosedMsg[i]))); + bytes memory indexBytes = uint64ToBytes(uint64(disclosedIndices[i])); + bytes memory msgBytes = reverseBytes(uintToBytes(disclosedMsg[i])); + + assembly { + let indexPtr := add(indexBytes, 0x20) + mstore(serializeBytesPtr, mload(indexPtr)) + serializeBytesPtr := add(serializeBytesPtr, 8) + } + + // Concatenate msgBytes (32 bytes) + assembly { + let msgPtr := add(msgBytes, 0x20) + mstore(serializeBytesPtr, mload(msgPtr)) + serializeBytesPtr := add(serializeBytesPtr, 32) + } } + // Serialize G1 points for (uint256 i = 0; i < initProof.points.length; i++) { - serializeBytes = abi.encodePacked(serializeBytes, g1ToBytes(initProof.points[i])); + bytes memory pointBytes = g1ToBytes(initProof.points[i]); + + assembly { + let pointPtr := add(pointBytes, 0x20) + mstore(serializeBytesPtr, mload(pointPtr)) + mstore(add(serializeBytesPtr, 0x20), mload(add(pointPtr, 0x20))) // Copy 64 bytes for G1 point + serializeBytesPtr := add(serializeBytesPtr, 64) + } } - serializeBytes = abi.encodePacked(serializeBytes, reverseBytes(uintToBytes(initProof.scalar))); - bytes1 zeroByte = 0x00; - serializeBytes = abi.encodePacked( - serializeBytes, zeroByte, zeroByte, zeroByte, zeroByte, zeroByte, zeroByte, zeroByte, zeroByte - ); + // Serialize scalar (32 bytes) + bytes memory scalarBytes = reverseBytes(uintToBytes(initProof.scalar)); + assembly { + let scalarPtr := add(scalarBytes, 0x20) + mstore(serializeBytesPtr, mload(scalarPtr)) // Copy 32 bytes + serializeBytesPtr := add(serializeBytesPtr, 32) + } + + assembly { + mstore(serializeBytesPtr, 0) // Zero bytes + serializeBytesPtr := add(serializeBytesPtr, 8) + } + // Return the calculated hash using BBS.hashToScalar return BBS.hashToScalar(serializeBytes, dst); } }