diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py index e6a59f5a9e..e2d29feda2 100644 --- a/pyro/nn/auto_reg_nn.py +++ b/pyro/nn/auto_reg_nn.py @@ -22,9 +22,7 @@ def sample_mask_indices( :param simple: True to space fractional indices by rounding to nearest int, false round randomly :type simple: bool """ - indices = torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to( - torch.Tensor().device - ) + indices = torch.linspace(1, input_dim, steps=hidden_dim) if simple: # Simple procedure tries to space fractional indices evenly by rounding to nearest int return torch.round(indices)