From ad6a2c4b0182e1a745fd0602a51a4303bf507f74 Mon Sep 17 00:00:00 2001 From: k-matsuzawa <49718559+ko-matsu@users.noreply.github.com> Date: Tue, 1 Dec 2020 19:07:04 +0900 Subject: [PATCH] update to v0.2.3 (#4) * update from cryptogarageinc v0.2.5 Co-authored-by: k-matsuzawa --- .github/workflows/check_pre-merge_develop.yml | 2 +- .github/workflows/check_pre-merge_master.yml | 2 +- .github/workflows/check_pre-merge_sprint.yml | 2 +- VERSION | 2 +- cfd/address.py | 62 ++-- cfd/confidential_address.py | 18 +- cfd/confidential_transaction.py | 346 ++++++++++++------ cfd/descriptor.py | 75 ++-- cfd/hdwallet.py | 140 ++++--- cfd/key.py | 150 +++++--- cfd/script.py | 25 +- cfd/transaction.py | 166 +++++---- cfd/util.py | 54 +-- external/CMakeLists.txt | 2 +- integration_test/tests/test_elements.py | 8 +- tests/data/elements_address_test.json | 6 +- tests/data/elements_transaction_test.json | 6 +- tests/data/key_test.json | 3 +- tests/test_address.py | 1 + tests/test_confidential_transaction.py | 24 +- tests/test_transaction.py | 3 + tests/util.py | 7 +- 22 files changed, 722 insertions(+), 382 deletions(-) diff --git a/.github/workflows/check_pre-merge_develop.yml b/.github/workflows/check_pre-merge_develop.yml index fe0902e..7842045 100644 --- a/.github/workflows/check_pre-merge_develop.yml +++ b/.github/workflows/check_pre-merge_develop.yml @@ -113,7 +113,7 @@ jobs: doxygen-ubuntu: name: doxygen-ubuntu - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: install_doxygen diff --git a/.github/workflows/check_pre-merge_master.yml b/.github/workflows/check_pre-merge_master.yml index 81c13bc..fd7bd0b 100644 --- a/.github/workflows/check_pre-merge_master.yml +++ b/.github/workflows/check_pre-merge_master.yml @@ -111,7 +111,7 @@ jobs: doxygen-ubuntu: name: doxygen-ubuntu - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: install_doxygen diff --git a/.github/workflows/check_pre-merge_sprint.yml b/.github/workflows/check_pre-merge_sprint.yml index 22d1aad..4f7527f 100644 --- a/.github/workflows/check_pre-merge_sprint.yml +++ b/.github/workflows/check_pre-merge_sprint.yml @@ -83,7 +83,7 @@ jobs: doxygen-ubuntu: name: doxygen-ubuntu - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: install_doxygen diff --git a/VERSION b/VERSION index f477849..373f8c6 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.2 \ No newline at end of file +0.2.3 \ No newline at end of file diff --git a/cfd/address.py b/cfd/address.py index f73b07d..1ce7b56 100644 --- a/cfd/address.py +++ b/cfd/address.py @@ -3,6 +3,7 @@ # @file address.py # @brief address function implements file. # @note Copyright 2020 CryptoGarage +import typing from .util import get_util, CfdError, JobHandle, to_hex_string from .key import Network, Pubkey from .script import HashType, Script @@ -15,27 +16,35 @@ class Address: ## # @var address # address string + address: str ## # @var locking_script # locking script (scriptPubkey) + locking_script: typing.Union[str, 'Script'] ## # @var pubkey # pubkey for pubkey hash. + pubkey: typing.Union[str, 'Pubkey'] ## # @var redeem_script # redeem script for script hash. + redeem_script: typing.Union[str, 'Script'] ## # @var p2sh_wrapped_script # witness locking script for p2sh. + p2sh_wrapped_script: typing.Union[str, 'Script'] ## # @var hash_type # hash type. + hash_type: 'HashType' ## # @var network # network. + network: 'Network' ## # @var witness_version # witness version. + witness_version: int ## # @brief constructor. @@ -48,27 +57,32 @@ class Address: # @param[in] p2sh_wrapped_script witness locking script for p2sh def __init__( self, - address, + address: str, locking_script, hash_type=HashType.P2SH, network=Network.MAINNET, pubkey='', redeem_script='', p2sh_wrapped_script=''): + _locking_script = to_hex_string(locking_script) + _redeem_script = to_hex_string(redeem_script) + _pubkey = to_hex_string(pubkey) self.address = address - self.locking_script = locking_script - self.pubkey = pubkey - self.redeem_script = redeem_script + self.locking_script = _locking_script if len( + _locking_script) == 0 else Script(locking_script) + self.pubkey = _pubkey if len(_pubkey) == 0 else Pubkey(pubkey) + self.redeem_script = _redeem_script if len( + _redeem_script) == 0 else Script(redeem_script) self.p2sh_wrapped_script = p2sh_wrapped_script - self.hash_type = hash_type - self.network = network + self.hash_type = HashType.get(hash_type) + self.network = Network.get(network) self.witness_version = -1 if p2sh_wrapped_script and len(p2sh_wrapped_script) > 2: - if int(locking_script[0:2], 16) < 16: + if int(_locking_script[0:2], 16) < 16: self.witness_version = int(p2sh_wrapped_script[0:2]) - elif len(locking_script) > 2: - if int(locking_script[0:2], 16) < 16: - self.witness_version = int(locking_script[0:2]) + elif len(_locking_script) > 2: + if int(_locking_script[0:2], 16) < 16: + self.witness_version = int(_locking_script[0:2]) ## # @brief get string. @@ -87,10 +101,10 @@ class AddressUtil: # @param[in] hash_type hash type # @return address object. @classmethod - def parse(cls, address, hash_type=HashType.P2WPKH): + def parse(cls, address, hash_type=HashType.P2WPKH) -> 'Address': util = get_util() with util.create_handle() as handle: - network, _hash_type, witness_version,\ + network, _hash_type, _witness_version,\ locking_script, _ = util.call_func( 'CfdGetAddressInfo', handle.get_handle(), str(address)) _hash_type = HashType.get(_hash_type) @@ -114,7 +128,7 @@ def parse(cls, address, hash_type=HashType.P2WPKH): # @param[in] network network # @return address object. @classmethod - def p2pkh(cls, pubkey, network=Network.MAINNET): + def p2pkh(cls, pubkey, network=Network.MAINNET) -> 'Address': return cls.from_pubkey_hash( pubkey, HashType.P2PKH, network) @@ -124,7 +138,7 @@ def p2pkh(cls, pubkey, network=Network.MAINNET): # @param[in] network network # @return address object. @classmethod - def p2wpkh(cls, pubkey, network=Network.MAINNET): + def p2wpkh(cls, pubkey, network=Network.MAINNET) -> 'Address': return cls.from_pubkey_hash( pubkey, HashType.P2WPKH, network) @@ -134,7 +148,7 @@ def p2wpkh(cls, pubkey, network=Network.MAINNET): # @param[in] network network # @return address object. @classmethod - def p2sh_p2wpkh(cls, pubkey, network=Network.MAINNET): + def p2sh_p2wpkh(cls, pubkey, network=Network.MAINNET) -> 'Address': return cls.from_pubkey_hash( pubkey, HashType.P2SH_P2WPKH, network) @@ -144,7 +158,7 @@ def p2sh_p2wpkh(cls, pubkey, network=Network.MAINNET): # @param[in] network network # @return address object. @classmethod - def p2sh(cls, redeem_script, network=Network.MAINNET): + def p2sh(cls, redeem_script, network=Network.MAINNET) -> 'Address': return cls.from_script_hash( redeem_script, HashType.P2SH, network) @@ -154,7 +168,7 @@ def p2sh(cls, redeem_script, network=Network.MAINNET): # @param[in] network network # @return address object. @classmethod - def p2wsh(cls, redeem_script, network=Network.MAINNET): + def p2wsh(cls, redeem_script, network=Network.MAINNET) -> 'Address': return cls.from_script_hash( redeem_script, HashType.P2WSH, network) @@ -164,7 +178,7 @@ def p2wsh(cls, redeem_script, network=Network.MAINNET): # @param[in] network network # @return address object. @classmethod - def p2sh_p2wsh(cls, redeem_script, network=Network.MAINNET): + def p2sh_p2wsh(cls, redeem_script, network=Network.MAINNET) -> 'Address': return cls.from_script_hash( redeem_script, HashType.P2SH_P2WSH, network) @@ -179,7 +193,7 @@ def from_pubkey_hash( cls, pubkey, hash_type, - network=Network.MAINNET): + network=Network.MAINNET) -> 'Address': _pubkey = str(pubkey) _hash_type = HashType.get(hash_type) _network = Network.get(network) @@ -208,7 +222,7 @@ def from_script_hash( cls, redeem_script, hash_type, - network=Network.MAINNET): + network=Network.MAINNET) -> 'Address': _script = str(redeem_script) _hash_type = HashType.get(hash_type) _network = Network.get(network) @@ -236,10 +250,10 @@ def from_script_hash( @classmethod def multisig( cls, - require_num, + require_num: int, pubkey_list, hash_type, - network=Network.MAINNET): + network=Network.MAINNET) -> 'Address': if isinstance(require_num, int) is False: raise CfdError( error_code=1, message='Error: Invalid require_num type.') @@ -290,7 +304,7 @@ def multisig( def from_locking_script( cls, locking_script, - network=Network.MAINNET): + network=Network.MAINNET) -> 'Address': _script = str(locking_script) _network = Network.get(network) util = get_util() @@ -311,7 +325,7 @@ def get_multisig_address_list( cls, redeem_script, hash_type, - network=Network.MAINNET): + network=Network.MAINNET) -> typing.List['Address']: _script = str(redeem_script) _hash_type = HashType.get(hash_type) _network = Network.get(network) diff --git a/cfd/confidential_address.py b/cfd/confidential_address.py index 09dfbae..c5d5c42 100644 --- a/cfd/confidential_address.py +++ b/cfd/confidential_address.py @@ -3,6 +3,9 @@ # @file confidential_address.py # @brief elements confidential address function implements file. # @note Copyright 2020 CryptoGarage +import typing +from .address import Address, AddressUtil +from .key import Pubkey from .util import get_util, to_hex_string, CfdError @@ -13,12 +16,15 @@ class ConfidentialAddress: ## # @var confidential_address # confidential address string + confidential_address: str ## # @var address # address + address: 'Address' ## # @var confidential_key # confidential key + confidential_key: 'Pubkey' ## # @brief check confidential address. @@ -26,7 +32,7 @@ class ConfidentialAddress: # @retval True confidential address # @retval False other @classmethod - def valid(cls, confidential_address): + def valid(cls, confidential_address) -> bool: util = get_util() try: with util.create_handle() as handle: @@ -42,7 +48,7 @@ def valid(cls, confidential_address): # @param[in] confidential_address confidential address # @return ConfidentialAddress object @classmethod - def parse(cls, confidential_address): + def parse(cls, confidential_address) -> 'ConfidentialAddress': util = get_util() with util.create_handle() as handle: _addr, _key, _ = util.call_func( @@ -55,8 +61,10 @@ def parse(cls, confidential_address): # @param[in] address address address # @param[in] confidential_key confidential key def __init__(self, address, confidential_key): - self.address = address - self.confidential_key = confidential_key + self.address = address if isinstance( + address, Address) else AddressUtil.parse(address) + self.confidential_key = confidential_key if isinstance( + confidential_key, Pubkey) else Pubkey(confidential_key) util = get_util() with util.create_handle() as handle: self.confidential_address = util.call_func( @@ -66,7 +74,7 @@ def __init__(self, address, confidential_key): ## # @brief get string. # @return confidential address. - def __str__(self): + def __str__(self) -> str: return self.confidential_address diff --git a/cfd/confidential_transaction.py b/cfd/confidential_transaction.py index 768d5df..bc6aaac 100644 --- a/cfd/confidential_transaction.py +++ b/cfd/confidential_transaction.py @@ -3,11 +3,14 @@ # @file confidential_transaction.py # @brief elements confidential transaction function implements file. # @note Copyright 2020 CryptoGarage +from typing import AnyStr, Dict, List, Optional, Union +import typing from .util import ReverseByteData, CfdError, JobHandle,\ CfdErrorCode, to_hex_string, get_util, ByteData from .address import Address, AddressUtil +from .descriptor import Descriptor from .key import Network, SigHashType, Privkey, Pubkey -from .script import HashType +from .script import HashType, Script from .transaction import UtxoData, OutPoint, Txid, TxIn, TxOut, _FundTxOpt,\ _TransactionBase from .confidential_address import ConfidentialAddress @@ -28,6 +31,22 @@ # @class BlindFactor # @brief blind factor (blinder) class. class BlindFactor(ReverseByteData): + ## + # @var hex + # hex string + hex: str + + ## + # @brief constructor. + # @param[in] data blind factor + # @return BlindFactor + @classmethod + def create(cls, data=''): + if isinstance(data, str) and (len(data) == 0): + return BlindFactor(EMPTY_BLINDER) + else: + return BlindFactor(data) + ## # @brief constructor. # @param[in] data blind factor @@ -35,7 +54,14 @@ def __init__(self, data): super().__init__(data) if len(self.hex) != 64: raise CfdError( - error_code=1, message='Error: Invalid blind factor.') + error_code=1, message=f'Error: Invalid blind factor.') + + ## + # @brief check empty. + # @retval True empty + # @retval False not empty + def is_empty(self): + return True if self.hex == '0'*64 else False ## @@ -45,6 +71,7 @@ class ConfidentialNonce: ## # @var hex # hex + hex: str ## # @brief constructor. @@ -58,14 +85,14 @@ def __init__(self, data=''): ## # @brief get string. # @return hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## # @brief check empty. # @retval True empty. # @retval False value exist. - def is_empty(self): + def is_empty(self) -> bool: return (len(self.hex) == 0) @@ -76,6 +103,7 @@ class ConfidentialAsset: ## # @var hex # hex + hex: str ## # @brief constructor. @@ -94,14 +122,14 @@ def __init__(self, data): ## # @brief get string. # @return hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## # @brief get blind state. # @retval True blinded. # @retval False unblind. - def has_blind(self): + def has_blind(self) -> bool: if (len(self.hex) == 66) and (self.hex[0] == '0') and ( self.hex[1].lower() in {'a', 'b'}): return True @@ -111,7 +139,7 @@ def has_blind(self): # @brief get commitment. (can use unblind only) # @param[in] asset_blind_factor asset blind factor # @return asset commitment - def get_commitment(self, asset_blind_factor): + def get_commitment(self, asset_blind_factor) -> 'ConfidentialAsset': if self.has_blind(): raise CfdError( error_code=1, message='Error: Blind asset.') @@ -130,9 +158,11 @@ class ConfidentialValue: ## # @var hex # hex + hex: str ## # @var amount # amount + amount: int ## # @brief create instance. @@ -140,7 +170,7 @@ class ConfidentialValue: # @param[in] amount amount # @return ConfidentialValue @classmethod - def create(cls, value, amount): + def create(cls, value, amount: int) -> 'ConfidentialValue': _value_hex = to_hex_string(value) if isinstance(value, ConfidentialValue): return value @@ -153,7 +183,8 @@ def create(cls, value, amount): # @brief get hex string from amount. # @param[in] amount amount # @return hex string - def _byte_from_amount(cls, amount): + @classmethod + def _byte_from_amount(cls, amount: int) -> str: util = get_util() with util.create_handle() as handle: value_hex = util.call_func( @@ -181,14 +212,14 @@ def __init__(self, data): ## # @brief get string. # @return hex or amount. - def __str__(self): + def __str__(self) -> str: return str(self.amount) if self.amount != 0 else self.hex ## # @brief get blind state. # @retval True blinded. # @retval False unblind. - def has_blind(self): + def has_blind(self) -> bool: return (len(self.hex) == 66) ## @@ -196,7 +227,8 @@ def has_blind(self): # @param[in] asset_commitment asset commitment # @param[in] blind_factor amount blind factor # @return amount commitment - def get_commitment(self, asset_commitment, blind_factor): + def get_commitment(self, asset_commitment, + blind_factor) -> 'ConfidentialValue': if self.has_blind(): raise CfdError( error_code=1, message='Error: Blind value.') @@ -220,36 +252,47 @@ class ElementsUtxoData(UtxoData): ## # @var outpoint # outpoint (for UtxoData class) + outpoint: 'OutPoint' ## # @var amount # amount (for UtxoData class) + amount: int ## # @var value # value + value: 'ConfidentialValue' ## # @var asset # asset + asset: 'ConfidentialAsset' ## # @var is_issuance # is issuance + is_issuance: bool ## # @var is_blind_issuance # is blinded issuance + is_blind_issuance: bool ## # @var is_pegin # is pegin + is_pegin: bool ## # @var pegin_btc_tx_size # pegin btc transaction size + pegin_btc_tx_size: int ## # @var fedpeg_script # fedpeg script + fedpeg_script: 'Script' ## # @var asset_blinder # asset blind factor + asset_blinder: 'BlindFactor' ## # @var amount_blinder # amount blind factor + amount_blinder: 'BlindFactor' ## # @brief constructor. @@ -269,24 +312,28 @@ class ElementsUtxoData(UtxoData): # @param[in] asset_blinder asset blind factor # @param[in] amount_blinder amount blind factor def __init__( - self, outpoint=None, txid='', vout=0, - amount=0, descriptor='', scriptsig_template='', - value='', asset='', is_issuance=False, is_blind_issuance=False, - is_pegin=False, pegin_btc_tx_size=0, fedpeg_script='', + self, outpoint: Optional['OutPoint'] = None, + txid='', vout: int = 0, amount: int = 0, + descriptor: Union[str, 'Descriptor'] = '', + scriptsig_template='', + value='', asset='', + is_issuance: bool = False, is_blind_issuance: bool = False, + is_pegin: bool = False, + pegin_btc_tx_size: int = 0, fedpeg_script='', asset_blinder='', amount_blinder=''): super().__init__( outpoint=outpoint, txid=txid, vout=vout, amount=amount, descriptor=descriptor, scriptsig_template=scriptsig_template) self.value = ConfidentialValue.create(value, amount) - self.asset = asset + self.asset = ConfidentialAsset(asset) self.is_issuance = is_issuance self.is_blind_issuance = is_blind_issuance self.is_pegin = is_pegin self.pegin_btc_tx_size = int(pegin_btc_tx_size) - self.fedpeg_script = fedpeg_script - self.asset_blinder = asset_blinder - self.amount_blinder = amount_blinder + self.fedpeg_script = Script(fedpeg_script) + self.asset_blinder = BlindFactor.create(asset_blinder) + self.amount_blinder = BlindFactor.create(amount_blinder) if self.amount == 0: self.amount = self.value.amount @@ -294,7 +341,7 @@ def __init__( # @brief equal method. # @param[in] other other object. # @return true or false. - def __eq__(self, other): + def __eq__(self, other: 'ElementsUtxoData') -> bool: if not isinstance(other, ElementsUtxoData): return NotImplemented return self.outpoint == other.outpoint @@ -303,7 +350,7 @@ def __eq__(self, other): # @brief diff method. # @param[in] other other object. # @return true or false. - def __lt__(self, other): + def __lt__(self, other: 'ElementsUtxoData') -> bool: if not isinstance(other, ElementsUtxoData): return NotImplemented return (self.outpoint) < (other.outpoint) @@ -312,28 +359,28 @@ def __lt__(self, other): # @brief equal method. # @param[in] other other object. # @return true or false. - def __ne__(self, other): + def __ne__(self, other: 'ElementsUtxoData') -> bool: return not self.__eq__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __le__(self, other): + def __le__(self, other: 'ElementsUtxoData') -> bool: return self.__lt__(other) or self.__eq__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __gt__(self, other): + def __gt__(self, other: 'ElementsUtxoData') -> bool: return not self.__le__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __ge__(self, other): + def __ge__(self, other: 'ElementsUtxoData') -> bool: return not self.__lt__(other) @@ -344,15 +391,19 @@ class UnblindData: ## # @var asset # asset + asset: Union[AnyStr, 'ConfidentialAsset'] ## # @var value # value + value: 'ConfidentialValue' ## # @var asset_blinder # asset blind factor + asset_blinder: 'BlindFactor' ## # @var amount_blinder # amount blind factor + amount_blinder: 'BlindFactor' ## # @brief constructor. @@ -360,7 +411,8 @@ class UnblindData: # @param[in] amount amount # @param[in] asset_blinder asset blind factor # @param[in] amount_blinder amount blind factor - def __init__(self, asset, amount, asset_blinder, amount_blinder): + def __init__(self, asset, amount: int, + asset_blinder, amount_blinder): self.asset = asset self.value = ConfidentialValue(amount) self.asset_blinder = BlindFactor(asset_blinder) @@ -369,7 +421,7 @@ def __init__(self, asset, amount, asset_blinder, amount_blinder): ## # @brief get string. # @return hex - def __str__(self): + def __str__(self) -> str: return '{},{}'.format(self.asset, self.value) @@ -380,9 +432,11 @@ class BlindData(UnblindData): ## # @var vout # txout array index + vout: int ## # @var is_issuance # issuance flag + is_issuance: bool ## # @brief constructor. @@ -391,7 +445,8 @@ class BlindData(UnblindData): # @param[in] amount amount # @param[in] asset_blinder asset blind factor # @param[in] amount_blinder amount blind factor - def __init__(self, vout, asset, amount, asset_blinder, amount_blinder): + def __init__(self, vout: int, asset, amount: int, + asset_blinder, amount_blinder): super().__init__(asset, amount, asset_blinder, amount_blinder) self.vout = vout self.is_issuance = False @@ -404,9 +459,11 @@ class IssuanceAssetBlindData(BlindData): ## # @var outpoint # issuance outpoint + outpoint: 'OutPoint' ## # @var is_issuance # issuance flag + is_issuance: bool ## # @brief constructor. @@ -415,7 +472,8 @@ class IssuanceAssetBlindData(BlindData): # @param[in] asset asset # @param[in] amount amount # @param[in] amount_blinder amount blind factor - def __init__(self, outpoint, vout, asset, amount, amount_blinder): + def __init__(self, outpoint: OutPoint, vout: int, asset, + amount: int, amount_blinder): super().__init__(vout, asset, amount, EMPTY_BLINDER, amount_blinder) self.outpoint = outpoint self.is_issuance = True @@ -428,9 +486,11 @@ class IssuanceTokenBlindData(BlindData): ## # @var outpoint # issuance outpoint + outpoint: 'OutPoint' ## # @var is_issuance # issuance flag + is_issuance: bool ## # @brief constructor. @@ -439,7 +499,8 @@ class IssuanceTokenBlindData(BlindData): # @param[in] asset asset # @param[in] amount amount # @param[in] amount_blinder amount blind factor - def __init__(self, outpoint, vout, asset, amount, amount_blinder): + def __init__(self, outpoint: OutPoint, vout: int, asset, + amount: int, amount_blinder): super().__init__(vout, asset, amount, EMPTY_BLINDER, amount_blinder) self.outpoint = outpoint self.is_issuance = True @@ -452,30 +513,38 @@ class Issuance: ## # @var entropy # entropy + entropy: 'BlindFactor' ## # @var nonce # nonce + nonce: 'BlindFactor' ## # @var asset_value # asset value + asset_value: 'ConfidentialValue' ## # @var token_value # token value + token_value: 'ConfidentialValue' ## # @brief constructor. - def __init__(self, entropy='', nonce='', asset_value=0, token_value=0): - self.entropy = entropy - self.nonce = nonce + # @param[in] entropy entropy + # @param[in] nonce nonce + # @param[in] asset_value asset amount + # @param[in] token_value token amount + def __init__(self, entropy='', nonce='', + asset_value: int = 0, token_value: int = 0): + self.entropy = BlindFactor.create(entropy) + self.nonce = BlindFactor.create(nonce) self.asset_value = ConfidentialValue(asset_value) self.token_value = ConfidentialValue(token_value) ## # @brief get string. # @return hex - def __str__(self): - return '{},{},{}'.format( - self.entropy, self.asset_value, self.token_value) + def __str__(self) -> str: + return '{},{},{}'.format(str(self.entropy), self.asset_value, self.token_value) ## @@ -485,22 +554,24 @@ class IssuanceKeyPair: ## # @var asset_key # asset blinding key + asset_key: Optional['Privkey'] ## # @var token_key # token blinding key + token_key: Optional['Privkey'] ## # @brief constructor. # @param[in] asset_key asset blinding key # @param[in] token_key token blinding key def __init__(self, asset_key='', token_key=''): - self.asset_key = asset_key - self.token_key = token_key + self.asset_key = Privkey.from_hex_ignore_error(asset_key) + self.token_key = Privkey.from_hex_ignore_error(token_key) ## # @brief get string. # @return hex - def __str__(self): + def __str__(self) -> str: return 'IssuanceKeyPair' @@ -511,9 +582,11 @@ class ConfidentialTxIn(TxIn): ## # @var pegin_witness_stack # pegin witness stack + pegin_witness_stack: List[Union['ByteData', 'Script']] ## # @var issuance # issuance + issuance: 'Issuance' ## # @brief constructor. @@ -521,8 +594,9 @@ class ConfidentialTxIn(TxIn): # @param[in] txid txid # @param[in] vout vout # @param[in] sequence sequence - def __init__(self, outpoint=None, txid='', vout=0, - sequence=TxIn.SEQUENCE_DISABLE): + def __init__(self, outpoint: Optional['OutPoint'] = None, + txid='', vout: int = 0, + sequence: int = TxIn.SEQUENCE_DISABLE): super().__init__(outpoint, txid, vout, sequence) self.pegin_witness_stack = [] self.issuance = Issuance() @@ -535,18 +609,23 @@ class ConfidentialTxOut(TxOut): ## # @var value # value + value: 'ConfidentialValue' ## # @var asset # asset + asset: 'ConfidentialAsset' ## # @var nonce # nonce + nonce: 'ConfidentialNonce' ## # @var surjectionproof # surjection proof + surjectionproof: Union[List[int], AnyStr, 'ByteData'] ## # @var rangeproof # range proof + rangeproof: Union[List[int], AnyStr, 'ByteData'] ## # @brief get destroy amount txout. @@ -555,7 +634,8 @@ class ConfidentialTxOut(TxOut): # @param[in] nonce nonce # @return ConfidentialTxOut @classmethod - def for_destroy_amount(cls, amount, asset, nonce=''): + def for_destroy_amount( + cls, amount: int, asset, nonce='') -> 'ConfidentialTxOut': return ConfidentialTxOut(amount, asset=asset, nonce=nonce, locking_script='6a') @@ -565,7 +645,7 @@ def for_destroy_amount(cls, amount, asset, nonce=''): # @param[in] asset asset # @return ConfidentialTxOut @classmethod - def for_fee(cls, amount, asset): + def for_fee(cls, amount: int, asset) -> 'ConfidentialTxOut': return ConfidentialTxOut(amount, asset=asset) ## @@ -607,7 +687,8 @@ def has_blind(self): # @param[in] is_confidential Returns Confidential Address if possible. # @return address. def get_address(self, network=Network.LIQUID_V1, - is_confidential=False): + is_confidential: bool = False, + ) -> Union['Address', 'ConfidentialAddress']: _network = Network.get(network) if _network not in [Network.LIQUID_V1, Network.ELEMENTS_REGTEST]: raise CfdError(error_code=1, @@ -638,22 +719,30 @@ class TargetAmountData: ## # @var amount # amount + amount: int ## # @var asset # asset + asset: ConfidentialAsset ## # @var reserved_address # reserved address + reserved_address: Union[str, 'Address', 'ConfidentialAddress'] ## # @brief constructor. # @param[in] amount amount # @param[in] asset asset # @param[in] reserved_address reserved address - def __init__(self, amount, asset, reserved_address=''): + def __init__(self, amount: int, asset, + reserved_address: Union[str, 'Address', 'ConfidentialAddress'] = ''): self.amount = amount - self.asset = asset - self.reserved_address = reserved_address + self.asset = ConfidentialAsset(asset) + if isinstance(reserved_address, Address) or \ + isinstance(reserved_address, ConfidentialAddress): + self.reserved_address = reserved_address + else: + self.reserved_address = str(reserved_address) ## @@ -663,43 +752,54 @@ class ConfidentialTransaction(_TransactionBase): ## # @var hex # transaction hex string + hex: str ## # @var txin_list # transaction input list + txin_list: List['ConfidentialTxIn'] ## # @var txout_list # transaction output list + txout_list: List['ConfidentialTxOut'] ## # @var txid # txid + txid: 'Txid' ## # @var wtxid # wtxid + wtxid: 'Txid' ## # @var wit_hash # wit_hash + wit_hash: str ## # @var size # transaction size + size: int ## # @var vsize # transaction vsize + vsize: int ## # @var weight # transaction size weight + weight: int ## # @var version # version + version: int ## # @var locktime # locktime + locktime: int ## # bitcoin network value. - NETWORK = Network.LIQUID_V1.value + NETWORK: int = Network.LIQUID_V1.value ## # blind minimumBits on default. - DEFAULT_BLIND_MINIMUM_BITS = 52 + DEFAULT_BLIND_MINIMUM_BITS: int = 52 ## # transaction's free function name. FREE_FUNC_NAME = 'CfdFreeTransactionHandle' @@ -711,8 +811,8 @@ class ConfidentialTransaction(_TransactionBase): # @param[in] full_dump full_dump flag # @return json string @classmethod - def parse_to_json(cls, hex, network=Network.LIQUID_V1, - full_dump=False): + def parse_to_json(cls, hex: str, network=Network.LIQUID_V1, + full_dump: bool = False) -> str: _network = Network.get(network) mainchain_str = 'mainnet' network_str = 'liquidv1' @@ -736,14 +836,16 @@ def parse_to_json(cls, hex, network=Network.LIQUID_V1, # @param[in] locking_script locking script # @return blinding key @classmethod - def get_default_blinding_key(cls, master_blinding_key, locking_script): + def get_default_blinding_key( + cls, master_blinding_key, locking_script) -> 'Privkey': _key = to_hex_string(master_blinding_key) _script = to_hex_string(locking_script) util = get_util() with util.create_handle() as handle: - return util.call_func( + key = util.call_func( 'CfdGetDefaultBlindingKey', handle.get_handle(), _key, _script) + return Privkey(hex=key) ## # @brief get issuance blinding key for elemens default. @@ -752,15 +854,16 @@ def get_default_blinding_key(cls, master_blinding_key, locking_script): # @param[in] vout vout # @return blinding key @classmethod - def get_issuance_blinding_key(cls, master_blinding_key, - txid, vout): + def get_issuance_blinding_key( + cls, master_blinding_key, txid, vout: int) -> 'Privkey': _key = to_hex_string(master_blinding_key) _txid = to_hex_string(txid) util = get_util() with util.create_handle() as handle: - return util.call_func( + key = util.call_func( 'CfdGetIssuanceBlindingKey', handle.get_handle(), _key, _txid, vout) + return Privkey(hex=key) ## # @brief create transaction. @@ -771,7 +874,10 @@ def get_issuance_blinding_key(cls, master_blinding_key, # @param[in] enable_cache enable tx cache # @return transaction object @classmethod - def create(cls, version, locktime, txins=[], txouts=[], enable_cache=True): + def create(cls, version: int, locktime: int, + txins: List[Union['ConfidentialTxIn', 'TxIn']] = [], + txouts: List['ConfidentialTxOut'] = [], + enable_cache: bool = True) -> 'ConfidentialTransaction': util = get_util() with util.create_handle() as handle: _tx_handle = util.call_func( @@ -803,14 +909,14 @@ def create(cls, version, locktime, txins=[], txouts=[], enable_cache=True): # @param[in] enable_cache enable tx cache # @return transaction object @classmethod - def from_hex(cls, hex, enable_cache=True): + def from_hex(cls, hex, enable_cache: bool = True) -> 'ConfidentialTransaction': return ConfidentialTransaction(hex, enable_cache) ## # @brief constructor. # @param[in] hex tx hex # @param[in] enable_cache enable tx cache - def __init__(self, hex, enable_cache=True): + def __init__(self, hex, enable_cache: bool = True): super().__init__(hex, self.NETWORK, enable_cache) self.txin_list = [] self.txout_list = [] @@ -824,7 +930,9 @@ def __init__(self, hex, enable_cache=True): # @param[in] outpoint outpoint # @retval [0] txin # @retval [1] index - def _get_txin(self, handle, tx_handle, index=0, outpoint=None): + def _get_txin(self, handle, tx_handle, index=0, + outpoint: Optional['OutPoint'] = None, + ) -> typing.Tuple['ConfidentialTxIn', int]: util = get_util() if isinstance(outpoint, OutPoint): @@ -874,7 +982,7 @@ def _get_txin(self, handle, tx_handle, index=0, outpoint=None): ## # @brief update transaction information. # @return void - def _update_info(self): + def _update_info(self) -> None: if self.enable_cache is False: return util = get_util() @@ -895,7 +1003,7 @@ def _update_info(self): # @brief update transaction input. # @param[in] outpoint outpoint # @return void - def _update_txin(self, outpoint): + def _update_txin(self, outpoint: 'OutPoint'): if self.enable_cache is False: return util = get_util() @@ -920,7 +1028,8 @@ def _update_txin(self, outpoint): # @brief get transaction all data. # @retval [0] txin list # @retval [1] txout list - def get_tx_all(self): + def get_tx_all(self) -> typing.Tuple[ + List['ConfidentialTxIn'], List['ConfidentialTxOut']]: def get_txin_list(handle, tx_handle): txin_list = [] _count = util.call_func( @@ -969,7 +1078,7 @@ def get_txout_list(handle, tx_handle): ## # @brief get transaction output fee index. # @return index - def get_txout_fee_index(self): + def get_txout_fee_index(self) -> int: return self.get_txout_index() ## @@ -979,8 +1088,8 @@ def get_txout_fee_index(self): # @param[in] txid txid # @param[in] vout vout # @return void - def add_txin(self, outpoint=None, sequence=-1, - txid='', vout=0): + def add_txin(self, outpoint: Optional['OutPoint'] = None, + sequence: int = -1, txid='', vout: int = 0) -> None: sec = TxIn.get_sequence_number(self.locktime, sequence) txin = ConfidentialTxIn( outpoint=outpoint, sequence=sec, txid=txid, vout=vout) @@ -995,9 +1104,8 @@ def add_txin(self, outpoint=None, sequence=-1, # @param[in] asset asset # @param[in] nonce nonce # @return void - def add_txout( - self, amount=0, address='', locking_script='', - value='', asset='', nonce=''): + def add_txout(self, amount: int = 0, address='', + locking_script='', value='', asset='', nonce='') -> None: txout = ConfidentialTxOut( amount, address, locking_script, value, asset, nonce) self.add([], [txout]) @@ -1007,7 +1115,7 @@ def add_txout( # @param[in] amount amount # @param[in] asset asset # @return void - def add_fee_txout(self, amount, asset): + def add_fee_txout(self, amount: int, asset) -> None: self.add_txout(amount, asset=asset) ## @@ -1016,7 +1124,8 @@ def add_fee_txout(self, amount, asset): # @param[in] asset asset # @param[in] nonce nonce # @return void - def add_destroy_amount_txout(self, amount, asset, nonce=''): + def add_destroy_amount_txout( + self, amount: int, asset, nonce='') -> None: self.add_txout(amount, locking_script='6a', asset=asset, nonce=nonce) ## @@ -1024,7 +1133,8 @@ def add_destroy_amount_txout(self, amount, asset, nonce=''): # @param[in] txins txin list # @param[in] txouts txout list # @return void - def add(self, txins, txouts): + def add(self, txins: List['ConfidentialTxIn'], + txouts: List['ConfidentialTxOut']) -> None: util = get_util() with util.create_handle() as handle: _tx_handle = util.call_func( @@ -1064,7 +1174,7 @@ def add(self, txins, txouts): # @param[in] index index # @param[in] amount amount # @return void - def update_txout_amount(self, index, amount): + def update_txout_amount(self, index: int, amount: int) -> None: util = get_util() with util.create_handle() as handle: self.hex = util.call_func( @@ -1077,7 +1187,7 @@ def update_txout_amount(self, index, amount): # @brief update transaction output fee amount. # @param[in] amount amount # @return void - def update_txout_fee_amount(self, amount): + def update_txout_fee_amount(self, amount: int) -> None: index = self.get_txout_index() self.update_txout_amount(index, amount) @@ -1091,10 +1201,12 @@ def update_txout_fee_amount(self, amount): # @param[in] minimum_bits minimum bits # @param[in] collect_blinder blinder collect flag. # @return blinder_list (if collect_blinder is true) - def blind_txout(self, utxo_list, confidential_address_list=[], + def blind_txout(self, utxo_list: List['ElementsUtxoData'], + confidential_address_list=[], direct_confidential_key_map={}, - minimum_range_value=1, exponent=0, minimum_bits=-1, - collect_blinder=False): + minimum_range_value: int = 1, exponent: int = 0, + minimum_bits: int = -1, + collect_blinder: bool = False) -> List['BlindData']: return self.blind( utxo_list=utxo_list, confidential_address_list=confidential_address_list, @@ -1114,12 +1226,14 @@ def blind_txout(self, utxo_list, confidential_address_list=[], # @param[in] minimum_bits minimum bits # @param[in] collect_blinder blinder collect flag. # @return blinder_list (if collect_blinder is true) - def blind(self, utxo_list, - issuance_key_map={}, - confidential_address_list=[], - direct_confidential_key_map={}, - minimum_range_value=1, exponent=0, minimum_bits=-1, - collect_blinder=False): + def blind( + self, utxo_list: List['ElementsUtxoData'], + issuance_key_map={}, + confidential_address_list=[], + direct_confidential_key_map={}, + minimum_range_value: int = 1, exponent: int = 0, minimum_bits: int = -1, + collect_blinder: bool = False, + ) -> List[Union['BlindData', 'IssuanceAssetBlindData', 'IssuanceTokenBlindData']]: if minimum_bits == -1: minimum_bits = self.DEFAULT_BLIND_MINIMUM_BITS @@ -1142,7 +1256,8 @@ def set_opt(handle, tx_handle, key, i_val=0): asset_key, token_key = '', '' if str(txin.outpoint) in issuance_key_map: item = issuance_key_map[str(txin.outpoint)] - asset_key, token_key = item.asset_key, item.token_key + asset_key = str(item.asset_key) + token_key = str(item.token_key) issuance_count += 1 util.call_func( 'CfdAddBlindTxInData', handle.get_handle(), @@ -1213,7 +1328,7 @@ def set_opt(handle, tx_handle, key, i_val=0): # @param[in] index txout index # @param[in] blinding_key blinding key # @return UnblindData - def unblind_txout(self, index, blinding_key): + def unblind_txout(self, index: int, blinding_key) -> 'UnblindData': util = get_util() with util.create_handle() as handle: asset, asset_amount, asset_blinder,\ @@ -1230,7 +1345,8 @@ def unblind_txout(self, index, blinding_key): # @param[in] token_key token blinding key # @retval [0] asset unblind data # @retval [1] token unblind data - def unblind_issuance(self, index, asset_key, token_key=''): + def unblind_issuance(self, index: int, asset_key, + token_key='') -> 'UnblindData': util = get_util() with util.create_handle() as handle: asset, asset_amount, asset_blinder, amount_blinder, token,\ @@ -1252,7 +1368,8 @@ def unblind_issuance(self, index, asset_key, token_key=''): # @param[in] address address # @param[in] entropy entropy # @return reissue asset - def set_raw_reissue_asset(self, utxo, amount, address, entropy): + def set_raw_reissue_asset(self, utxo: 'ElementsUtxoData', amount: int, + address, entropy) -> 'ConfidentialAsset': _amount = amount if isinstance(amount, ConfidentialValue): _amount = amount.amount @@ -1275,8 +1392,8 @@ def set_raw_reissue_asset(self, utxo, amount, address, entropy): # @param[in] redeem_script redeem script # @param[in] sighashtype sighash type # @return sighash - def get_sighash(self, outpoint, hash_type, value, pubkey='', - redeem_script='', sighashtype=SigHashType.ALL): + def get_sighash(self, outpoint: 'OutPoint', hash_type, value, pubkey='', + redeem_script='', sighashtype=SigHashType.ALL) -> 'ByteData': _hash_type = HashType.get(hash_type) _pubkey = to_hex_string(pubkey) _script = to_hex_string(redeem_script) @@ -1304,8 +1421,8 @@ def get_sighash(self, outpoint, hash_type, value, pubkey='', # @param[in] grind_r grind-R flag # @return void def sign_with_privkey( - self, outpoint, hash_type, privkey, value, - sighashtype=SigHashType.ALL, grind_r=True): + self, outpoint: 'OutPoint', hash_type, privkey, value, + sighashtype=SigHashType.ALL, grind_r: bool = True) -> None: _hash_type = HashType.get(hash_type) if isinstance(privkey, Privkey): _privkey = privkey @@ -1336,7 +1453,8 @@ def sign_with_privkey( # @param[in] hash_type hash type # @param[in] value value # @return void - def verify_sign(self, outpoint, address, hash_type, value): + def verify_sign(self, outpoint: 'OutPoint', address, hash_type, + value) -> None: _hash_type = HashType.get(hash_type) _value = value if isinstance(value, ConfidentialValue) is False: @@ -1361,8 +1479,9 @@ def verify_sign(self, outpoint, address, hash_type, value): # @retval True signature valid. # @retval False signature invalid. def verify_signature( - self, outpoint, signature, hash_type, pubkey, value, - redeem_script='', sighashtype=SigHashType.ALL): + self, outpoint: 'OutPoint', signature, + hash_type, pubkey, value, + redeem_script='', sighashtype=SigHashType.ALL) -> bool: _signature = to_hex_string(signature) _pubkey = to_hex_string(pubkey) _script = to_hex_string(redeem_script) @@ -1404,10 +1523,15 @@ def verify_signature( # @retval [2] total tx fee. @classmethod def select_coins( - cls, utxo_list, tx_fee_amount, target_list, fee_asset, - effective_fee_rate=0.11, long_term_fee_rate=0.11, - dust_fee_rate=3.0, knapsack_min_change=-1, - exponent=0, minimum_bits=52): + cls, + utxo_list: List['ElementsUtxoData'], + tx_fee_amount: int, + target_list: List['TargetAmountData'], + fee_asset, + effective_fee_rate: float = 0.11, long_term_fee_rate: float = 0.11, + dust_fee_rate: float = 3.0, knapsack_min_change: int = -1, + exponent: int = 0, minimum_bits: int = 52, + ) -> typing.Tuple[List['ElementsUtxoData'], int, Dict[str, int]]: if (isinstance(utxo_list, list) is False) or ( len(utxo_list) == 0): raise CfdError(error_code=1, message='Error: Invalid utxo_list.') @@ -1471,7 +1595,8 @@ def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): total_amount = util.call_func( 'CfdGetSelectedCoinAssetAmount', handle.get_handle(), tx_handle.get_handle(), index) - total_amount_map[target.asset] = total_amount + key = str(target.asset) # not optimisation + total_amount_map[key] = total_amount return selected_utxo_list, _utxo_fee, total_amount_map ## @@ -1485,8 +1610,11 @@ def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): # @retval [0] total tx fee. (txout fee + utxo fee) # @retval [1] txout fee. # @retval [2] utxo fee. - def estimate_fee(self, utxo_list, fee_asset, fee_rate=0.11, - is_blind=True, exponent=0, minimum_bits=52): + def estimate_fee(self, utxo_list: List['ElementsUtxoData'], + fee_asset, + fee_rate: float = 0.11, is_blind: bool = True, + exponent: int = 0, minimum_bits: int = 52, + ) -> typing.Tuple[int, int, int]: _fee_asset = ConfidentialAsset(fee_asset) if (isinstance(utxo_list, list) is False) or ( len(utxo_list) == 0): @@ -1548,11 +1676,15 @@ def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): # @retval [0] total tx fee. # @retval [1] used reserved address. (None or reserved_address) def fund_raw_transaction( - self, txin_utxo_list, utxo_list, target_list, - fee_asset, effective_fee_rate=0.11, - long_term_fee_rate=-1.0, dust_fee_rate=-1.0, - knapsack_min_change=-1, is_blind=True, - exponent=0, minimum_bits=52): + self, txin_utxo_list: List['ElementsUtxoData'], + utxo_list: List['ElementsUtxoData'], + target_list: List['TargetAmountData'], + fee_asset, + effective_fee_rate: float = 0.11, + long_term_fee_rate: float = -1.0, + dust_fee_rate: float = -1.0, + knapsack_min_change: int = -1, is_blind: bool = True, + exponent: int = 0, minimum_bits: int = 52) -> typing.Tuple[int, List[str]]: util = get_util() def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): diff --git a/cfd/descriptor.py b/cfd/descriptor.py index eea5be1..2173be1 100644 --- a/cfd/descriptor.py +++ b/cfd/descriptor.py @@ -3,10 +3,12 @@ # @file descriptor.py # @brief hdwallet function implements file. # @note Copyright 2020 CryptoGarage +from typing import List, Optional, Union from .util import get_util, JobHandle, CfdError -from .address import AddressUtil -from .key import Network -from .script import HashType +from .address import Address, AddressUtil +from .key import Network, Pubkey +from .hdwallet import ExtPubkey, ExtPrivkey +from .script import HashType, Script from enum import Enum @@ -51,7 +53,7 @@ class DescriptorScriptType(Enum): ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: return self.name.lower().replace('_', '') ## @@ -59,7 +61,7 @@ def as_str(self): # @param[in] desc_type descriptor type # @return object. @classmethod - def get(cls, desc_type): + def get(cls, desc_type) -> 'DescriptorScriptType': if (isinstance(desc_type, DescriptorScriptType)): return desc_type elif (isinstance(desc_type, int)): @@ -97,7 +99,7 @@ class DescriptorKeyType(Enum): ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: if self == DescriptorKeyType.PUBLIC: return 'pubkey' elif self == DescriptorKeyType.BIP32: @@ -111,7 +113,7 @@ def as_str(self): # @param[in] desc_type descriptor type # @return object. @classmethod - def get(cls, desc_type): + def get(cls, desc_type) -> 'DescriptorKeyType': if (isinstance(desc_type, DescriptorKeyType)): return desc_type elif (isinstance(desc_type, int)): @@ -142,15 +144,19 @@ class DescriptorKeyData: ## # @var key_type # key type + key_type: 'DescriptorKeyType' ## # @var pubkey # pubkey + pubkey: Union['Pubkey', str] ## # @var ext_pubkey # ext pubkey + ext_pubkey: Union['ExtPubkey', str] ## # @var ext_privkey # ext privkey + ext_privkey: Union['ExtPrivkey', str] ## # @brief constructor. @@ -165,14 +171,24 @@ def __init__( ext_pubkey='', ext_privkey=''): self.key_type = DescriptorKeyType.get(key_type) - self.pubkey = pubkey - self.ext_pubkey = ext_pubkey - self.ext_privkey = ext_privkey + self.pubkey = pubkey if isinstance(pubkey, str) else Pubkey(pubkey) + if ext_pubkey is None: + self.ext_pubkey = '' + elif isinstance(ext_pubkey, str): + self.ext_pubkey = ext_pubkey + else: + self.ext_pubkey = ExtPubkey(ext_pubkey) + if ext_privkey is None: + self.ext_privkey = '' + elif isinstance(ext_privkey, str): + self.ext_privkey = ext_privkey + else: + self.ext_privkey = ExtPrivkey(ext_privkey) ## # @brief get string. # @return descriptor. - def __str__(self): + def __str__(self) -> str: if self.key_type == DescriptorKeyType.PUBLIC: return str(self.pubkey) elif self.key_type == DescriptorKeyType.BIP32: @@ -189,30 +205,39 @@ class DescriptorScriptData: ## # @var script_type # script type + script_type: 'DescriptorScriptType' ## # @var depth # depth + depth: int ## # @var hash_type # hash type + hash_type: 'HashType' ## # @var address # address + address: Union[str, 'Address'] ## # @var locking_script # locking script + locking_script: Union[str, 'Script'] ## # @var redeem_script # redeem script for script hash + redeem_script: Union[str, 'Script'] ## # @var key_data # key data + key_data: Optional['DescriptorKeyType'] ## # @var key_list # key list + key_list: List['DescriptorKeyData'] ## # @var multisig_require_num # multisig require num + multisig_require_num: int ## # @brief constructor. @@ -226,16 +251,18 @@ class DescriptorScriptData: # @param[in] key_list key list # @param[in] multisig_require_num multisig require num def __init__( - self, script_type, depth, hash_type, address, + self, script_type: 'DescriptorScriptType', depth: int, + hash_type: 'HashType', address, locking_script, redeem_script='', - key_data=None, - key_list=[], - multisig_require_num=0): + key_data: Optional['DescriptorKeyData'] = None, + key_list: List['DescriptorKeyType'] = [], + multisig_require_num: int = 0): self.script_type = script_type self.depth = depth self.hash_type = hash_type - self.address = address + self.address = address if isinstance( + address, Address) else str(address) self.locking_script = locking_script self.redeem_script = redeem_script self.key_data = key_data @@ -250,25 +277,30 @@ class Descriptor: ## # @var path # bip32 path + path: str ## # @var descriptor # descriptor string + descriptor: str ## # @var network # network + network: 'Network' ## # @var script_list # script list + script_list: List['DescriptorScriptData'] ## # @var data # reference data + data: 'DescriptorScriptData' ## # @brief constructor. # @param[in] descriptor descriptor # @param[in] network network # @param[in] path bip32 path - def __init__(self, descriptor, network=Network.MAINNET, path=''): + def __init__(self, descriptor, network=Network.MAINNET, path: str = ''): self.network = Network.get(network) self.path = str(path) self.descriptor = self._verify(str(descriptor)) @@ -279,7 +311,7 @@ def __init__(self, descriptor, network=Network.MAINNET, path=''): # @brief verify descriptor. # @param[in] descriptor descriptor # @return append checksum descriptor - def _verify(self, descriptor): + def _verify(self, descriptor: str) -> str: util = get_util() with util.create_handle() as handle: return util.call_func( @@ -289,7 +321,7 @@ def _verify(self, descriptor): ## # @brief parse descriptor. # @return script list - def _parse(self): + def _parse(self) -> List['DescriptorScriptData']: util = get_util() with util.create_handle() as handle: word_handle, max_index = util.call_func( @@ -357,7 +389,7 @@ def get_key(index): ## # @brief analyze descriptor. # @return reference data - def _analyze(self): + def _analyze(self) -> 'DescriptorScriptData': if (self.script_list[0].hash_type in [ HashType.P2WSH, HashType.P2SH]) and ( len(self.script_list) > 1) and ( @@ -433,7 +465,8 @@ def __str__(self): # @param[in] network network # @param[in] path bip32 path # @retval Descriptor descriptor object -def parse_descriptor(descriptor, network=Network.MAINNET, path=''): +def parse_descriptor(descriptor, network=Network.MAINNET, + path: str = '') -> 'Descriptor': return Descriptor(descriptor, network=network, path=path) diff --git a/cfd/hdwallet.py b/cfd/hdwallet.py index c97d3f5..e5cfa53 100644 --- a/cfd/hdwallet.py +++ b/cfd/hdwallet.py @@ -3,7 +3,9 @@ # @file hdwallet.py # @brief hdwallet function implements file. # @note Copyright 2020 CryptoGarage -from .util import get_util, JobHandle, to_hex_string, CfdError +import typing +from typing import List, Tuple, Union +from .util import ByteData, CfdUtil, get_util, JobHandle, to_hex_string, CfdError from .key import Network, Privkey, Pubkey from enum import Enum import unicodedata @@ -37,13 +39,13 @@ class ExtKeyType(Enum): ## # @brief get string. # @return name. - def __str__(self): + def __str__(self) -> str: return self.name.lower().replace('_', '') ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: return self.name.lower().replace('_', '') ## @@ -51,7 +53,7 @@ def as_str(self): # @param[in] key_type key type # @return object. @classmethod - def get(cls, key_type): + def get(cls, key_type) -> 'ExtKeyType': if (isinstance(key_type, ExtKeyType)): return key_type elif (isinstance(key_type, int)): @@ -80,30 +82,39 @@ class Extkey(object): ## # @var extkey_type # extkey type + extkey_type: 'ExtKeyType' ## # @var util # cfd util + util: 'CfdUtil' ## # @var version # version + version: str ## # @var fingerprint # fingerprint + fingerprint: 'ByteData' ## # @var chain_code # chain code + chain_code: 'ByteData' ## # @var depth # depth + depth: int ## # @var child_number # child number + child_number: int ## # @var extkey - # extkey + # extkey string + extkey: str ## # @var network # network + network: 'Network' ## # @brief constructor. @@ -112,8 +123,8 @@ def __init__(self, extkey_type): self.extkey_type = extkey_type self.util = get_util() self.version = '' - self.fingerprint = '' - self.chain_code = '' + self.fingerprint = ByteData('') + self.chain_code = ByteData('') self.depth = 0 self.child_number = 0 self.extkey = '' @@ -124,12 +135,15 @@ def __init__(self, extkey_type): # @param[in] extkey extkey # @return void def _get_information(self, extkey): + _extkey = str(extkey) with self.util.create_handle() as handle: result = self.util.call_func( - 'CfdGetExtkeyInformation', handle.get_handle(), extkey) - self.version, self.fingerprint, self.chain_code, self.depth, \ + 'CfdGetExtkeyInformation', handle.get_handle(), _extkey) + self.version, _fingerprint, _chain_code, self.depth, \ self.child_number = result - self.extkey = extkey + self.fingerprint = ByteData(_fingerprint) + self.chain_code = ByteData(_chain_code) + self.extkey = _extkey if self.extkey_type == ExtKeyType.EXT_PRIVKEY: main, test, name = XPRIV_MAINNET_VERSION,\ XPRIV_TESTNET_VERSION, 'privkey' @@ -196,20 +210,20 @@ def _get_path_data(self, bip32_path, key_type): @classmethod def _create( cls, key_type, network, fingerprint, key, chain_code, - depth, number, parent_key=''): + depth, number, parent_key='') -> str: _network = Network.get_mainchain(network) _fingerprint = '' _path, _num_list = cls._convert_path(number=number) _number = _num_list[0] if len(_num_list) > 0 else number if parent_key == '': - _fingerprint = fingerprint + _fingerprint = str(fingerprint) _network = Network.get_mainchain(network) util = get_util() with util.create_handle() as handle: _extkey = util.call_func( 'CfdCreateExtkey', handle.get_handle(), - _network.value, key_type.value, parent_key, - _fingerprint, key, chain_code, depth, _number) + _network.value, key_type.value, str(parent_key), + _fingerprint, str(key), str(chain_code), depth, _number) return _extkey @@ -220,6 +234,7 @@ class ExtPrivkey(Extkey): ## # @var privkey # privkey + privkey: 'Privkey' ## # @brief create extkey from seed. @@ -227,7 +242,7 @@ class ExtPrivkey(Extkey): # @param[in] network network # @return ExtPrivkey @classmethod - def from_seed(cls, seed, network=Network.MAINNET): + def from_seed(cls, seed, network=Network.MAINNET) -> 'ExtPrivkey': _seed = to_hex_string(seed) _network = Network.get_mainchain(network) util = get_util() @@ -250,7 +265,7 @@ def from_seed(cls, seed, network=Network.MAINNET): @classmethod def create( cls, network, fingerprint, key, chain_code, - depth, number, parent_key=''): + depth: int, number: int, parent_key='') -> 'ExtPrivkey': _extkey = cls._create( ExtKeyType.EXT_PRIVKEY, network, fingerprint, key, chain_code, depth, number, parent_key) @@ -262,16 +277,19 @@ def create( def __init__(self, extkey): super().__init__(ExtKeyType.EXT_PRIVKEY) self._get_information(extkey) - with self.util.create_handle() as handle: - _hex, wif = self.util.call_func( - 'CfdGetPrivkeyFromExtkey', handle.get_handle(), - self.extkey, self.network.value) - self.privkey = Privkey(wif=wif) + if isinstance(extkey, ExtPrivkey): + self.privkey = extkey.privkey + else: + with self.util.create_handle() as handle: + _hex, wif = self.util.call_func( + 'CfdGetPrivkeyFromExtkey', handle.get_handle(), + self.extkey, self.network.value) + self.privkey = Privkey(wif=wif) ## # @brief get string. # @return extkey. - def __str__(self): + def __str__(self) -> str: return self.extkey ## @@ -280,7 +298,8 @@ def __str__(self): # @param[in] number bip32 number # @param[in] number_list bip32 number list # @return ExtPrivkey - def derive(self, path='', number=0, number_list=[]): + def derive(self, path: str = '', number: int = 0, + number_list: typing.List[int] = []) -> 'ExtPrivkey': _path, _list = self._convert_path(path, number, number_list) with self.util.create_handle() as handle: if _path == '': @@ -305,7 +324,8 @@ def derive(self, path='', number=0, number_list=[]): # @param[in] number bip32 number # @param[in] number_list bip32 number list # @return ExtPubkey - def derive_pubkey(self, path='', number=0, number_list=[]): + def derive_pubkey(self, path: str = '', number: int = 0, + number_list: typing.List[int] = []) -> 'ExtPrivkey': return self.derive( path=path, number=number, @@ -314,7 +334,7 @@ def derive_pubkey(self, path='', number=0, number_list=[]): ## # @brief get ext pubkey. # @return ExtPubkey - def get_extpubkey(self): + def get_extpubkey(self) -> 'ExtPubkey': with self.util.create_handle() as handle: ext_pubkey = self.util.call_func( 'CfdCreateExtPubkey', handle.get_handle(), @@ -325,8 +345,10 @@ def get_extpubkey(self): # @brief get extkey path data. # @param[in] bip32_path bip32 path # @param[in] key_type key type - # @return path data - def get_path_data(self, bip32_path, key_type=ExtKeyType.EXT_PRIVKEY): + # @retval [0] path data + # @retval [1] object + def get_path_data(self, bip32_path: str, key_type=ExtKeyType.EXT_PRIVKEY, + ) -> Tuple[int, Union['ExtPubkey', 'ExtPrivkey']]: path_data, child_key = self._get_path_data( bip32_path, key_type) _key_type = ExtKeyType.get(key_type) @@ -343,6 +365,7 @@ class ExtPubkey(Extkey): ## # @var pubkey # pubkey + pubkey: 'Pubkey' ## # @brief create extkey. @@ -357,7 +380,7 @@ class ExtPubkey(Extkey): @classmethod def create( cls, network, fingerprint, key, chain_code, - depth, number, parent_key=''): + depth: int, number: int, parent_key='') -> 'ExtPubkey': _extkey = cls._create( ExtKeyType.EXT_PUBKEY, network, fingerprint, key, chain_code, depth, number, parent_key) @@ -369,16 +392,19 @@ def create( def __init__(self, extkey): super().__init__(ExtKeyType.EXT_PUBKEY) self._get_information(extkey) - with self.util.create_handle() as handle: - hex = self.util.call_func( - 'CfdGetPubkeyFromExtkey', handle.get_handle(), - self.extkey, self.network.value) - self.pubkey = Pubkey(hex) + if isinstance(extkey, ExtPubkey): + self.pubkey = extkey.pubkey + else: + with self.util.create_handle() as handle: + hex = self.util.call_func( + 'CfdGetPubkeyFromExtkey', handle.get_handle(), + self.extkey, self.network.value) + self.pubkey = Pubkey(hex) ## # @brief get string. # @return extkey. - def __str__(self): + def __str__(self) -> str: return self.extkey ## @@ -387,7 +413,8 @@ def __str__(self): # @param[in] number bip32 number # @param[in] number_list bip32 number list # @return ExtPubkey - def derive(self, path='', number=0, number_list=[]): + def derive(self, path: str = '', number: int = 0, + number_list: List[int] = []) -> 'ExtPubkey': _path, _list = self._convert_path(path, number, number_list) with self.util.create_handle() as handle: if len(_path) == 0: @@ -410,8 +437,9 @@ def derive(self, path='', number=0, number_list=[]): ## # @brief get extkey path data. # @param[in] bip32_path bip32 path - # @return path data - def get_path_data(self, bip32_path): + # @retval [0] path data + # @retval [1] object + def get_path_data(self, bip32_path: str) -> Tuple[str, 'ExtPubkey']: path_data, child_key = self._get_path_data( bip32_path, ExtKeyType.EXT_PUBKEY) return path_data, ExtPubkey(child_key) @@ -448,7 +476,7 @@ class MnemonicLanguage(Enum): # @param[in] language language # @return object. @classmethod - def get(cls, language): + def get(cls, language) -> 'MnemonicLanguage': if (isinstance(language, MnemonicLanguage)): return language else: @@ -476,19 +504,22 @@ class HDWallet: ## # @var seed # seed + seed: 'ByteData' ## # @var network # network + network: 'Network' ## # @var ext_privkey # ext privkey + ext_privkey: 'ExtPrivkey' - @classmethod ## # @brief get mnemonic word list. # @param[in] language language # @return word_list mnemonic word list - def get_mnemonic_word_list(cls, language): + @classmethod + def get_mnemonic_word_list(cls, language) -> List[str]: util = get_util() _lang = MnemonicLanguage.get(language).value word_list = [] @@ -512,7 +543,7 @@ def get_mnemonic_word_list(cls, language): # @param[in] language language # @return mnemonic @classmethod - def get_mnemonic(cls, entropy, language): + def get_mnemonic(cls, entropy, language) -> str: _entropy = to_hex_string(entropy) _lang = MnemonicLanguage.get(language).value util = get_util() @@ -529,7 +560,8 @@ def get_mnemonic(cls, entropy, language): # @param[in] strict_check strict check # @return entropy @classmethod - def get_entropy(cls, mnemonic, language, strict_check=True): + def get_entropy(cls, mnemonic: Union[str, List[str]], language, + strict_check: bool = True) -> 'ByteData': _mnemonic = cls._convert_mnemonic(mnemonic) _lang = MnemonicLanguage.get(language).value _mnemonic = unicodedata.normalize('NFKD', _mnemonic) @@ -538,7 +570,7 @@ def get_entropy(cls, mnemonic, language, strict_check=True): _, entropy = util.call_func( 'CfdConvertMnemonicToSeed', handle.get_handle(), _mnemonic, '', strict_check, _lang, False) - return entropy + return ByteData(entropy) ## # @brief create extkey from seed. @@ -546,7 +578,7 @@ def get_entropy(cls, mnemonic, language, strict_check=True): # @param[in] network network # @return HDWallet @classmethod - def from_seed(cls, seed, network=Network.MAINNET): + def from_seed(cls, seed, network=Network.MAINNET) -> 'HDWallet': return HDWallet(seed=seed, network=network) ## @@ -559,8 +591,8 @@ def from_seed(cls, seed, network=Network.MAINNET): # @return HDWallet @classmethod def from_mnemonic( - cls, mnemonic, language='en', passphrase='', - network=Network.MAINNET, strict_check=True): + cls, mnemonic: Union[str, List[str]], language='en', passphrase: str = '', + network=Network.MAINNET, strict_check: bool = True) -> 'HDWallet': return HDWallet( mnemonic=mnemonic, language=language, passphrase=passphrase, network=network, strict_check=strict_check) @@ -574,9 +606,10 @@ def from_mnemonic( # @param[in] network network # @param[in] strict_check strict check def __init__( - self, seed='', mnemonic='', language='en', passphrase='', - network=Network.MAINNET, strict_check=True): - self.seed = to_hex_string(seed) + self, seed='', mnemonic: Union[str, List[str]] = '', + language='en', passphrase: str = '', + network=Network.MAINNET, strict_check: bool = True): + self.seed = ByteData(seed) self.network = Network.get_mainchain(network) _mnemonic = self._convert_mnemonic(mnemonic) _lang = MnemonicLanguage.get(language).value @@ -585,10 +618,11 @@ def __init__( if _mnemonic != '': util = get_util() with util.create_handle() as handle: - self.seed, _ = util.call_func( + _seed, _ = util.call_func( 'CfdConvertMnemonicToSeed', handle.get_handle(), _mnemonic, _passphrase, strict_check, _lang, False) + self.seed = ByteData(_seed) self.ext_privkey = ExtPrivkey.from_seed(self.seed, self.network) ## @@ -597,7 +631,8 @@ def __init__( # @param[in] number bip32 number # @param[in] number_list bip32 number list # @return ExtPrivkey - def get_privkey(self, path='', number=0, number_list=[]): + def get_privkey(self, path: str = '', number: int = 0, + number_list: List[int] = []) -> 'ExtPrivkey': return self.ext_privkey.derive(path, number, number_list) ## @@ -606,7 +641,8 @@ def get_privkey(self, path='', number=0, number_list=[]): # @param[in] number bip32 number # @param[in] number_list bip32 number list # @return ExtPubkey - def get_pubkey(self, path='', number=0, number_list=[]): + def get_pubkey(self, path: str = '', number: int = 0, + number_list: List[int] = []) -> 'ExtPubkey': return self.ext_privkey.derive_pubkey(path, number, number_list) ## diff --git a/cfd/key.py b/cfd/key.py index 6705b5b..bc10d61 100644 --- a/cfd/key.py +++ b/cfd/key.py @@ -3,7 +3,9 @@ # @file key.py # @brief key function implements file. # @note Copyright 2020 CryptoGarage -from .util import get_util, CfdError, to_hex_string, CfdErrorCode, JobHandle +from typing import Optional, Union +import typing +from .util import ByteData, get_util, CfdError, to_hex_string, CfdErrorCode, JobHandle import hashlib from enum import Enum @@ -34,13 +36,13 @@ class Network(Enum): ## # @brief get string. # @return name. - def __str__(self): + def __str__(self) -> str: return self.name.lower().replace('_', '') ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: return self.name.lower().replace('_', '') ## @@ -48,7 +50,7 @@ def as_str(self): # @param[in] network network # @return object. @classmethod - def get(cls, network): + def get(cls, network) -> 'Network': if (isinstance(network, Network)): return network elif (isinstance(network, int)): @@ -74,7 +76,7 @@ def get(cls, network): # @param[in] network network # @return object. @classmethod - def get_mainchain(cls, network): + def get_mainchain(cls, network) -> 'Network': _network = cls.get(network) if _network == Network.LIQUID_V1: _network = Network.MAINNET @@ -109,32 +111,32 @@ class SigHashType(Enum): ## # @brief get string. # @return name. - def __str__(self): + def __str__(self) -> str: return self.name.lower().replace('_', '') ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: return self.name.lower().replace('_', '') ## # @brief get type value. # @return value. - def get_type(self): + def get_type(self) -> int: return self.value & 0x0f ## # @brief get anyone can pay flag. # @retval True anyone can pay is true. # @retval False anyone can pay is false. - def anyone_can_pay(self): + def anyone_can_pay(self) -> bool: return self.value >= 0x80 ## # @brief get type object. # @return object. - def get_type_object(self): + def get_type_object(self) -> 'SigHashType': return self.get(self.get_type()) ## @@ -143,7 +145,7 @@ def get_type_object(self): # @param[in] anyone_can_pay anyone can pay flag # @return object. @classmethod - def get(cls, sighashtype, anyone_can_pay=False): + def get(cls, sighashtype, anyone_can_pay: bool = False) -> 'SigHashType': if (isinstance(sighashtype, SigHashType)): if anyone_can_pay is True: return cls.get(sighashtype.value | 0x80) @@ -176,21 +178,27 @@ class Privkey: ## # @var hex # privkey hex + hex: str ## # @var wif # wallet import format + wif: str ## # @var network # network type. + network: 'Network' ## # @var is_compressed # pubkey compressed flag + is_compressed: bool ## # @var wif_first # wif set flag. + wif_first: bool ## # @var pubkey # pubkey + pubkey: 'Pubkey' ## # @brief generate key pair. @@ -198,7 +206,7 @@ class Privkey: # @param[in] network network type # @return private key @classmethod - def generate(cls, is_compressed=True, network=Network.MAINNET): + def generate(cls, is_compressed: bool = True, network=Network.MAINNET): _network = Network.get_mainchain(network) util = get_util() with util.create_handle() as handle: @@ -214,16 +222,30 @@ def generate(cls, is_compressed=True, network=Network.MAINNET): # @param[in] is_compressed pubkey compressed # @return private key @classmethod - def from_hex(cls, hex, network=Network.MAINNET, is_compressed=True): + def from_hex(cls, hex, network=Network.MAINNET, is_compressed: bool = True): return Privkey(hex=hex, network=network, is_compressed=is_compressed) + ## + # @brief create privkey from hex string. + # @param[in] hex hex string + # @param[in] network network type + # @param[in] is_compressed pubkey compressed + # @return private key or None + @classmethod + def from_hex_ignore_error( + cls, hex, network=Network.MAINNET, + is_compressed: bool = True) -> Optional['Privkey']: + if not hex: + return None + return Privkey(hex=hex, network=network, is_compressed=is_compressed) + ## # @brief create privkey from hex string. # @param[in] wif wallet import format # @return private key @classmethod - def from_wif(cls, wif): + def from_wif(cls, wif: str) -> 'Privkey': return Privkey(wif=wif) ## @@ -234,10 +256,10 @@ def from_wif(cls, wif): # @param[in] is_compressed pubkey compressed def __init__( self, - wif='', + wif: str = '', hex='', network=Network.MAINNET, - is_compressed=True): + is_compressed: bool = True): self.hex = to_hex_string(hex) self.wif = wif self.network = Network.get_mainchain(network) @@ -256,21 +278,22 @@ def __init__( 'CfdParsePrivkeyWif', handle.get_handle(), self.wif) self.network = Network.get_mainchain(self.network) - self.pubkey = util.call_func( + _pubkey = util.call_func( 'CfdGetPubkeyFromPrivkey', handle.get_handle(), self.hex, '', self.is_compressed) + self.pubkey = Pubkey(_pubkey) ## # @brief get string. # @return pubkey hex. - def __str__(self): + def __str__(self) -> str: return self.wif if (self.wif_first) else self.hex ## # @brief add tweak. # @param[in] tweak tweak bytes. (32 byte) # @return tweaked private key - def add_tweak(self, tweak): + def add_tweak(self, tweak) -> 'Privkey': _tweak = to_hex_string(tweak) util = get_util() with util.create_handle() as handle: @@ -285,7 +308,7 @@ def add_tweak(self, tweak): # @brief mul tweak. # @param[in] tweak tweak bytes. (32 byte) # @return tweaked private key - def mul_tweak(self, tweak): + def mul_tweak(self, tweak) -> 'Privkey': _tweak = to_hex_string(tweak) util = get_util() with util.create_handle() as handle: @@ -299,7 +322,7 @@ def mul_tweak(self, tweak): ## # @brief negate. # @return negated private key - def negate(self): + def negate(self) -> 'Privkey': util = get_util() with util.create_handle() as handle: _key = util.call_func( @@ -313,7 +336,7 @@ def negate(self): # @param[in] sighash sighash # @param[in] grind_r grind-r flag # @return signature - def calculate_ec_signature(self, sighash, grind_r=True): + def calculate_ec_signature(self, sighash, grind_r: bool = True) -> 'SignParameter': _sighash = to_hex_string(sighash) util = get_util() with util.create_handle() as handle: @@ -332,13 +355,14 @@ class Pubkey: ## # @var _hex # pubkey hex + _hex: str ## # @brief combine pubkey. # @param[in] pubkey_list pubkey list # @return combined pubkey @classmethod - def combine(cls, pubkey_list): + def combine(cls, pubkey_list) -> 'Pubkey': if (isinstance(pubkey_list, list) is False) or ( len(pubkey_list) <= 1): raise CfdError( @@ -365,7 +389,10 @@ def combine(cls, pubkey_list): # @brief constructor. # @param[in] pubkey pubkey def __init__(self, pubkey): - self._hex = to_hex_string(pubkey) + if isinstance(pubkey, Pubkey): + self._hex = pubkey._hex + else: + self._hex = to_hex_string(pubkey) # validate util = get_util() with util.create_handle() as handle: @@ -375,13 +402,13 @@ def __init__(self, pubkey): ## # @brief get string. # @return pubkey hex. - def __str__(self): + def __str__(self) -> str: return self._hex ## # @brief compress pubkey. # @return compressed pubkey. - def compress(self): + def compress(self) -> 'Pubkey': util = get_util() with util.create_handle() as handle: _pubkey = util.call_func( @@ -391,7 +418,7 @@ def compress(self): ## # @brief uncompress pubkey. # @return uncompressed pubkey. - def uncompress(self): + def uncompress(self) -> 'Pubkey': util = get_util() with util.create_handle() as handle: _pubkey = util.call_func( @@ -402,7 +429,7 @@ def uncompress(self): # @brief add tweak. # @param[in] tweak tweak bytes. (32 byte) # @return tweaked public key - def add_tweak(self, tweak): + def add_tweak(self, tweak) -> 'Pubkey': _tweak = to_hex_string(tweak) util = get_util() with util.create_handle() as handle: @@ -415,7 +442,7 @@ def add_tweak(self, tweak): # @brief mul tweak. # @param[in] tweak tweak bytes. (32 byte) # @return tweaked public key - def mul_tweak(self, tweak): + def mul_tweak(self, tweak) -> 'Pubkey': _tweak = to_hex_string(tweak) util = get_util() with util.create_handle() as handle: @@ -427,7 +454,7 @@ def mul_tweak(self, tweak): ## # @brief negate. # @return negated public key - def negate(self): + def negate(self) -> 'Pubkey': util = get_util() with util.create_handle() as handle: _pubkey = util.call_func( @@ -440,13 +467,13 @@ def negate(self): # @param[in] signature signature # @retval True Verify success. # @retval False Verify fail. - def verify_ec_signature(self, sighash, signature): + def verify_ec_signature(self, sighash, signature) -> bool: try: util = get_util() with util.create_handle() as handle: util.call_func( 'CfdVerifyEcSignature', handle.get_handle(), - sighash, self._hex, signature) + to_hex_string(sighash), self._hex, to_hex_string(signature)) return True except CfdError as err: if err.error_code == CfdErrorCode.SIGN_VERIFICATION.value: @@ -462,15 +489,19 @@ class SignParameter: ## # @var hex # hex data + hex: str ## # @var related_pubkey # related pubkey for multisig + related_pubkey: Union[str, 'Pubkey'] ## # @var sighashtype # sighash type + sighashtype: 'SigHashType' ## # @var use_der_encode # use der encode. + use_der_encode: bool ## # @brief encode signature to der. @@ -478,7 +509,7 @@ class SignParameter: # @param[in] sighashtype sighash type # @return der encoded signature @classmethod - def encode_by_der(cls, signature, sighashtype=SigHashType.ALL): + def encode_by_der(cls, signature, sighashtype=SigHashType.ALL) -> 'SignParameter': _signature = to_hex_string(signature) _sighashtype = SigHashType.get(sighashtype) util = get_util() @@ -494,7 +525,7 @@ def encode_by_der(cls, signature, sighashtype=SigHashType.ALL): # @param[in] signature signature # @return der decoded signature @classmethod - def decode_from_der(cls, signature): + def decode_from_der(cls, signature) -> 'SignParameter': der_signature = to_hex_string(signature) util = get_util() with util.create_handle() as handle: @@ -509,7 +540,7 @@ def decode_from_der(cls, signature): # @param[in] signature signature # @return normalized signature @classmethod - def normalize(cls, signature): + def normalize(cls, signature) -> 'SignParameter': _signature = to_hex_string(signature) _sighashtype = SigHashType.ALL if isinstance(signature, SignParameter): @@ -529,20 +560,23 @@ def normalize(cls, signature): def __init__(self, data, related_pubkey='', sighashtype=SigHashType.ALL, use_der_encode=False): self.hex = to_hex_string(data) - self.related_pubkey = related_pubkey + if isinstance(related_pubkey, Pubkey): + self.related_pubkey = related_pubkey + else: + self.related_pubkey = to_hex_string(related_pubkey) self.sighashtype = SigHashType.get(sighashtype) self.use_der_encode = use_der_encode ## # @brief get string. # @return sing data hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## # @brief set der encode flag. # @return void - def set_der_encode(self): + def set_der_encode(self) -> None: self.use_der_encode = True @@ -560,7 +594,7 @@ class EcdsaAdaptor: # @retval result[1] adaptor proof @classmethod def sign(cls, message, secret_key, adaptor, - is_message_hashed=True): + is_message_hashed=True) -> typing.Tuple['ByteData', 'ByteData']: msg = message if (not is_message_hashed) and isinstance(message, str): m = hashlib.sha256() @@ -574,7 +608,7 @@ def sign(cls, message, secret_key, adaptor, signature, proof = util.call_func( 'CfdSignEcdsaAdaptor', handle.get_handle(), _msg, _sk, _adaptor) - return signature, proof + return ByteData(signature), ByteData(proof) ## # @brief adapt. @@ -582,14 +616,14 @@ def sign(cls, message, secret_key, adaptor, # @param[in] adaptor_secret adaptor secret key # @return adapted signature @classmethod - def adapt(cls, adaptor_signature, adaptor_secret): + def adapt(cls, adaptor_signature, adaptor_secret) -> 'ByteData': _sig = to_hex_string(adaptor_signature) _sk = to_hex_string(adaptor_secret) util = get_util() with util.create_handle() as handle: signature = util.call_func( 'CfdAdaptEcdsaAdaptor', handle.get_handle(), _sig, _sk) - return signature + return ByteData(signature) ## # @brief extract secret. @@ -598,7 +632,7 @@ def adapt(cls, adaptor_signature, adaptor_secret): # @param[in] adaptor adaptor bytes # @return adaptor secret key @classmethod - def extract_secret(cls, adaptor_signature, signature, adaptor): + def extract_secret(cls, adaptor_signature, signature, adaptor) -> 'Privkey': _adaptor_signature = to_hex_string(adaptor_signature) _signature = to_hex_string(signature) _adaptor = to_hex_string(adaptor) @@ -621,7 +655,7 @@ def extract_secret(cls, adaptor_signature, signature, adaptor): # @retval False Verify fail. @classmethod def verify(cls, adaptor_signature, proof, adaptor, message, pubkey, - is_message_hashed=True): + is_message_hashed: bool = True) -> bool: msg = message if (not is_message_hashed) and isinstance(message, str): m = hashlib.sha256() @@ -653,6 +687,7 @@ class SchnorrPubkey: ## # @var hex # hex data + hex: str ## # @brief create SchnorrPubkey from privkey. @@ -660,7 +695,7 @@ class SchnorrPubkey: # @retval [0] SchnorrPubkey # @retval [1] parity flag @classmethod - def from_privkey(cls, privkey): + def from_privkey(cls, privkey) -> 'SchnorrPubkey': if isinstance(privkey, Privkey): _privkey = privkey.hex elif isinstance(privkey, str) and (len(privkey) != 64): @@ -681,7 +716,7 @@ def from_privkey(cls, privkey): # @retval [0] SchnorrPubkey # @retval [1] parity flag @classmethod - def from_pubkey(cls, pubkey): + def from_pubkey(cls, pubkey) -> 'SchnorrPubkey': _pubkey = to_hex_string(pubkey) util = get_util() with util.create_handle() as handle: @@ -698,7 +733,7 @@ def from_pubkey(cls, pubkey): # @retval [1] tweaked parity flag # @retval [2] tweaked Privkey @classmethod - def add_tweak_from_privkey(cls, privkey, tweak): + def add_tweak_from_privkey(cls, privkey, tweak) -> 'SchnorrPubkey': if isinstance(privkey, Privkey): _privkey = privkey.hex elif isinstance(privkey, str) and (len(privkey) != 64): @@ -727,7 +762,7 @@ def __init__(self, data): ## # @brief get string. # @return pubkey hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## @@ -735,7 +770,7 @@ def __str__(self): # @param[in] tweak tweak data # @retval [0] tweaked SchnorrPubkey # @retval [1] tweaked parity flag - def add_tweak(self, tweak): + def add_tweak(self, tweak) -> 'SchnorrPubkey': _tweak = to_hex_string(tweak) util = get_util() with util.create_handle() as handle: @@ -751,7 +786,7 @@ def add_tweak(self, tweak): # @param[in] tweak tweak data # @retval True tweaked pubkey from base pubkey. # @retval False other. - def is_tweaked(self, tweaked_parity, base_pubkey, tweak): + def is_tweaked(self, tweaked_parity, base_pubkey, tweak) -> bool: _base_pubkey = to_hex_string(base_pubkey) _tweak = to_hex_string(tweak) try: @@ -775,12 +810,15 @@ class SchnorrSignature: ## # @var signature # signature data + signature: str ## # @var nonce # nonce data + nonce: 'SchnorrPubkey' ## # @var key # key data + key: 'Privkey' ## # @brief constructor. @@ -789,16 +827,16 @@ def __init__(self, signature): self.signature = to_hex_string(signature) util = get_util() with util.create_handle() as handle: - self.nonce, self.key = util.call_func( + _nonce, self.key = util.call_func( 'CfdSplitSchnorrSignature', handle.get_handle(), self.signature) - self.nonce = SchnorrPubkey(self.nonce) + self.nonce = SchnorrPubkey(_nonce) self.key = Privkey(hex=self.key) ## # @brief get string. # @return signature hex. - def __str__(self): + def __str__(self) -> str: return self.signature @@ -816,7 +854,7 @@ class SchnorrUtil: # @return signature @classmethod def sign(cls, message, secret_key, aux_rand='', nonce='', - is_message_hashed=True): + is_message_hashed: bool = True) -> 'SchnorrSignature': msg = message if (not is_message_hashed) and isinstance(message, str): m = hashlib.sha256() @@ -846,7 +884,7 @@ def sign(cls, message, secret_key, aux_rand='', nonce='', # @return signature @classmethod def compute_sig_point(cls, message, nonce, pubkey, - is_message_hashed=True): + is_message_hashed: bool = True) -> 'Pubkey': msg = message if (not is_message_hashed) and isinstance(message, str): m = hashlib.sha256() @@ -872,7 +910,7 @@ def compute_sig_point(cls, message, nonce, pubkey, # @retval False Verify fail. @classmethod def verify(cls, signature, message, pubkey, - is_message_hashed=True): + is_message_hashed: bool = True) -> bool: msg = message if (not is_message_hashed) and isinstance(message, str): m = hashlib.sha256() diff --git a/cfd/script.py b/cfd/script.py index 43c0fbf..a1222ed 100644 --- a/cfd/script.py +++ b/cfd/script.py @@ -3,6 +3,7 @@ # @file script.py # @brief bitcoin script function implements file. # @note Copyright 2020 CryptoGarage +from typing import List from .util import CfdError, to_hex_string, get_util, JobHandle from .key import SignParameter, SigHashType from enum import Enum @@ -34,13 +35,13 @@ class HashType(Enum): ## # @brief get string. # @return name. - def __str__(self): + def __str__(self) -> str: return self.name.lower().replace('_', '-') ## # @brief get string. # @return name. - def as_str(self): + def as_str(self) -> str: return self.name.lower().replace('_', '-') ## @@ -48,7 +49,7 @@ def as_str(self): # @param[in] hashtype hashtype # @return object @classmethod - def get(cls, hashtype): + def get(cls, hashtype) -> 'HashType': if (isinstance(hashtype, HashType)): return hashtype elif (isinstance(hashtype, int)): @@ -77,16 +78,18 @@ class Script: ## # @var hex # script hex + hex: str ## # @var asm # asm + asm: str ## # @brief get script from asm. # @param[in] script_items asm strings (list or string) # @return script object @classmethod - def from_asm(cls, script_items): + def from_asm(cls, script_items: List[str]) -> 'Script': _asm = script_items if isinstance(script_items, list): _asm = ' '.join(script_items) @@ -106,7 +109,9 @@ def from_asm(cls, script_items): # @param[in] sign_parameter_list signature list # @return script object @classmethod - def create_multisig_scriptsig(cls, redeem_script, sign_parameter_list): + def create_multisig_scriptsig( + cls, redeem_script, + sign_parameter_list: List['SignParameter']) -> 'Script': _script = to_hex_string(redeem_script) util = get_util() with util.create_handle() as handle: @@ -144,13 +149,17 @@ def create_multisig_scriptsig(cls, redeem_script, sign_parameter_list): # @brief constructor. # @param[in] script script def __init__(self, script): - self.hex = to_hex_string(script) - self.asm = Script._parse(self.hex) + if isinstance(script, Script): + self.hex = script.hex + self.asm = script.asm + else: + self.hex = to_hex_string(script) + self.asm = Script._parse(self.hex) ## # @brief get string. # @return script hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## diff --git a/cfd/transaction.py b/cfd/transaction.py index c59de70..986aeb7 100644 --- a/cfd/transaction.py +++ b/cfd/transaction.py @@ -3,11 +3,14 @@ # @file transaction.py # @brief transaction function implements file. # @note Copyright 2020 CryptoGarage +from typing import AnyStr, List, Optional, Tuple, Union +import typing from .util import get_util, JobHandle, CfdError, to_hex_string,\ CfdErrorCode, ReverseByteData, ByteData from .address import Address, AddressUtil from .key import Network, SigHashType, SignParameter, Privkey -from .script import HashType +from .script import HashType, Script +from .descriptor import Descriptor from enum import Enum import ctypes import copy @@ -34,15 +37,17 @@ class OutPoint: ## # @var txid # txid + txid: 'Txid' ## # @var vout # vout + vout: int ## # @brief constructor. # @param[in] txid txid # @param[in] vout vout - def __init__(self, txid, vout): + def __init__(self, txid, vout: int): self.txid = Txid(txid) self.vout = vout if isinstance(vout, int) is False: @@ -53,14 +58,14 @@ def __init__(self, txid, vout): ## # @brief get string. # @return txid. - def __str__(self): + def __str__(self) -> str: return '{},{}'.format(str(self.txid), self.vout) ## # @brief equal method. # @param[in] other other object. # @return true or false. - def __eq__(self, other): + def __eq__(self, other: 'OutPoint') -> bool: if not isinstance(other, OutPoint): return NotImplemented return (self.txid.hex == other.txid.hex) and ( @@ -70,7 +75,7 @@ def __eq__(self, other): # @brief diff method. # @param[in] other other object. # @return true or false. - def __lt__(self, other): + def __lt__(self, other: 'OutPoint') -> bool: if not isinstance(other, OutPoint): return NotImplemented return (self.txid.hex, self.vout) < (other.txid.hex, other.vout) @@ -79,28 +84,28 @@ def __lt__(self, other): # @brief equal method. # @param[in] other other object. # @return true or false. - def __ne__(self, other): + def __ne__(self, other: 'OutPoint') -> bool: return not self.__eq__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __le__(self, other): + def __le__(self, other: 'OutPoint') -> bool: return self.__lt__(other) or self.__eq__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __gt__(self, other): + def __gt__(self, other: 'OutPoint') -> bool: return not self.__le__(other) ## # @brief diff method. # @param[in] other other object. # @return true or false. - def __ge__(self, other): + def __ge__(self, other: 'OutPoint') -> bool: return not self.__lt__(other) @@ -111,15 +116,19 @@ class UtxoData: ## # @var outpoint # outpoint + outpoint: 'OutPoint' ## # @var amount # amount + amount: int ## # @var descriptor # descriptor + descriptor: Union[str, 'Descriptor'] ## # @var scriptsig_template # scriptsig template + scriptsig_template: Union['Script', 'ByteData', AnyStr] ## # @brief constructor. @@ -130,8 +139,10 @@ class UtxoData: # @param[in] descriptor descriptor # @param[in] scriptsig_template scriptsig template def __init__( - self, outpoint=None, txid='', vout=0, - amount=0, descriptor='', scriptsig_template=''): + self, outpoint: Optional['OutPoint'] = None, + txid='', vout: int = 0, + amount: int = 0, descriptor: Union[str, 'Descriptor'] = '', + scriptsig_template: Union['Script', 'ByteData', AnyStr] = ''): if isinstance(outpoint, OutPoint): self.outpoint = outpoint else: @@ -200,15 +211,19 @@ class TxIn: ## # @var outpoint # outpoint + outpoint: 'OutPoint' ## # @var sequence # sequence + sequence: int ## # @var script_sig # script sig + script_sig: 'Script' ## # @var witness_stack # witness stack + witness_stack: List[Union['Script', 'ByteData', AnyStr]] ## # sequence disable. @@ -223,7 +238,7 @@ class TxIn: # @param[in] sequence sequence # @return sequence number. @classmethod - def get_sequence_number(cls, locktime=0, sequence=SEQUENCE_DISABLE): + def get_sequence_number(cls, locktime: int = 0, sequence: int = SEQUENCE_DISABLE): if sequence not in [-1, TxIn.SEQUENCE_DISABLE]: return sequence elif locktime == 0: @@ -237,20 +252,20 @@ def get_sequence_number(cls, locktime=0, sequence=SEQUENCE_DISABLE): # @param[in] txid txid # @param[in] vout vout # @param[in] sequence sequence - def __init__(self, outpoint=None, txid='', vout=0, - sequence=SEQUENCE_DISABLE): + def __init__(self, outpoint: Optional['OutPoint'] = None, + txid='', vout: int = 0, sequence: int = SEQUENCE_DISABLE): if isinstance(outpoint, OutPoint): self.outpoint = outpoint else: self.outpoint = OutPoint(txid=txid, vout=vout) self.sequence = sequence - self.script_sig = '' + self.script_sig = Script('') self.witness_stack = [] ## # @brief get string. # @return hex. - def __str__(self): + def __str__(self) -> str: return str(self.outpoint) @@ -261,32 +276,36 @@ class TxOut: ## # @var amount # amount + amount: int ## # @var address # address + address: Union['Address', str] ## # @var locking_script # locking script + locking_script: 'Script' ## # @brief constructor. # @param[in] amount amount # @param[in] address address # @param[in] locking_script locking script - def __init__(self, amount, address='', locking_script=''): + def __init__(self, amount: int, address='', locking_script=''): self.amount = amount if address != '': - self.address = address - self.locking_script = '' + self.address = address if isinstance( + address, Address) else str(address) + self.locking_script = Script('') else: - self.locking_script = locking_script + self.locking_script = Script(locking_script) self.address = '' ## # @brief constructor. # @param[in] network network # @return address. - def get_address(self, network=Network.MAINNET): + def get_address(self, network=Network.MAINNET) -> 'Address': if isinstance(self.address, Address): return self.address if self.address != '': @@ -296,7 +315,7 @@ def get_address(self, network=Network.MAINNET): ## # @brief get string. # @return address or script. - def __str__(self): + def __str__(self) -> str: if (self.address != ''): return str(self.address) else: @@ -310,12 +329,15 @@ class _TransactionBase: ## # @var hex # transaction hex string + hex: str ## # @var network # transaction network type + network: int ## # @var enable_cache # use transaction cache + enable_cache: bool ## # @brief constructor. @@ -325,12 +347,12 @@ class _TransactionBase: def __init__(self, hex, network, enable_cache=True): self.hex = to_hex_string(hex) self.enable_cache = enable_cache - self.network = network + self.network = Network.get(network).value ## # @brief get string. # @return tx hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## @@ -380,7 +402,8 @@ def _get_txin(self, handle, tx_handle, index=0, outpoint=None): # @param[in] txid txid # @param[in] vout vout # @return index - def get_txin_index(self, outpoint=None, txid='', vout=0): + def get_txin_index(self, outpoint: Optional['OutPoint'] = None, + txid='', vout=0) -> int: txin = TxIn(outpoint=outpoint, txid=txid, vout=vout) util = get_util() with util.create_handle() as handle: @@ -395,7 +418,7 @@ def get_txin_index(self, outpoint=None, txid='', vout=0): # @param[in] address address # @param[in] locking_script locking_script # @return index - def get_txout_index(self, address='', locking_script=''): + def get_txout_index(self, address='', locking_script='') -> int: # get first target only. _script = to_hex_string(locking_script) util = get_util() @@ -414,8 +437,8 @@ def get_txout_index(self, address='', locking_script=''): # @param[in] sighashtype sighash type # @return void def add_pubkey_hash_sign( - self, outpoint, hash_type, pubkey, signature, - sighashtype=SigHashType.ALL): + self, outpoint: 'OutPoint', hash_type, pubkey, signature, + sighashtype=SigHashType.ALL) -> None: _hash_type = HashType.get(hash_type) _pubkey = to_hex_string(pubkey) _signature = to_hex_string(signature) @@ -442,8 +465,8 @@ def add_pubkey_hash_sign( # @param[in] signature_list signature list # @return void def add_multisig_sign( - self, outpoint, hash_type, redeem_script, - signature_list): + self, outpoint: 'OutPoint', hash_type, redeem_script, + signature_list) -> None: if (isinstance(signature_list, list) is False) or ( len(signature_list) == 0): raise CfdError( @@ -496,8 +519,8 @@ def add_multisig_sign( # @param[in] signature_list signature list # @return void def add_script_hash_sign( - self, outpoint, hash_type, redeem_script, - signature_list): + self, outpoint: 'OutPoint', hash_type, redeem_script, + signature_list) -> None: if (isinstance(signature_list, list) is False) or ( len(signature_list) == 0): raise CfdError( @@ -542,9 +565,9 @@ def add_script_hash_sign( # @param[in] sighashtype sighash type # @return void def add_sign( - self, outpoint, hash_type, sign_data, - clear_stack=False, use_der_encode=False, - sighashtype=SigHashType.ALL): + self, outpoint: 'OutPoint', hash_type, sign_data, + clear_stack: bool = False, use_der_encode: bool = False, + sighashtype=SigHashType.ALL) -> None: _hash_type = HashType.get(hash_type) _sign_data = sign_data if not isinstance(sign_data, str): @@ -568,33 +591,43 @@ class Transaction(_TransactionBase): ## # @var hex # transaction hex string + hex: str ## # @var txin_list # transaction input list + txin_list: List['TxIn'] ## # @var txout_list # transaction output list + txout_list: List['TxOut'] ## # @var txid # txid + txid: 'Txid' ## # @var wtxid # wtxid + wtxid: 'Txid' ## # @var size # transaction size + size: int ## # @var vsize # transaction vsize + vsize: int ## # @var weight # transaction size weight + weight: int ## # @var version # version + version: int ## # @var locktime # locktime + locktime: int ## # bitcoin network value. @@ -609,7 +642,7 @@ class Transaction(_TransactionBase): # @param[in] network network # @return json string @classmethod - def parse_to_json(cls, hex, network=Network.MAINNET): + def parse_to_json(cls, hex: str, network=Network.MAINNET) -> str: _network = Network.get(network) network_str = 'mainnet' if _network == Network.TESTNET: @@ -632,7 +665,8 @@ def parse_to_json(cls, hex, network=Network.MAINNET): # @param[in] enable_cache enable tx cache # @return transaction object @classmethod - def create(cls, version, locktime, txins, txouts, enable_cache=True): + def create(cls, version: int, locktime: int, txins: List['TxIn'], + txouts: List['TxOut'], enable_cache: bool = True) -> 'Transaction': util = get_util() with util.create_handle() as handle: _tx_handle = util.call_func( @@ -663,14 +697,14 @@ def create(cls, version, locktime, txins, txouts, enable_cache=True): # @param[in] enable_cache enable tx cache # @return transaction object @classmethod - def from_hex(cls, hex, enable_cache=True): + def from_hex(cls, hex, enable_cache: bool = True) -> 'Transaction': return Transaction(hex, enable_cache) ## # @brief constructor. # @param[in] hex tx hex # @param[in] enable_cache enable tx cache - def __init__(self, hex, enable_cache=True): + def __init__(self, hex, enable_cache: bool = True): super().__init__(hex, self.NETWORK, enable_cache) self.txin_list = [] self.txout_list = [] @@ -725,7 +759,7 @@ def _update_txin(self, outpoint): # @brief get transaction all data. # @retval [0] txin list # @retval [1] txout list - def get_tx_all(self): + def get_tx_all(self) -> typing.Tuple[List['TxIn'], List['TxOut']]: def get_txin_list(handle, tx_handle): txin_list = [] _count = util.call_func( @@ -773,8 +807,8 @@ def get_txout_list(handle, tx_handle): # @param[in] txid txid # @param[in] vout vout # @return void - def add_txin(self, outpoint=None, sequence=-1, - txid='', vout=0): + def add_txin(self, outpoint: Optional['OutPoint'] = None, + sequence: int = -1, txid='', vout: int = 0) -> None: sec = TxIn.get_sequence_number(self.locktime, sequence) txin = TxIn( outpoint=outpoint, sequence=sec, txid=txid, vout=vout) @@ -786,7 +820,7 @@ def add_txin(self, outpoint=None, sequence=-1, # @param[in] address address # @param[in] locking_script locking script # @return void - def add_txout(self, amount, address='', locking_script=''): + def add_txout(self, amount: int, address='', locking_script='') -> None: txout = TxOut(amount, address, locking_script) self.add([], [txout]) @@ -795,7 +829,7 @@ def add_txout(self, amount, address='', locking_script=''): # @param[in] txins txin list # @param[in] txouts txout list # @return void - def add(self, txins, txouts): + def add(self, txins: List['TxIn'], txouts: List['TxOut']) -> None: util = get_util() with util.create_handle() as handle: _tx_handle = util.call_func( @@ -833,7 +867,7 @@ def add(self, txins, txouts): # @param[in] index index # @param[in] amount amount # @return void - def update_txout_amount(self, index, amount): + def update_txout_amount(self, index: int, amount: int): util = get_util() with util.create_handle() as handle: self.hex = util.call_func( @@ -853,12 +887,12 @@ def update_txout_amount(self, index, amount): # @return sighash def get_sighash( self, - outpoint, + outpoint: 'OutPoint', hash_type, - amount=0, + amount: int = 0, pubkey='', redeem_script='', - sighashtype=SigHashType.ALL): + sighashtype=SigHashType.ALL) -> 'ByteData': _hash_type = HashType.get(hash_type) _pubkey = to_hex_string(pubkey) _script = to_hex_string(redeem_script) @@ -884,12 +918,12 @@ def get_sighash( # @return void def sign_with_privkey( self, - outpoint, + outpoint: 'OutPoint', hash_type, privkey, - amount=0, + amount: int = 0, sighashtype=SigHashType.ALL, - grind_r=True): + grind_r: bool = True) -> None: _hash_type = HashType.get(hash_type) if isinstance(privkey, Privkey): _privkey = privkey @@ -916,7 +950,8 @@ def sign_with_privkey( # @param[in] hash_type hash type # @param[in] amount amount # @return void - def verify_sign(self, outpoint, address, hash_type, amount): + def verify_sign(self, outpoint: 'OutPoint', address, hash_type, + amount: int) -> None: _hash_type = HashType.get(hash_type) util = get_util() with util.create_handle() as handle: @@ -938,8 +973,8 @@ def verify_sign(self, outpoint, address, hash_type, amount): # @retval True signature valid. # @retval False signature invalid. def verify_signature( - self, outpoint, signature, hash_type, pubkey, amount=0, - redeem_script='', sighashtype=SigHashType.ALL): + self, outpoint: 'OutPoint', signature, hash_type, pubkey, + amount: int = 0, redeem_script='', sighashtype=SigHashType.ALL) -> bool: _signature = to_hex_string(signature) _pubkey = to_hex_string(pubkey) _script = to_hex_string(redeem_script) @@ -974,10 +1009,11 @@ def verify_signature( # @retval [1] utxo fee. # @retval [2] total tx fee. @classmethod - def select_coins( - cls, utxo_list, tx_fee_amount, target_amount, - effective_fee_rate=20.0, long_term_fee_rate=20.0, - dust_fee_rate=3.0, knapsack_min_change=-1): + def select_coins(cls, utxo_list: List['UtxoData'], tx_fee_amount: int, + target_amount: int, effective_fee_rate: float = 20.0, + long_term_fee_rate: float = 20.0, dust_fee_rate: float = 3.0, + knapsack_min_change: int = -1, + ) -> Tuple[List['UtxoData'], int, int]: if (isinstance(utxo_list, list) is False) or ( len(utxo_list) == 0): raise CfdError( @@ -1029,7 +1065,8 @@ def select_coins( # @retval [0] total tx fee. (txout fee + utxo fee) # @retval [1] txout fee. # @retval [2] utxo fee. - def estimate_fee(self, utxo_list, fee_rate=20.0): + def estimate_fee(self, utxo_list: List['UtxoData'], fee_rate: float = 20.0, + ) -> Tuple[int, int, int]: if (isinstance(utxo_list, list) is False) or ( len(utxo_list) == 0): raise CfdError( @@ -1073,10 +1110,11 @@ def estimate_fee(self, utxo_list, fee_rate=20.0): # @retval [0] total tx fee. # @retval [1] used reserved address. (None or reserved_address) def fund_raw_transaction( - self, txin_utxo_list, utxo_list, reserved_address, - target_amount=0, effective_fee_rate=20.0, - long_term_fee_rate=20.0, dust_fee_rate=-1.0, - knapsack_min_change=-1): + self, txin_utxo_list: List['UtxoData'], utxo_list: List['UtxoData'], + reserved_address, target_amount: int = 0, + effective_fee_rate: float = 20.0, + long_term_fee_rate: float = 20.0, dust_fee_rate: float = -1.0, + knapsack_min_change: int = -1) -> Tuple[int, str]: util = get_util() def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): @@ -1142,7 +1180,7 @@ def set_opt(handle, tx_handle, key, i_val=0, f_val=0, b_val=False): handle.get_handle(), tx_handle.get_handle(), 0) used_addr = None if _used_addr == reserved_address: - used_addr = reserved_address + used_addr = str(reserved_address) self.hex = _new_hex self._update_tx_all() diff --git a/cfd/util.py b/cfd/util.py index 0fd7c03..7c44a47 100644 --- a/cfd/util.py +++ b/cfd/util.py @@ -3,7 +3,7 @@ # @file util.py # @brief cfd utility file. # @note Copyright 2020 CryptoGarage -from ctypes import c_int, c_void_p, c_char_p, c_int32, c_int64,\ +from ctypes import Union, c_int, c_void_p, c_char_p, c_int32, c_int64,\ c_uint32, c_uint64, c_bool, c_double, c_ubyte, \ CDLL, byref, POINTER, ArgumentError from os.path import isfile, abspath @@ -11,6 +11,7 @@ import platform import os import re +from typing import List ################ # Public class # @@ -63,23 +64,29 @@ class CfdError(Exception): ## # @var error_code # error code + error_code: int ## # @var message # error message + message: str ## # @brief constructor. # @param[in] error_code error code # @param[in] message error message - def __init__(self, error_code=CfdErrorCode.UNKNOWN.value, message=''): + def __init__( + self, + error_code: int = CfdErrorCode.UNKNOWN.value, + message: str = '', + ) -> None: self.error_code = error_code self.message = message ## # @brief get error information. # @return error information. - def __str__(self): - return 'code={}, msg={}'.format(self.error_code, self.message) + def __str__(self) -> str: + return f'code={self.error_code}, msg={self.message}' ## @@ -89,11 +96,12 @@ class ByteData: ## # @var hex # hex string + hex: str ## # @brief constructor. # @param[in] data byte data - def __init__(self, data): + def __init__(self, data) -> None: if isinstance(data, bytes) or isinstance(data, bytearray): self.hex = data.hex() elif isinstance(data, list): @@ -105,26 +113,26 @@ def __init__(self, data): ## # @brief get string. # @return hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## # @brief get bytes data. # @return bytes data. - def as_bytes(self): + def as_bytes(self) -> bytes: return bytes.fromhex(self.hex) ## # @brief get array data. # @return array data. - def as_array(self): - _hex_list = re.split('(..)', self.hex)[1::2] + def as_array(self) -> List[int]: + _hex_list = re.split('(..)', self.hex)[1:: 2] return [int('0x' + s, 16) for s in _hex_list] ## # @brief get serialized data. # @return serialize hex. - def serialize(self): + def serialize(self) -> 'ByteData': util = get_util() with util.create_handle() as handle: _serialized = util.call_func( @@ -139,11 +147,12 @@ class ReverseByteData: ## # @var hex # hex string + hex: str ## # @brief constructor. # @param[in] data byte data - def __init__(self, data): + def __init__(self, data) -> None: if isinstance(data, bytes) or isinstance(data, bytearray): _data = data.hex() _list = re.split('(..)', _data)[1::2] @@ -165,13 +174,13 @@ def __init__(self, data): ## # @brief get string. # @return hex. - def __str__(self): + def __str__(self) -> str: return self.hex ## # @brief get bytes data. # @return bytes data. - def as_bytes(self): + def as_bytes(self) -> bytes: _hex_list = re.split('(..)', self.hex)[1::2] _hex_list = _hex_list[::-1] return bytes.fromhex(''.join(_hex_list)) @@ -179,7 +188,7 @@ def as_bytes(self): ## # @brief get array data. # @return array data. - def as_array(self): + def as_array(self) -> List[int]: _hex_list = re.split('(..)', self.hex)[1::2] _hex_list = _hex_list[::-1] return [int('0x' + s, 16) for s in _hex_list] @@ -189,15 +198,15 @@ def as_array(self): # @brief get hex string. # @param[in] value data # @return hex string. -def to_hex_string(value): +def to_hex_string(value) -> str: if isinstance(value, bytes): return value.hex() elif isinstance(value, bytearray): return value.hex() elif isinstance(value, list): - return "".join("%02x" % b for b in value) + return "".join("%02x" % int(b) for b in value) elif str(type(value)) == "": - return value.hex + return str(value.hex) else: _hex = str(value) if _hex != '': @@ -352,7 +361,8 @@ class JobHandle: # @param[in] handle handle # @param[in] job_handle job handle # @param[in] close_function_name close func name. - def __init__(self, handle, job_handle, close_function_name): + def __init__(self, handle: 'CfdHandle', + job_handle, close_function_name): self._handle = handle self._job_handle = job_handle self._close_func = close_function_name @@ -619,7 +629,6 @@ def get_instance(cls): ## # @brief constructor. - # @return utility instance. def __init__(self): self._func_map = {} @@ -725,6 +734,7 @@ def in_string_fn_wrapper(fn, pos, *args): return fn(*args) def string_fn_wrapper(fn, *args): + new_args = None try: # Return output string parameters directly without leaking p = c_char_p() @@ -872,7 +882,7 @@ def call_func(self, name, *args): # @brief create cfd handle. # @return cfd handle # @throw CfdError occurred error. - def create_handle(self): + def create_handle(self) -> 'CfdHandle': ret, handle = self._func_map['CfdCreateSimpleHandle']() if ret != 0: raise CfdError( @@ -884,14 +894,14 @@ def create_handle(self): # @brief free cfd handle. # @param[in] handle cfd handle # @return result - def free_handle(self, handle): + def free_handle(self, handle) -> c_int: return self._func_map['CfdFreeHandle'](handle) ## # @brief get utility object. # @return utility object. -def get_util(): +def get_util() -> 'CfdUtil': return CfdUtil.get_instance() diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 491f376..4ebcc93 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -43,7 +43,7 @@ if(CFD_TARGET_VERSION) set(CFD_TARGET_TAG ${CFD_TARGET_VERSION}) message(STATUS "[external project local] cfd target=${CFD_TARGET_VERSION}") else() -set(CFD_TARGET_TAG v0.2.1) +set(CFD_TARGET_TAG v0.2.2) endif() if(CFD_TARGET_URL) set(CFD_TARGET_REP ${CFD_TARGET_URL}) diff --git a/integration_test/tests/test_elements.py b/integration_test/tests/test_elements.py index 202ce43..45ec369 100644 --- a/integration_test/tests/test_elements.py +++ b/integration_test/tests/test_elements.py @@ -289,8 +289,8 @@ def update_pegin_tx(test_obj, pegin_tx, btc_tx, pegin_address): fee_amount = 0 has_fee = len(tx.txout_list) == 2 for index, txout in enumerate(tx.txout_list): - if txout.locking_script: - target_script_pubkey = txout.locking_script + if len(txout.locking_script.hex) > 0: + target_script_pubkey = str(txout.locking_script) target_amount = txout.amount target_index = index else: @@ -475,7 +475,7 @@ def test_elements_pkh(test_obj): txin_list = [] txin_utxo_list = [] for index, txout in enumerate(tx.txout_list): - if not txout.locking_script: + if not txout.locking_script.hex: continue temp_addr = str(txout.get_address(network=NETWORK)) if temp_addr == fee_addr: @@ -601,7 +601,7 @@ def test_elements_multisig(test_obj): txin_list = [] txin_utxo_list = [] for index, txout in enumerate(tx.txout_list): - if not txout.locking_script: + if not txout.locking_script.hex: continue temp_addr = str(txout.get_address(network=NETWORK)) if temp_addr == fee_addr: diff --git a/tests/data/elements_address_test.json b/tests/data/elements_address_test.json index c4a0709..b76d155 100644 --- a/tests/data/elements_address_test.json +++ b/tests/data/elements_address_test.json @@ -132,7 +132,8 @@ "code": 1, "type": "illegal_argument", "capi": "Failed to parameter. address is null or empty.", - "cfd": "unblinded_addrss is empty." + "cfd": "unblinded_addrss is empty.", + "python": "Failed to parameter. address is null." } }, { @@ -145,7 +146,8 @@ "code": 1, "type": "illegal_argument", "capi": "Failed to parameter. confidential key is null or empty.", - "cfd": "key is empty." + "cfd": "key is empty.", + "python": "Failed to parameter. pubkey is null or empty." } }, { diff --git a/tests/data/elements_transaction_test.json b/tests/data/elements_transaction_test.json index e0bcb9b..90685e3 100644 --- a/tests/data/elements_transaction_test.json +++ b/tests/data/elements_transaction_test.json @@ -8349,7 +8349,8 @@ "error": { "code": 1, "type": "illegal_argument", - "message": "Value hex string length Invalid." + "message": "Value hex string length Invalid.", + "python": "Error: Invalid blind factor." } }, { @@ -8746,7 +8747,8 @@ "error": { "code": 1, "type": "illegal_argument", - "message": "hex to byte convert error." + "message": "hex to byte convert error.", + "python": "Error: Invalid hex value." } }, { diff --git a/tests/data/key_test.json b/tests/data/key_test.json index 3022fac..d5f9971 100644 --- a/tests/data/key_test.json +++ b/tests/data/key_test.json @@ -815,7 +815,8 @@ }, "error": { "capi": "hex to byte convert error.", - "json": "" + "json": "", + "python": "Error: Invalid hex value." }, "exclude": [ "json" diff --git a/tests/test_address.py b/tests/test_address.py index ed4fdb5..8eac958 100644 --- a/tests/test_address.py +++ b/tests/test_address.py @@ -9,6 +9,7 @@ def test_address_func(obj, name, case, req, exp, error): try: + resp = None _network = req.get('network', 'mainnet') if req.get('isElements', False) and ( _network.lower() == Network.REGTEST.as_str()): diff --git a/tests/test_confidential_transaction.py b/tests/test_confidential_transaction.py index 1a80a40..f9676d9 100644 --- a/tests/test_confidential_transaction.py +++ b/tests/test_confidential_transaction.py @@ -1,4 +1,6 @@ import json +import typing +from typing import List from unittest import TestCase from tests.util import load_json_file, get_json_file, exec_test,\ assert_equal, assert_error, assert_match, assert_message @@ -8,7 +10,7 @@ from cfd.script import HashType from cfd.key import SigHashType, SignParameter, Network from cfd.transaction import OutPoint, TxIn -from cfd.confidential_transaction import ConfidentialTxOut,\ +from cfd.confidential_transaction import BlindData, ConfidentialTxOut,\ ConfidentialTransaction, ElementsUtxoData, IssuanceKeyPair,\ TargetAmountData, Issuance, UnblindData,\ IssuanceAssetBlindData, IssuanceTokenBlindData @@ -544,15 +546,18 @@ def test_ct_transaction_func4(obj, name, case, req, exp, error): 'Fail: {}:{}:{}'.format( name, case, 'maxVsize')) txout_list = resp['req_output'] - tx = resp['tx'] + tx = typing.cast('ConfidentialTransaction', resp['tx']) + blinder_list = typing.cast( + typing.List[typing.Union['BlindData', 'IssuanceAssetBlindData', 'IssuanceTokenBlindData']], resp['blinder_list']) blinding_keys = exp.get('blindingKeys', []) issuance_list = exp.get('issuanceList', []) txout_index_list = [] for index, txout in enumerate(tx.txout_list): if txout.value.has_blind(): txout_index_list.append(index) - for blind_index, blinder in enumerate(resp['blinder_list']): + for blind_index, blinder in enumerate(blinder_list): is_find = False + data = {} has_asset = isinstance(blinder, IssuanceAssetBlindData) has_token = isinstance(blinder, IssuanceTokenBlindData) if has_asset or has_token: @@ -746,13 +751,15 @@ def test_parse_tx_func(obj, name, case, req, exp, error): def test_elements_tx_func(obj, name, case, req, exp, error): try: + coin_resp = () + resp = () if name == 'Elements.CoinSelection': # selected_utxo_list, _utxo_fee, total_amount_map utxo_list = obj.utxos.get(req['utxoFile'], []) target_list = convert_target_amount(req['targets']) fee_info = req.get('feeInfo', {}) fee_rate = fee_info.get('feeRate', 20.0) - resp = ConfidentialTransaction.select_coins( + coin_resp = ConfidentialTransaction.select_coins( utxo_list, tx_fee_amount=fee_info.get('txFeeAmount', 0), target_list=target_list, @@ -798,7 +805,7 @@ def test_elements_tx_func(obj, name, case, req, exp, error): if name == 'Elements.CoinSelection': # selected_utxo_list, _utxo_fee, total_amount_map - selected_utxo_list, utxo_fee, total_amount_map = resp + selected_utxo_list, utxo_fee, total_amount_map = coin_resp assert_equal(obj, name, case, exp, utxo_fee, 'utxoFeeAmount') exp_list = convert_elements_utxo(exp['utxos']) for exp_utxo in exp_list: @@ -809,12 +816,13 @@ def test_elements_tx_func(obj, name, case, req, exp, error): assert_match(obj, name, case, len(exp_amount_list), len(total_amount_map), 'selectedAmountsLen') for exp_amount_data in exp_amount_list: - if exp_amount_data.asset not in total_amount_map: + if str(exp_amount_data.asset) not in total_amount_map: + print(f'{total_amount_map}') assert_message(obj, name, case, 'selectedAmounts:{}'.format( exp_amount_data.asset)) assert_match(obj, name, case, exp_amount_data.amount, - total_amount_map[exp_amount_data.asset], + total_amount_map[str(exp_amount_data.asset)], 'selectedAmounts:{}:amount'.format( exp_amount_data.asset)) elif name == 'Elements.EstimateFee': @@ -864,7 +872,7 @@ def convert_elements_utxo(json_utxo_list): return utxo_list -def convert_target_amount(json_target_list): +def convert_target_amount(json_target_list) -> List['TargetAmountData']: target_list = [] for target in json_target_list: data = TargetAmountData( diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 732bc01..5997722 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -13,6 +13,7 @@ def test_transaction_func1(obj, name, case, req, exp, error): try: + resp = None if 'tx' in req: resp = Transaction.from_hex(req['tx']) txins, txouts = [], [] @@ -66,6 +67,8 @@ def test_transaction_func1(obj, name, case, req, exp, error): def test_transaction_func2(obj, name, case, req, exp, error): try: + resp = None + txin = {} if 'tx' in req: resp = Transaction.from_hex(req['tx']) if 'txin' in req: diff --git a/tests/util.py b/tests/util.py index 3f67e46..fcca5cb 100644 --- a/tests/util.py +++ b/tests/util.py @@ -74,6 +74,8 @@ def assert_equal(test_obj, test_name, case, expect, value, err_msg, _value, 'Fail: {}:{}'.format(test_name, case)) elif param_name in expect: + if isinstance(expect[param_name], str) and (not isinstance(_value, str)): + _value = str(_value) fail_param_name = log_name if log_name else param_name test_obj.assertEqual( expect[param_name], _value, @@ -81,8 +83,11 @@ def assert_equal(test_obj, test_name, case, expect, value, def assert_match(test_obj, test_name, case, expect, value, param_name): + _value = value + if isinstance(expect, str) and (not isinstance(_value, str)): + _value = str(_value) test_obj.assertEqual( - expect, value, + expect, _value, 'Fail: {}:{}:{}'.format(test_name, case, str(param_name)))