From c944c2339868e33b00339fd6027027d03240c4af Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 14 Aug 2020 23:39:42 -0400 Subject: [PATCH] Make torch backend log return the correct dtype (#349) --- funsor/torch/ops.py | 2 +- test/test_tensor.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 32a63d048..55bec50df 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -107,7 +107,7 @@ def _is_numeric_array(x): @ops.log.register(torch.Tensor) def _log(x): if x.dtype in (torch.bool, torch.uint8, torch.long): - x = x.float() + x = x.to(dtype=torch.get_default_dtype()) return x.log() diff --git a/test/test_tensor.py b/test/test_tensor.py index a52452fdc..56b50c420 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -941,3 +941,15 @@ def test_tensor_to_funsor_ambiguous_output(): f2 = funsor.to_funsor(x, output=reals(), dim_to_name=OrderedDict({-2: 'a'})) assert f.inputs == f2.inputs == OrderedDict(a=bint(2)) assert f.output.shape == () == f2.output.shape + + +@pytest.mark.skipif(get_backend() != "torch", reason="torch-specific regression") +def test_log_correct_dtype(): + import torch + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + x = Tensor(torch.rand(3, dtype=torch.get_default_dtype())) + try: + assert (x == x).all().log().data.dtype is x.data.dtype + finally: + torch.set_default_dtype(old_dtype)