Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract subtyping metaclass from FunsorMeta and domains #433

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 29 additions & 86 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@
import operator
import warnings
from functools import reduce
from weakref import WeakValueDictionary

import funsor.ops as ops
from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote
from funsor.util import GenericTypeMeta, broadcast_shape, get_backend, get_tracing_state, quote

Domain = type

class Domain(GenericTypeMeta):
pass


class ArrayType(Domain):
"""
Base class of array-like domains.
"""
_type_cache = WeakValueDictionary()

def __getitem__(cls, dtype_shape):
dtype, shape = dtype_shape
assert dtype is not None
Expand All @@ -32,23 +31,7 @@ def __getitem__(cls, dtype_shape):
if shape is not None:
shape = tuple(map(int, shape))

assert cls.dtype in (None, dtype)
assert cls.shape in (None, shape)
key = dtype, shape
result = ArrayType._type_cache.get(key, None)
if result is None:
if dtype == "real":
assert all(isinstance(size, int) and size >= 0 for size in shape)
name = "Reals[{}]".format(",".join(map(str, shape))) if shape else "Real"
result = RealsType(name, (), {"shape": shape})
elif isinstance(dtype, int):
assert dtype >= 0
name = "Bint[{}, {}]".format(dtype, ",".join(map(str, shape)))
result = BintType(name, (), {"dtype": dtype, "shape": shape})
else:
raise ValueError("invalid dtype: {}".format(dtype))
ArrayType._type_cache[key] = result
return result
return super().__getitem__((dtype, shape))

def __subclasscheck__(cls, subcls):
if not isinstance(subcls, ArrayType):
Expand All @@ -59,34 +42,18 @@ def __subclasscheck__(cls, subcls):
return False
return True

def __repr__(cls):
return cls.__name__
@property
def dtype(cls):
return cls.__args__[0]

def __str__(cls):
return cls.__name__
@property
def shape(cls):
return cls.__args__[1]

@property
def num_elements(cls):
return reduce(operator.mul, cls.shape, 1)


class BintType(ArrayType):
def __getitem__(cls, size_shape):
if isinstance(size_shape, tuple):
size, shape = size_shape[0], size_shape[1:]
else:
size, shape = size_shape, ()
return super().__getitem__((size, shape))

def __subclasscheck__(cls, subcls):
if not isinstance(subcls, BintType):
return False
if cls.dtype not in (None, subcls.dtype):
return False
if cls.shape not in (None, subcls.shape):
return False
return True

@property
def size(cls):
return cls.dtype
Expand All @@ -96,27 +63,25 @@ def __iter__(cls):
return (Number(i, cls.size) for i in range(cls.size))


class RealsType(ArrayType):
dtype = "real"
class BintType(ArrayType):
def __getitem__(cls, size_shape):
if isinstance(size_shape, tuple):
size, shape = size_shape[0], size_shape[1:]
else:
size, shape = size_shape, ()
return Array.__getitem__((size, shape))


class RealsType(ArrayType):
def __getitem__(cls, shape):
if not isinstance(shape, tuple):
shape = (shape,)
return super().__getitem__(("real", shape))

def __subclasscheck__(cls, subcls):
if not isinstance(subcls, RealsType):
return False
if cls.dtype not in (None, subcls.dtype):
return False
if cls.shape not in (None, subcls.shape):
return False
return True
return Array.__getitem__(("real", shape))


def _pickle_array(cls):
if cls in (Array, Bint, Real, Reals):
return cls.__name__
if cls in (Array, Bint, Reals):
return repr(cls)
return operator.getitem, (Array, (cls.dtype, cls.shape))


Expand All @@ -132,30 +97,28 @@ class Array(metaclass=ArrayType):
Arary["real", (3, 3)] = Reals[3, 3]
Array["real", ()] = Real
"""
dtype = None
shape = None
pass


class Bint(metaclass=BintType):
class Bint(Array, metaclass=BintType):
"""
Factory for bounded integer types::

Bint[5] # integers ranging in {0,1,2,3,4}
Bint[2, 3, 3] # 3x3 matrices with entries in {0,1}
"""
dtype = None
shape = None
pass


class Reals(metaclass=RealsType):
class Reals(Array, metaclass=RealsType):
"""
Type of a real-valued array with known shape::

