From 9d653cb0a93d90ddc802b0f9966ef9435b950dc9 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 18 Jan 2021 22:14:12 -0500 Subject: [PATCH 1/2] Extract parametric type metaclass from FunsorMeta and domains --- funsor/domains.py | 115 ++++++++++++---------------------------------- funsor/terms.py | 85 ++-------------------------------- funsor/util.py | 92 +++++++++++++++++++++++++++++++++++++ test/test_cnf.py | 2 +- 4 files changed, 126 insertions(+), 168 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 87f2ce644..22568700f 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -6,20 +6,19 @@ import operator import warnings from functools import reduce -from weakref import WeakValueDictionary import funsor.ops as ops -from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote +from funsor.util import GenericTypeMeta, broadcast_shape, get_backend, get_tracing_state, quote -Domain = type + +class Domain(GenericTypeMeta): + pass class ArrayType(Domain): """ Base class of array-like domains. """ - _type_cache = WeakValueDictionary() - def __getitem__(cls, dtype_shape): dtype, shape = dtype_shape assert dtype is not None @@ -32,23 +31,7 @@ def __getitem__(cls, dtype_shape): if shape is not None: shape = tuple(map(int, shape)) - assert cls.dtype in (None, dtype) - assert cls.shape in (None, shape) - key = dtype, shape - result = ArrayType._type_cache.get(key, None) - if result is None: - if dtype == "real": - assert all(isinstance(size, int) and size >= 0 for size in shape) - name = "Reals[{}]".format(",".join(map(str, shape))) if shape else "Real" - result = RealsType(name, (), {"shape": shape}) - elif isinstance(dtype, int): - assert dtype >= 0 - name = "Bint[{}, {}]".format(dtype, ",".join(map(str, shape))) - result = BintType(name, (), {"dtype": dtype, "shape": shape}) - else: - raise ValueError("invalid dtype: {}".format(dtype)) - ArrayType._type_cache[key] = result - return result + return super().__getitem__((dtype, shape)) def __subclasscheck__(cls, subcls): if not isinstance(subcls, ArrayType): @@ -59,34 +42,18 @@ def __subclasscheck__(cls, subcls): return False return True - def __repr__(cls): - return cls.__name__ + @property + def dtype(cls): + return cls.__args__[0] - def __str__(cls): - return cls.__name__ + @property + def shape(cls): + return cls.__args__[1] @property def num_elements(cls): return reduce(operator.mul, cls.shape, 1) - -class BintType(ArrayType): - def __getitem__(cls, size_shape): - if isinstance(size_shape, tuple): - size, shape = size_shape[0], size_shape[1:] - else: - size, shape = size_shape, () - return super().__getitem__((size, shape)) - - def __subclasscheck__(cls, subcls): - if not isinstance(subcls, BintType): - return False - if cls.dtype not in (None, subcls.dtype): - return False - if cls.shape not in (None, subcls.shape): - return False - return True - @property def size(cls): return cls.dtype @@ -96,27 +63,25 @@ def __iter__(cls): return (Number(i, cls.size) for i in range(cls.size)) -class RealsType(ArrayType): - dtype = "real" +class BintType(ArrayType): + def __getitem__(cls, size_shape): + if isinstance(size_shape, tuple): + size, shape = size_shape[0], size_shape[1:] + else: + size, shape = size_shape, () + return Array.__getitem__((size, shape)) + +class RealsType(ArrayType): def __getitem__(cls, shape): if not isinstance(shape, tuple): shape = (shape,) - return super().__getitem__(("real", shape)) - - def __subclasscheck__(cls, subcls): - if not isinstance(subcls, RealsType): - return False - if cls.dtype not in (None, subcls.dtype): - return False - if cls.shape not in (None, subcls.shape): - return False - return True + return Array.__getitem__(("real", shape)) def _pickle_array(cls): - if cls in (Array, Bint, Real, Reals): - return cls.__name__ + if cls in (Array, Bint, Reals): + return repr(cls) return operator.getitem, (Array, (cls.dtype, cls.shape)) @@ -132,22 +97,20 @@ class Array(metaclass=ArrayType): Arary["real", (3, 3)] = Reals[3, 3] Array["real", ()] = Real """ - dtype = None - shape = None + pass -class Bint(metaclass=BintType): +class Bint(Array, metaclass=BintType): """ Factory for bounded integer types:: Bint[5] # integers ranging in {0,1,2,3,4} Bint[2, 3, 3] # 3x3 matrices with entries in {0,1} """ - dtype = None - shape = None + pass -class Reals(metaclass=RealsType): +class Reals(Array, metaclass=RealsType): """ Type of a real-valued array with known shape:: @@ -155,7 +118,7 @@ class Reals(metaclass=RealsType): Reals[8] # vector of length 8 Reals[3, 3] # 3x3 matrix """ - shape = None + pass Real = Reals[()] @@ -176,26 +139,6 @@ def bint(size): class ProductDomain(Domain): - - _type_cache = WeakValueDictionary() - - def __getitem__(cls, arg_domains): - try: - return ProductDomain._type_cache[arg_domains] - except KeyError: - assert isinstance(arg_domains, tuple) - assert all(isinstance(arg_domain, Domain) for arg_domain in arg_domains) - subcls = type("Product_", (Product,), {"__args__": arg_domains}) - ProductDomain._type_cache[arg_domains] = subcls - return subcls - - def __repr__(cls): - return "Product[{}]".format(", ".join(map(repr, cls.__args__))) - - @property - def __origin__(cls): - return Product - @property def shape(cls): return (len(cls.__args__),) @@ -203,7 +146,7 @@ def shape(cls): class Product(tuple, metaclass=ProductDomain): """like typing.Tuple, but works with issubclass""" - __args__ = NotImplemented + pass @quote.register(BintType) diff --git a/funsor/terms.py b/funsor/terms.py index 7cf6d1f51..45007c2a2 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -13,14 +13,14 @@ from weakref import WeakValueDictionary from multipledispatch import dispatch -from multipledispatch.variadic import Variadic, isvariadic +from multipledispatch.variadic import Variadic import funsor.interpreter as interpreter import funsor.ops as ops from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op -from funsor.util import getargspec, get_backend, lazy_property, pretty, quote +from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote def substitute(expr, subs): @@ -182,7 +182,7 @@ def moment_matching(cls, *args): interpreter.set_interpretation(eager) # Use eager interpretation by default. -class FunsorMeta(type): +class FunsorMeta(GenericTypeMeta): """ Metaclass for Funsors to perform four independent tasks: @@ -206,15 +206,9 @@ class FunsorMeta(type): """ def __init__(cls, name, bases, dct): super(FunsorMeta, cls).__init__(name, bases, dct) - if not hasattr(cls, "__args__"): - cls.__args__ = () - if cls.__args__: - base, = bases - cls.__origin__ = base - else: + if not cls.__args__: cls._ast_fields = getargspec(cls.__init__)[0][1:] cls._cons_cache = WeakValueDictionary() - cls._type_cache = WeakValueDictionary() def __call__(cls, *args, **kwargs): if cls.__args__: @@ -230,77 +224,6 @@ def __call__(cls, *args, **kwargs): return interpret(cls, *args) - def __getitem__(cls, arg_types): - if not isinstance(arg_types, tuple): - arg_types = (arg_types,) - assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" - # switch tuple to typing.Tuple - arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types) - if arg_types not in cls._type_cache: - assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls) - assert len(arg_types) == len(cls._ast_fields), "must provide types for all params" - new_dct = cls.__dict__.copy() - new_dct.update({"__args__": arg_types}) - # type(cls) to handle FunsorMeta subclasses - cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) - return cls._type_cache[arg_types] - - def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) - if cls is subcls: - return True - if not isinstance(subcls, FunsorMeta): - return super(FunsorMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls) - - cls_origin = getattr(cls, "__origin__", cls) - subcls_origin = getattr(subcls, "__origin__", subcls) - if not super(FunsorMeta, cls_origin).__subclasscheck__(subcls_origin): - return False - - if cls.__args__: - if not subcls.__args__: - return False - if len(cls.__args__) != len(subcls.__args__): - return False - for subcls_param, param in zip(subcls.__args__, cls.__args__): - if not _issubclass_tuple(subcls_param, param): - return False - return True - - @lazy_property - def classname(cls): - return cls.__name__ + "[{}]".format(", ".join( - str(getattr(t, "classname", t)) # Tuple doesn't have __name__ - for t in cls.__args__)) - - -def _issubclass_tuple(subcls, cls): - """ - utility for pattern matching with tuple subexpressions - """ - # so much boilerplate... - cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union - if isinstance(cls, tuple) or cls_is_union: - return any(_issubclass_tuple(subcls, option) - for option in (getattr(cls, "__args__", []) if cls_is_union else cls)) - - subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union - if isinstance(subcls, tuple) or subcls_is_union: - return any(_issubclass_tuple(option, cls) - for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls)) - - subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple) - cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple) - if subcls_is_tuple != cls_is_tuple: - return False - if not cls_is_tuple: - return issubclass(subcls, cls) - if not cls.__args__: - return True - if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): - return False - - return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) - def _convert_reduced_vars(reduced_vars, inputs): """ diff --git a/funsor/util.py b/funsor/util.py index d0692ebd8..004f5d596 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -4,8 +4,12 @@ import functools import inspect import re +import typing +import weakref import numpy as np +from multipledispatch.variadic import isvariadic + _FUNSOR_BACKEND = "numpy" _JAX_LOADED = False @@ -230,3 +234,91 @@ def decorator(fn): setattr(cls, name_, fn) return fn return decorator + + +class GenericTypeMeta(type): + """ + Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int]. + """ + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + if not hasattr(cls, "__args__"): + cls.__args__ = () + if cls.__args__: + base, = bases + cls.__origin__ = base + else: + cls._type_cache = weakref.WeakValueDictionary() + + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported" + # switch tuple to typing.Tuple + arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types) + if arg_types not in cls._type_cache: + assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls) + new_dct = cls.__dict__.copy() + new_dct.update({"__args__": arg_types}) + # type(cls) to handle GenericTypeMeta subclasses + cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct) + return cls._type_cache[arg_types] + + def __subclasscheck__(cls, subcls): # issubclass(subcls, cls) + if cls is subcls: + return True + if not isinstance(subcls, GenericTypeMeta): + return super(GenericTypeMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls) + + cls_origin = getattr(cls, "__origin__", cls) + subcls_origin = getattr(subcls, "__origin__", subcls) + if not super(GenericTypeMeta, cls_origin).__subclasscheck__(subcls_origin): + return False + + if cls.__args__: + if not subcls.__args__: + return False + if len(cls.__args__) != len(subcls.__args__): + return False + for subcls_param, param in zip(subcls.__args__, cls.__args__): + if not _issubclass_tuple(subcls_param, param): + return False + return True + + def __repr__(cls): + return cls.__name__ + ( + "" if not cls.__args__ else + "[{}]".format(", ".join(repr(t) for t in cls.__args__))) + + @lazy_property + def classname(cls): + return repr(cls) + + +def _issubclass_tuple(subcls, cls): + """ + utility for structural subtype checking with tuple subexpressions + """ + # so much boilerplate... + cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union + if isinstance(cls, tuple) or cls_is_union: + return any(_issubclass_tuple(subcls, option) + for option in (getattr(cls, "__args__", []) if cls_is_union else cls)) + + subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union + if isinstance(subcls, tuple) or subcls_is_union: + return any(_issubclass_tuple(option, cls) + for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls)) + + subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple) + cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple) + if subcls_is_tuple != cls_is_tuple: + return False + if not cls_is_tuple: + return issubclass(subcls, cls) + if not cls.__args__: + return True + if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__): + return False + + return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__)) diff --git a/test/test_cnf.py b/test/test_cnf.py index 2e1c6494d..a5e184955 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -9,7 +9,7 @@ from funsor import ops from funsor.cnf import Contraction, BACKEND_TO_EINSUM_BACKEND, BACKEND_TO_LOGSUMEXP_BACKEND -from funsor.domains import Bint, Bint # noqa F403 +from funsor.domains import Array, Bint # noqa F403 from funsor.domains import Reals from funsor.einsum import einsum, naive_plated_einsum from funsor.interpreter import interpretation, reinterpret From 0177b072faadd9822a38651c602256dec087d439 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 19 Jan 2021 00:23:20 -0500 Subject: [PATCH 2/2] restore assertion in FunsorMeta.__getitem__ --- funsor/terms.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/funsor/terms.py b/funsor/terms.py index 45007c2a2..019b1d7f5 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -210,6 +210,13 @@ def __init__(cls, name, bases, dct): cls._ast_fields = getargspec(cls.__init__)[0][1:] cls._cons_cache = WeakValueDictionary() + def __getitem__(cls, arg_types): + if not isinstance(arg_types, tuple): + arg_types = (arg_types,) + assert len(arg_types) == len(cls._ast_fields), \ + "Must provide exactly one type per subexpression" + return super().__getitem__(arg_types) + def __call__(cls, *args, **kwargs): if cls.__args__: cls = cls.__origin__