Skip to content

Commit 16a8fc7

Browse files
committed
added gain scheduling
1 parent b047801 commit 16a8fc7

File tree

3 files changed

+162
-30
lines changed

3 files changed

+162
-30
lines changed

tests/test_pid.py

+52-6
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,75 @@
11
import unittest
2-
from tinypid import PID
2+
from tinypid import PID, Gain, PIDGainScheduler
3+
34

45
class TestPID(unittest.TestCase):
56
def test_output_with_no_manual_output(self):
6-
pid = PID(K_p=1.0, K_i=0.5, K_d=0.2, setpoint=10.0, dt=0.1)
7+
pid = PID(k_p=1.0, k_i=0.5, k_d=0.2, setpoint=10.0, dt=0.1)
78
process_variable = 8.0
89
expected_output = 6.1 # P = 1.0 * (10.0 - 8.0) = 2.0, I = 0.5 * (10.0 - 8.0) * 0.1 = 0.1, D = 0.2 * ((10.0 - 8.0) / 0.1) = 4
910
output = pid(process_variable)
1011
self.assertAlmostEqual(output, expected_output)
1112

1213
def test_output_with_manual_output(self):
13-
pid = PID(K_p=1.0, K_i=0.5, K_d=0.2, setpoint=10.0, dt=0.1)
14+
pid = PID(k_p=1.0, k_i=0.5, k_d=0.2, setpoint=10.0, dt=0.1)
1415
process_variable = 8.0
1516
manual_output = 5.0
1617
expected_output = manual_output # Since manual_output is provided, the output should be equal to it
1718
output = pid(process_variable, manual_output=manual_output)
1819
self.assertAlmostEqual(output, expected_output)
1920

2021
def test_output_with_anti_windup_disabled(self):
21-
pid = PID(K_p=1.0, K_i=0.5, K_d=0.2, setpoint=10.0, dt=0.1)
22+
pid = PID(k_p=1.0, k_i=0.5, k_d=0.2, setpoint=10.0, dt=0.1)
2223
process_variable = 8.0
2324
manual_output = 15.0
2425
expected_output = manual_output # Since manual_output is provided, the output should be equal to it
2526
output = pid(process_variable, manual_output=manual_output, anti_windup=False)
2627
self.assertAlmostEqual(output, expected_output)
2728

28-
if __name__ == '__main__':
29-
unittest.main()
29+
30+
class TestGainSchedule(unittest.TestCase):
31+
def test_gain_scheduler_basic(self):
32+
gains = [
33+
Gain(setpoint_scope=(0, 20), k_p=1.0, k_i=0.5, k_d=0.2),
34+
Gain(setpoint_scope=(20, 30), k_p=2.0, k_i=1.0, k_d=0.4),
35+
]
36+
37+
pid = PIDGainScheduler(gains=gains, setpoint=10)
38+
39+
pid.update_gain(10)
40+
self.assertAlmostEqual(pid.k_p, 1.0)
41+
self.assertAlmostEqual(pid.k_i, 0.5)
42+
self.assertAlmostEqual(pid.k_d, 0.2)
43+
44+
pid.update_gain(25)
45+
self.assertAlmostEqual(pid.k_p, 2.0)
46+
self.assertAlmostEqual(pid.k_i, 1.0)
47+
self.assertAlmostEqual(pid.k_d, 0.4)
48+
49+
def test_gain_scheduler_constant(self):
50+
gains = [
51+
Gain(setpoint_scope=(0, 20), k_p=1.0, k_i=0.5, k_d=0.2),
52+
Gain(setpoint_scope=(20, 30), k_p=2.0, k_i=1.0, k_d=0.4),
53+
]
54+
55+
pid = PIDGainScheduler(gains=gains, setpoint=10, dt=0.1)
56+
57+
process_variable = 8.0
58+
expected_output = 6.1 # P = 1.0 * (10.0 - 8.0) = 2.0, I = 0.5 * (10.0 - 8.0) * 0.1 = 0.1, D = 0.2 * ((10.0 - 8.0) / 0.1) = 4
59+
output = pid(process_variable)
60+
self.assertAlmostEqual(output, expected_output)
61+
62+
def test_gain_scheduler_no_gain_found(self):
63+
gains = [
64+
Gain(setpoint_scope=(0, 20), k_p=1.0, k_i=0.5, k_d=0.2),
65+
Gain(setpoint_scope=(20, 30), k_p=2.0, k_i=1.0, k_d=0.4),
66+
]
67+
68+
pid = PIDGainScheduler(gains=gains, setpoint=10)
69+
70+
with self.assertRaises(ValueError):
71+
pid.update_gain(40)
72+
73+
74+
if __name__ == "__main__":
75+
unittest.main()

