Commit c79acc3

Always initialize the capacity for SimpleNeuralField
1 parent: db26722 · commit c79acc3

2 files changed: +21 -27 lines

neuralfields/simple_neural_fields.py  (+20 -26)
@@ -341,56 +341,50 @@ def __init__(
             self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
         )
 
-        # Potential dynamics.
+        # Potential dynamics' capacity.
         self.potentials_dyn_fcn = potentials_dyn_fcn
         self.capacity_learnable = capacity_learnable
         if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
             if _is_iterable(activation_nonlin):
-                self._init_capacity(activation_nonlin[0], device)
+                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin[0])
             else:
-                self._init_capacity(activation_nonlin, device)  # type: ignore[arg-type]
+                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin)  # type: ignore[arg-type]
         else:
-            self._capacity_opt = None
-
-        # Initialize cubic decay and capacity if learnable.
-        if (self.potentials_dyn_fcn is pd_cubic) and self.kappa_learnable:
-            self._kappa_opt.data = self._kappa_opt_init
-        elif self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
-            self._capacity_opt.data = self._capacity_opt_init
+            # Even if the potential function does not include a capacity term, we initialize it to be compatible with
+            # custom functions.
+            self._capacity_opt_init = torch.tensor(1.0, dtype=torch.get_default_dtype())
+        self._capacity_opt = nn.Parameter(self._capacity_opt_init.to(device=device), requires_grad=capacity_learnable)
 
         # Move the complete model to the given device.
         self.to(device=device)
 
-    def _init_capacity(self, activation_nonlin: ActivationFunction, device: Union[str, torch.device]) -> None:
+    def _init_capacity_heuristic(self, activation_nonlin: ActivationFunction) -> torch.Tensor:
         """Initialize the value of the capacity parameter $C$ depending on the activation function.
 
         Args:
             activation_nonlin: Nonlinear activation function used.
+
+        Returns:
+            Heuristic initial value for the capacity parameter.
         """
         if activation_nonlin is torch.sigmoid:
             # sigmoid(7.) approx 0.999
-            self._capacity_opt_init = PotentialBased.transform_to_opt_space(
-                torch.tensor([7.0], device=device, dtype=torch.get_default_dtype())
-            )
+            return PotentialBased.transform_to_opt_space(torch.tensor([7.0], dtype=torch.get_default_dtype()))
         elif activation_nonlin is torch.tanh:
             # tanh(3.8) approx 0.999
-            self._capacity_opt_init = PotentialBased.transform_to_opt_space(
-                torch.tensor([3.8], device=device, dtype=torch.get_default_dtype())
-            )
-        else:
-            raise ValueError(
-                "For the potential dynamics including a capacity, only output nonlinearities of type "
-                "torch.sigmoid and torch.tanh are supported!"
-            )
-        self._capacity_opt = nn.Parameter(self._capacity_opt_init, requires_grad=self.capacity_learnable)
+            return PotentialBased.transform_to_opt_space(torch.tensor([3.8], dtype=torch.get_default_dtype()))
+        raise NotImplementedError(
+            "For the potential dynamics including a capacity, the initialization heuristic only supports "
+            "the activation functions `torch.sigmoid` and `torch.tanh`!"
+        )
 
     def extra_repr(self) -> str:
         return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"
 
     @property
-    def capacity(self) -> Optional[torch.Tensor]:
-        """Get the capacity parameter (exists for capacity-based dynamics functions), otherwise return `None`."""
-        return None if self._capacity_opt is None else PotentialBased.transform_to_img_space(self._capacity_opt)
+    def capacity(self) -> Union[torch.Tensor, nn.Parameter]:
+        """Get the capacity parameter (only used for capacity-based dynamics functions)."""
+        return PotentialBased.transform_to_img_space(self._capacity_opt)
 
     def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
         r"""Compute the derivative of the neurons' potentials per time step.

tests/test_simple_neural_fields.py  (+1 -1)
@@ -166,7 +166,7 @@ def test_neural_fields_trafos(kappa_init: float, tau_init: float):
 
 
 def test_simple_neural_fields_fail():
-    with pytest.raises(ValueError):
+    with pytest.raises(NotImplementedError):
         SimpleNeuralField(input_size=6, output_size=3, potentials_dyn_fcn=pd_capacity_21, activation_nonlin=torch.sqrt)
 
     with pytest.raises(ValueError):
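
Taken together, the two hunks change the behavior roughly as sketched below: every SimpleNeuralField now owns a capacity parameter even when its potential dynamics function has no capacity term, and unsupported activation functions raise NotImplementedError instead of ValueError. This is only a usage sketch, not part of the commit; the import path and the use of pd_cubic as a capacity-free dynamics function are assumptions based on the names appearing in this diff.

import torch
# Assumed import path, matching the file neuralfields/simple_neural_fields.py.
from neuralfields.simple_neural_fields import SimpleNeuralField, pd_cubic

# pd_cubic has no capacity term, yet after this commit the model still holds an
# initialized capacity parameter, so the `capacity` property no longer returns None.
snf = SimpleNeuralField(input_size=6, output_size=3, potentials_dyn_fcn=pd_cubic, activation_nonlin=torch.sigmoid)
print(snf.capacity)  # a tensor transformed back into image space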
