llvm/ExecutionMode: Use identity comparisons
Enum members are singletons.

Signed-off-by: Jan Vesely <[email protected]>
jvesely committed Jan 31, 2025
1 parent f1f8d46 commit f2523c7
Showing 7 changed files with 15 additions and 15 deletions.
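Why identity comparisons are safe here: Python `Enum` members are constructed once when the class is defined and then reused, so for members of the same enum `is` and `==` always agree, while `is` states the singleton assumption explicitly and skips the `__eq__` dispatch. A minimal sketch with a hypothetical stand-in enum (not the real `pnlvm.ExecutionMode`):

```python
from enum import Enum

class Mode(Enum):          # hypothetical stand-in, not pnlvm.ExecutionMode
    Python = 1
    PyTorch = 2

mode = Mode.PyTorch

# Members are singletons, so identity and equality agree:
assert mode is Mode.PyTorch
assert mode == Mode.PyTorch

# ...and a member is distinct from other members and from raw values:
assert mode is not Mode.Python
assert mode != 2
```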
4 changes: 2 additions & 2 deletions psyneulink/library/compositions/autodiffcomposition.py
@@ -949,7 +949,7 @@ def create_pathway(node)->list:
if node not in self.get_nodes_by_role(NodeRole.TARGET)
for pathway in _get_pytorch_backprop_pathway(node)]

- if execution_mode == pnlvm.ExecutionMode.PyTorch:
+ if execution_mode is pnlvm.ExecutionMode.PyTorch:
# For PyTorch mode, only need to construct dummy TARGET Nodes, to allow targets to be:
# - specified in the same way as for other execution_modes
# - trial-by-trial values kept aligned with inputs in batch / minibatch construction
@@ -1073,7 +1073,7 @@ def autodiff_forward(self, inputs, targets,
before the next time it calls run(), in a call to backward() by do_gradient_optimization()
in _batch_inputs() or _batch_function_inputs(),
"""
- assert execution_mode == pnlvm.ExecutionMode.PyTorch
+ assert execution_mode is pnlvm.ExecutionMode.PyTorch
pytorch_rep = self.parameters.pytorch_representation._get(context)

# --------- Do forward computation on current inputs -------------------------------------------------
4 changes: 2 additions & 2 deletions psyneulink/library/compositions/compositionrunner.py
@@ -359,7 +359,7 @@ def run_learning(self,
**kwargs)
skip_initialization = True

- if execution_mode == ExecutionMode.PyTorch:
+ if execution_mode is ExecutionMode.PyTorch:
pytorch_rep = (self._composition.parameters.pytorch_representation._get(context).
copy_weights_to_psyneulink(context))
if pytorch_rep and synch_with_pnl_options[MATRIX_WEIGHTS] == MINIBATCH:
@@ -372,7 +372,7 @@ def run_learning(self,
self._composition.parameters.results.get(context)[-1 * num_epoch_results:], context)
# return result of last *trial* (as usual for a call to run)

- if execution_mode == ExecutionMode.PyTorch and synch_with_pnl_options[MATRIX_WEIGHTS] == EPOCH:
+ if execution_mode is ExecutionMode.PyTorch and synch_with_pnl_options[MATRIX_WEIGHTS] == EPOCH:
# Copy weights at end of learning run
pytorch_rep.copy_weights_to_psyneulink(context)

5 changes: 3 additions & 2 deletions tests/composition/test_autodiffcomposition.py
@@ -2907,12 +2907,13 @@ def test_optimizer_specs(self, learning_rate, weight_decay, optimizer_type, expe

# fp32 results are different due to rounding
if pytest.helpers.llvm_current_fp_precision() == 'fp32' and \
- autodiff_mode != pnl.ExecutionMode.PyTorch and \
+ autodiff_mode is not pnl.ExecutionMode.PyTorch and \
optimizer_type == 'sgd' and \
learning_rate == 10:
expected = [[[0.9918830394744873]], [[0.9982172846794128]], [[0.9978305697441101]], [[0.9994590878486633]]]

# FIXME: LLVM version is broken with learning rate == 1.5
- if learning_rate != 1.5 or autodiff_mode == pnl.ExecutionMode.PyTorch:
+ if learning_rate != 1.5 or autodiff_mode is pnl.ExecutionMode.PyTorch:
np.testing.assert_allclose(results, expected)


2 changes: 1 addition & 1 deletion tests/composition/test_composition.py
@@ -4274,7 +4274,7 @@ def test_one_time_warning_for_run_with_no_inputs(self):
comp.run()

def _check_comp_ex(self, comp, comparison, comp_mode, struct_name, context=None, is_not=False):
- if comp_mode == pnl.ExecutionMode.Python:
+ if comp_mode is pnl.ExecutionMode.Python:
return

if context is None:
8 changes: 4 additions & 4 deletions tests/composition/test_control.py
@@ -2469,7 +2469,7 @@ def test_modulation_simple(self, cost, expected, exp_values, comp_mode):

ret = comp.run(inputs={mech: [2]}, num_trials=1, execution_mode=comp_mode)
np.testing.assert_allclose(ret, expected)
- if comp_mode == pnl.ExecutionMode.Python:
+ if comp_mode is pnl.ExecutionMode.Python:
np.testing.assert_allclose(comp.controller.function.saved_values.flatten(), exp_values)

@pytest.mark.benchmark
@@ -2531,7 +2531,7 @@ def test_modulation_of_random_state_DDM(self, comp_mode, benchmark, prng):
benchmark(comp.run, inputs={ctl_mech:seeds, mech:5.0}, num_trials=len(seeds) * 2, execution_mode=comp_mode)

# Python uses fp64 irrespective of the pytest precision setting
- precision = 'fp64' if comp_mode == pnl.ExecutionMode.Python else pytest.helpers.llvm_current_fp_precision()
+ precision = 'fp64' if comp_mode is pnl.ExecutionMode.Python else pytest.helpers.llvm_current_fp_precision()
if prng == 'Default':
np.testing.assert_allclose(np.squeeze(comp.results[:len(seeds) * 2]), [[100, 21], [100, 23], [100, 20]] * 2)
elif prng == 'Philox' and precision == 'fp64':
@@ -2644,7 +2644,7 @@ def test_modulation_of_random_state_DDM_Analytical(self, comp_mode, benchmark, p
benchmark(comp.run, inputs={ctl_mech:seeds, mech:0.1}, num_trials=len(seeds) * 2, execution_mode=comp_mode)

# Python uses fp64 irrespective of the pytest precision setting
- precision = 'fp64' if comp_mode == pnl.ExecutionMode.Python else pytest.helpers.llvm_current_fp_precision()
+ precision = 'fp64' if comp_mode is pnl.ExecutionMode.Python else pytest.helpers.llvm_current_fp_precision()
if prng == 'Default':
np.testing.assert_allclose(np.squeeze(comp.results[:len(seeds) * 2]), [[-1, 3.99948962], [1, 3.99948962], [-1, 3.99948962]] * 2)
elif prng == 'Philox' and precision == 'fp64':
@@ -3359,7 +3359,7 @@ def comp_run(inputs, execution_mode):
results, saved_values = benchmark(comp_run, inputs, mode)

np.testing.assert_array_equal(results, result)
- if mode == pnl.ExecutionMode.Python:
+ if mode is pnl.ExecutionMode.Python:
np.testing.assert_array_equal(saved_values.flatten(), [0.75, 1.5, 2.25])

def test_model_based_ocm_with_buffer(self):
4 changes: 2 additions & 2 deletions tests/models/test_greedy_agent.py
@@ -214,7 +214,7 @@ def action_fn(variable):
# np.testing.assert_allclose(run_results, [[0.9705216285127504, -0.1343332460369043]])
np.testing.assert_allclose(run_results, [[0.9705216285127504, -0.1343332460369043]], atol=1e-6, rtol=1e-6)
elif prng == 'Philox':
- if mode == pnl.ExecutionMode.Python or pytest.helpers.llvm_current_fp_precision() == 'fp64':
+ if mode is pnl.ExecutionMode.Python or pytest.helpers.llvm_current_fp_precision() == 'fp64':
# np.testing.assert_allclose(run_results[0], [[-0.16882940384606543, -0.07280074899749223]])
np.testing.assert_allclose(run_results, [[-0.16882940384606543, -0.07280074899749223]])
elif pytest.helpers.llvm_current_fp_precision() == 'fp32':
@@ -225,7 +225,7 @@ def action_fn(variable):
else:
assert False, "Unknown PRNG!"

- if mode == pnl.ExecutionMode.Python and not benchmark.enabled:
+ if mode is pnl.ExecutionMode.Python and not benchmark.enabled:
# FIXME: The results are 'close' for both Philox and MT,
# because they're dominated by costs
# FIX: Requires 1e-5 tolerance
3 changes: 1 addition & 2 deletions tests/ports/test_output_ports.py
@@ -46,8 +46,7 @@ def test_output_port_variable_spec(self, mech_mode):
], ids=lambda x: str(x) if len(x) != 1 else '')
@pytest.mark.usefixtures("comp_mode_no_per_node")
def tests_output_port_variable_spec_composition(self, comp_mode, spec, expected1, expected2):
- if (len(spec) == 2) and (spec[1] == pnl.TimeScale.RUN) and \
-     ((comp_mode & pnl.ExecutionMode._Exec) == pnl.ExecutionMode._Exec):
+ if (len(spec) == 2) and (spec[1] == pnl.TimeScale.RUN) and (comp_mode & pnl.ExecutionMode._Exec):
pytest.skip("{} is not supported in {}".format(spec[1], comp_mode))

# Test specification of OutputPort's variable
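The last hunk also leans on `enum.Flag` semantics: a bitwise AND of flag values yields another flag, and that flag is truthy exactly when the intersection is non-empty, so the explicit `== pnl.ExecutionMode._Exec` comparison can be dropped when "any compiled-execution bit set" is the intended test. A hedged sketch with a stand-in `Flag` (the real `ExecutionMode._Exec` and its bit layout are assumptions here):

```python
from enum import Flag

class Mode(Flag):          # hypothetical stand-in for pnl.ExecutionMode
    Python = 1
    LLVM = 2
    PTX = 4
    _Exec = LLVM | PTX     # assumed composite "compiled execution" flag

comp_mode = Mode.LLVM

# A non-empty intersection is truthy, so this reads as "uses compiled execution":
assert comp_mode & Mode._Exec

# The stricter equality form would require *all* _Exec bits to be set:
assert (comp_mode & Mode._Exec) != Mode._Exec
```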
