Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lokta-volterra inference module #8

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ python causal_inference/base/lv_simulator.py
This will take all the default arguments and configuration to run a simulation instance of lotka-volterra population dynamics.

- The simulation statistics will be saved in the `repo/results` directory by default.

## Simulation and inference

- The simulation and inference methods are separately implemented in `repo/causal_inference/base/lotka_volterra/lv_system.py`.
- Currently, this inference method is experimental and may not always converge to correct optimal parameters.
- More work is needed to find a good approximation schema to initiate the parameters of the LV-system.
Empty file.
129 changes: 129 additions & 0 deletions causal_inference/base/lotka_volterra/lv_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from scipy import integrate
from scipy.optimize import minimize

from causal_inference.config import LV_PARAMS

class LotkaVolterra():
'''
Base Lotka-Volterra Class that defines a predator-prey system.
'''
def __init__(self,
A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'],
prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'],
pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'],
total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'],
max_iter=LV_PARAMS['MAX_ITERATIONS']):
# Lotka-Volterra parameters
self.A = A
self.B = B
self.C = C
self.D = D

self.prey_population = prey_population # Initial prey population
self.predator_population = pred_population # Initial predator population

self.init_time = 0 # initial time
self.total_time = total_time # total time in units
self.step_size = step_size # increment for each time step
self.max_iterations = max_iter # tolerance parameter

self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size)

@staticmethod
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population

Simulates Lotka-Volterra dynamics

Parameters:
t (list): [t0, tf] initial and final time points for simulation (Not used but necessary for integration step)
Z (tuple): (x, y) state of the system
A: prey growth rate (model parameter)
B: predation rate (model parameter)
C: predator death rate (model parameter)
D: predator growth rate from eating prey (model parameter)

Returns:
array: rate of change of prey and predator population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

Check warning on line 57 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L54-L57

Added lines #L54 - L57 were not covered by tests

def simulate_lotka_volterra(params, t, initial_conditions):
"""
Simulates Lotka-Volterra dynamics

Parameters:
params (tuple): (A, B, C, D) model parameters
A: prey growth rate
B: predation rate
C: predator death rate
D: predator growth rate from eating prey
t (array): time points for simulation
initial_conditions (tuple): (prey0, predator0) initial populations

Returns:
population: (n, 2) array where each row is [prey_pop, predator_pop]
"""
A, B, C, D = params

Check warning on line 75 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L75

Added line #L75 was not covered by tests

solution = integrate.solve_ivp(LV_derivative, [t[0], t[-1]], initial_conditions,

Check warning on line 77 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L77

Added line #L77 was not covered by tests
args=(A, B, C, D), dense_output=True)

population = solution.sol(t)
return population

Check warning on line 81 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L80-L81

Added lines #L80 - L81 were not covered by tests

def fit_lotka_volterra(time_points, observed_data, initial_guess):
"""
Fits Lotka-Volterra parameters to observed population data

Parameters:
time_points (array): time points of observations
observed_data (array): observed population data [prey, predator]
initial_guess (tuple): initial parameter guess (A, B, C, D)

Returns:
tuple: Fitted parameters (A, B, C, D) after optimization
"""
def objective_function(params):

Check warning on line 95 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L95

Added line #L95 was not covered by tests
# Simulate with current parameters
simulated = simulate_lotka_volterra(params, time_points,

Check warning on line 97 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L97

Added line #L97 was not covered by tests
observed_data[:, 0])
# Calculate mean squared error
mse = np.mean((simulated - observed_data) ** 2)
return mse

Check warning on line 101 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L100-L101

Added lines #L100 - L101 were not covered by tests

# Parameter bounds (all parameters must be positive)
bounds = [(0, None) for _ in range(4)]

Check warning on line 104 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L104

Added line #L104 was not covered by tests

# Optimize parameters
result = minimize(objective_function, initial_guess,

Check warning on line 107 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L107

Added line #L107 was not covered by tests
bounds=bounds, method='L-BFGS-B')

return result.x

Check warning on line 110 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L110

Added line #L110 was not covered by tests

# Example usage:
if __name__ == "__main__":
# Generate synthetic data
m = LotkaVolterra()
true_params = (m.A, m.B, m.C, m.D)
initial_conditions = (m.prey_population, m.predator_population) # (prey0, predator0)

Check warning on line 117 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L115-L117

Added lines #L115 - L117 were not covered by tests

