Skip to content

Commit e32e1da

Browse files
committed
refactor: allow for multiple Myokit simulations
Refactor the Myokit model mixin to allow multiple simulations to be run. Each simulation has its own dosing protocol and dosing events.
1 parent 7d8edea commit e32e1da

File tree

1 file changed

+183
-105
lines changed

1 file changed

+183
-105
lines changed

pkpdapp/pkpdapp/models/myokit_model_mixin.py

+183-105
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#
66

77
import pkpdapp
8+
from pkpdapp.models import Protocol
89
import numpy as np
910
from myokit.formats.mathml import MathMLExpressionWriter
1011
from myokit.formats.sbml import SBMLParser
@@ -19,6 +20,79 @@
1920

2021

2122
class MyokitModelMixin:
23+
def _initialise_variables(self, model, variables):
24+
# Convert units
25+
variables = {
26+
qname: self._convert_unit_qname(qname, value, model)
27+
for qname, value in variables.items()
28+
}
29+
30+
# Set constants in model
31+
for var_name, var_value in variables.items():
32+
model.get(var_name).set_rhs(float(var_value))
33+
34+
return variables
35+
36+
def _get_myokit_protocols(
37+
self, model, dosing_protocols, override_tlag, time_max
38+
):
39+
protocols = {}
40+
is_target = False
41+
time_var = model.binding("time")
42+
project = self.get_project()
43+
if project is None:
44+
compound = None
45+
else:
46+
compound = project.compound
47+
48+
for qname, protocol in dosing_protocols.items():
49+
amount_var = model.get(qname)
50+
set_administration(model, amount_var)
51+
tlag_value = self._get_tlag_value(qname)
52+
# override tlag if set
53+
if qname in override_tlag:
54+
tlag_value = override_tlag[qname]
55+
56+
if self.is_library_model:
57+
is_target = "CT1" in qname or "AT1" in qname
58+
59+
amount_conversion_factor = protocol.amount_unit.convert_to(
60+
amount_var.unit(), compound=compound, is_target=is_target
61+
)
62+
63+
time_conversion_factor = protocol.time_unit.convert_to(
64+
time_var.unit(), compound=compound
65+
)
66+
67+
dosing_events = _get_dosing_events(
68+
protocol.doses,
69+
amount_conversion_factor,
70+
time_conversion_factor,
71+
tlag_value,
72+
time_max
73+
)
74+
protocols[_get_pacing_label(amount_var)] = get_protocol(dosing_events)
75+
return protocols
76+
77+
def _get_override_tlag(self, variables):
78+
override_tlag = {}
79+
if isinstance(self, pkpdapp.models.CombinedModel):
80+
for dv in self.derived_variables.all():
81+
if dv.type == "TLG":
82+
derived_param = dv.pk_variable.qname + "_tlag_ud"
83+
if derived_param in variables:
84+
override_tlag[dv.pk_variable.qname] = variables[derived_param]
85+
return override_tlag
86+
87+
def _get_tlag_value(self, qname):
88+
from pkpdapp.models import Variable
89+
# get tlag value default to 0
90+
derived_param = qname + "_tlag_ud"
91+
try:
92+
return self.variables.get(qname=derived_param).default_value
93+
except Variable.DoesNotExist:
94+
return 0.0
95+
2296
def _get_myokit_model_cache_key(self):
2397
return "myokit_model_{}_{}".format(self._meta.db_table, self.id)
2498

@@ -45,97 +119,72 @@ def parse_mmt_string(mmt):
45119
def create_myokit_model(self):
46120
return self.parse_mmt_string(self.mmt)
47121

