Skip to content

Commit

Permalink
Make torch backend log return the correct dtype (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Aug 15, 2020
1 parent 661ed03 commit c944c23
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion funsor/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _is_numeric_array(x):
@ops.log.register(torch.Tensor)
def _log(x):
if x.dtype in (torch.bool, torch.uint8, torch.long):
x = x.float()
x = x.to(dtype=torch.get_default_dtype())
return x.log()


Expand Down
12 changes: 12 additions & 0 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,15 @@ def test_tensor_to_funsor_ambiguous_output():
f2 = funsor.to_funsor(x, output=reals(), dim_to_name=OrderedDict({-2: 'a'}))
assert f.inputs == f2.inputs == OrderedDict(a=bint(2))
assert f.output.shape == () == f2.output.shape


@pytest.mark.skipif(get_backend() != "torch", reason="torch-specific regression")
def test_log_correct_dtype():
import torch
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
x = Tensor(torch.rand(3, dtype=torch.get_default_dtype()))
try:
assert (x == x).all().log().data.dtype is x.data.dtype
finally:
torch.set_default_dtype(old_dtype)

0 comments on commit c944c23

Please sign in to comment.