Skip to content

Commit

Permalink
Fix/emcomposition fieldweights assignment (#3086)
Browse files Browse the repository at this point in the history
* • emcomposition.py
  - _parse_fields: fix bug to properly assign field_weights using self.key_indices
  - add _validate_options_with_learning() that checks compatibility of softmax_choice and normalize_field_weights with learning
  - learn(): raise exception if normalize_field_weights is False and loss_spec is BINARY_CROSS_ENTROPY
  - docstring mods

• test_emcomposition.py
  - test_normalize_field_weights_with_learning_enabled()
  • Loading branch information
jdcpni authored Oct 28, 2024
1 parent 04e9a59 commit 529c967
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def calc_prob(em_preds, test_ys):
# context_weight = 1, # weight of the context used during memory retrieval
state_weight = .5, # weight of the state used during memory retrieval
context_weight = .5, # weight of the context used during memory retrieval
normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
# normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

# Settings for running script:

MODEL_PARAMS = 'TestParams'
# MODEL_PARAMS = 'DeclanParams'
# MODEL_PARAMS = 'TestParams'
MODEL_PARAMS = 'DeclanParams'

CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script
DISPLAY_MODEL = ( # Only one of the following can be uncommented:
# None # suppress display of model
{ # show simple visual display of model
# 'show_pytorch': True, # show pytorch graph of model
'show_learning': True
# 'show_projections_not_in_composition': True,
# 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
# {'show_node_structure': True # show detailed view of node structures and projections
}
None # suppress display of model
# { # show simple visual display of model
# # 'show_pytorch': True, # show pytorch graph of model
# 'show_learning': True
# # 'show_projections_not_in_composition': True,
# # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
# # {'show_node_structure': True # show detailed view of node structures and projections
# }
)
RUN_MODEL = False # False => don't run the model
# RUN_MODEL = True, # True => run the model
# RUN_MODEL = False # False => don't run the model
RUN_MODEL = True, # True => run the model
# REPORT_OUTPUT = ReportOutput.FULL # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
REPORT_OUTPUT = ReportOutput.OFF # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
REPORT_PROGRESS = ReportProgress.OFF # Sets console progress bar during run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@
integration_rate = .69, # rate at which state is integrated into new context
state_weight = 1, # weight of the state used during memory retrieval
context_weight = 1, # weight of the context used during memory retrieval
normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
# normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# enable_learning=[True, True, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# enable_learning=[True, True, True]
# enable_learning=True,
# enable_learning=False,
learn_field_weights = False,
learn_field_weights = True,
# learn_field_weights = False,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.CROSS_ENTROPY,
# loss_spec = Loss.MSE,
learning_rate = .5,
num_optimization_steps = 10,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ class Kohonen(LearningFunction): # --------------------------------------------
and :math:`w_j` is the column of the matrix in `variable <Kohonen.variable>`\\[2] that corresponds to
the jth element of the activity array in `variable <Kohonen.variable>`\\[1].
.. _note::
.. note::
the array of activities in `variable <Kohonen.variable>`\\[1] is assumed to have been generated by the
dot product of the input pattern in `variable <Kohonen.variable>`\\[0] and the matrix in `variable
<Kohonen.variable>`\\[2], and thus the element with the greatest value in `variable <Kohonen.variable>`\\[1]
Expand Down
4 changes: 2 additions & 2 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5297,7 +5297,7 @@ def _determine_node_roles(self, context=None):

ORIGIN:
- all Nodes that are in first consideration_set (i.e., self.scheduler.consideration_queue[0]).
.. _note::
.. note::
- this takes account of any Projections designated as feedback by graph_processing
(i.e., self.graph.comp_to_vertex[efferent].feedback == EdgeType.FEEDBACK)
                - these will all be assigned afferent Projections from Composition.input_CIM
Expand Down Expand Up @@ -5390,7 +5390,7 @@ def _determine_node_roles(self, context=None):
- or for which any efferent projections are either:
- to output_CIM OR
- assigned as feedback (i.e., self.graph.comp_to_vertex[efferent].feedback == EdgeType.FEEDBACK
.. _note::
.. note::
                - this ensures that for cases in which there are nested CYCLES
(e.g., LearningMechanisms for a `learning Pathway <Composition.Learning_Pathway>`),
only the Node in the *outermost* CYCLE that is specified as a FEEDBACK_SENDER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class ParameterEstimationComposition(Composition):
number of trials executed (see `number of trials <Composition_Execution_Num_Trials>` for additional
information).
.. _note::
.. note::
The **num_trials_per_estimate** is distinct from the **num_trials** argument of the
ParameterEstimationComposition's `run <Composition.run>` method. The latter determines how many full fits
of the `model <ParameterEstimationComposition.model>` are carried out (that is, how many times the
Expand Down
Loading

0 comments on commit 529c967

Please sign in to comment.