48-
def create_myokit_simulator(self, override_tlag=None, model=None, time_max=None):
122+
def create_myokit_simulations_from_dataset(
123+
self,
124+
dataset=None,
125+
variables=None,
126+
time_max=None,
127+
outputs=None
128+
):
129+
override_tlag = self._get_override_tlag(variables),
130+
if dataset is None:
131+
subject_protocols = []
132+
else:
133+
subject_protocols = [
134+
Protocol.objects.get(pk=p['protocol'])
135+
for p in dataset.subjects.values('protocol').distinct()
136+
if p['protocol'] is not None
137+
]
138+
for protocol in subject_protocols:
139+
model = self.get_myokit_model()
140+
141+
# Convert units
142+
variables = self._initialise_variables(model, variables)
143+
time_max = self._convert_bound_unit("time", time_max, model)
144+
# get tlag vars
145+
override_tlag = self._get_override_tlag(variables)
146+
# define a dosing protocol for this subject protocol
147+
dosing_protocols = {}
148+
if protocol.mapped_qname:
149+
dosing_protocols[protocol.mapped_qname] = protocol
150+
# create simulator
151+
sim = self.create_myokit_simulator(
152+
override_tlag=override_tlag,
153+
model=model,
154+
time_max=time_max,
155+
dosing_protocols=dosing_protocols
156+
)
157+
# TODO: take these from simulation model
158+
sim.set_tolerance(abs_tol=1e-06, rel_tol=1e-08)
159+
# Simulate, logging only state variables given by `outputs`
160+
print('##########################################')
161+
print(self.serialize_datalog(sim.run(time_max, log=outputs), model))
162+
163+
def create_myokit_simulator(
164+
self, override_tlag=None, model=None, time_max=None, dosing_protocols=None
165+
):
49166
if override_tlag is None:
50167
override_tlag = {}
51168

52169
if model is None:
53170
model = self.get_myokit_model()
54171

55-
from pkpdapp.models import Variable
56-
57-
if override_tlag is None:
58-
try:
59-
tlag_value = self.variables.get(
60-
qname="PKCompartment.tlag"
61-
).default_value
62-
except Variable.DoesNotExist:
63-
tlag_value = 0.0
64-
else:
65-
tlag_value = override_tlag
66-
67-
# add a dose_rate variable to the model for each
68-
# dosed variable
69-
for v in self.variables.filter(state=True):
70-
if v.protocol:
71-
myokit_v = model.get(v.qname)
72-
set_administration(model, myokit_v)
73-
74-
protocols = {}
75-
project = self.get_project()
76-
if project is None:
77-
compound = None
78-
else:
79-
compound = project.compound
80-
for v in self.variables.filter(state=True):
81-
if v.protocol:
82-
# get tlag value default to 0
83-
derived_param = v.qname + "_tlag_ud"
84-
try:
85-
tlag_value = self.variables.get(qname=derived_param).default_value
86-
except Variable.DoesNotExist:
87-
tlag_value = 0.0
88-
89-
# override tlag if set
90-
if v.qname in override_tlag:
91-
tlag_value = override_tlag[v.qname]
92-
93-
amount_var = model.get(v.qname)
94-
time_var = model.binding("time")
95-
96-
is_target = False
97-
if self.is_library_model:
98-
is_target = "CT1" in v.qname or "AT1" in v.qname
99-
100-
amount_conversion_factor = v.protocol.amount_unit.convert_to(
101-
amount_var.unit(), compound=compound, is_target=is_target
102-
)
103-
104-
time_conversion_factor = v.protocol.time_unit.convert_to(
105-
time_var.unit(), compound=compound
106-
)
107-
108-
dosing_events = []
109-
last_dose_time = tlag_value
110-
for d in v.protocol.doses.all():
111-
if d.repeat_interval <= 0:
112-
continue
113-
start_times = np.arange(
114-
d.start_time + last_dose_time,
115-
d.start_time + last_dose_time + d.repeat_interval * d.repeats,
116-
d.repeat_interval,
117-
)
118-
if len(start_times) == 0:
119-
continue
120-
last_dose_time = start_times[-1]
121-
dosing_events += [
122-
(
123-
(amount_conversion_factor / time_conversion_factor)
124-
* (d.amount / d.duration),
125-
time_conversion_factor * start_time,
126-
time_conversion_factor * d.duration,
127-
)
128-
for start_time in start_times
129-
]
130-
# if any dosing events are close to time_max,
131-
# make them equal to time_max
132-
if time_max is not None:
133-
for i, (level, start, duration) in enumerate(dosing_events):
134-
if abs(start - time_max) < 1e-6:
135-
dosing_events[i] = (level, time_max, duration)
136-
elif abs(start + duration - time_max) < 1e-6:
137-
dosing_events[i] = (level, start, time_max - start)
138-
protocols[_get_pacing_label(amount_var)] = get_protocol(dosing_events)
172+
if dosing_protocols is None:
173+
# add a dose_rate variable to the model for each
174+
# dosed variable
175+
dosing_variables = [
176+
v for v in self.variables.filter(state=True) if v.protocol
177+
]
178+
dosing_protocols = {}
179+
for v in dosing_variables:
180+
dosing_protocols[v.qname] = v.protocol
181+
182+
protocols = self._get_myokit_protocols(
183+
model=model,
184+
dosing_protocols=dosing_protocols,
185+
override_tlag=override_tlag,
186+
time_max=time_max
187+
)
139188

