diff --git a/frontend-v2/src/app/backendApi.ts b/frontend-v2/src/app/backendApi.ts index 1b6524ed..eec30f4e 100644 --- a/frontend-v2/src/app/backendApi.ts +++ b/frontend-v2/src/app/backendApi.ts @@ -1127,7 +1127,7 @@ export type CombinedModelSetVariablesFromInferenceUpdateApiArg = { combinedModel: CombinedModel; }; export type CombinedModelSimulateCreateApiResponse = - /** status 200 */ SimulateResponse; + /** status 200 */ SimulateResponse[]; export type CombinedModelSimulateCreateApiArg = { id: number; simulate: Simulate; @@ -1350,7 +1350,7 @@ export type PharmacodynamicSetVariablesFromInferenceUpdateApiArg = { pharmacodynamic: Pharmacodynamic; }; export type PharmacodynamicSimulateCreateApiResponse = - /** status 200 */ SimulateResponse; + /** status 200 */ SimulateResponse[]; export type PharmacodynamicSimulateCreateApiArg = { id: number; simulate: Simulate; diff --git a/frontend-v2/src/features/simulation/SimulationPlotView.tsx b/frontend-v2/src/features/simulation/SimulationPlotView.tsx index d5a7cc68..c7e883b7 100644 --- a/frontend-v2/src/features/simulation/SimulationPlotView.tsx +++ b/frontend-v2/src/features/simulation/SimulationPlotView.tsx @@ -127,7 +127,7 @@ function genIcLines( interface SimulationPlotProps { index: number; plot: FieldArrayWithId; - data: SimulateResponse; + data: SimulateResponse[]; variables: VariableRead[]; control: Control; setValue: UseFormSetValue; @@ -183,7 +183,7 @@ const SimulationPlotView: FC = ({ ? parseFloat(xcompatibleUnit.conversion_factor) : 1.0; - const convertedTime = data.time.map((t) => t * xconversionFactor); + const convertedTime = data[0].time.map((t) => t * xconversionFactor); const minX = Math.min(...convertedTime); const maxX = Math.max(...convertedTime); @@ -192,71 +192,73 @@ const SimulationPlotView: FC = ({ let maxY: number | undefined = undefined; let maxY2: number | undefined = undefined; - const plotData: Data[] = plot.y_axes.map((y_axis) => { - const variableValues = data.outputs[y_axis.variable]; - const variable = variables.find((v) => v.id === y_axis.variable); - const variableName = variable?.name; - const variableUnit = units.find((u) => u.id === variable?.unit); + const plotData = plot.y_axes.map((y_axis) => { + return data.map((d, index) => { + const variableValues = d.outputs[y_axis.variable]; + const variable = variables.find((v) => v.id === y_axis.variable); + const variableName = variable?.name; + const variableUnit = units.find((u) => u.id === variable?.unit); - const yaxisUnit = y_axis.right - ? units.find((u) => u.id === plot.y_unit2) - : units.find((u) => u.id === plot.y_unit); - const ycompatibleUnit = variableUnit?.compatible_units.find( - (u) => parseInt(u.id) === yaxisUnit?.id, - ); + const yaxisUnit = y_axis.right + ? units.find((u) => u.id === plot.y_unit2) + : units.find((u) => u.id === plot.y_unit); + const ycompatibleUnit = variableUnit?.compatible_units.find( + (u) => parseInt(u.id) === yaxisUnit?.id, + ); - const is_target = model.is_library_model - ? variableName?.includes("CT1") || variableName?.includes("AT1") - : false; - const yconversionFactor = ycompatibleUnit - ? parseFloat( - is_target - ? ycompatibleUnit.target_conversion_factor - : ycompatibleUnit.conversion_factor, - ) - : 1.0; + const is_target = model.is_library_model + ? variableName?.includes("CT1") || variableName?.includes("AT1") + : false; + const yconversionFactor = ycompatibleUnit + ? parseFloat( + is_target + ? ycompatibleUnit.target_conversion_factor + : ycompatibleUnit.conversion_factor, + ) + : 1.0; - if (variableValues) { - const y = variableValues.map((v) => v * yconversionFactor); - if (y_axis.right) { - if (minY2 === undefined) { - minY2 = Math.min(...y); - } else { - minY2 = Math.min(minY2, ...y); - } - if (maxY2 === undefined) { - maxY2 = Math.max(...y); + if (variableValues) { + const y = variableValues.map((v) => v * yconversionFactor); + if (y_axis.right) { + if (minY2 === undefined) { + minY2 = Math.min(...y); + } else { + minY2 = Math.min(minY2, ...y); + } + if (maxY2 === undefined) { + maxY2 = Math.max(...y); + } else { + maxY2 = Math.max(maxY2, ...y); + } } else { - maxY2 = Math.max(maxY2, ...y); + if (minY === undefined) { + minY = Math.min(...y); + } else { + minY = Math.min(minY, ...y); + } + if (maxY === undefined) { + maxY = Math.max(...y); + } else { + maxY = Math.max(maxY, ...y); + } } + return { + yaxis: y_axis.right ? "y2" : undefined, + x: convertedTime, + y: variableValues.map((v) => v * yconversionFactor), + name: `${variableName} ${index}` || "unknown", + }; } else { - if (minY === undefined) { - minY = Math.min(...y); - } else { - minY = Math.min(minY, ...y); - } - if (maxY === undefined) { - maxY = Math.max(...y); - } else { - maxY = Math.max(maxY, ...y); - } + return { + yaxis: y_axis.right ? "y2" : undefined, + x: [], + y: [], + type: "scatter", + name: `${y_axis.variable} ${index}`, + }; } - return { - yaxis: y_axis.right ? "y2" : undefined, - x: convertedTime, - y: variableValues.map((v) => v * yconversionFactor), - name: variableName || "unknown", - }; - } else { - return { - yaxis: y_axis.right ? "y2" : undefined, - x: [], - y: [], - type: "scatter", - name: `${y_axis.variable}`, - }; - } - }); + }); + }).flat() as Data[]; const concentrationUnit = units.find((unit) => unit.symbol === "pmol/L"); if (concentrationUnit === undefined) { diff --git a/frontend-v2/src/features/simulation/Simulations.tsx b/frontend-v2/src/features/simulation/Simulations.tsx index a2e3a3a9..60e007c1 100644 --- a/frontend-v2/src/features/simulation/Simulations.tsx +++ b/frontend-v2/src/features/simulation/Simulations.tsx @@ -180,7 +180,7 @@ const Simulations: FC = () => { ? (simulateErrorBase.data as ErrorObject) : { error: "Unknown error" } : undefined; - const [data, setData] = useState(null); + const [data, setData] = useState(null); const { data: compound, isLoading: isLoadingCompound } = useCompoundRetrieveQuery( { id: project?.compound || 0 }, @@ -303,7 +303,7 @@ const Simulations: FC = () => { }).then((response) => { setLoadingSimulate(false); if ("data" in response) { - const responseData = response.data as SimulateResponse; + const responseData = response.data as SimulateResponse[]; setData(responseData); } }); @@ -342,10 +342,10 @@ const Simulations: FC = () => { ), }).then((response) => { if ("data" in response) { - const responseData = response.data as SimulateResponse; + const [ projectData ] = response.data; const nrows = - responseData.outputs[Object.keys(responseData.outputs)[0]].length; - const cols = Object.keys(responseData.outputs); + projectData.outputs[Object.keys(projectData.outputs)[0]].length; + const cols = Object.keys(projectData.outputs); const vars = cols.map((vid) => variables.find((v) => v.id === parseInt(vid)), ); @@ -377,7 +377,7 @@ const Simulations: FC = () => { for (let i = 0; i < nrows; i++) { rows[rowi] = new Array(ncols); for (let j = 0; j < ncols; j++) { - rows[rowi][j] = responseData.outputs[cols[j]][i]; + rows[rowi][j] = projectData.outputs[cols[j]][i]; } rowi++; } diff --git a/pkpdapp/pkpdapp/api/views/simulate.py b/pkpdapp/pkpdapp/api/views/simulate.py index 716dc1c7..6cbf9fd8 100644 --- a/pkpdapp/pkpdapp/api/views/simulate.py +++ b/pkpdapp/pkpdapp/api/views/simulate.py @@ -38,6 +38,10 @@ def to_representation(self, instance): } +class SimulateResponseListSerializer(serializers.ListSerializer): + child = SimulateResponseSerializer() + + class ErrorResponseSerializer(serializers.Serializer): error = serializers.CharField() @@ -45,7 +49,7 @@ class ErrorResponseSerializer(serializers.Serializer): @extend_schema( request=SimulateSerializer, responses={ - 200: SimulateResponseSerializer, + 200: SimulateResponseListSerializer, 400: ErrorResponseSerializer, 404: None, }, @@ -67,7 +71,7 @@ def post(self, request, pk, format=None): return Response( serialized_result.data, status=status.HTTP_400_BAD_REQUEST ) - serialized_result = SimulateResponseSerializer(result) + serialized_result = SimulateResponseSerializer(result, many=True) return Response(serialized_result.data) diff --git a/pkpdapp/pkpdapp/models/myokit_model_mixin.py b/pkpdapp/pkpdapp/models/myokit_model_mixin.py index ae668ff2..8101e680 100644 --- a/pkpdapp/pkpdapp/models/myokit_model_mixin.py +++ b/pkpdapp/pkpdapp/models/myokit_model_mixin.py @@ -5,6 +5,7 @@ # import pkpdapp +from pkpdapp.models import SubjectGroup, Protocol import numpy as np from myokit.formats.mathml import MathMLExpressionWriter from myokit.formats.sbml import SBMLParser @@ -440,6 +441,27 @@ def serialize_datalog(self, datalog, myokit_model): def get_time_max(self): return self.time_max + def simulate_model( + self, outputs=None, variables=None, time_max=None, dosing_protocols=None + ): + model = self.get_myokit_model() + # Convert units + variables = self._initialise_variables(model, variables) + time_max = self._convert_bound_unit("time", time_max, model) + # get tlag vars + override_tlag = self._get_override_tlag(variables) + # create simulator + sim = self.create_myokit_simulator( + override_tlag=override_tlag, + model=model, + time_max=time_max, + dosing_protocols=dosing_protocols + ) + # TODO: take these from simulation model + sim.set_tolerance(abs_tol=1e-06, rel_tol=1e-08) + # Simulate, logging only state variables given by `outputs` + return self.serialize_datalog(sim.run(time_max, log=outputs), model) + def simulate(self, outputs=None, variables=None, time_max=None): """ Arguments @@ -475,21 +497,32 @@ def simulate(self, outputs=None, variables=None, time_max=None): **variables, } - model = self.get_myokit_model() - # Convert units - variables = self._initialise_variables(model, variables) - time_max = self._convert_bound_unit("time", time_max, model) - # get tlag vars - override_tlag = self._get_override_tlag(variables) - # create simulator - sim = self.create_myokit_simulator( - override_tlag=override_tlag, model=model, time_max=time_max + project_sim = self.simulate_model( + variables=variables, time_max=time_max, outputs=outputs ) - # TODO: take these from simulation model - sim.set_tolerance(abs_tol=1e-06, rel_tol=1e-08) - # Simulate, logging only state variables given by `outputs` - return self.serialize_datalog(sim.run(time_max, log=outputs), model) + project = self.get_project() + sims = [project_sim] + if project is not None: + for group in get_subject_groups(project): + # find unique protocols for this subject cohort + dosing_protocols = {} + subject_protocols = [ + Protocol.objects.get(pk=p['protocol']) + for p in group.subjects.values('protocol').distinct() + if p['protocol'] is not None + ] + for protocol in subject_protocols: + dosing_protocols[protocol.mapped_qname] = protocol + sim = self.simulate_model( + outputs=outputs, + variables=variables, + time_max=time_max, + dosing_protocols=dosing_protocols + ) + sims.append(sim) + + return sims def set_administration(model, drug_amount, direct=True): @@ -689,3 +722,15 @@ def _get_dosing_events( elif abs(start + duration - time_max) < 1e-6: dosing_events[i] = (level, start, time_max - start) return dosing_events + + +def get_subject_groups(project): + dataset = project.datasets.first() + dataset_groups = [] + if dataset is not None: + dataset_groups = [ + SubjectGroup.objects.get(pk=g['group']) + for g in dataset.subjects.values('group').distinct() + if g['group'] is not None + ] + return dataset_groups diff --git a/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py b/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py index 1e7ef1b3..aa37dc44 100644 --- a/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py +++ b/pkpdapp/pkpdapp/tests/test_views/test_combined_model.py @@ -135,11 +135,12 @@ def simulate_combined_model(self, id): format="json", ) self.assertEqual(response.status_code, status.HTTP_200_OK) - keys = [key for key in response.data["outputs"].keys()] + data = response.data[0] + keys = [key for key in data["outputs"].keys()] return ( - response.data["outputs"][keys[0]], - response.data["outputs"][keys[1]], - response.data["outputs"][keys[2]], + data["outputs"][keys[0]], + data["outputs"][keys[1]], + data["outputs"][keys[2]], ) def test_swap_mapped_pd_model(self): diff --git a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py index 91733d42..a3ce536b 100644 --- a/pkpdapp/pkpdapp/tests/test_views/test_simulate.py +++ b/pkpdapp/pkpdapp/tests/test_views/test_simulate.py @@ -5,9 +5,17 @@ # from pkpdapp.models import ( - PharmacodynamicModel, Variable, + PharmacodynamicModel, + PharmacokineticModel, + Variable, CombinedModel, - + Project, + Compound, + Dataset, + Protocol, + Subject, + SubjectGroup, + Unit ) from django.contrib.auth.models import User @@ -19,7 +27,29 @@ class TestSimulateView(APITestCase): def setUp(self): + au = Unit.objects.get(symbol='mg') + tu = Unit.objects.get(symbol='h') + compound = Compound.objects.create(name="demo", compound_type="LM") + self.project = Project.objects.create(name='test project', compound=compound) + self.dataset = Dataset.objects.create(name='test dataset', project=self.project) + self.protocol = Protocol.objects.create( + name='my_cool_protocol', + compound=compound, + amount_unit=au, + time_unit=tu, + mapped_qname='PKCompartment.A1' + ) + self.subject_group = SubjectGroup.objects.create( + name='my_cool_group', + ) + self.subject = Subject.objects.create( + id_in_dataset=1, + dataset=self.dataset, + group=self.subject_group, + protocol=self.protocol, + ) self.user = User.objects.create_user(username='testuser', password='12345') + self.project.users.add(self.user) self.client = APIClient() self.client.force_authenticate(user=self.user) @@ -28,9 +58,14 @@ def test_simulate(self): name='tumour_growth_gompertz', read_only=False, ) + pk = PharmacokineticModel.objects.get( + name="one_compartment_clinical", + ) m = CombinedModel.objects.create( name='my wonderful model', pd_model=pd, + pk_model=pk, + project=self.project, ) url = reverse('simulate-combined-model', args=(m.pk,)) @@ -44,14 +79,16 @@ def test_simulate(self): response = self.client.post(url, data, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - outputs = response.data.get('outputs') - self.assertCountEqual( - list(outputs.keys()), - [ - Variable.objects.get(qname=qname, dosed_pk_model=m).id - for qname in data['outputs'] - ] - ) + + for sim in response.data: + outputs = sim.get('outputs') + self.assertCountEqual( + list(outputs.keys()), + [ + Variable.objects.get(qname=qname, dosed_pk_model=m).id + for qname in data['outputs'] + ] + ) url = reverse('simulate-combined-model', args=(123,)) response = self.client.post(url, data, format='json') diff --git a/pkpdapp/pkpdapp/tests/utils.py b/pkpdapp/pkpdapp/tests/utils.py index 150baebe..ef6f60a5 100644 --- a/pkpdapp/pkpdapp/tests/utils.py +++ b/pkpdapp/pkpdapp/tests/utils.py @@ -30,7 +30,7 @@ def create_pd_inference(sampling=False): # generate some fake data output = model.variables.get(qname='PDCompartment.TS') time = model.variables.get(qname='environment.t') - data = model.simulate(outputs=[output.qname, time.qname]) + data = model.simulate(outputs=[output.qname, time.qname])[0] print(data) TS = data[output.id] times = data[time.id] diff --git a/pkpdapp/schema.yml b/pkpdapp/schema.yml index 5e2e72c5..35b5bd01 100644 --- a/pkpdapp/schema.yml +++ b/pkpdapp/schema.yml @@ -539,7 +539,9 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SimulateResponse' + type: array + items: + $ref: '#/components/schemas/SimulateResponse' description: '' '400': content: @@ -1674,7 +1676,9 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SimulateResponse' + type: array + items: + $ref: '#/components/schemas/SimulateResponse' description: '' '400': content: