Skip to content

Commit

Permalink
Set funsor backend in backend-specific examples (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Sep 3, 2020
1 parent 6410b19 commit 41735ec
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/discrete_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@


def main(args):
funsor.set_backend("torch")

# Declare parameters.
trans_probs = torch.tensor([[0.2, 0.8], [0.7, 0.3]], requires_grad=True)
emit_probs = torch.tensor([[0.4, 0.6], [0.1, 0.9]], requires_grad=True)
Expand Down
2 changes: 2 additions & 0 deletions examples/eeg_slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def filter_and_predict(self, data, smoothing=False):


def main(args):
funsor.set_backend("torch")

# download and pre-process EEG data if not in test mode
if not args.test:
download_data()
Expand Down
2 changes: 2 additions & 0 deletions examples/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


def main(args):
funsor.set_backend("torch")

# Declare parameters.
trans_noise = torch.tensor(0.1, requires_grad=True)
emit_noise = torch.tensor(0.5, requires_grad=True)
Expand Down
3 changes: 3 additions & 0 deletions examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from pyro.generic import infer, optim, pyro, pyro_backend
from torch.distributions import constraints

import funsor
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo


def main(args):
funsor.set_backend("torch")

# Define a basic model with a single Normal latent random variable `loc`
# and a batch of Normally distributed observations.
def model(data):
Expand Down
2 changes: 2 additions & 0 deletions examples/pcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

import funsor
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Bint
Expand Down Expand Up @@ -36,6 +37,7 @@ def model(size, position=0):


def main(args):
funsor.set_backend("torch")
torch.manual_seed(args.seed)

print_ = print if args.verbose else lambda msg: None
Expand Down
2 changes: 2 additions & 0 deletions examples/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.nn as nn
from torch.optim import Adam

import funsor
import funsor.torch.distributions as f_dist
import funsor.ops as ops
from funsor.domains import Reals
Expand Down Expand Up @@ -203,6 +204,7 @@ def track(args):


def main(args):
funsor.set_backend("torch")
if args.force or not args.metrics_filename or not os.path.exists(args.metrics_filename):
results = track(args)
else:
Expand Down
2 changes: 2 additions & 0 deletions examples/slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


def main(args):
funsor.set_backend("torch")

# Declare parameters.
trans_probs = funsor.Tensor(torch.tensor([[0.9, 0.1],
[0.1, 0.9]], requires_grad=True))
Expand Down
2 changes: 2 additions & 0 deletions examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def forward(self, z):


def main(args):
funsor.set_backend("torch")

encoder = Encoder()
decoder = Decoder()

Expand Down

0 comments on commit 41735ec

Please sign in to comment.