Reals[()] = Real # scalar
Reals[8] # vector of length 8
Reals[3, 3] # 3x3 matrix
"""
shape = None
pass


Real = Reals[()]
Expand All @@ -176,34 +139,14 @@ def bint(size):


class ProductDomain(Domain):

_type_cache = WeakValueDictionary()

def __getitem__(cls, arg_domains):
try:
return ProductDomain._type_cache[arg_domains]
except KeyError:
assert isinstance(arg_domains, tuple)
assert all(isinstance(arg_domain, Domain) for arg_domain in arg_domains)
subcls = type("Product_", (Product,), {"__args__": arg_domains})
ProductDomain._type_cache[arg_domains] = subcls
return subcls

def __repr__(cls):
return "Product[{}]".format(", ".join(map(repr, cls.__args__)))

@property
def __origin__(cls):
return Product

@property
def shape(cls):
return (len(cls.__args__),)


class Product(tuple, metaclass=ProductDomain):
"""like typing.Tuple, but works with issubclass"""
__args__ = NotImplemented
pass


@quote.register(BintType)
Expand Down
92 changes: 11 additions & 81 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from weakref import WeakValueDictionary

from multipledispatch import dispatch
from multipledispatch.variadic import Variadic, isvariadic
from multipledispatch.variadic import Variadic

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.domains import Array, Bint, Domain, Product, Real, find_domain
from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.util import getargspec, get_backend, lazy_property, pretty, quote
from funsor.util import GenericTypeMeta, getargspec, get_backend, lazy_property, pretty, quote


def substitute(expr, subs):
Expand Down Expand Up @@ -182,7 +182,7 @@ def moment_matching(cls, *args):
interpreter.set_interpretation(eager) # Use eager interpretation by default.


class FunsorMeta(type):
class FunsorMeta(GenericTypeMeta):
"""
Metaclass for Funsors to perform four independent tasks:

Expand All @@ -206,15 +206,16 @@ class FunsorMeta(type):
"""
def __init__(cls, name, bases, dct):
super(FunsorMeta, cls).__init__(name, bases, dct)
if not hasattr(cls, "__args__"):
cls.__args__ = ()
if cls.__args__:
base, = bases
cls.__origin__ = base
else:
if not cls.__args__:
cls._ast_fields = getargspec(cls.__init__)[0][1:]
cls._cons_cache = WeakValueDictionary()
cls._type_cache = WeakValueDictionary()

def __getitem__(cls, arg_types):
if not isinstance(arg_types, tuple):
arg_types = (arg_types,)
assert len(arg_types) == len(cls._ast_fields), \
"Must provide exactly one type per subexpression"
return super().__getitem__(arg_types)

def __call__(cls, *args, **kwargs):
if cls.__args__:
Expand All @@ -230,77 +231,6 @@ def __call__(cls, *args, **kwargs):

return interpret(cls, *args)

def __getitem__(cls, arg_types):
if not isinstance(arg_types, tuple):
arg_types = (arg_types,)
assert not any(isvariadic(arg_type) for arg_type in arg_types), "nested variadic types not supported"
# switch tuple to typing.Tuple
arg_types = tuple(typing.Tuple if arg_type is tuple else arg_type for arg_type in arg_types)
if arg_types not in cls._type_cache:
assert not cls.__args__, "cannot subscript a subscripted type {}".format(cls)
assert len(arg_types) == len(cls._ast_fields), "must provide types for all params"
new_dct = cls.__dict__.copy()
new_dct.update({"__args__": arg_types})
# type(cls) to handle FunsorMeta subclasses
cls._type_cache[arg_types] = type(cls)(cls.__name__, (cls,), new_dct)
return cls._type_cache[arg_types]

def __subclasscheck__(cls, subcls): # issubclass(subcls, cls)
if cls is subcls:
return True
if not isinstance(subcls, FunsorMeta):
return super(FunsorMeta, getattr(cls, "__origin__", cls)).__subclasscheck__(subcls)

cls_origin = getattr(cls, "__origin__", cls)
subcls_origin = getattr(subcls, "__origin__", subcls)
if not super(FunsorMeta, cls_origin).__subclasscheck__(subcls_origin):
return False

if cls.__args__:
if not subcls.__args__:
return False
if len(cls.__args__) != len(subcls.__args__):
return False
for subcls_param, param in zip(subcls.__args__, cls.__args__):
if not _issubclass_tuple(subcls_param, param):
return False
return True

@lazy_property
def classname(cls):
return cls.__name__ + "[{}]".format(", ".join(
str(getattr(t, "classname", t)) # Tuple doesn't have __name__
for t in cls.__args__))


def _issubclass_tuple(subcls, cls):
"""
utility for pattern matching with tuple subexpressions
"""
# so much boilerplate...
cls_is_union = hasattr(cls, "__origin__") and (cls.__origin__ or cls) is typing.Union
if isinstance(cls, tuple) or cls_is_union:
return any(_issubclass_tuple(subcls, option)
for option in (getattr(cls, "__args__", []) if cls_is_union else cls))

subcls_is_union = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) is typing.Union
if isinstance(subcls, tuple) or subcls_is_union:
return any(_issubclass_tuple(option, cls)
for option in (getattr(subcls, "__args__", []) if subcls_is_union else subcls))

subcls_is_tuple = hasattr(subcls, "__origin__") and (subcls.__origin__ or subcls) in (tuple, typing.Tuple)
cls_is_tuple = hasattr(cls, "__origin__") and (cls.__origin__ or cls) in (tuple, typing.Tuple)
if subcls_is_tuple != cls_is_tuple:
return False
if not cls_is_tuple:
return issubclass(subcls, cls)
if not cls.__args__:
return True
if not subcls.__args__ or len(subcls.__args__) != len(cls.__args__):
return False

return all(_issubclass_tuple(a, b) for a, b in zip(subcls.__args__, cls.__args__))


def _convert_reduced_vars(reduced_vars, inputs):
"""
Expand Down
Loading