tinypid/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .pid import PID
1+
from .pid import PID, Gain, PIDGainScheduler

tinypid/pid.py

+109-23
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,29 @@
1010
output = controller(10)
1111
1212
"""
13-
from typing import Optional, Tuple
13+
14+
from typing import List, Optional, Tuple
15+
16+
17+
class Gain:
18+
"""
19+
A simple class to store PID gains and the setpoint range for which they apply.
20+
"""
21+
def __init__(self, setpoint_scope: Tuple[float, float], k_p: float, k_i: float, k_d: float) -> None:
22+
"""
23+
Initializes a Gain object with a setpoint range and PID gains.
24+
25+
Parameters:
26+
setpoint_scope: The range of setpoints for which these gains apply.
27+
The range is inclusive of the lower bound and exclusive of the upper bound.
28+
k_p : The proportional gain.
29+
k_i : The integral gain.
30+
k_d : The derivative gain.
31+
"""
32+
self.setpoint_scope = setpoint_scope
33+
self.k_p = k_p
34+
self.k_i = k_i
35+
self.k_d = k_d
1436

1537

1638
class PID:
@@ -25,9 +47,9 @@ class PID:
2547

2648
def __init__(
2749
self,
28-
K_p: float = 1,
29-
K_i: float = 0.1,
30-
K_d: float = 0,
50+
k_p: float = 1,
51+
k_i: float = 0.1,
52+
k_d: float = 0,
3153
setpoint: float = 0,
3254
dt: float = 1,
3355
derivative_lowpass: float = 1,
@@ -38,9 +60,9 @@ def __init__(
3860
Initialize PID controller
3961
4062
Parameters:
41-
K_p : Proportional gain
42-
K_i : Integral gain
43-
K_d : Derivative gain
63+
k_p : Proportional gain
64+
k_i : Integral gain
65+
k_d : Derivative gain
4466
dt : Time step
4567
derivative_lowpass: lowpass constant (between 1 and 0, 1 meaning no lowpass)
4668
upper_limit : Upper limit for the output
@@ -51,9 +73,9 @@ def __init__(
5173
if not 0 <= derivative_lowpass <= 1:
5274
raise ValueError("derivative_lowpass must be between 0 and 1")
5375

54-
self.K_p = K_p
55-
self.K_i = K_i
56-
self.K_d = K_d
76+
self.k_p = k_p
77+
self.k_i = k_i
78+
self.k_d = k_d
5779
self.P, self.I, self.D = None, None, None
5880
self.dt = dt
5981
self.alpha = derivative_lowpass
@@ -108,22 +130,25 @@ def limit(self, output: float) -> Tuple[bool, float]:
108130
saturated = output != unlimited
109131

110132
return saturated, output
111-
112-
def update_gains(self, K_p: float, K_i: float, K_d: float) -> None:
133+
134+
def update_gains(self, k_p: float, k_i: float, k_d: float) -> None:
113135
"""
114136
Update the PID gains.
115137
116138
Parameters:
117-
K_p : The new proportional gain
118-
K_i : The new integral gain
119-
K_d : The new derivative gain
139+
k_p : The new proportional gain
140+
k_i : The new integral gain
141+
k_d : The new derivative gain
120142
"""
121-
self.K_p = K_p
122-
self.K_i = K_i
123-
self.K_d = K_d
143+
self.k_p = k_p
144+
self.k_i = k_i
145+
self.k_d = k_d
124146

125147
def __call__(
126-
self, process_variable: float, manual_output: Optional[float] = None, anti_windup: bool = True
148+
self,
149+
process_variable: float,
150+
manual_output: Optional[float] = None,
151+
anti_windup: bool = True,
127152
) -> float:
128153
"""
129154
Process the input signal and return the controller output.
@@ -137,9 +162,9 @@ def __call__(
137162
self.integral += error * self.dt
138163
derivative = (error - self._previous_error) / self.dt if self.dt != 0 else 0
139164

140-
self.P = self.K_p * error
141-
self.I = self.K_i * self.integral
142-
self.D = self.K_d * (self.alpha * derivative + (1 - self.alpha) * self._previous_derivative)
165+
self.P = self.k_p * error
166+
self.I = self.k_i * self.integral
167+
self.D = self.k_d * (self.alpha * derivative + (1 - self.alpha) * self._previous_derivative)
143168

144169
output = self.P + self.I + self.D
145170

@@ -154,7 +179,7 @@ def __call__(
154179

155180
if manual_output:
156181
# Use setpoint tracking by calculating integral so that the output matches the manual setpoint
157-
self.integral = -(self.P + self.D - manual_output) / self.K_i if self.K_i != 0 else 0
182+
self.integral = -(self.P + self.D - manual_output) / self.k_i if self.k_i != 0 else 0
158183
output = manual_output
159184

160185
return output
@@ -165,3 +190,64 @@ def __repr__(self):
165190
f"P: {self.P}, I: {self.I}, D: {self.D}\n"
166191
f"Limits: {self.lower_limit} < output < {self.upper_limit}"
167192
)
193+
194+
195+
class PIDGainScheduler(PID):
196+
"""
197+
An extended Proportional-Integral-Derivative controller that uses
198+
gain scheduling to allow different gains, i.e., k_p, k_i, and k_d depending on
199+
the setpoint.
200+
201+
"""
202+
203+
def __init__(
204+
self,
205+
gains: List[Gain],
206+
setpoint: float = 0,
207+
dt: float = 1,
208+
derivative_lowpass: float = 1,
209+
upper_limit: Optional[float] = None,
210+
lower_limit: Optional[float] = None,
211+
) -> None:
212+
"""
213+
Initialize PID controller
214+
215+
Parameters:
216+
gains : List of Gain objects
217+
setpoint : The setpoint
218+
dt : Time step
219+
derivative_lowpass: lowpass constant (between 1 and 0, 1 meaning no lowpass)
220+
upper_limit : Upper limit for the output
221+
lower_limit : Lower limit for the output
222+
"""
223+
224+
super().__init__(None, None, None, setpoint, dt, derivative_lowpass, upper_limit, lower_limit)
225+
226+
self.gains = gains
227+
self.update_gain(setpoint)
228+
229+
def update_gain(self, setpoint: float) -> Tuple[float, float, float]:
230+
"""
231+
Update the PID gains based on the current setpoint
232+
233+
Parameters:
234+
output : The current output
235+
"""
236+
for gain in self.gains:
237+
lower, upper = gain.setpoint_scope
238+
if lower <= setpoint < upper:
239+
self.k_p, self.k_i, self.k_d = gain.k_p, gain.k_i, gain.k_d
240+
break
241+
else:
242+
raise ValueError("No gain found for the given setpoint.")
243+
244+
def __call__(
245+
self,
246+
process_variable: float,
247+
manual_output: Optional[float] = None,
248+
anti_windup: bool = True,
249+
) -> float:
250+
self.update_gain(self.setpoint)
251+
252+
output = super().__call__(process_variable, manual_output, anti_windup)
253+
return output

0 commit comments

Comments
 (0)