Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…Link into pytorch_batch
  • Loading branch information
David Turner committed Feb 13, 2025
2 parents bd9bb4f + 359737e commit f2f7519
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 22 deletions.
4 changes: 2 additions & 2 deletions psyneulink/core/components/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
from psyneulink.core.globals.preferences.preferenceset import PreferenceEntry, PreferenceLevel
from psyneulink.core.globals.registry import register_category
from psyneulink.core.globals.utilities import (
convert_all_elements_to_np_array, convert_to_np_array, get_global_seed, is_instance_or_subclass, object_has_single_value, parameter_spec, parse_valid_identifier, safe_len,
convert_all_elements_to_np_array, convert_to_np_array, _get_global_seed, is_instance_or_subclass, object_has_single_value, parameter_spec, parse_valid_identifier, safe_len,
SeededRandomState, try_extract_0d_array_item, contains_type, is_numeric, NumericCollections,
random_matrix, array_from_matrix_string
)
Expand Down Expand Up @@ -357,7 +357,7 @@ def _seed_setter(value, owning_component, context, *, compilation_sync):

value = try_extract_0d_array_item(value)
if value is None or value == DEFAULT_SEED():
value = get_global_seed()
value = _get_global_seed()

# Remove any old PRNG state
owning_component.parameters.random_state.set(None, context=context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def _evaluate(self, variable=None, context=None, params=None, fit_evaluate=False
# Run compiled mode if requested by parameter and everything is initialized
if self.owner and self.owner.parameters.comp_execution_mode._get(context) != 'Python' and \
ContextFlags.PROCESSING in context.flags:
all_samples = [s for s in itertools.product(*self.search_space)]
all_samples = list(itertools.product(*self.search_space))
all_values, num_evals = self._grid_evaluate(self.owner, context, fit_evaluate)
assert len(all_values) == num_evals
assert len(all_samples) == num_evals
Expand Down Expand Up @@ -846,7 +846,7 @@ def reset_grid(self, context):
"""Reset iterators in `search_space <GridSearch.search_space>`"""
for s in self.search_space:
s.reset()
self.parameters.grid._set(itertools.product(*[s for s in self.search_space]), context)
self.parameters.grid._set((s for s in itertools.product(*[s for s in self.search_space])), context)

def _traverse_grid(self, variable, sample_num, context=None):
"""Get next sample from grid.
Expand Down
3 changes: 1 addition & 2 deletions psyneulink/core/globals/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
* `ContentAddressableList`
* `make_readonly_property`
* `get_class_attributes`
* `get_global_seed`
* `set_global_seed`
"""
Expand Down Expand Up @@ -1718,7 +1717,7 @@ def seed(self, seed):


_seed = np.uint32((time.time() * 1000) % 2**31)
def get_global_seed(offset=1):
def _get_global_seed(offset=1):
global _seed
old_seed = _seed
_seed = (_seed + offset) % 2**31
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet, REPORT_OUTPUT_PREF
from psyneulink.core.globals.preferences.preferenceset import PreferenceEntry, PreferenceLevel
from psyneulink.core.globals.utilities import convert_all_elements_to_np_array, is_numeric, is_same_function_spec, object_has_single_value, get_global_seed
from psyneulink.core.globals.utilities import convert_all_elements_to_np_array, is_numeric, is_same_function_spec, object_has_single_value
from psyneulink.core.scheduling.condition import AtTrialStart
from psyneulink.core.components.functions.userdefinedfunction import UserDefinedFunction

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ filterwarnings =
error:the matrix subclass is not the recommended way to represent matrices or deal with linear algebra
error:Passing (type, 1) or '1type' as a synonym of type is deprecated
error:A builtin ctypes object gave a PEP3118:RuntimeWarning
error:Pickle, copy, and deepcopy support will be removed from itertools:DeprecationWarning

[pycodestyle]
# for code explanation see https://pep8.readthedocs.io/en/latest/intro.html#error-codes
Expand Down
30 changes: 17 additions & 13 deletions tests/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_output_type_conversion_failure(output_type, variable):
)
def test_seed_setting_results(obj):
obj = obj()
new_seed = pnl.core.components.functions.function.get_global_seed()

# A seed different from the one used by the instance
new_seed = obj.parameters.seed.get() + 1

obj.parameters.seed.set(new_seed, context='c1')

Expand All @@ -81,20 +83,22 @@ def test_seed_setting_results(obj):
# with different seeds, unlike those in test_seed_setting_results
@pytest.mark.function
@pytest.mark.parametrize(
"obj",
"config",
[
pnl.DDM,
pnl.DriftDiffusionIntegrator,
pnl.DriftOnASphereIntegrator(dimension=3),
pnl.OrnsteinUhlenbeckIntegrator,
pnl.DictionaryMemory,
pnl.ContentAddressableMemory,
]
(pnl.DDM, {}),
(pnl.DriftDiffusionIntegrator, {}),
(pnl.DriftOnASphereIntegrator, {"dimension":3}),
(pnl.OrnsteinUhlenbeckIntegrator, {}),
(pnl.DictionaryMemory, {}),
(pnl.ContentAddressableMemory, {}),
],
ids=lambda x: x[0]
)
def test_seed_setting_params(obj):
if not isinstance(obj, pnl.Component):
obj = obj()
new_seed = pnl.core.components.functions.function.get_global_seed()
def test_seed_setting_params(config):
obj_c, params = config
obj = obj_c(**params)

new_seed = obj.parameters.seed.get() + 1

obj._initialize_from_context(pnl.Context(execution_id='c1'))
obj._initialize_from_context(pnl.Context(execution_id='c2'), pnl.Context(execution_id='c1'))
Expand Down
4 changes: 2 additions & 2 deletions tests/misc/test_user_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
import sys

def test_user_seed():
seed1 = subprocess.check_output((sys.executable, "-c" ,"from psyneulink.core.globals.utilities import get_global_seed; print(get_global_seed())"))
seed2 = subprocess.check_output((sys.executable, "-c" ,"from psyneulink.core.globals.utilities import get_global_seed; print(get_global_seed())"))
seed1 = subprocess.check_output((sys.executable, "-c" ,"from psyneulink.core.globals.utilities import _get_global_seed; print(_get_global_seed())"))
seed2 = subprocess.check_output((sys.executable, "-c" ,"from psyneulink.core.globals.utilities import _get_global_seed; print(_get_global_seed())"))
assert seed1 != seed2

0 comments on commit f2f7519

Please sign in to comment.