# Generate synthetic data with some noise
data = simulate_lotka_volterra(true_params, m.time_stamps, initial_conditions)
noisy_data = data + np.random.normal(0, 0.1, data.shape)

Check warning on line 121 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L120-L121

Added lines #L120 - L121 were not covered by tests

# Fit parameters
initial_guess = (0.5, 0.05, 0.05, 0.05) #FIXME Use more educated schema for initial guess

Check warning on line 124 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L124

Added line #L124 was not covered by tests

fitted_params = fit_lotka_volterra(m.time_stamps, noisy_data, initial_guess)

Check warning on line 126 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L126

Added line #L126 was not covered by tests

print("True parameters:", true_params)
print("Fitted parameters:", fitted_params)

Check warning on line 129 in causal_inference/base/lotka_volterra/lv_system.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lotka_volterra/lv_system.py#L128-L129

Added lines #L128 - L129 were not covered by tests
14 changes: 9 additions & 5 deletions causal_inference/base/lv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@
raise AssertionError(f'{method} is not implemented!')
return solver

def simulate_lotka_volterra(method):
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
return prey_list, predator_list, solver.time_stamps

Check warning on line 32 in causal_inference/base/lv_simulator.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lv_simulator.py#L28-L32

Added lines #L28 - L32 were not covered by tests

def main(method, results_dir):
'''
Main function that solves LV system.
'''
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
_save_population(prey_list, predator_list, solver.time_stamps, results_dir)
plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir)
prey_list, predator_list, t = simulate_lotka_volterra(method)
_save_population(prey_list, predator_list, t, results_dir)
plot_population_over_time(prey_list, predator_list, t, results_dir)

Check warning on line 40 in causal_inference/base/lv_simulator.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/lv_simulator.py#L38-L40

Added lines #L38 - L40 were not covered by tests

if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
Expand Down
31 changes: 0 additions & 31 deletions causal_inference/base/lv_system.py

This file was deleted.

15 changes: 2 additions & 13 deletions causal_inference/base/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# -*- coding: utf-8 -*-

import logging
import numpy as np
from scipy import integrate

from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra, LV_derivative

Check warning on line 7 in causal_inference/base/ode_solver.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/ode_solver.py#L7

Added line #L7 was not covered by tests

class ODE_solver(LotkaVolterra):
'''
Expand All @@ -15,24 +14,14 @@
super().__init__()
logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver')

@staticmethod
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

def _solve(self):
'''
ODE solver that returns the predator and prey populations at each time step in time series.
'''
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

INIT_POP = [self.prey_population, self.predator_population]
sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
sol = integrate.solve_ivp(LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)

Check warning on line 24 in causal_inference/base/ode_solver.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/ode_solver.py#L24

Added line #L24 was not covered by tests
prey_list, predator_list = sol.sol(self.time_stamps)

logging.info('done!')
Expand Down
2 changes: 1 addition & 1 deletion causal_inference/base/runge_kutta_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import logging
from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra

Check warning on line 4 in causal_inference/base/runge_kutta_solver.py

View check run for this annotation

Codecov / codecov/patch

causal_inference/base/runge_kutta_solver.py#L4

Added line #L4 was not covered by tests

class RungeKuttaSolver(LotkaVolterra):
'''
Expand Down
14 changes: 7 additions & 7 deletions causal_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

# Lotka-Volterra Parameters
LV_PARAMS = {
'A' : 10.0,
'B' : 7.0,
'C' : 3.0,
'D' : 5.0,
'A' : 1.0,
'B' : 0.1,
'C' : 0.3,
'D' : 0.4,
'STEP_SIZE' : 0.01,
'TOTAL_TIME' : 20,
'INITIAL_PREY_POPULATION' : 60,
'TOTAL_TIME' : 10,
'INITIAL_PREY_POPULATION' : 40,
'INITIAL_PREDATOR_POPULATION' : 25,
'MAX_ITERATIONS' : 200
'MAX_ITERATIONS' : 100
}

# PATHS
Expand Down
2 changes: 1 addition & 1 deletion causal_inference/tests/test_lv_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# flake8: noqa
from causal_inference.base.lv_system import LotkaVolterra
from causal_inference.base.lotka_volterra.lv_system import LotkaVolterra
from causal_inference.config import LV_PARAMS

def test_lotka_volterra():
Expand Down
Loading