@@ -341,56 +341,50 @@ def __init__(
341
341
self ._hidden_size , nonlin = activation_nonlin , bias = False , weight = True
342
342
)
343
343
344
- # Potential dynamics.
344
+ # Potential dynamics' capacity .
345
345
self .potentials_dyn_fcn = potentials_dyn_fcn
346
346
self .capacity_learnable = capacity_learnable
347
347
if self .potentials_dyn_fcn in [pd_capacity_21 , pd_capacity_21_abs , pd_capacity_32 , pd_capacity_32_abs ]:
348
348
if _is_iterable (activation_nonlin ):
349
- self ._init_capacity (activation_nonlin [0 ], device )
349
+ self ._capacity_opt_init = self . _init_capacity_heuristic (activation_nonlin [0 ])
350
350
else :
351
- self ._init_capacity (activation_nonlin , device ) # type: ignore[arg-type]
351
+ self ._capacity_opt_init = self . _init_capacity_heuristic (activation_nonlin ) # type: ignore[arg-type]
352
352
else :
353
- self ._capacity_opt = None
354
-
355
- # Initialize cubic decay and capacity if learnable.
356
- if (self .potentials_dyn_fcn is pd_cubic ) and self .kappa_learnable :
357
- self ._kappa_opt .data = self ._kappa_opt_init
358
- elif self .potentials_dyn_fcn in [pd_capacity_21 , pd_capacity_21_abs , pd_capacity_32 , pd_capacity_32_abs ]:
359
- self ._capacity_opt .data = self ._capacity_opt_init
353
+ # Even if the potential function does not include a capacity term, we initialize it to be compatible with
354
+ # custom functions.
355
+ self ._capacity_opt_init = torch .tensor (1.0 , dtype = torch .get_default_dtype ())
356
+ self ._capacity_opt = nn .Parameter (self ._capacity_opt_init .to (device = device ), requires_grad = capacity_learnable )
360
357
361
358
# Move the complete model to the given device.
362
359
self .to (device = device )
363
360
364
- def _init_capacity (self , activation_nonlin : ActivationFunction , device : Union [ str , torch . device ] ) -> None :
361
+ def _init_capacity_heuristic (self , activation_nonlin : ActivationFunction ) -> torch . Tensor :
365
362
"""Initialize the value of the capacity parameter $C$ depending on the activation function.
366
363
367
364
Args:
368
365
activation_nonlin: Nonlinear activation function used.
366
+
367
+ Returns:
368
+ Heuristic initial value for the capacity parameter.
369
369
"""
370
370
if activation_nonlin is torch .sigmoid :
371
371
# sigmoid(7.) approx 0.999
372
- self ._capacity_opt_init = PotentialBased .transform_to_opt_space (
373
- torch .tensor ([7.0 ], device = device , dtype = torch .get_default_dtype ())
374
- )
372
+ return PotentialBased .transform_to_opt_space (torch .tensor ([7.0 ], dtype = torch .get_default_dtype ()))
375
373
elif activation_nonlin is torch .tanh :
376
374
# tanh(3.8) approx 0.999
377
- self ._capacity_opt_init = PotentialBased .transform_to_opt_space (
378
- torch .tensor ([3.8 ], device = device , dtype = torch .get_default_dtype ())
379
- )
380
- else :
381
- raise ValueError (
382
- "For the potential dynamics including a capacity, only output nonlinearities of type "
383
- "torch.sigmoid and torch.tanh are supported!"
384
- )
385
- self ._capacity_opt = nn .Parameter (self ._capacity_opt_init , requires_grad = self .capacity_learnable )
375
+ return PotentialBased .transform_to_opt_space (torch .tensor ([3.8 ], dtype = torch .get_default_dtype ()))
376
+ raise NotImplementedError (
377
+ "For the potential dynamics including a capacity, the initialization heuristic only supports "
378
+ "the activation functions `torch.sigmoid` and `torch.tanh`!"
379
+ )
386
380
387
381
def extra_repr (self ) -> str :
388
382
return super ().extra_repr () + f", capacity_learnable={ self .capacity_learnable } "
389
383
390
384
@property
391
- def capacity (self ) -> Optional [torch .Tensor ]:
392
- """Get the capacity parameter (exists for capacity-based dynamics functions), otherwise return `None` ."""
393
- return None if self . _capacity_opt is None else PotentialBased .transform_to_img_space (self ._capacity_opt )
385
+ def capacity (self ) -> Union [torch .Tensor , nn . Parameter ]:
386
+ """Get the capacity parameter (only used for capacity-based dynamics functions)."""
387
+ return PotentialBased .transform_to_img_space (self ._capacity_opt )
394
388
395
389
def potentials_dot (self , potentials : torch .Tensor , stimuli : torch .Tensor ) -> torch .Tensor :
396
390
r"""Compute the derivative of the neurons' potentials per time step.
0 commit comments