140189
with lock:
141190
sim = myokit.Simulation(model, protocol=protocols)
@@ -468,30 +517,23 @@ def simulate(self, outputs=None, variables=None, time_max=None):
468517
**variables,
469518
}
470519

471-
model = self.get_myokit_model()
520+
project = self.get_project()
521+
if project is not None:
522+
dataset = project.datasets.first()
523+
self.create_myokit_simulations_from_dataset(
524+
dataset=dataset,
525+
variables=variables,
526+
time_max=time_max,
527+
outputs=outputs
528+
)
472529

530+
model = self.get_myokit_model()
473531
# Convert units
474-
variables = {
475-
qname: self._convert_unit_qname(qname, value, model)
476-
for qname, value in variables.items()
477-
}
532+
variables = self._initialise_variables(model, variables)
478533
time_max = self._convert_bound_unit("time", time_max, model)
479-
480-
# Set constants in model
481-
for var_name, var_value in variables.items():
482-
model.get(var_name).set_rhs(float(var_value))
483-
484-
# create simulator
485-
486534
# get tlag vars
487-
override_tlag = {}
488-
if isinstance(self, pkpdapp.models.CombinedModel):
489-
for dv in self.derived_variables.all():
490-
if dv.type == "TLG":
491-
derived_param = dv.pk_variable.qname + "_tlag_ud"
492-
if derived_param in variables:
493-
override_tlag[dv.pk_variable.qname] = variables[derived_param]
494-
535+
override_tlag = self._get_override_tlag(variables)
536+
# create simulator
495537
sim = self.create_myokit_simulator(
496538
override_tlag=override_tlag, model=model, time_max=time_max
497539
)
@@ -663,3 +705,39 @@ def _add_dose_compartment(model, drug_amount, time_unit):
663705
)
664706

665707
return dose_drug_amount
708+
709+
710+
def _get_dosing_events(
711+
doses,
712+
amount_conversion_factor=1.0,
713+
time_conversion_factor=1.0,
714+
last_dose_time=0.0,
715+
time_max=None
716+
):
717+
dosing_events = []
718+
for d in doses.all():
719+
if d.repeat_interval <= 0:
720+
continue
721+
start_times = np.arange(
722+
d.start_time + last_dose_time,
723+
d.start_time + last_dose_time + d.repeat_interval * d.repeats,
724+
d.repeat_interval,
725+
)
726+
if len(start_times) == 0:
727+
continue
728+
last_dose_time = start_times[-1]
729+
dose_level = d.amount / d.duration
730+
dosing_events += [(
731+
(amount_conversion_factor / time_conversion_factor) * dose_level,
732+
time_conversion_factor * start_time,
733+
time_conversion_factor * d.duration,
734+
) for start_time in start_times]
735+
# if any dosing events are close to time_max,
736+
# make them equal to time_max
737+
if time_max is not None:
738+
for i, (level, start, duration) in enumerate(dosing_events):
739+
if abs(start - time_max) < 1e-6:
740+
dosing_events[i] = (level, time_max, duration)
741+
elif abs(start + duration - time_max) < 1e-6:
742+
dosing_events[i] = (level, start, time_max - start)
743+
return dosing_events

0 commit comments

Comments
 (0)