From 3c0fd46a59468e2484ea81ae720062f4afba74ca Mon Sep 17 00:00:00 2001 From: Jim O'Donnell Date: Mon, 11 Mar 2024 12:47:27 +0000 Subject: [PATCH 1/5] feat: allow for multiple Myokit simulations Refactor the Myokit model mixin to allow multiple simulations to be run. Each simulation has its own dosing protocols and dosing events. Update the simulate API to return multiple simulations. --- frontend-v2/src/app/backendApi.ts | 4 +- pkpdapp/pkpdapp/api/views/simulate.py | 8 +- pkpdapp/pkpdapp/models/myokit_model_mixin.py | 77 +++++++++++++++---- .../tests/test_views/test_combined_model.py | 9 ++- .../pkpdapp/tests/test_views/test_simulate.py | 17 ++-- pkpdapp/pkpdapp/tests/utils.py | 2 +- pkpdapp/schema.yml | 8 +- 7 files changed, 93 insertions(+), 32 deletions(-) diff --git a/frontend-v2/src/app/backendApi.ts b/frontend-v2/src/app/backendApi.ts index 1b6524ed..eec30f4e 100644 --- a/frontend-v2/src/app/backendApi.ts +++ b/frontend-v2/src/app/backendApi.ts @@ -1127,7 +1127,7 @@ export type CombinedModelSetVariablesFromInferenceUpdateApiArg = { combinedModel: CombinedModel; }; export type CombinedModelSimulateCreateApiResponse = - /** status 200 */ SimulateResponse; + /** status 200 */ SimulateResponse[]; export type CombinedModelSimulateCreateApiArg = { id: number; simulate: Simulate; @@ -1350,7 +1350,7 @@ export type PharmacodynamicSetVariablesFromInferenceUpdateApiArg = { pharmacodynamic: Pharmacodynamic; }; export type PharmacodynamicSimulateCreateApiResponse = - /** status 200 */ SimulateResponse; + /** status 200 */ SimulateResponse[]; export type PharmacodynamicSimulateCreateApiArg = { id: number; simulate: Simulate; diff --git a/pkpdapp/pkpdapp/api/views/simulate.py b/pkpdapp/pkpdapp/api/views/simulate.py index 716dc1c7..6cbf9fd8 100644 --- a/pkpdapp/pkpdapp/api/views/simulate.py +++ b/pkpdapp/pkpdapp/api/views/simulate.py @@ -38,6 +38,10 @@ def to_representation(self, instance): } +class SimulateResponseListSerializer(serializers.ListSerializer): + child = SimulateResponseSerializer() + + class ErrorResponseSerializer(serializers.Serializer): error = serializers.CharField() @@ -45,7 +49,7 @@ class ErrorResponseSerializer(serializers.Serializer): @extend_schema( request=SimulateSerializer, responses={ - 200: SimulateResponseSerializer, + 200: SimulateResponseListSerializer, 400: ErrorResponseSerializer, 404: None, }, @@ -67,7 +71,7 @@ def post(self, request, pk, format=None): return Response( serialized_result.data, status=status.HTTP_400_BAD_REQUEST ) - serialized_result = SimulateResponseSerializer(result) + serialized_result = SimulateResponseSerializer(result, many=True) return Response(serialized_result.data) diff --git a/pkpdapp/pkpdapp/models/myokit_model_mixin.py b/pkpdapp/pkpdapp/models/myokit_model_mixin.py index ae668ff2..2f348b1f 100644 --- a/pkpdapp/pkpdapp/models/myokit_model_mixin.py +++ b/pkpdapp/pkpdapp/models/myokit_model_mixin.py @@ -5,6 +5,7 @@ # import pkpdapp +from pkpdapp.models import Protocol import numpy as np from myokit.formats.mathml import MathMLExpressionWriter from myokit.formats.sbml import SBMLParser @@ -440,6 +441,27 @@ def serialize_datalog(self, datalog, myokit_model): def get_time_max(self): return self.time_max + def simulate_model( + self, outputs=None, variables=None, time_max=None, dosing_protocols=None + ): + model = self.get_myokit_model() + # Convert units + variables = self._initialise_variables(model, variables) + time_max = self._convert_bound_unit("time", time_max, model) + # get tlag vars + override_tlag = self._get_override_tlag(variables) + # create simulator + sim = self.create_myokit_simulator( + override_tlag=override_tlag, + model=model, + time_max=time_max, + dosing_protocols=dosing_protocols + ) + # TODO: take these from simulation model + sim.set_tolerance(abs_tol=1e-06, rel_tol=1e-08) + # Simulate, logging only state variables given by `outputs` + return self.serialize_datalog(sim.run(time_max, log=outputs), model) + def simulate(self, outputs=None, variables=None, time_max=None): """ Arguments @@ -475,21 +497,32 @@ def simulate(self, outputs=None, variables=None, time_max=None): **variables, } - model = self.get_myokit_model() - # Convert units - variables = self._initialise_variables(model, variables) - time_max = self._convert_bound_unit("time", time_max, model) - # get tlag vars - override_tlag = self._get_override_tlag(variables) - # create simulator - sim = self.create_myokit_simulator( - override_tlag=override_tlag, model=model, time_max=time_max + project_sim = self.simulate_model( + variables=variables, time_max=time_max, outputs=outputs ) - # TODO: take these from simulation model - sim.set_tolerance(abs_tol=1e-06, rel_tol=1e-08) - # Simulate, logging only state variables given by `outputs` - return self.serialize_datalog(sim.run(time_max, log=outputs), model) + project = self.get_project() + sims = [project_sim] + if project is not None: + for subjects in get_project_cohorts(project): + # find unique protocols for this subject cohort + dosing_protocols = {} + subject_protocols = [ + Protocol.objects.get(pk=p['protocol']) + for p in subjects.values('protocol').distinct() + if p['protocol'] is not None + ] + for protocol in subject_protocols: + dosing_protocols[protocol.mapped_qname] = protocol + sim = self.simulate_model( + outputs=outputs, + variables=variables, + time_max=time_max, + dosing_protocols=dosing_protocols + ) + sims.append(sim) + + return sims def set_administration(model, drug_amount, direct=True): @@ -689,3 +722,21 @@ def _get_dosing_events( elif abs(start + duration - time_max) < 1e-6: dosing_events[i] = (level, start, time_max - start) return dosing_events + + +def get_project_cohorts(project): + dataset = project.datasets.first() + cohorts = [] + if dataset is not None: + # TODO: create backend subject cohorts based on the + # frontend upload stepper + dataset_protocols = [ + Protocol.objects.get(pk=p['protocol']) + for p in dataset.subjects.values('protocol').distinct() + if p['protocol'] is not None + ] + cohorts = [ + dataset.subjects.filter(protocol=protocol) + for protocol in dataset_protocols + ] + return cohorts diff --git a/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py b/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py index 1e7ef1b3..aa37dc44 100644 --- a/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py +++ b/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py @@ -135,11 +135,12 @@ def simulate_combined_model(self, id): format="json", ) self.assertEqual(response.status_code, status.HTTP_200_OK) - keys = [key for key in response.data["outputs"].keys()] + data = response.data[0] + keys = [key for key in data["outputs"].keys()] return ( - response.data["outputs"][keys[0]], - response.data["outputs"][keys[1]], - response.data["outputs"][keys[2]], + data["outputs"][keys[0]], + data["outputs"][keys[1]], + data["outputs"][keys[2]], ) def test_swap_mapped_pd_model(self): diff --git a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py index 91733d42..d59b3b61 100644 --- a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py +++ b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py @@ -44,14 +44,15 @@ def test_simulate(self): response = self.client.post(url, data, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - outputs = response.data.get('outputs') - self.assertCountEqual( - list(outputs.keys()), - [ - Variable.objects.get(qname=qname, dosed_pk_model=m).id - for qname in data['outputs'] - ] - ) + for sim in response.data: + outputs = sim.get('outputs') + self.assertCountEqual( + list(outputs.keys()), + [ + Variable.objects.get(qname=qname, dosed_pk_model=m).id + for qname in data['outputs'] + ] + ) url = reverse('simulate-combined-model', args=(123,)) response = self.client.post(url, data, format='json') diff --git a/pkpdapp/pkpdapp/tests/utils.py b/pkpdapp/pkpdapp/tests/utils.py index 150baebe..ef6f60a5 100644 --- a/pkpdapp/pkpdapp/tests/utils.py +++ b/pkpdapp/pkpdapp/tests/utils.py @@ -30,7 +30,7 @@ def create_pd_inference(sampling=False): # generate some fake data output = model.variables.get(qname='PDCompartment.TS') time = model.variables.get(qname='environment.t') - data = model.simulate(outputs=[output.qname, time.qname]) + data = model.simulate(outputs=[output.qname, time.qname])[0] print(data) TS = data[output.id] times = data[time.id] diff --git a/pkpdapp/schema.yml b/pkpdapp/schema.yml index 5e2e72c5..35b5bd01 100644 --- a/pkpdapp/schema.yml +++ b/pkpdapp/schema.yml @@ -539,7 +539,9 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SimulateResponse' + type: array + items: + $ref: '#/components/schemas/SimulateResponse' description: '' '400': content: @@ -1674,7 +1676,9 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SimulateResponse' + type: array + items: + $ref: '#/components/schemas/SimulateResponse' description: '' '400': content: From cc33ed40184d2b5e686e83187ed3a35ecd539a87 Mon Sep 17 00:00:00 2001 From: Jim O'Donnell Date: Wed, 13 Mar 2024 17:04:58 +0000 Subject: [PATCH 2/5] Show multiple plots in the Simulations tab --- .../src/features/simulation/Simulations.tsx | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/frontend-v2/src/features/simulation/Simulations.tsx b/frontend-v2/src/features/simulation/Simulations.tsx index a2e3a3a9..a427614c 100644 --- a/frontend-v2/src/features/simulation/Simulations.tsx +++ b/frontend-v2/src/features/simulation/Simulations.tsx @@ -180,7 +180,7 @@ const Simulations: FC = () => { ? (simulateErrorBase.data as ErrorObject) : { error: "Unknown error" } : undefined; - const [data, setData] = useState(null); + const [data, setData] = useState(null); const { data: compound, isLoading: isLoadingCompound } = useCompoundRetrieveQuery( { id: project?.compound || 0 }, @@ -303,7 +303,8 @@ const Simulations: FC = () => { }).then((response) => { setLoadingSimulate(false); if ("data" in response) { - const responseData = response.data as SimulateResponse; + const responseData = response.data as SimulateResponse[]; + console.log({ responseData }) setData(responseData); } }); @@ -342,10 +343,10 @@ const Simulations: FC = () => { ), }).then((response) => { if ("data" in response) { - const responseData = response.data as SimulateResponse; + const [ projectData ] = response.data; const nrows = - responseData.outputs[Object.keys(responseData.outputs)[0]].length; - const cols = Object.keys(responseData.outputs); + projectData.outputs[Object.keys(projectData.outputs)[0]].length; + const cols = Object.keys(projectData.outputs); const vars = cols.map((vid) => variables.find((v) => v.id === parseInt(vid)), ); @@ -377,7 +378,7 @@ const Simulations: FC = () => { for (let i = 0; i < nrows; i++) { rows[rowi] = new Array(ncols); for (let j = 0; j < ncols; j++) { - rows[rowi][j] = responseData.outputs[cols[j]][i]; + rows[rowi][j] = projectData.outputs[cols[j]][i]; } rowi++; } @@ -592,24 +593,26 @@ const Simulations: FC = () => { )} {plots.map((plot, index) => ( - - {data && model ? ( - - ) : ( -
Loading...
- )} -
+ data?.map(d => ( + + {d && model ? ( + + ) : ( +
Loading...
+ )} +
+ )) ))}
From cf486b21ff4f5d0c98a0fb8fd7de986067f7bb6d Mon Sep 17 00:00:00 2001 From: Jim O'Donnell Date: Fri, 15 Mar 2024 11:00:48 +0000 Subject: [PATCH 3/5] Use dataset subject groups to run simulations --- pkpdapp/pkpdapp/models/myokit_model_mixin.py | 26 ++++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/pkpdapp/pkpdapp/models/myokit_model_mixin.py b/pkpdapp/pkpdapp/models/myokit_model_mixin.py index 2f348b1f..8101e680 100644 --- a/pkpdapp/pkpdapp/models/myokit_model_mixin.py +++ b/pkpdapp/pkpdapp/models/myokit_model_mixin.py @@ -5,7 +5,7 @@ # import pkpdapp -from pkpdapp.models import Protocol +from pkpdapp.models import SubjectGroup, Protocol import numpy as np from myokit.formats.mathml import MathMLExpressionWriter from myokit.formats.sbml import SBMLParser @@ -504,12 +504,12 @@ def simulate(self, outputs=None, variables=None, time_max=None): project = self.get_project() sims = [project_sim] if project is not None: - for subjects in get_project_cohorts(project): + for group in get_subject_groups(project): # find unique protocols for this subject cohort dosing_protocols = {} subject_protocols = [ Protocol.objects.get(pk=p['protocol']) - for p in subjects.values('protocol').distinct() + for p in group.subjects.values('protocol').distinct() if p['protocol'] is not None ] for protocol in subject_protocols: @@ -724,19 +724,13 @@ def _get_dosing_events( return dosing_events -def get_project_cohorts(project): +def get_subject_groups(project): dataset = project.datasets.first() - cohorts = [] + dataset_groups = [] if dataset is not None: - # TODO: create backend subject cohorts based on the - # frontend upload stepper - dataset_protocols = [ - Protocol.objects.get(pk=p['protocol']) - for p in dataset.subjects.values('protocol').distinct() - if p['protocol'] is not None + dataset_groups = [ + SubjectGroup.objects.get(pk=g['group']) + for g in dataset.subjects.values('group').distinct() + if g['group'] is not None ] - cohorts = [ - dataset.subjects.filter(protocol=protocol) - for protocol in dataset_protocols - ] - return cohorts + return dataset_groups From b0e98ef05584d723a17fed9e603fb6d4452105ad Mon Sep 17 00:00:00 2001 From: Jim O'Donnell Date: Fri, 15 Mar 2024 14:52:30 +0000 Subject: [PATCH 4/5] Plot multiple simulations in the same plot --- .../simulation/SimulationPlotView.tsx | 122 +++++++++--------- .../src/features/simulation/Simulations.tsx | 39 +++--- 2 files changed, 80 insertions(+), 81 deletions(-) diff --git a/frontend-v2/src/features/simulation/SimulationPlotView.tsx b/frontend-v2/src/features/simulation/SimulationPlotView.tsx index d5a7cc68..c7e883b7 100644 --- a/frontend-v2/src/features/simulation/SimulationPlotView.tsx +++ b/frontend-v2/src/features/simulation/SimulationPlotView.tsx @@ -127,7 +127,7 @@ function genIcLines( interface SimulationPlotProps { index: number; plot: FieldArrayWithId; - data: SimulateResponse; + data: SimulateResponse[]; variables: VariableRead[]; control: Control; setValue: UseFormSetValue; @@ -183,7 +183,7 @@ const SimulationPlotView: FC = ({ ? parseFloat(xcompatibleUnit.conversion_factor) : 1.0; - const convertedTime = data.time.map((t) => t * xconversionFactor); + const convertedTime = data[0].time.map((t) => t * xconversionFactor); const minX = Math.min(...convertedTime); const maxX = Math.max(...convertedTime); @@ -192,71 +192,73 @@ const SimulationPlotView: FC = ({ let maxY: number | undefined = undefined; let maxY2: number | undefined = undefined; - const plotData: Data[] = plot.y_axes.map((y_axis) => { - const variableValues = data.outputs[y_axis.variable]; - const variable = variables.find((v) => v.id === y_axis.variable); - const variableName = variable?.name; - const variableUnit = units.find((u) => u.id === variable?.unit); + const plotData = plot.y_axes.map((y_axis) => { + return data.map((d, index) => { + const variableValues = d.outputs[y_axis.variable]; + const variable = variables.find((v) => v.id === y_axis.variable); + const variableName = variable?.name; + const variableUnit = units.find((u) => u.id === variable?.unit); - const yaxisUnit = y_axis.right - ? units.find((u) => u.id === plot.y_unit2) - : units.find((u) => u.id === plot.y_unit); - const ycompatibleUnit = variableUnit?.compatible_units.find( - (u) => parseInt(u.id) === yaxisUnit?.id, - ); + const yaxisUnit = y_axis.right + ? units.find((u) => u.id === plot.y_unit2) + : units.find((u) => u.id === plot.y_unit); + const ycompatibleUnit = variableUnit?.compatible_units.find( + (u) => parseInt(u.id) === yaxisUnit?.id, + ); - const is_target = model.is_library_model - ? variableName?.includes("CT1") || variableName?.includes("AT1") - : false; - const yconversionFactor = ycompatibleUnit - ? parseFloat( - is_target - ? ycompatibleUnit.target_conversion_factor - : ycompatibleUnit.conversion_factor, - ) - : 1.0; + const is_target = model.is_library_model + ? variableName?.includes("CT1") || variableName?.includes("AT1") + : false; + const yconversionFactor = ycompatibleUnit + ? parseFloat( + is_target + ? ycompatibleUnit.target_conversion_factor + : ycompatibleUnit.conversion_factor, + ) + : 1.0; - if (variableValues) { - const y = variableValues.map((v) => v * yconversionFactor); - if (y_axis.right) { - if (minY2 === undefined) { - minY2 = Math.min(...y); - } else { - minY2 = Math.min(minY2, ...y); - } - if (maxY2 === undefined) { - maxY2 = Math.max(...y); + if (variableValues) { + const y = variableValues.map((v) => v * yconversionFactor); + if (y_axis.right) { + if (minY2 === undefined) { + minY2 = Math.min(...y); + } else { + minY2 = Math.min(minY2, ...y); + } + if (maxY2 === undefined) { + maxY2 = Math.max(...y); + } else { + maxY2 = Math.max(maxY2, ...y); + } } else { - maxY2 = Math.max(maxY2, ...y); + if (minY === undefined) { + minY = Math.min(...y); + } else { + minY = Math.min(minY, ...y); + } + if (maxY === undefined) { + maxY = Math.max(...y); + } else { + maxY = Math.max(maxY, ...y); + } } + return { + yaxis: y_axis.right ? "y2" : undefined, + x: convertedTime, + y: variableValues.map((v) => v * yconversionFactor), + name: `${variableName} ${index}` || "unknown", + }; } else { - if (minY === undefined) { - minY = Math.min(...y); - } else { - minY = Math.min(minY, ...y); - } - if (maxY === undefined) { - maxY = Math.max(...y); - } else { - maxY = Math.max(maxY, ...y); - } + return { + yaxis: y_axis.right ? "y2" : undefined, + x: [], + y: [], + type: "scatter", + name: `${y_axis.variable} ${index}`, + }; } - return { - yaxis: y_axis.right ? "y2" : undefined, - x: convertedTime, - y: variableValues.map((v) => v * yconversionFactor), - name: variableName || "unknown", - }; - } else { - return { - yaxis: y_axis.right ? "y2" : undefined, - x: [], - y: [], - type: "scatter", - name: `${y_axis.variable}`, - }; - } - }); + }); + }).flat() as Data[]; const concentrationUnit = units.find((unit) => unit.symbol === "pmol/L"); if (concentrationUnit === undefined) { diff --git a/frontend-v2/src/features/simulation/Simulations.tsx b/frontend-v2/src/features/simulation/Simulations.tsx index a427614c..60e007c1 100644 --- a/frontend-v2/src/features/simulation/Simulations.tsx +++ b/frontend-v2/src/features/simulation/Simulations.tsx @@ -304,7 +304,6 @@ const Simulations: FC = () => { setLoadingSimulate(false); if ("data" in response) { const responseData = response.data as SimulateResponse[]; - console.log({ responseData }) setData(responseData); } }); @@ -593,26 +592,24 @@ const Simulations: FC = () => { )} {plots.map((plot, index) => ( - data?.map(d => ( - - {d && model ? ( - - ) : ( -
Loading...
- )} -
- )) + + {data && model ? ( + + ) : ( +
Loading...
+ )} +
))}
From d030a9b77eb86fe142a3307eee42a4065c6dc4f4 Mon Sep 17 00:00:00 2001 From: Jim O'Donnell Date: Fri, 15 Mar 2024 15:44:41 +0000 Subject: [PATCH 5/5] Test simulations with a project and dataset --- .../pkpdapp/tests/test_views/test_simulate.py | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py index d59b3b61..a3ce536b 100644 --- a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py +++ b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py @@ -5,9 +5,17 @@ # from pkpdapp.models import ( - PharmacodynamicModel, Variable, + PharmacodynamicModel, + PharmacokineticModel, + Variable, CombinedModel, - + Project, + Compound, + Dataset, + Protocol, + Subject, + SubjectGroup, + Unit ) from django.contrib.auth.models import User @@ -19,7 +27,29 @@ class TestSimulateView(APITestCase): def setUp(self): + au = Unit.objects.get(symbol='mg') + tu = Unit.objects.get(symbol='h') + compound = Compound.objects.create(name="demo", compound_type="LM") + self.project = Project.objects.create(name='test project', compound=compound) + self.dataset = Dataset.objects.create(name='test dataset', project=self.project) + self.protocol = Protocol.objects.create( + name='my_cool_protocol', + compound=compound, + amount_unit=au, + time_unit=tu, + mapped_qname='PKCompartment.A1' + ) + self.subject_group = SubjectGroup.objects.create( + name='my_cool_group', + ) + self.subject = Subject.objects.create( + id_in_dataset=1, + dataset=self.dataset, + group=self.subject_group, + protocol=self.protocol, + ) self.user = User.objects.create_user(username='testuser', password='12345') + self.project.users.add(self.user) self.client = APIClient() self.client.force_authenticate(user=self.user) @@ -28,9 +58,14 @@ def test_simulate(self): name='tumour_growth_gompertz', read_only=False, ) + pk = PharmacokineticModel.objects.get( + name="one_compartment_clinical", + ) m = CombinedModel.objects.create( name='my wonderful model', pd_model=pd, + pk_model=pk, + project=self.project, ) url = reverse('simulate-combined-model', args=(m.pk,)) @@ -44,6 +79,7 @@ def test_simulate(self): response = self.client.post(url, data, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) + for sim in response.data: outputs = sim.get('outputs') self.assertCountEqual(