Skip to content

Commit

Permalink
musig-spec: Add Session Context to reference implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
robot-dreams committed Mar 4, 2022
1 parent 01f62b2 commit 9e773e6
Showing 1 changed file with 37 additions and 23 deletions.
60 changes: 37 additions & 23 deletions doc/musig-reference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import namedtuple
from typing import Any, List, Optional, Tuple
import hashlib
import secrets
Expand Down Expand Up @@ -101,6 +102,23 @@ def pointc(x: bytes) -> Point:
return point_negate(P)
assert False

SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'msg'])

def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]:
(aggnonce, pubkeys, msg) = session_ctx
Q = key_agg_internal(pubkeys)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
R = point_add(R_1, point_mul(R_2, b))
assert not is_infinite(R)
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
return (Q, b, R, e)

def get_session_key_coeff(session_ctx: SessionContext, P: Point) -> int:
(_, pubkeys, _) = session_ctx
return key_agg_coeff(pubkeys, bytes_from_point(P))

def key_agg(pubkeys: List[bytes]) -> bytes:
Q = key_agg_internal(pubkeys)
return bytes_from_point(Q)
Expand Down Expand Up @@ -152,13 +170,8 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
aggnonce += cbytes(R_i)
return aggnonce

def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg: bytes) -> bytes:
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
Q = key_agg_internal(pubkeys)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R = point_add(R_1, point_mul(R_2, b))
assert not is_infinite(R)
def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
(Q, b, R, e) = get_session_values(session_ctx)
k_1_ = int_from_bytes(secnonce[0:32])
k_2_ = int_from_bytes(secnonce[32:64])
assert 0 < k_1_ < n
Expand All @@ -168,35 +181,30 @@ def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg:
d_ = int_from_bytes(sk)
assert 0 < d_ < n
P = point_mul(G, d_)
mu = get_session_key_coeff(session_ctx, P)
d = n - d_ if has_even_y(P) != has_even_y(Q) else d_
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
mu = key_agg_coeff(pubkeys, bytes_from_point(P))
s = (k_1 + b * k_2 + e * mu * d) % n
psig = bytes_from_int(s)
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_))
assert partial_sig_verify_internal(psig, pubnonce, aggnonce, pubkeys, bytes_from_point(P), msg)
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
return psig

def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool:
aggnonce = nonce_agg(pubnonces)
return partial_sig_verify_internal(psig, pubnonces[i], aggnonce, pubkeys, pubkeys[i], msg)
session_ctx = SessionContext(aggnonce, pubkeys, msg)
return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx)

def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, aggnonce: bytes, pubkeys: List[bytes], pk: bytes, msg: bytes) -> bool:
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool:
(Q, b, R, e) = get_session_values(session_ctx)
s = int_from_bytes(psig)
assert s < n
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
Q = key_agg_internal(pubkeys)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R = point_add(R_1, point_mul(R_2, b))
R_1_ = pointc(pubnonce[0:33])
R_2_ = pointc(pubnonce[33:66])
R__ = point_add(R_1_, point_mul(R_2_, b))
R_ = R__ if has_even_y(R) else point_negate(R__)
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
mu = key_agg_coeff(pubkeys, pk)
P_ = lift_x(pk)
P = P_ if has_even_y(Q) else point_negate(P_)
mu = get_session_key_coeff(session_ctx, P)
return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n))

#
Expand Down Expand Up @@ -252,9 +260,14 @@ def test_sign_vectors():

pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))

assert sign(secnonce, sk, aggnonce, [pk, X[0], X[1]], msg) == expected[0]
assert sign(secnonce, sk, aggnonce, [X[0], pk, X[1]], msg) == expected[1]
assert sign(secnonce, sk, aggnonce, [X[0], X[1], pk], msg) == expected[2]
session_ctx = SessionContext(aggnonce, [pk, X[0], X[1]], msg)
assert sign(secnonce, sk, session_ctx) == expected[0]

session_ctx = SessionContext(aggnonce, [X[0], pk, X[1]], msg)
assert sign(secnonce, sk, session_ctx) == expected[1]

session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], msg)
assert sign(secnonce, sk, session_ctx) == expected[2]

def test_sign_and_verify_random(iters):
for i in range(iters):
Expand All @@ -271,7 +284,8 @@ def test_sign_and_verify_random(iters):

msg = secrets.token_bytes(32)

psig = sign(secnonce_1, sk_1, aggnonce, pubkeys, msg)
session_ctx = SessionContext(aggnonce, pubkeys, msg)
psig = sign(secnonce_1, sk_1, session_ctx)
assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0)

# Wrong signer index
Expand Down

0 comments on commit 9e773e6

Please sign in to comment.