Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a compatibility shim for using contrib.funsor with existing models #2997

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyro/contrib/funsor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import os

import pyroapi

from pyro.contrib.funsor.handlers import condition, do, markov
Expand Down Expand Up @@ -38,6 +40,8 @@ def plate(*args, **kwargs):
},
)

os.environ["PYRO_FUNSOR_ACTIVE"] = "1" # TODO better toggle

__all__ = [
"clear_param_store",
"condition",
Expand Down
2 changes: 0 additions & 2 deletions pyro/contrib/funsor/handlers/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,12 @@ def push_global(self, frame):
self._global_stack.append(frame)

def pop_global(self):
assert self._global_stack, "cannot pop the global frame"
return self._global_stack.pop()

def push_iter(self, frame):
self._iter_stack.append(frame)

def pop_iter(self):
assert self._iter_stack, "cannot pop the global frame"
return self._iter_stack.pop()

def push_local(self, frame):
Expand Down
16 changes: 14 additions & 2 deletions pyro/poutine/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyro.util import ignore_jit_warnings

from .messenger import Messenger
from .runtime import _ENUM_ALLOCATOR
from .runtime import _ENUM_ALLOCATOR, _is_funsor_active


def _tmc_mixture_sample(msg):
Expand Down Expand Up @@ -138,6 +138,10 @@ def __init__(self, first_available_dim=None):
first_available_dim is None or first_available_dim < 0
), first_available_dim
self.first_available_dim = first_available_dim
if _is_funsor_active():
from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger

self._funsor_named = NamedMessenger()
super().__init__()

def __enter__(self):
Expand All @@ -146,7 +150,15 @@ def __enter__(self):
self._markov_depths = {} # site name -> depth (nonnegative integer)
self._param_dims = {} # site name -> (enum dim -> unique id)
self._value_dims = {} # site name -> (enum dim -> unique id)
return super().__enter__()
result = super().__enter__()
if hasattr(self, "_funsor_named"):
self._funsor_named.__enter__()
return result

def __exit__(self, *args):
if hasattr(self, "_funsor_named"):
self._funsor_named.__exit__(*args)
return super().__exit__(*args)

@ignore_jit_warnings()
def _pyro_sample(self, msg):
Expand Down
21 changes: 20 additions & 1 deletion pyro/poutine/markov_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import ExitStack # python 3

from .reentrant_messenger import ReentrantMessenger
from .runtime import _is_funsor_active


class MarkovMessenger(ReentrantMessenger):
Expand Down Expand Up @@ -44,6 +45,14 @@ def __init__(self, history=1, keep=False, dim=None, name=None):
self._iterable = None
self._pos = -1
self._stack = []

if _is_funsor_active():
from pyro.contrib.funsor.handlers.named_messenger import (
MarkovMessenger as FunsorMarkovMessenger,
)

self._funsor_markov = FunsorMarkovMessenger(history=history, keep=keep)

super().__init__()

def generator(self, iterable):
Expand All @@ -60,12 +69,22 @@ def __enter__(self):
self._pos += 1
if len(self._stack) <= self._pos:
self._stack.append(set())
return super().__enter__()

result = super().__enter__()

if hasattr(self, "_funsor_markov"):
self._funsor_markov.__enter__()

return result

def __exit__(self, *args, **kwargs):
if not self.keep:
self._stack.pop()
self._pos -= 1

if hasattr(self, "_funsor_markov"):
self._funsor_markov.__exit__(*args, **kwargs)

return super().__exit__(*args, **kwargs)

def _pyro_sample(self, msg):
Expand Down
57 changes: 57 additions & 0 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

import functools
import os
from collections import OrderedDict
from typing import Dict

from pyro.params.param_store import ( # noqa: F401
Expand All @@ -16,6 +18,10 @@
_PYRO_PARAM_STORE = ParamStoreDict()


def _is_funsor_active() -> bool:
return "PYRO_FUNSOR_ACTIVE" in os.environ


class _DimAllocator:
"""
Dimension allocator for internal use by :class:`plate`.
Expand Down Expand Up @@ -62,6 +68,16 @@ def allocate(self, name, dim):
)
)
self._stack[-1 - dim] = name

if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import (
_DIM_STACK,
DimRequest,
DimType,
)

_DIM_STACK.allocate({name: DimRequest(dim, DimType.VISIBLE)})

return dim

def free(self, name, dim):
Expand All @@ -74,6 +90,11 @@ def free(self, name, dim):
while self._stack and self._stack[-1] is None:
self._stack.pop()

if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK

del _DIM_STACK.global_frame[name]


# Handles placement of plate dimensions
_DIM_ALLOCATOR = _DimAllocator()
Expand All @@ -100,6 +121,12 @@ def set_first_available_dim(self, first_available_dim):
self.next_available_id = 0
self.dim_to_id = {} # only the global ids

if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, StackFrame

_DIM_STACK.set_first_available_dim(first_available_dim)
self.global_frame = StackFrame(OrderedDict(), OrderedDict())

def allocate(self, scope_dims=None):
"""
Allocate a new recyclable dim and a unique id.
Expand Down Expand Up @@ -132,8 +159,38 @@ def allocate(self, scope_dims=None):
while dim in scope_dims:
dim -= 1

if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import (
_DIM_STACK,
DimRequest,
DimType,
)

dim_ = dim
name = f"_enum_dim_{id_}"
if scope_dims is None:
dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.GLOBAL)})[
name
]
self.dim_to_id[dim] = self.dim_to_id.pop(dim_)
else:
dim = _DIM_STACK.allocate({name: DimRequest(None, DimType.LOCAL)})[name]
assert dim not in scope_dims

return dim, id_

def restore_globals(self):
if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK

_DIM_STACK.push_global(self.global_frame)

def remove_globals(self):
if _is_funsor_active():
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK

self.global_frame = _DIM_STACK.pop_global()


# Handles placement of enumeration dimensions
_ENUM_ALLOCATOR = _EnumAllocator()
Expand Down