Refactor/emcomposition mechs and concatenate (#3084)
* • emcomposition.py
  - TransferMechanism -> ProcessingMechanism

* • emcomposition.py
  concatenate_keys -> concatenate_queries
jdcpni authored Oct 27, 2024
1 parent ba25f6e commit 04e9a59
Showing 6 changed files with 137 additions and 142 deletions.
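The second of these changes is a keyword rename in EMComposition's constructor: scripts that passed concatenate_keys now pass concatenate_queries, as the hunks below show for the model scripts and for EMStorageMechanism. A minimal, hypothetical sketch of the update for a downstream caller; the top-level psyneulink import and the memory_template/field_weights arguments are illustrative and not taken from this commit:

import psyneulink as pnl

# Two key (query) fields and one value field; the sizes here are made up for illustration.
# In field_weights, numeric entries mark key fields and None marks a value field.
em = pnl.EMComposition(
    memory_template=[[0, 0], [0, 0, 0], [0]],
    field_weights=(1, 1, None),
    # concatenate_keys=True,      # spelling before this commit
    concatenate_queries=True,     # spelling after this commit
)

The TransferMechanism -> ProcessingMechanism change applies to emcomposition.py itself, whose diff is not among the hunks loaded on this page; it appears to affect only the mechanisms the composition constructs internally.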
@@ -52,8 +52,8 @@ def calc_prob(em_preds, test_ys):
 memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
 memory_init = (0,.0001), # Initialize memory with random values in interval
 # memory_init = None, # Initialize with zeros
-concatenate_keys = False,
-# concatenate_keys = True,
+concatenate_queries = False,
+# concatenate_queries = True,
 
 # environment
 # curriculum_type = 'Interleaved',
@@ -209,7 +209,7 @@ def construct_model(model_name:str=model_params['name'],
 previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'],
 context_retrieval_weight:Union[float,int]=model_params['context_weight'],
 normalize_field_weights = model_params['normalize_field_weights'],
-concatenate_keys = model_params['concatenate_keys'],
+concatenate_queries = model_params['concatenate_queries'],
 learn_field_weights = model_params['learn_field_weights'],
 memory_capacity = memory_capacity,
 memory_init=model_params['memory_init'],
@@ -260,7 +260,7 @@ def construct_model(model_name:str=model_params['name'],
 context_retrieval_weight
 ),
 normalize_field_weights=normalize_field_weights,
-concatenate_keys=concatenate_keys,
+concatenate_queries=concatenate_queries,
 learn_field_weights=learn_field_weights,
 learning_rate=learning_rate,
 enable_learning=enable_learning,
@@ -18,8 +18,8 @@
 memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
 memory_init = (0,.0001), # Initialize memory with random values in interval
 # memory_init = None, # Initialize with zeros
-concatenate_keys = False,
-# concatenate_keys = True,
+concatenate_queries = False,
+# concatenate_queries = True,
 
 # environment
 # curriculum_type = 'Interleaved',
@@ -282,7 +282,7 @@ class EMStorageMechanism(LearningMechanism):
 concatenation_node : OutputPort or Mechanism : default None
     specifies the `OutputPort` or `Mechanism` in which the `value <OutputPort.value>` of the `key fields
-    <EMStorageMechanism_Fields>` are concatenated (see `concatenate keys <EMComposition_Concatenate_Keys>`
+    <EMStorageMechanism_Fields>` are concatenated (see `concatenate keys <EMComposition_Concatenate_Queries>`
     for additional details).
 
 memory_matrix : List or 2d np.array : default None
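For readers of the docstring above: the concatenation_node's value is the key (query) fields of the current entry joined end to end, so downstream matching operates on a single query vector rather than one per key field. A rough NumPy illustration with made-up field values, not part of this commit:

import numpy as np

# Hypothetical key-field values; the sizes are illustrative only.
key_fields = [np.array([0.1, 0.2]), np.array([0.3, 0.4, 0.5])]

# Conceptually what the concatenation node emits: one vector of length 2 + 3 = 5.
concatenated_query = np.concatenate(key_fields)
print(concatenated_query.shape)   # (5,)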
@@ -657,12 +657,12 @@ def _validate_params(self, request_set, target_set=None, context=None):
                          f"the same number of items as its 'fields' arg ({len(fields)}).")
 
 num_keys = len([i for i in field_types if i==1])
-concatenate_keys = 'concatenation_node' in request_set and request_set['concatenation_node'] is not None
+concatenate_queries = 'concatenation_node' in request_set and request_set['concatenation_node'] is not None
 
 # Ensure the number of learning_signals is equal to the number of fields + number of keys
 if LEARNING_SIGNALS in request_set:
     learning_signals = request_set[LEARNING_SIGNALS]
-    if concatenate_keys:
+    if concatenate_queries:
         num_match_fields = 1
     else:
         num_match_fields = num_keys
@@ -674,7 +674,7 @@ def _validate_params(self, request_set, target_set=None, context=None):
     # Ensure shape of learning_signals matches shapes of matrices for match nodes (i.e., either keys or concatenate)
     for i, learning_signal in enumerate(learning_signals[:num_match_fields]):
         learning_signal_shape = learning_signal.parameters.matrix._get(context).shape
-        if concatenate_keys:
+        if concatenate_queries:
             memory_matrix_field_shape = np.array([np.concatenate(row, dtype=object).flatten()
                                                   for row in memory_matrix[:,0:num_keys]]).T.shape
         else:
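The validation above turns on how many match fields the mechanism expects: one when queries are concatenated, otherwise one per key, with the key columns of memory_matrix flattened into a single field in the concatenated case. A standalone sketch of that bookkeeping, using invented entry counts and field sizes purely for illustration:

import numpy as np

# Invented example: 3 memory entries, 2 key fields (sizes 2 and 3), 1 value field (size 4).
memory_matrix = np.array([[np.zeros(2), np.zeros(3), np.zeros(4)] for _ in range(3)],
                         dtype=object)
num_keys = 2

def expected_match_fields(concatenate_queries):
    # Mirrors the shape logic in the hunk above: one concatenated field vs. one field per key.
    if concatenate_queries:
        num_match_fields = 1
        # Flatten the key fields of each entry into one row, then transpose to get the
        # (concatenated key size, number of entries) shape expected for the match projection.
        field_shape = np.array([np.concatenate(row).flatten()
                                for row in memory_matrix[:, 0:num_keys]]).T.shape
    else:
        num_match_fields = num_keys
        field_shape = None   # per-key shapes are checked field by field instead
    return num_match_fields, field_shape

print(expected_match_fields(True))    # (1, (5, 3)): 5 = 2 + 3 concatenated key sizes, 3 entries
print(expected_match_fields(False))   # (2, None)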