diff --git a/clvm/CLVMObject.py b/clvm/CLVMObject.py index 7586968c..9ac935f4 100644 --- a/clvm/CLVMObject.py +++ b/clvm/CLVMObject.py @@ -20,7 +20,10 @@ def __new__(class_, v): self = super(CLVMObject, class_).__new__(class_) if isinstance(v, tuple): if len(v) != 2: - raise ValueError("tuples must be of size 2, cannot create CLVMObject from: %s" % str(v)) + raise ValueError( + "tuples must be of size 2, cannot create CLVMObject from: %s" + % str(v) + ) self.pair = v self.atom = None else: diff --git a/clvm/SExp.py b/clvm/SExp.py index 398fe108..baa30090 100644 --- a/clvm/SExp.py +++ b/clvm/SExp.py @@ -131,6 +131,7 @@ class SExp: elements implementing the CLVM object protocol. Exactly one of "atom" and "pair" must be None. """ + true: "SExp" false: "SExp" __null__: "SExp" diff --git a/clvm/__init__.py b/clvm/__init__.py index a7062e1e..3b05ea60 100644 --- a/clvm/__init__.py +++ b/clvm/__init__.py @@ -1,6 +1,8 @@ from .SExp import SExp +from .dialect import Dialect +from .chia_dialect import dialect_factories # noqa from .operators import ( # noqa - QUOTE_ATOM, + QUOTE_ATOM, # deprecated KEYWORD_TO_ATOM, KEYWORD_FROM_ATOM, ) diff --git a/clvm/chainable_multi_op_fn.py b/clvm/chainable_multi_op_fn.py new file mode 100644 index 00000000..4a5d5f8b --- /dev/null +++ b/clvm/chainable_multi_op_fn.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +from .types import CLVMObjectType, MultiOpFn, OperatorDict + + +@dataclass +class ChainableMultiOpFn: + """ + This structure handles clvm operators. Given an atom, it looks it up in a `dict`, then + falls back to calling `unknown_op_handler`. + """ + + op_lookup: OperatorDict + unknown_op_handler: MultiOpFn + + def __call__( + self, op: bytes, arguments: CLVMObjectType, max_cost: Optional[int] = None + ) -> Tuple[int, CLVMObjectType]: + f = self.op_lookup.get(op) + if f: + try: + return f(arguments) + except TypeError: + # some operators require `max_cost` + return f(arguments, max_cost) + return self.unknown_op_handler(op, arguments, max_cost) diff --git a/clvm/chia_dialect.py b/clvm/chia_dialect.py new file mode 100644 index 00000000..4ea42d8f --- /dev/null +++ b/clvm/chia_dialect.py @@ -0,0 +1,102 @@ +from .SExp import SExp +from .casts import int_to_bytes +from .types import CLVMObjectType, ConversionFn, MultiOpFn, OperatorDict +from .chainable_multi_op_fn import ChainableMultiOpFn +from .handle_unknown_op import ( + handle_unknown_op_softfork_ready, + handle_unknown_op_strict, +) +from .dialect import ( + ConversionFn, + Dialect, + new_dialect, + opcode_table_for_backend, + python_new_dialect, + native_new_dialect, +) +from .chia_dialect_constants import KEYWORDS, KEYWORD_FROM_ATOM, KEYWORD_TO_ATOM # noqa +from .operators import OPERATOR_LOOKUP + + +def configure_chia_dialect(dialect: Dialect, backend=None) -> Dialect: + quote_kw = KEYWORD_TO_ATOM["q"] + apply_kw = KEYWORD_TO_ATOM["a"] + table = opcode_table_for_backend(KEYWORD_TO_ATOM, backend=backend) + dialect.update(table) + return dialect + + +def chia_dialect(strict: bool, to_python: ConversionFn, backend=None) -> Dialect: + dialect = new_dialect(quote_kw, apply_kw, strict, to_python, backend=backend) + return configure_chia_dialect(dialect, backend) + + +class DebugDialect(Dialect): + def __init__( + self, + quote_kw: bytes, + apply_kw: bytes, + multi_op_fn: MultiOpFn, + to_python: ConversionFn, + ): + super().__init__(quote_kw, apply_kw, multi_op_fn, to_python) + self.tracer = lambda x, y: None + + def do_sha256_with_trace(self, prev): + def _run(value, max_cost=None): + try: + cost, result = prev(value) + except TypeError: + cost, result = prev(value, max_cost) + + self.tracer(value, result) + return cost, result + + return _run + + def configure(self, **kwargs): + if "sha256_tracer" in kwargs: + self.tracer = kwargs["sha256_tracer"] + + +def chia_python_new_dialect( + quote_kw: bytes, + apply_kw: bytes, + strict: bool, + to_python: ConversionFn, + backend="python", +) -> Dialect: + unknown_op_callback = ( + handle_unknown_op_strict if strict else handle_unknown_op_softfork_ready + ) + + # Setup as a chia style clvm provider giving the chia operators. + return configure_chia_dialect( + DebugDialect(quote_kw, apply_kw, OPERATOR_LOOKUP, to_python), backend + ) + + +# Dialect that can allow acausal tracing of sha256 hashes. +def debug_new_dialect( + quote_kw: bytes, + apply_kw: bytes, + strict: bool, + to_python: ConversionFn, + backend="python", +) -> Dialect: + d = chia_python_new_dialect(quote_kw, apply_kw, strict, to_python, backend) + + # Override operators we want to track. + std_op_table = opcode_table_for_backend(KEYWORD_TO_ATOM, backend="python") + sha256_op = KEYWORD_TO_ATOM["sha256"] + table = {sha256_op: d.do_sha256_with_trace(std_op_table[sha256_op])} + d.update(table) + + return d + + +dialect_factories = { + "python": chia_python_new_dialect, + "native": native_new_dialect, + "debug": debug_new_dialect, +} diff --git a/clvm/chia_dialect_constants.py b/clvm/chia_dialect_constants.py new file mode 100644 index 00000000..fa906b46 --- /dev/null +++ b/clvm/chia_dialect_constants.py @@ -0,0 +1,37 @@ +from .casts import int_to_bytes + +KEYWORDS = ( + # core opcodes 0x01-x08 + ". q a i c f r l x " + # opcodes on atoms as strings 0x09-0x0f + "= >s sha256 substr strlen concat . " + # opcodes on atoms as ints 0x10-0x17 + "+ - * / divmod > ash lsh " + # opcodes on atoms as vectors of bools 0x18-0x1c + "logand logior logxor lognot . " + # opcodes for bls 1381 0x1d-0x1f + "point_add pubkey_for_exp . " + # bool opcodes 0x20-0x23 + "not any all . " + # misc 0x24 + "softfork " +).split() + +KEYWORD_FROM_ATOM = {int_to_bytes(k): v for k, v in enumerate(KEYWORDS)} +KEYWORD_TO_ATOM = {v: k for k, v in KEYWORD_FROM_ATOM.items()} + +KEYWORD_TO_LONG_KEYWORD = { + "i": "op_if", + "c": "op_cons", + "f": "op_first", + "r": "op_rest", + "l": "op_listp", + "x": "op_raise", + "=": "op_eq", + "+": "op_add", + "-": "op_subtract", + "*": "op_multiply", + "/": "op_divmod", + ">": "op_gr", + ">s": "op_gr_bytes", +} diff --git a/clvm/dialect.py b/clvm/dialect.py new file mode 100644 index 00000000..1392cbb8 --- /dev/null +++ b/clvm/dialect.py @@ -0,0 +1,221 @@ +from typing import Callable, Optional, Tuple +from .SExp import SExp + +try: + import clvm_rs +except ImportError: + clvm_rs = None + +import io +from . import core_ops, more_ops +from .chainable_multi_op_fn import ChainableMultiOpFn +from .handle_unknown_op import ( + handle_unknown_op_softfork_ready, + handle_unknown_op_strict, +) +from .run_program import _run_program +from .types import CLVMObjectType, ConversionFn, MultiOpFn, OperatorDict +from clvm.serialize import sexp_from_stream, sexp_to_stream +from .chia_dialect_constants import KEYWORD_FROM_ATOM, KEYWORD_TO_LONG_KEYWORD + + +OP_REWRITE = { + "+": "add", + "-": "subtract", + "*": "multiply", + "/": "div", + "i": "if", + "c": "cons", + "f": "first", + "r": "rest", + "l": "listp", + "x": "raise", + "=": "eq", + ">": "gr", + ">s": "gr_bytes", +} + + +def op_table_for_module(mod): + return {k: v for k, v in mod.__dict__.items() if k.startswith("op_")} + + +def op_imp_table_for_backend(backend): + if backend is None and clvm_rs: + backend = "native" + + if backend == "native": + if clvm_rs is None: + raise RuntimeError("native backend not installed") + return clvm_rs.native_opcodes_dict() + + table = {} + table.update(op_table_for_module(core_ops)) + table.update(op_table_for_module(more_ops)) + return table + + +def op_atom_to_imp_table(op_imp_table, keyword_to_atom, op_rewrite=OP_REWRITE): + op_atom_to_imp_table = {} + for op, bytecode in keyword_to_atom.items(): + op_name = "op_%s" % op_rewrite.get(op, op) + op_f = op_imp_table.get(op_name) + if op_f: + op_atom_to_imp_table[bytecode] = op_f + return op_atom_to_imp_table + + +def opcode_table_for_backend(keyword_to_atom, backend): + op_imp_table = op_imp_table_for_backend(backend) + return op_atom_to_imp_table(op_imp_table, keyword_to_atom) + + +class Dialect: + def __init__( + self, + quote_kw: bytes, + apply_kw: bytes, + multi_op_fn: MultiOpFn, + to_python: ConversionFn, + ): + self.quote_kw = quote_kw + self.apply_kw = apply_kw + self.opcode_lookup = dict() + self.multi_op_fn = ChainableMultiOpFn(self.opcode_lookup, multi_op_fn) + self.to_python = to_python + + def configure(self, **kwargs): + pass + + def update(self, d: OperatorDict) -> None: + self.opcode_lookup.update(d) + + def clear(self) -> None: + self.opcode_lookup.clear() + + def run_program( + self, + program: CLVMObjectType, + env: CLVMObjectType, + max_cost: int, + pre_eval_f: Optional[ + Callable[[CLVMObjectType, CLVMObjectType], Tuple[int, CLVMObjectType]] + ] = None, + ) -> Tuple[int, CLVMObjectType]: + cost, r = _run_program( + program, + env, + self.multi_op_fn, + self.quote_kw, + self.apply_kw, + max_cost, + pre_eval_f, + ) + return cost, self.to_python(r) + + +class NativeDialect: + def __init__( + self, + quote_kw: bytes, + apply_kw: bytes, + multi_op_fn: MultiOpFn, + to_python: ConversionFn, + ): + native_dict = clvm_rs.native_opcodes_dict() + + def get_native_op_for_kw(op, k): + kw = ( + KEYWORD_TO_LONG_KEYWORD[k] + if k in KEYWORD_TO_LONG_KEYWORD + else "op_%s" % k + ) + return (op, native_dict[kw]) + + native_opcode_names_by_opcode = dict( + get_native_op_for_kw(op, k) + for op, k in KEYWORD_FROM_ATOM.items() + if k not in "qa." + ) + + self.quote_kw = quote_kw + self.apply_kw = apply_kw + self.to_python = to_python + self.callbacks = multi_op_fn + self.held = clvm_rs.Dialect(quote_kw, apply_kw, multi_op_fn, to_python) + + self.held.update(native_opcode_names_by_opcode) + + def update(self, d): + return self.held.update(d) + + def clear(self) -> None: + return self.held.clear() + + def run_program( + self, + program: CLVMObjectType, + env: CLVMObjectType, + max_cost: int, + pre_eval_f: Optional[ + Callable[[CLVMObjectType, CLVMObjectType], Tuple[int, CLVMObjectType]] + ] = None, + ) -> Tuple[int, CLVMObjectType]: + prog = io.BytesIO() + e = io.BytesIO() + sexp_to_stream(program, prog) + sexp_to_stream(env, e) + + return self.held.deserialize_and_run_program( + prog.getvalue(), e.getvalue(), max_cost, pre_eval_f + ) + + def configure(self, **kwargs): + pass + + +def native_new_dialect( + quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn +) -> Dialect: + unknown_op_callback = ( + clvm_rs.NATIVE_OP_UNKNOWN_STRICT + if strict + else clvm_rs.NATIVE_OP_UNKNOWN_NON_STRICT + ) + + dialect = NativeDialect( + quote_kw, + apply_kw, + unknown_op_callback, + to_python=to_python, + ) + return dialect + + +def python_new_dialect( + quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn +) -> Dialect: + unknown_op_callback = ( + handle_unknown_op_strict if strict else handle_unknown_op_softfork_ready + ) + + dialect = Dialect( + quote_kw, + apply_kw, + unknown_op_callback, + to_python=to_python, + ) + return dialect + + +def new_dialect( + quote_kw: bytes, + apply_kw: bytes, + strict: bool, + to_python: ConversionFn, + backend=None, +): + if backend is None: + backend = "python" if clvm_rs is None else "native" + backend_f = native_new_dialect if backend == "native" else python_new_dialect + return backend_f(quote_kw, apply_kw, strict, to_python) diff --git a/clvm/handle_unknown_op.py b/clvm/handle_unknown_op.py new file mode 100644 index 00000000..e495a647 --- /dev/null +++ b/clvm/handle_unknown_op.py @@ -0,0 +1,124 @@ +from typing import Tuple + +from .CLVMObject import CLVMObject +from .EvalError import EvalError + +from .costs import ( + ARITH_BASE_COST, + ARITH_COST_PER_BYTE, + ARITH_COST_PER_ARG, + MUL_BASE_COST, + MUL_COST_PER_OP, + MUL_LINEAR_COST_PER_BYTE, + MUL_SQUARE_COST_PER_BYTE_DIVIDER, + CONCAT_BASE_COST, + CONCAT_COST_PER_ARG, + CONCAT_COST_PER_BYTE, +) + + +def handle_unknown_op_strict(op, arguments, _max_cost=None): + raise EvalError("unimplemented operator", arguments.to(op)) + + +def args_len(op_name, args): + for arg in args.as_iter(): + if arg.pair: + raise EvalError("%s requires int args" % op_name, arg) + yield len(arg.as_atom()) + + +# unknown ops are reserved if they start with 0xffff +# otherwise, unknown ops are no-ops, but they have costs. The cost is computed +# like this: + +# byte index (reverse): +# | 4 | 3 | 2 | 1 | 0 | +# +---+---+---+---+------------+ +# | multiplier |XX | XXXXXX | +# +---+---+---+---+---+--------+ +# ^ ^ ^ +# | | + 6 bits ignored when computing cost +# cost_multiplier | +# + 2 bits +# cost_function + +# 1 is always added to the multiplier before using it to multiply the cost, this +# is since cost may not be 0. + +# cost_function is 2 bits and defines how cost is computed based on arguments: +# 0: constant, cost is 1 * (multiplier + 1) +# 1: computed like operator add, multiplied by (multiplier + 1) +# 2: computed like operator mul, multiplied by (multiplier + 1) +# 3: computed like operator concat, multiplied by (multiplier + 1) + +# this means that unknown ops where cost_function is 1, 2, or 3, may still be +# fatal errors if the arguments passed are not atoms. + + +def handle_unknown_op_softfork_ready( + op: bytes, args: CLVMObject, max_cost: int +) -> Tuple[int, CLVMObject]: + # any opcode starting with ffff is reserved (i.e. fatal error) + # opcodes are not allowed to be empty + if len(op) == 0 or op[:2] == b"\xff\xff": + raise EvalError("reserved operator", args.to(op)) + + # all other unknown opcodes are no-ops + # the cost of the no-ops is determined by the opcode number, except the + # 6 least significant bits. + + cost_function = (op[-1] & 0b11000000) >> 6 + # the multiplier cannot be 0. it starts at 1 + + if len(op) > 5: + raise EvalError("invalid operator", args.to(op)) + + cost_multiplier = int.from_bytes(op[:-1], "big", signed=False) + 1 + + # 0 = constant + # 1 = like op_add/op_sub + # 2 = like op_multiply + # 3 = like op_concat + if cost_function == 0: + cost = 1 + elif cost_function == 1: + # like op_add + cost = ARITH_BASE_COST + arg_size = 0 + for length in args_len("unknown op", args): + arg_size += length + cost += ARITH_COST_PER_ARG + cost += arg_size * ARITH_COST_PER_BYTE + elif cost_function == 2: + # like op_multiply + cost = MUL_BASE_COST + operands = args_len("unknown op", args) + try: + vs = next(operands) + for rs in operands: + cost += MUL_COST_PER_OP + cost += (rs + vs) * MUL_LINEAR_COST_PER_BYTE + cost += (rs * vs) // MUL_SQUARE_COST_PER_BYTE_DIVIDER + # this is an estimate, since we don't want to actually multiply the + # values + vs += rs + except StopIteration: + pass + + elif cost_function == 3: + # like concat + cost = CONCAT_BASE_COST + length = 0 + for arg in args.as_iter(): + if arg.pair: + raise EvalError("unknown op on list", arg) + cost += CONCAT_COST_PER_ARG + length += len(arg.atom) + cost += length * CONCAT_COST_PER_BYTE + + cost *= cost_multiplier + if cost >= 2 ** 32: + raise EvalError("invalid operator", args.to(op)) + + return (cost, args.to(b"")) diff --git a/clvm/more_ops.py b/clvm/more_ops.py index abea509e..80a7c011 100644 --- a/clvm/more_ops.py +++ b/clvm/more_ops.py @@ -81,7 +81,9 @@ def args_as_int32(op_name, args: SExp): if arg.pair: raise EvalError("%s requires int32 args" % op_name, arg) if len(arg.atom) > 4: - raise EvalError("%s requires int32 args (with no leading zeros)" % op_name, arg) + raise EvalError( + "%s requires int32 args (with no leading zeros)" % op_name, arg + ) yield arg.as_int() @@ -89,7 +91,9 @@ def args_as_int_list(op_name, args, count): int_list = list(args_as_ints(op_name, args)) if len(int_list) != count: plural = "s" if count != 1 else "" - raise EvalError("%s takes exactly %d argument%s" % (op_name, count, plural), args) + raise EvalError( + "%s takes exactly %d argument%s" % (op_name, count, plural), args + ) return int_list @@ -106,7 +110,9 @@ def args_as_bool_list(op_name, args, count): bool_list = list(args_as_bools(op_name, args)) if len(bool_list) != count: plural = "s" if count != 1 else "" - raise EvalError("%s takes exactly %d argument%s" % (op_name, count, plural), args) + raise EvalError( + "%s takes exactly %d argument%s" % (op_name, count, plural), args + ) return bool_list @@ -249,7 +255,7 @@ def op_substr(args: SExp): s0 = a0.as_atom() if arg_count == 2: - i1, = list(args_as_int32("substr", args.rest())) + (i1,) = list(args_as_int32("substr", args.rest())) i2 = len(s0) else: i1, i2 = list(args_as_int32("substr", args.rest())) @@ -277,7 +283,9 @@ def op_concat(args: SExp): def op_ash(args): (i0, l0), (i1, l1) = args_as_int_list("ash", args, 2) if l1 > 4: - raise EvalError("ash requires int32 args (with no leading zeros)", args.rest().first()) + raise EvalError( + "ash requires int32 args (with no leading zeros)", args.rest().first() + ) if abs(i1) > 65535: raise EvalError("shift too large", args.to(i1)) if i1 >= 0: @@ -292,7 +300,9 @@ def op_ash(args): def op_lsh(args): (i0, l0), (i1, l1) = args_as_int_list("lsh", args, 2) if l1 > 4: - raise EvalError("lsh requires int32 args (with no leading zeros)", args.rest().first()) + raise EvalError( + "lsh requires int32 args (with no leading zeros)", args.rest().first() + ) if abs(i1) > 65535: raise EvalError("shift too large", args.to(i1)) # we actually want i0 to be an *unsigned* int @@ -344,7 +354,7 @@ def binop(a, b): def op_lognot(args): - (i0, l0), = args_as_int_list("lognot", args, 1) + ((i0, l0),) = args_as_int_list("lognot", args, 1) cost = LOGNOT_BASE_COST + l0 * LOGNOT_COST_PER_BYTE return malloc_cost(cost, args.to(~i0)) diff --git a/clvm/operators.py b/clvm/operators.py index a63e6a88..c02aaed2 100644 --- a/clvm/operators.py +++ b/clvm/operators.py @@ -1,168 +1,14 @@ +# this API is deprecated in favor of dialects. See `dialect.py` and `chia_dialect.py` + from typing import Dict, Tuple from . import core_ops, more_ops from .CLVMObject import CLVMObject -from .SExp import SExp -from .EvalError import EvalError - -from .casts import int_to_bytes from .op_utils import operators_for_module - -from .costs import ( - ARITH_BASE_COST, - ARITH_COST_PER_BYTE, - ARITH_COST_PER_ARG, - MUL_BASE_COST, - MUL_COST_PER_OP, - MUL_LINEAR_COST_PER_BYTE, - MUL_SQUARE_COST_PER_BYTE_DIVIDER, - CONCAT_BASE_COST, - CONCAT_COST_PER_ARG, - CONCAT_COST_PER_BYTE, -) - -KEYWORDS = ( - # core opcodes 0x01-x08 - ". q a i c f r l x " - - # opcodes on atoms as strings 0x09-0x0f - "= >s sha256 substr strlen concat . " - - # opcodes on atoms as ints 0x10-0x17 - "+ - * / divmod > ash lsh " - - # opcodes on atoms as vectors of bools 0x18-0x1c - "logand logior logxor lognot . " - - # opcodes for bls 1381 0x1d-0x1f - "point_add pubkey_for_exp . " - - # bool opcodes 0x20-0x23 - "not any all . " - - # misc 0x24 - "softfork " -).split() - -KEYWORD_FROM_ATOM = {int_to_bytes(k): v for k, v in enumerate(KEYWORDS)} -KEYWORD_TO_ATOM = {v: k for k, v in KEYWORD_FROM_ATOM.items()} - -OP_REWRITE = { - "+": "add", - "-": "subtract", - "*": "multiply", - "/": "div", - "i": "if", - "c": "cons", - "f": "first", - "r": "rest", - "l": "listp", - "x": "raise", - "=": "eq", - ">": "gr", - ">s": "gr_bytes", -} - - -def args_len(op_name, args): - for arg in args.as_iter(): - if arg.pair: - raise EvalError("%s requires int args" % op_name, arg) - yield len(arg.as_atom()) - - -# unknown ops are reserved if they start with 0xffff -# otherwise, unknown ops are no-ops, but they have costs. The cost is computed -# like this: - -# byte index (reverse): -# | 4 | 3 | 2 | 1 | 0 | -# +---+---+---+---+------------+ -# | multiplier |XX | XXXXXX | -# +---+---+---+---+---+--------+ -# ^ ^ ^ -# | | + 6 bits ignored when computing cost -# cost_multiplier | -# + 2 bits -# cost_function - -# 1 is always added to the multiplier before using it to multiply the cost, this -# is since cost may not be 0. - -# cost_function is 2 bits and defines how cost is computed based on arguments: -# 0: constant, cost is 1 * (multiplier + 1) -# 1: computed like operator add, multiplied by (multiplier + 1) -# 2: computed like operator mul, multiplied by (multiplier + 1) -# 3: computed like operator concat, multiplied by (multiplier + 1) - -# this means that unknown ops where cost_function is 1, 2, or 3, may still be -# fatal errors if the arguments passed are not atoms. - -def default_unknown_op(op: bytes, args: CLVMObject) -> Tuple[int, CLVMObject]: - # any opcode starting with ffff is reserved (i.e. fatal error) - # opcodes are not allowed to be empty - if len(op) == 0 or op[:2] == b"\xff\xff": - raise EvalError("reserved operator", args.to(op)) - - # all other unknown opcodes are no-ops - # the cost of the no-ops is determined by the opcode number, except the - # 6 least significant bits. - - cost_function = (op[-1] & 0b11000000) >> 6 - # the multiplier cannot be 0. it starts at 1 - - if len(op) > 5: - raise EvalError("invalid operator", args.to(op)) - - cost_multiplier = int.from_bytes(op[:-1], "big", signed=False) + 1 - - # 0 = constant - # 1 = like op_add/op_sub - # 2 = like op_multiply - # 3 = like op_concat - if cost_function == 0: - cost = 1 - elif cost_function == 1: - # like op_add - cost = ARITH_BASE_COST - arg_size = 0 - for length in args_len("unknown op", args): - arg_size += length - cost += ARITH_COST_PER_ARG - cost += arg_size * ARITH_COST_PER_BYTE - elif cost_function == 2: - # like op_multiply - cost = MUL_BASE_COST - operands = args_len("unknown op", args) - try: - vs = next(operands) - for rs in operands: - cost += MUL_COST_PER_OP - cost += (rs + vs) * MUL_LINEAR_COST_PER_BYTE - cost += (rs * vs) // MUL_SQUARE_COST_PER_BYTE_DIVIDER - # this is an estimate, since we don't want to actually multiply the - # values - vs += rs - except StopIteration: - pass - - elif cost_function == 3: - # like concat - cost = CONCAT_BASE_COST - length = 0 - for arg in args.as_iter(): - if arg.pair: - raise EvalError("unknown op on list", arg) - cost += CONCAT_COST_PER_ARG - length += len(arg.atom) - cost += length * CONCAT_COST_PER_BYTE - - cost *= cost_multiplier - if cost >= 2**32: - raise EvalError("invalid operator", args.to(op)) - - return (cost, SExp.null()) +from .handle_unknown_op import handle_unknown_op_softfork_ready +from .dialect import OP_REWRITE +from .chia_dialect_constants import KEYWORDS, KEYWORD_FROM_ATOM, KEYWORD_TO_ATOM # noqa class OperatorDict(dict): @@ -184,13 +30,18 @@ def __new__(class_, d: Dict, *args, **kwargs): if "unknown_op_handler" in kwargs: self.unknown_op_handler = kwargs["unknown_op_handler"] else: - self.unknown_op_handler = default_unknown_op + self.unknown_op_handler = handle_unknown_op_softfork_ready return self - def __call__(self, op: bytes, arguments: CLVMObject) -> Tuple[int, CLVMObject]: + def __call__( + self, op: bytes, arguments: CLVMObject, max_cost=None + ) -> Tuple[int, CLVMObject]: f = self.get(op) if f is None: - return self.unknown_op_handler(op, arguments) + try: + return self.unknown_op_handler(op, arguments, max_cost) + except TypeError: + return self.unknown_op_handler(op, arguments) else: return f(arguments) @@ -199,6 +50,8 @@ def __call__(self, op: bytes, arguments: CLVMObject) -> Tuple[int, CLVMObject]: APPLY_ATOM = KEYWORD_TO_ATOM["a"] OPERATOR_LOOKUP = OperatorDict( - operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), quote=QUOTE_ATOM, apply=APPLY_ATOM + operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), + quote=QUOTE_ATOM, + apply=APPLY_ATOM, ) OPERATOR_LOOKUP.update(operators_for_module(KEYWORD_TO_ATOM, more_ops, OP_REWRITE)) diff --git a/clvm/run_program.py b/clvm/run_program.py index 20f4b75c..0910f2c4 100644 --- a/clvm/run_program.py +++ b/clvm/run_program.py @@ -9,12 +9,14 @@ QUOTE_COST, PATH_LOOKUP_BASE_COST, PATH_LOOKUP_COST_PER_LEG, - PATH_LOOKUP_COST_PER_ZERO_BYTE + PATH_LOOKUP_COST_PER_ZERO_BYTE, ) # the "Any" below should really be "OpStackType" but # recursive types aren't supported by mypy +MultiOpFn = Callable[[bytes, SExp, int], Tuple[int, SExp]] + OpCallable = Callable[[Any, "ValStackType"], int] ValStackType = List[SExp] @@ -53,6 +55,27 @@ def run_program( pre_eval_f=None, ) -> Tuple[int, CLVMObject]: + return _run_program( + program, + args, + operator_lookup, + operator_lookup.quote_atom, + operator_lookup.apply_atom, + max_cost, + pre_eval_f, + ) + + +def _run_program( + program: CLVMObject, + args: CLVMObject, + operator_lookup: MultiOpFn, + quote_atom: bytes, + apply_atom: bytes, + max_cost=None, + pre_eval_f=None, +) -> Tuple[int, CLVMObject]: + program = SExp.to(program) if pre_eval_f: pre_eval_op = to_pre_eval_op(pre_eval_f, program.to) @@ -137,7 +160,7 @@ def eval_op(op_stack: OpStackType, value_stack: ValStackType) -> int: op = operator.as_atom() operand_list = sexp.rest() - if op == operator_lookup.quote_atom: + if op == quote_atom: value_stack.append(operand_list) return QUOTE_COST @@ -160,7 +183,7 @@ def apply_op(op_stack: OpStackType, value_stack: ValStackType) -> int: raise EvalError("internal error", operator) op = operator.as_atom() - if op == operator_lookup.apply_atom: + if op == apply_atom: if operand_list.list_len() != 2: raise EvalError("apply requires exactly 2 parameters", operand_list) new_program = operand_list.first() diff --git a/clvm/serialize.py b/clvm/serialize.py index 8d23c8b4..c889aecf 100644 --- a/clvm/serialize.py +++ b/clvm/serialize.py @@ -22,13 +22,13 @@ def sexp_to_byte_iterator(sexp): todo_stack = [sexp] while todo_stack: sexp = todo_stack.pop() - pair = sexp.as_pair() + pair = sexp.pair if pair: yield bytes([CONS_BOX_MARKER]) todo_stack.append(pair[1]) todo_stack.append(pair[0]) else: - yield from atom_to_byte_iterator(sexp.as_atom()) + yield from atom_to_byte_iterator(sexp.atom) def atom_to_byte_iterator(as_atom): diff --git a/clvm/types.py b/clvm/types.py new file mode 100644 index 00000000..c9f794ba --- /dev/null +++ b/clvm/types.py @@ -0,0 +1,15 @@ +from typing import Any, Callable, Dict, Tuple, Union + + +CLVMAtom = Any +CLVMPair = Any + +CLVMObjectType = Union["CLVMAtom", "CLVMPair"] + +MultiOpFn = Callable[[bytes, CLVMObjectType, int], Tuple[int, CLVMObjectType]] + +ConversionFn = Callable[[CLVMObjectType], CLVMObjectType] + +OpFn = Callable[[CLVMObjectType, int], Tuple[int, CLVMObjectType]] + +OperatorDict = Dict[bytes, Callable[[CLVMObjectType, int], Tuple[int, CLVMObjectType]]] diff --git a/setup.py b/setup.py index a1a86be7..6be83646 100755 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ dependencies = [ "blspy>=0.9", + "clvm_rs>=0.1.8" ] dev_dependencies = [ diff --git a/tests/brun/trace-1.txt b/tests/brun/trace-1.txt index 72fb303f..a1d5c92d 100644 --- a/tests/brun/trace-1.txt +++ b/tests/brun/trace-1.txt @@ -1,4 +1,4 @@ -brun --backend=python -c -v '(+ (q . 10) (f 1))' '(51)' +brun -c -v '(+ (q . 10) (f 1))' '(51)' cost = 860 61 diff --git a/tests/brun/trace-2.txt b/tests/brun/trace-2.txt index fb475ec5..08ba2687 100644 --- a/tests/brun/trace-2.txt +++ b/tests/brun/trace-2.txt @@ -1,4 +1,4 @@ -brun --backend=python -c -v '(x)' +brun -c -v '(x)' FAIL: clvm raise () (a 2 3) [((x))] => (didn't finish) diff --git a/tests/operatordict_test.py b/tests/operatordict_test.py deleted file mode 100644 index 897f6ffe..00000000 --- a/tests/operatordict_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -from clvm.operators import OperatorDict - - -class OperatorDictTest(unittest.TestCase): - def test_operatordict_constructor(self): - """Constructing should fail if quote or apply are not specified, - either by object property or by keyword argument. - Note that they cannot be specified in the operator dictionary itself. - """ - d = {1: "hello", 2: "goodbye"} - with self.assertRaises(AttributeError): - o = OperatorDict(d) - with self.assertRaises(AttributeError): - o = OperatorDict(d, apply=1) - with self.assertRaises(AttributeError): - o = OperatorDict(d, quote=1) - o = OperatorDict(d, apply=1, quote=2) - print(o) - # Why does the constructed Operator dict contain entries for "apply":1 and "quote":2 ? - # assert d == o - self.assertEqual(o.apply_atom, 1) - self.assertEqual(o.quote_atom, 2) - - # Test construction from an already existing OperatorDict - o2 = OperatorDict(o) - self.assertEqual(o2.apply_atom, 1) - self.assertEqual(o2.quote_atom, 2) diff --git a/tests/operators_test.py b/tests/operators_test.py index 9c84d719..c0282412 100644 --- a/tests/operators_test.py +++ b/tests/operators_test.py @@ -1,59 +1,90 @@ import unittest -from clvm.operators import (OPERATOR_LOOKUP, KEYWORD_TO_ATOM, default_unknown_op, OperatorDict) +from clvm.chainable_multi_op_fn import ChainableMultiOpFn +from clvm.costs import CONCAT_BASE_COST +from clvm.dialect import opcode_table_for_backend +from clvm.handle_unknown_op import handle_unknown_op_softfork_ready +from clvm.operators import KEYWORD_TO_ATOM from clvm.EvalError import EvalError from clvm import SExp -from clvm.costs import CONCAT_BASE_COST +OPERATOR_LOOKUP = opcode_table_for_backend(KEYWORD_TO_ATOM, backend=None) +MAX_COST = int(1e18) -class OperatorsTest(unittest.TestCase): +class OperatorsTest(unittest.TestCase): def setUp(self): self.handler_called = False - def unknown_handler(self, name, args): + def unknown_handler(self, name, args, _max_cost): self.handler_called = True - self.assertEqual(name, b'\xff\xff1337') + self.assertEqual(name, b"\xff\xff1337") self.assertEqual(args, SExp.to(1337)) - return 42, SExp.to(b'foobar') + return 42, SExp.to(b"foobar") def test_unknown_op(self): - self.assertRaises(EvalError, lambda: OPERATOR_LOOKUP(b'\xff\xff1337', SExp.to(1337))) - od = OperatorDict(OPERATOR_LOOKUP, unknown_op_handler=lambda name, args: self.unknown_handler(name, args)) - cost, ret = od(b'\xff\xff1337', SExp.to(1337)) + self.assertRaises( + KeyError, lambda: OPERATOR_LOOKUP[b"\xff\xff1337"](SExp.to(1337), None) + ) + od = ChainableMultiOpFn( + opcode_table_for_backend(KEYWORD_TO_ATOM, backend=None), + self.unknown_handler, + ) + cost, ret = od(b"\xff\xff1337", SExp.to(1337), None) self.assertTrue(self.handler_called) self.assertEqual(cost, 42) - self.assertEqual(ret, SExp.to(b'foobar')) + self.assertEqual(ret, SExp.to(b"foobar")) def test_plus(self): print(OPERATOR_LOOKUP) - self.assertEqual(OPERATOR_LOOKUP(KEYWORD_TO_ATOM['+'], SExp.to([3, 4, 5]))[1], SExp.to(12)) + self.assertEqual( + OPERATOR_LOOKUP[KEYWORD_TO_ATOM["+"]](SExp.to([3, 4, 5]), MAX_COST)[1], + SExp.to(12), + ) def test_unknown_op_reserved(self): # any op that starts with ffff is reserved, and results in a hard # failure with self.assertRaises(EvalError): - default_unknown_op(b"\xff\xff", SExp.null()) + handle_unknown_op_softfork_ready(b"\xff\xff", SExp.null(), max_cost=None) for suffix in [b"\xff", b"0", b"\x00", b"\xcc\xcc\xfe\xed\xfa\xce"]: with self.assertRaises(EvalError): - default_unknown_op(b"\xff\xff" + suffix, SExp.null()) + handle_unknown_op_softfork_ready( + b"\xff\xff" + suffix, SExp.null(), max_cost=None + ) with self.assertRaises(EvalError): # an empty atom is not a valid opcode - self.assertEqual(default_unknown_op(b"", SExp.null()), (1, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready(b"", SExp.null(), max_cost=None), + (1, SExp.null()), + ) # a single ff is not sufficient to be treated as a reserved opcode - self.assertEqual(default_unknown_op(b"\xff", SExp.null()), (CONCAT_BASE_COST, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready(b"\xff", SExp.null(), max_cost=None), + (CONCAT_BASE_COST, SExp.null()), + ) # leading zeroes count, and this does not count as a ffff-prefix # the cost is 0xffff00 = 16776960 - self.assertEqual(default_unknown_op(b"\x00\xff\xff\x00\x00", SExp.null()), (16776961, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready( + b"\x00\xff\xff\x00\x00", SExp.null(), max_cost=None + ), + (16776961, SExp.null()), + ) def test_unknown_ops_last_bits(self): # The last byte is ignored for no-op unknown ops for suffix in [b"\x3f", b"\x0f", b"\x00", b"\x2c"]: # the cost is unchanged by the last byte - self.assertEqual(default_unknown_op(b"\x3c" + suffix, SExp.null()), (61, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready( + b"\x3c" + suffix, SExp.null(), max_cost=None + ), + (61, SExp.null()), + ) diff --git a/tests/run_program_test.py b/tests/run_program_test.py index d64462ca..e5e5f47f 100644 --- a/tests/run_program_test.py +++ b/tests/run_program_test.py @@ -4,7 +4,6 @@ class BitTest(unittest.TestCase): - def test_msb_mask(self): self.assertEqual(msb_mask(0x0), 0x0) self.assertEqual(msb_mask(0x01), 0x01) @@ -17,6 +16,6 @@ def test_msb_mask(self): self.assertEqual(msb_mask(0x80), 0x80) self.assertEqual(msb_mask(0x44), 0x40) - self.assertEqual(msb_mask(0x2a), 0x20) - self.assertEqual(msb_mask(0xff), 0x80) - self.assertEqual(msb_mask(0x0f), 0x08) + self.assertEqual(msb_mask(0x2A), 0x20) + self.assertEqual(msb_mask(0xFF), 0x80) + self.assertEqual(msb_mask(0x0F), 0x08) diff --git a/tests/serialize_test.py b/tests/serialize_test.py index 786f1c95..70687bcb 100644 --- a/tests/serialize_test.py +++ b/tests/serialize_test.py @@ -2,7 +2,11 @@ import unittest from clvm import to_sexp_f -from clvm.serialize import (sexp_from_stream, sexp_buffer_from_stream, atom_to_byte_iterator) +from clvm.serialize import ( + sexp_from_stream, + sexp_buffer_from_stream, + atom_to_byte_iterator, +) TEXT = b"the quick brown fox jumps over the lazy dogs" @@ -13,12 +17,12 @@ def __init__(self, b): self.buf = b def read(self, n): - ret = b'' + ret = b"" while n > 0 and len(self.buf) > 0: ret += self.buf[0:1] self.buf = self.buf[1:] n -= 1 - ret += b' ' * n + ret += b" " * n return ret @@ -79,7 +83,7 @@ def test_long_blobs(self): def test_blob_limit(self): with self.assertRaises(ValueError): for b in atom_to_byte_iterator(LargeAtom()): - print('%02x' % b) + print("%02x" % b) def test_very_long_blobs(self): for size in [0x40, 0x2000, 0x100000, 0x8000000]: @@ -100,7 +104,7 @@ def test_very_deep_tree(self): self.check_serde(s) def test_deserialize_empty(self): - bytes_in = b'' + bytes_in = b"" with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -110,7 +114,7 @@ def test_deserialize_empty(self): def test_deserialize_truncated_size(self): # fe means the total number of bytes in the length-prefix is 7 # one for each bit set. 5 bytes is too few - bytes_in = b'\xfe ' + bytes_in = b"\xfe " with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -120,7 +124,7 @@ def test_deserialize_truncated_size(self): def test_deserialize_truncated_blob(self): # this is a complete length prefix. The blob is supposed to be 63 bytes # the blob itself is truncated though, it's less than 63 bytes - bytes_in = b'\xbf ' + bytes_in = b"\xbf " with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -134,7 +138,7 @@ def test_deserialize_large_blob(self): # we don't support blobs this large, and we should fail immediately when # exceeding the max blob size, rather than trying to read this many # bytes from the stream - bytes_in = b'\xfe' + b'\xff' * 6 + bytes_in = b"\xfe" + b"\xff" * 6 with self.assertRaises(ValueError): sexp_from_stream(InfiniteStream(bytes_in), to_sexp_f) diff --git a/tests/to_sexp_test.py b/tests/to_sexp_test.py index 36a73e49..5de88960 100644 --- a/tests/to_sexp_test.py +++ b/tests/to_sexp_test.py @@ -91,11 +91,17 @@ def pair(self) -> Optional[Tuple[Any, Any]]: if self.depth == 0: return None new_depth: int = self.depth - 1 - return (GeneratedTree(new_depth, self.val), GeneratedTree(new_depth, self.val + 2**new_depth)) + return ( + GeneratedTree(new_depth, self.val), + GeneratedTree(new_depth, self.val + 2 ** new_depth), + ) tree = SExp.to(GeneratedTree(5, 0)) - assert print_leaves(tree) == "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 " + \ - "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 " + assert ( + print_leaves(tree) + == "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 " + + "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 " + ) tree = SExp.to(GeneratedTree(3, 0)) assert print_leaves(tree) == "0 1 2 3 4 5 6 7 "