diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
index 16f881f7c06..9f5f652b28a 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
@@ -67,8 +67,8 @@ def calc_prob(em_preds, test_ys):
     # context_weight = 1, # weight of the context used during memory retrieval
     state_weight = .5, # weight of the state used during memory retrieval
     context_weight = .5, # weight of the context used during memory retrieval
-    normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
-    # normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
+    # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
+    normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
     # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
     softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
     # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
index d448f79c29e..04027649aa3 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
@@ -2,22 +2,22 @@
 # Settings for running script:
 
-MODEL_PARAMS = 'TestParams'
-# MODEL_PARAMS = 'DeclanParams'
+# MODEL_PARAMS = 'TestParams'
+MODEL_PARAMS = 'DeclanParams'
 
 CONSTRUCT_MODEL = True  # THIS MUST BE SET TO True to run the script
 DISPLAY_MODEL = (  # Only one of the following can be uncommented:
-    # None  # suppress display of model
-    {  # show simple visual display of model
-        # 'show_pytorch': True,  # show pytorch graph of model
-        'show_learning': True
-        # 'show_projections_not_in_composition': True,
-        # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
-        # {'show_node_structure': True # show detailed view of node structures and projections
-    }
+    None  # suppress display of model
+    # {  # show simple visual display of model
+    #     # 'show_pytorch': True,  # show pytorch graph of model
+    #     'show_learning': True
+    #     # 'show_projections_not_in_composition': True,
+    #     # 'exclude_from_gradient_calc_style': 'dashed',  # show target mechanisms for learning
+    #     # 'show_node_structure': True,  # show detailed view of node structures and projections
+    # }
 )
-RUN_MODEL = False  # False => don't run the model
-# RUN_MODEL = True,  # True => run the model
+# RUN_MODEL = False  # False => don't run the model
+RUN_MODEL = True  # True => run the model
 # REPORT_OUTPUT = ReportOutput.FULL  # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
 REPORT_OUTPUT = ReportOutput.OFF  # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
 REPORT_PROGRESS = ReportProgress.OFF  # Sets console progress bar during run
diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/TestParams.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/TestParams.py
index 3b8caf6c0bc..2c9bb768d2a 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/TestParams.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/TestParams.py
@@ -31,8 +31,8 @@
     integration_rate = .69, # rate at which state is integrated into new context
     state_weight = 1, # weight of the state used during memory retrieval
     context_weight = 1, # weight of the context used during memory retrieval
-    normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
-    # normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
+    # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
+    normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
     # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
     softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
     # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
@@ -40,10 +40,13 @@
     # softmax_threshold = None, # threshold used to mask out small values in softmax
     softmax_threshold = .001, # threshold used to mask out small values in softmax
     enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
-    # enable_learning=[True, True, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
+    # enable_learning=[True, True, True],
+    # enable_learning=True,
     # enable_learning=False,
-    learn_field_weights = False,
+    learn_field_weights = True,
+    # learn_field_weights = False,
     loss_spec = Loss.BINARY_CROSS_ENTROPY,
+    # loss_spec = Loss.CROSS_ENTROPY,
     # loss_spec = Loss.MSE,
     learning_rate = .5,
     num_optimization_steps = 10,
diff --git a/psyneulink/core/components/functions/nonstateful/learningfunctions.py b/psyneulink/core/components/functions/nonstateful/learningfunctions.py
index 96b4b3c3085..c8f4d6ea349 100644
--- a/psyneulink/core/components/functions/nonstateful/learningfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/learningfunctions.py
@@ -1014,7 +1014,7 @@ class Kohonen(LearningFunction):  # --------------------------------------------
     and :math:`w_j` is the column of the matrix in `variable `\\[2] that corresponds to
     the jth element of the activity array in `variable `\\[1].
 
-    .. _note::
+    .. note::
        the array of activities in `variable `\\[1] is assumed to have been generated by the
        dot product of the input pattern in `variable `\\[0] and the matrix in `variable `\\[2],
        and thus the element with the greatest value in `variable `\\[1]
diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py
index c10706330d0..98116b80c70 100644
--- a/psyneulink/core/compositions/composition.py
+++ b/psyneulink/core/compositions/composition.py
@@ -5297,7 +5297,7 @@ def _determine_node_roles(self, context=None):
         ORIGIN:
           - all Nodes that are in first consideration_set (i.e., self.scheduler.consideration_queue[0]).
 
-          .. _note::
+          .. note::
             - this takes account of any Projections designated as feedback by graph_processing
               (i.e., self.graph.comp_to_vertex[efferent].feedback == EdgeType.FEEDBACK)
             - these will all be assigned afferent Projections from Composition.input_CIM
@@ -5390,7 +5390,7 @@ def _determine_node_roles(self, context=None):
           - or for which any efferent projections are either:
             - to output_CIM OR
             - assigned as feedback (i.e., self.graph.comp_to_vertex[efferent].feedback == EdgeType.FEEDBACK)
-          .. _note::
+          .. note::
             - this ensures that for cases in which there are nested CYCLES
               (e.g., LearningMechanisms for a `learning Pathway `), only the Node in the
               *outermost* CYCLE that is specified as a FEEDBACK_SENDER
diff --git a/psyneulink/core/compositions/parameterestimationcomposition.py b/psyneulink/core/compositions/parameterestimationcomposition.py
index 1b39557d5bc..024f9152da4 100644
--- a/psyneulink/core/compositions/parameterestimationcomposition.py
+++ b/psyneulink/core/compositions/parameterestimationcomposition.py
@@ -363,7 +363,7 @@ class ParameterEstimationComposition(Composition):
         number of trials executed (see `number of trials ` for additional information).
 
-        .. _note::
+        .. note::
           The **num_trials_per_estimate** is distinct from the **num_trials** argument of the
           ParameterEstimationComposition's `run ` method.  The latter determines how many full
           fits of the `model ` are carried out (that is, how many times the
diff --git a/psyneulink/library/compositions/emcomposition.py b/psyneulink/library/compositions/emcomposition.py
index e0ef65c2443..807f316fe55 100644
--- a/psyneulink/library/compositions/emcomposition.py
+++ b/psyneulink/library/compositions/emcomposition.py
@@ -453,11 +453,18 @@
 * **normalize_field_weights**: specifies whether the `field_weights ` are normalized or
   their raw values are used.  If True, the `field_weights ` are normalized so that
-  they sum to 1.0, and are used to weight the corresponding fields during retrieval (see `Weight fields
-  `).  If False, the raw values of the `field_weights ` are
-  used to weight (i.e., multiply) the retrieved value of each field. This setting is ignored if **field_weights**
+  they sum to 1.0, and are used to weight (i.e., multiply) the corresponding fields during retrieval (see `Weight
+  fields `).  If False, the raw values of the `field_weights `
+  are used to weight the retrieved value of each field.  This setting is ignored if **field_weights**
   is None or `concatenate_queries ` is in effect.
 
+  .. warning::
+     If **normalize_field_weights** is False and **enable_learning** is True, a warning is issued indicating that
+     this may produce an error if the `loss_spec ` for the EMComposition (or an
+     `AutodiffComposition` that contains it) requires all values to be between 0 and 1, and calling the
+     EMComposition's `learn ` method will generate an error if the specified loss_spec is
+     one known to be incompatible (e.g., `BINARY_CROSS_ENTROPY `).
+
 .. _EMComposition_Field_Names:
 
 * **field_names**: specifies names that can be assigned to the fields.  The number of names specified must
@@ -698,7 +705,7 @@
    product for each key field is passed to the corresponding `field_weight_node ` where it is
    multiplied by the corresponding `field_weight ` (if `use_gating_for_weighting ` is True,
    this is done by using the `field_weight
-   ` to output gate the `softmax_node `). The weighted softamx
+   ` to output gate the `softmax_node `). The weighted softmax
   vectors for all key fields are then passed to the `combined_softmax_node `,
   where they are Hadamard summed to produce a single weighting for each memory.
@@ -997,9 +1004,7 @@
 from psyneulink._typing import Optional, Union
 from psyneulink.core.components.functions.nonstateful.transferfunctions import SoftMax, LinearMatrix
 from psyneulink.core.components.functions.nonstateful.combinationfunctions import Concatenate, LinearCombination
-from psyneulink.core.components.functions.nonstateful.selectionfunctions import ARG_MAX, ARG_MAX_INDICATOR
-from psyneulink.core.components.functions.function import \
-    DEFAULT_SEED, _random_state_getter, _seed_setter
+from psyneulink.core.components.functions.function import DEFAULT_SEED, _random_state_getter, _seed_setter
 from psyneulink.core.compositions.composition import CompositionError, NodeRole
 from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition, torch_available
 from psyneulink.library.components.mechanisms.modulatory.learning.EMstoragemechanism import EMStorageMechanism
@@ -1010,15 +1015,15 @@
 from psyneulink.core.globals.parameters import Parameter, check_user_specified
 from psyneulink.core.globals.context import handle_external_context
 from psyneulink.core.globals.keywords import \
-    (ADAPTIVE, ALL, AUTO, CONTEXT, CONTROL, DEFAULT_INPUT, DEFAULT_VARIABLE, EM_COMPOSITION, FULL_CONNECTIVITY_MATRIX,
-     GAIN, IDENTITY_MATRIX, MAX_INDICATOR, MULTIPLICATIVE_PARAM, NAME, PARAMS, PROB_INDICATOR, PRODUCT, PROJECTIONS,
-     RANDOM, SIZE, VARIABLE)
+    (ADAPTIVE, ALL, ARG_MAX, ARG_MAX_INDICATOR, AUTO, CONTEXT, CONTROL, DEFAULT_INPUT, DEFAULT_VARIABLE,
+     EM_COMPOSITION, FULL_CONNECTIVITY_MATRIX, GAIN, IDENTITY_MATRIX, MULTIPLICATIVE_PARAM, NAME,
+     PARAMS, PROB_INDICATOR, PRODUCT, PROJECTIONS, RANDOM, SIZE, VARIABLE, Loss)
 from psyneulink.core.globals.utilities import convert_all_elements_to_np_array, is_numeric_scalar
 from psyneulink.core.globals.registry import name_without_suffix
 from psyneulink.core.llvm import ExecutionMode
 
-__all__ = ['EMComposition', 'WEIGHTED_AVG', 'PROBABILISTIC']
+__all__ = ['EMComposition', 'EMCompositionError', 'WEIGHTED_AVG', 'PROBABILISTIC']
 
 STORAGE_PROB = 'storage_prob'
 WEIGHTED_AVG = ALL
@@ -1029,7 +1034,7 @@
 MATCH_TO_KEYS_AFFIX = ' [MATCH to KEYS]'
 RETRIEVED_AFFIX = ' [RETRIEVED]'
 WEIGHTED_SOFTMAX_AFFIX = ' [WEIGHTED SOFTMAX]'
-RETRIEVE_NODE_NAME = 'RETRIEVE'
+COMBINED_SOFTMAX_NODE_NAME = 'RETRIEVE'
 STORE_NODE_NAME = 'STORE'
@@ -1290,26 +1295,18 @@ class EMComposition(AutodiffComposition):
         `field_names ` is specified, then the name of each value_input_node is assigned the
         corresponding field name appended with * [VALUE]*.
 
-    input_nodes : list[ProcessingMechanism]
-        Full list of `INPUT ` `Nodes ` ordered with query_input_nodes first
-        followed by value_input_nodes; used primarily for internal computations
-
-    input_nodes_by_fields : list[ProcessingMechanism]
-        Full list of `INPUT ` `Nodes ` in the same order specified in the
-        **field_names** argument of the constructor and in `self.field_names `.
-
     concatenate_queries_node : ProcessingMechanism
         `ProcessingMechanism` that concatenates the inputs to `query_input_nodes ` into a
         single vector used for the matching processing if `concatenate keys ` is True.  This is
        not created if the **concatenate_queries** argument to the EMComposition's constructor is False or is
        overridden (see `concatenate_queries `), or there is only one
-        query_input_node.
+        query_input_node.  This node is named *CONCATENATE_KEYS*.
 
     match_nodes : list[ProcessingMechanism]
         `ProcessingMechanisms ` that receive the dot product of each key and those stored in
         the corresponding field of `memory ` (see `Match memories by field
-        ` for additional details).  These are assigned names that prepend *MATCH_n* to the
-        name of the corresponding `query_input_nodes `.
+        ` for additional details).  These are named the same as the corresponding
+        `query_input_nodes ` appended with the suffix *[MATCH to KEYS]*.
 
     softmax_gain_control_nodes : list[ControlMechanism]
         `ControlMechanisms ` that adaptively control the `softmax_gain `
@@ -1320,7 +1317,8 @@ class EMComposition(AutodiffComposition):
     softmax_nodes : list[ProcessingMechanism]
         `ProcessingMechanisms ` that compute the softmax over the vectors received from the
         corresponding `match_nodes ` (see `Softmax normalize matches over fields
-        ` for additional details).
+        ` for additional details).  These are named the same as the corresponding
+        `query_input_nodes ` appended with the suffix *[SOFTMAX]*.
 
     field_weight_nodes : list[ProcessingMechanism]
         `ProcessingMechanisms `, each of which uses the `field weight ` for
@@ -1328,7 +1326,8 @@ class EMComposition(AutodiffComposition):
         a given `field ` as its (fixed) input and provides this to the corresponding
         `weighted_softmax_node `.  These are implemented only if more than one
         `key field ` is specified (see `Fields ` for additional details),
         and are replaced with `retrieval_gating_nodes ` if
-        `use_gating_for_weighting ` is True.
+        `use_gating_for_weighting ` is True.  These are named the same as the
+        corresponding `query_input_nodes ` appended with the suffix *[WEIGHT]*.
 
     weighted_softmax_nodes : list[ProcessingMechanism]
         `ProcessingMechanisms `, each of which receives the output of the corresponding
@@ -1336,9 +1335,10 @@ class EMComposition(AutodiffComposition):
         `softmax_node ` and `field_weight_node ` for a given
         `field `, and multiplies them to produce the weighted softmax for that field;
         these are implemented only if more than one `key field ` is specified (see `Fields
         ` for additional details) and `use_gating_for_weighting
-        ` is False (in which case, `field_weights `
+        ` is False (otherwise, `field_weights `
         are applied through output gating of the `softmax_nodes ` by the
-        `retrieval_gating_nodes `).
+        `retrieval_gating_nodes `).  These are named the same as the corresponding
+        `query_input_nodes ` appended with the suffix *[WEIGHTED SOFTMAX]*.
 
     retrieval_gating_nodes : list[GatingMechanism]
         `GatingMechanisms ` that use the `field weight ` for each
@@ -1348,32 +1348,41 @@ class EMComposition(AutodiffComposition):
         `key field ` is specified (see `Fields ` for additional details).
 
     combined_softmax_node : ProcessingMechanism
-        `ProcessingMechanism` that receives the softmax normalized dot products of the keys and memories
-        from the `softmax_nodes `, weighted by the `field_weights_nodes
+        `ProcessingMechanism` that receives the softmax normalized dot products of the keys and memories from the
+        `softmax_nodes `, weighted by the `field_weights_nodes
         ` if more than one `key field ` is specified
-        (or `retrieval_gating_nodes ` if `use_gating_for_weighting
+        (or by `retrieval_gating_nodes ` if `use_gating_for_weighting
         ` is True), and combines them into a single vector that is used to retrieve
         the corresponding memory for each field from `memory ` (see `Retrieve values by
-        field ` for additional details).
+        field ` for additional details).  This node is named *RETRIEVE*.
 
     retrieved_nodes : list[ProcessingMechanism]
         `ProcessingMechanisms ` that receive the vector retrieved for each field in `memory
-        ` (see `Retrieve values by field ` for additional details);
-        these are assigned the same names as the `query_input_nodes ` and
+        ` (see `Retrieve values by field ` for additional details).
+        These are assigned the same names as the `query_input_nodes ` and
         `value_input_nodes ` to which they correspond, appended with the suffix
         * [RETRIEVED]*, and are in the same order as the `input_nodes_by_fields ` to
         which they correspond.
 
     storage_node : EMStorageMechanism
         `EMStorageMechanism` that receives inputs from the `query_input_nodes ` and
-        `value_input_nodes `, and stores these in the corresponding field of
-        `memory ` with probability `storage_prob ` after a retrieval
-        has been made (see `Retrieval and Storage ` for additional details).
+        `value_input_nodes `, and stores these in the corresponding field of `memory
+        ` with probability `storage_prob ` after a retrieval has been
+        made (see `Retrieval and Storage ` for additional details).  This node is named *STORE*.
 
         .. technical_note::
            The `storage_node ` is assigned a Condition to execute after the `retrieved_nodes
            ` have executed, to ensure that storage occurs after retrieval, but before any
           subsequent processing is done (i.e., in a composition in which the EMComposition may be embedded).
+
+    input_nodes : list[ProcessingMechanism]
+        Full list of `INPUT ` `Nodes ` ordered with query_input_nodes first
+        followed by value_input_nodes; used primarily for internal computations
+
+    input_nodes_by_fields : list[ProcessingMechanism]
+        Full list of `INPUT ` `Nodes ` in the same order specified in the
+        **field_names** argument of the constructor and in `self.field_names `.
+
     """
 
     componentCategory = EM_COMPOSITION
@@ -1618,6 +1627,7 @@ def __init__(self,
                          memory_capacity = memory_capacity,
                          field_weights = field_weights,
                          field_names = field_names,
+                         normalize_field_weights = normalize_field_weights,
                          concatenate_queries = concatenate_queries,
                          softmax_gain = softmax_gain,
                          softmax_threshold = softmax_threshold,
@@ -1633,7 +1643,7 @@ def __init__(self,
                          **kwargs
                          )
 
-        self._validate_softmax_choice(softmax_choice, enable_learning)
+        self._validate_options_with_learning(softmax_choice, normalize_field_weights, enable_learning)
 
         self._construct_pathways(self.memory_template,
                                  self.memory_capacity,
@@ -1666,7 +1676,7 @@ def __init__(self,
         self.scheduler.add_condition(self.storage_node, conditions.AllHaveRun(*self.retrieved_nodes))
 
         # # Generates expected results, but execution_sets has a second set for INPUT nodes
-        # # and the the match_nodes again with storage_node
+        # # and the match_nodes again with storage_node
         # # ---------------------------------------
         #
@@ -1904,11 +1914,18 @@ def _parse_fields(self,
         self.num_fields = len(self.entry_template)
         keys_weights = [i for i in parsed_field_weights if i != 0]
         self.num_keys = len(keys_weights)
+
+        # Get indices of field_weights that specify keys and values:
         self.key_indices = np.flatnonzero(parsed_field_weights)
+        assert len(self.key_indices) == self.num_keys, \
+            f"PROGRAM ERROR: number of keys ({self.num_keys}) does not match number of " \
+            f"non-zero values in field_weights ({len(self.key_indices)})."
         self.value_indices = np.where(parsed_field_weights==0)[0]
         self.num_values = self.num_fields - self.num_keys
+        assert len(self.value_indices) == self.num_values, \
+            f"PROGRAM ERROR: number of values ({self.num_values}) does not match number of " \
+            f"zero values in field_weights ({len(self.value_indices)})."
+
         if parsed_field_names:
             self.key_names = [parsed_field_names[i] for i in self.key_indices]
             # self.value_names = parsed_field_names[self.num_keys:]
@@ -2208,12 +2225,17 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q
 
         return match_nodes
 
-    def _validate_softmax_choice(self, softmax_choice, enable_learning):
+    def _validate_options_with_learning(self, softmax_choice, normalize_field_weights, enable_learning):
         if softmax_choice in {ARG_MAX, PROBABILISTIC} and enable_learning:
             warnings.warn(f"The 'softmax_choice' arg of '{self.name}' is set to '{softmax_choice}' with "
                           f"'enable_learning' set to True (or a list); this will generate an error if its "
                           f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.")
 
+        if enable_learning and not normalize_field_weights:
+            warnings.warn(f"The 'normalize_field_weights' arg of '{self.name}' is set to False with "
+                          f"'enable_learning' set to True (or a list); this may generate an error if "
+                          f"the 'loss_spec' used for learning requires values to be between 0 and 1.")
+
     def _construct_softmax_nodes(self, memory_capacity, field_weights,
                                  softmax_gain, softmax_threshold, softmax_choice)->list:
         """Create nodes that, for each key field, compute the softmax over the similarities between the input and the
@@ -2221,12 +2243,7 @@ def _construct_softmax_nodes(self, memory_capacity, field_weights,
         """
 
         # Get indices of field_weights that specify keys:
-        key_indices = np.where(np.array(field_weights) != 0)
-        key_weights = [field_weights[i] for i in key_indices[0]]
-
-        assert len(key_indices[0]) == self.num_keys, \
-            f"PROGRAM ERROR: number of keys ({self.num_keys}) does not match number of " \
-            f"non-zero values in field_weights ({len(key_indices)})."
+        key_weights = [field_weights[i] for i in self.key_indices]
 
         if softmax_choice == ARG_MAX:
             # ARG_MAX would return entry multiplied by its dot product
@@ -2269,7 +2286,8 @@ def _construct_field_weight_nodes(self, field_weights, concatenate_queries, use_
 
         if not concatenate_queries and self.num_keys > 1:
             if use_gating_for_weighting:
-                field_weight_nodes = [GatingMechanism(input_ports={VARIABLE: np.array(field_weights[i]),
+                field_weight_nodes = [GatingMechanism(input_ports={VARIABLE:
+                                                                       np.array(field_weights[self.key_indices[i]]),
                                                                    PARAMS:{DEFAULT_INPUT: DEFAULT_VARIABLE},
                                                                    NAME: 'OUTCOME'},
                                                       gate=[key_match_pair[1].output_ports[0]],
@@ -2278,8 +2296,9 @@ def _construct_field_weight_nodes(self, field_weights, concatenate_queries, use_
                                       for i, key_match_pair in enumerate(zip(self.query_input_nodes,
                                                                              self.softmax_nodes))]
             else:
-                field_weight_nodes = [ProcessingMechanism(input_ports={VARIABLE: np.array(field_weights[i]),
-                                                                       PARAMS:{DEFAULT_INPUT: DEFAULT_VARIABLE},
+                field_weight_nodes = [ProcessingMechanism(input_ports={VARIABLE:
+                                                                           np.array(field_weights[self.key_indices[i]]),
+                                                                       PARAMS: {DEFAULT_INPUT: DEFAULT_VARIABLE},
                                                                        NAME: 'FIELD_WEIGHT'},
                                                           name= 'WEIGHT' if self.num_keys == 1
                                                           else f'{self.key_names[i]} [WEIGHT]')
@@ -2330,7 +2349,7 @@ def _construct_combined_softmax_node(self,
                                                       name=f'WEIGHTED SOFTMAX to RETRIEVAL for '
                                                            f'{self.key_names[i]}')
                                            for i, s in enumerate(input_source)]}],
-                                name=RETRIEVE_NODE_NAME))
+                                name=COMBINED_SOFTMAX_NODE_NAME))
 
         assert len(combined_softmax_node.output_port.value) == memory_capacity, \
             'PROGRAM ERROR: number of items in combined_softmax_node ' \
@@ -2534,6 +2553,9 @@ def learn(self, *args, **kwargs)->list:
             if arg in {ARG_MAX, PROBABILISTIC}:
                 raise EMCompositionError(f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg "
                                          f"of '{self.name}' cannot be used during learning; change to WEIGHTED_AVG.")
+        if self.loss_spec in {Loss.BINARY_CROSS_ENTROPY} and not self.normalize_field_weights:
+            raise EMCompositionError(f"The 'loss_spec' arg of '{self.name}' is set to '{self.loss_spec.name}' with "
+                                     f"'normalize_field_weights' set to False; this must be True to use this loss_spec.")
         return super().learn(*args, **kwargs)
 
     def _get_execution_mode(self, execution_mode):
diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py
index 3c90676b990..d0206a020a4 100644
--- a/tests/composition/test_emcomposition.py
+++ b/tests/composition/test_emcomposition.py
@@ -238,7 +238,7 @@ def test_softmax_choice(self):
         em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]])
         for softmax_choice in [pnl.ARG_MAX, pnl.PROBABILISTIC]:
-            with pytest.raises(pnl.ComponentError) as error_text:
+            with pytest.raises(EMCompositionError) as error_text:
                 em.parameters.softmax_choice.set(softmax_choice)
                 em.learn()
             assert (f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg "
@@ -252,6 +252,22 @@ def test_softmax_choice(self):
                            f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.")
             assert warning_msg in str(warning[0].message)
 
+    def test_normalize_field_weights_with_learning_enabled(self):
+        with pytest.warns(UserWarning) as warning:
+            em = EMComposition(normalize_field_weights=False,
+                               enable_learning=True,
+                               memory_fill=(0,.1),
+                               loss_spec=pnl.Loss.BINARY_CROSS_ENTROPY)
+        warning_msg = (f"The 'normalize_field_weights' arg of 'EM_Composition' is set to False with "
+                       f"'enable_learning' set to True (or a list); this may generate an error if the "
+                       f"'loss_spec' used for learning requires values to be between 0 and 1.")
+        assert warning_msg in str(warning[0].message)
+
+        with pytest.raises(EMCompositionError) as error_text:
+            em.learn()
+        assert (f"The 'loss_spec' arg of 'EM_Composition' is set to 'BINARY_CROSS_ENTROPY' with "
+                f"'normalize_field_weights' set to False; this must be True to use this loss_spec."
+                in str(error_text.value))
 
 @pytest.mark.pytorch
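Illustrative usage (not part of the diff): a minimal sketch of the new validation behavior, mirroring test_normalize_field_weights_with_learning_enabled above. It assumes a PsyNeuLink build that includes this change, in which EMCompositionError is exported via __all__ of emcomposition.py and learn() performs the loss_spec check before training begins.

# Sketch only: exercises the warning added in _validate_options_with_learning and
# the error added in EMComposition.learn (both introduced in this diff).
import warnings

import psyneulink as pnl
from psyneulink.library.compositions.emcomposition import EMComposition, EMCompositionError

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Construction succeeds but warns: learning is enabled with unnormalized field weights.
    em = EMComposition(normalize_field_weights=False,
                       enable_learning=True,
                       memory_fill=(0, .1),
                       loss_spec=pnl.Loss.BINARY_CROSS_ENTROPY)
assert any("normalize_field_weights" in str(w.message) for w in caught)

try:
    em.learn()  # incompatible loss_spec -> EMCompositionError raised before any training
except EMCompositionError as err:
    print(err)

Leaving normalize_field_weights at True (the setting now enabled in DeclanParams.py and TestParams.py above), or choosing a loss_spec that the new check in learn() does not reject (e.g., Loss.MSE), avoids the error.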