Skip to content

Commit

Permalink
Merge pull request #1521 from pints-team/grad-desc-unit-test
Browse files Browse the repository at this point in the history
Added value-based test for gradient descent.
  • Loading branch information
MichaelClerx authored Mar 21, 2024
2 parents 13b263a + ce7abf9 commit abcc8e5
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions pints/tests/test_opt_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,38 @@ def test_name(self):
opt = method(np.array([0, 1.01]))
self.assertIn('radient descent', opt.name())

def test_step(self):
    # Value-based check: verify that the optimiser takes exactly the
    # step implied by the learning rate and the supplied gradient.

    # Start a gradient-descent optimiser at the origin
    origin = np.zeros(8)
    opt = pints.GradientDescent(origin)
    opt.set_learning_rate(1)

    # The first ask() should propose the starting point itself
    proposals = opt.ask()
    self.assertEqual(len(proposals), 1)
    # Compare as lists: failure messages then show every element
    self.assertEqual(list(proposals[0]), list(origin))

    # With learning rate 1, telling it gradient -g moves the point to +g
    g = np.array([1, 2, 3, 4, 8, -7, 6, 5])
    opt.tell([(0, -g)])
    self.assertEqual(list(opt.ask()[0]), list(g))

    # Halving the learning rate and telling +g retraces half a step
    opt.set_learning_rate(0.5)
    opt.tell([(0, g)])
    self.assertEqual(list(opt.ask()[0]), list(0.5 * g))

    # A second +g at the halved rate returns the point to the origin
    opt.tell([(0, g)])
    self.assertEqual(list(opt.ask()[0]), list(origin))


if __name__ == '__main__':
print('Add -v for more debug output')
Expand Down

0 comments on commit abcc8e5

Please sign in to comment.