Skip to content

Commit

Permalink
implementation of odeint solver as an alternative to RK4
Browse files Browse the repository at this point in the history
  • Loading branch information
pranjaldhole committed May 14, 2024
1 parent 54c1ba2 commit 9e3529a
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 123 deletions.
1 change: 0 additions & 1 deletion causal_inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Importing core modules
"""
from .base.lotka_volterra import LotkaVolterra
from .config import LV_PARAMS
122 changes: 0 additions & 122 deletions causal_inference/base/lotka_volterra.py

This file was deleted.

85 changes: 85 additions & 0 deletions causal_inference/base/lv_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import os
import logging
import datetime

import h5py
import matplotlib.pyplot as plt

from causal_inference.config import RESULTS_DIR
from causal_inference.utils.log_config import log_LV_params
from causal_inference.base.ode_solver import ODE_solver
from causal_inference.base.runge_kutta_solver import RungeKuttaSolver

def _save_population(prey_list, predator_list):
filename = os.path.join(RESULTS_DIR, 'populations.h5')
hf = h5py.File(filename, 'w')
hf.create_dataset('prey_pop', data=prey_list)
hf.create_dataset('pred_pop', data=predator_list)
hf.close()

def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'):
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(2, 1, 1)
PreyLine, = plt.plot(prey_list , color='g')
PredatorsLine, = plt.plot(predator_list, color='r')
ax.set_xscale('log')

plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
plt.ylabel('Population')
plt.xlabel('Time')
if save:
plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"),
format='svg', transparent=False, bbox_inches='tight')
else:
plt.show()
plt.close()

def get_solver(method):
'''
solving LV equation with scipy function.
'''
if method == 'RK4':
solver = RungeKuttaSolver()
elif method == 'ODE':
solver = ODE_solver()
else:
raise AssertionError(f'{method} is not implemented!')
return solver

def main(method):
'''
Main function that solves LV system.
'''
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
_save_population(prey_list, predator_list)
plot_population_over_time(prey_list, predator_list)

if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
PARSER.add_argument('-log', '--logfile', help='name of the logfile', default='log')
PARSER.add_argument('-out', '--outdir', help='Where to place the results', default='lv_simulation')
PARSER.add_argument('-s', '--solver', help='Mathematical solver for solving LV system',
choices=['RK4', 'ODE'], default='RK4')
ARGS = PARSER.parse_args()

RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))

if not os.path.exists(RESULTS_DIR):
os.makedirs(RESULTS_DIR)

LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler(LOG_FILE),
logging.StreamHandler()
]
)
solver = ARGS.solver
main(solver)
20 changes: 20 additions & 0 deletions causal_inference/base/lv_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

from causal_inference.config import LV_PARAMS

class LotkaVolterra():
'''
Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method.
'''
def __init__(self):
self.A = LV_PARAMS['A']
self.B = LV_PARAMS['B']
self.C = LV_PARAMS['C']
self.D = LV_PARAMS['D']
self.time = LV_PARAMS['INITIAL_TIME']
self.step_size = LV_PARAMS['STEP_SIZE']
self.max_iterations = LV_PARAMS['MAX_ITERATIONS']

self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION']
self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION']
33 changes: 33 additions & 0 deletions causal_inference/base/ode_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import logging
import numpy as np
from scipy import integrate

from causal_inference.config import LV_PARAMS
from causal_inference.base.lv_system import LotkaVolterra

class ODE_solver(LotkaVolterra):
'''
This class implements scipy ODE solver for LV system.
'''
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver')

@staticmethod
def LV_derivative(X, t, alpha, beta, delta, gamma):
x, y = X
dotx = x * (alpha - beta * y)
doty = y * (-delta + gamma * x)
return np.array([dotx, doty])

def _solve(self):
logging.info('Computing population over time...')
t = np.arange(0.,self.max_iterations, self.step_size)
X0 = [self.prey_population, self.predator_population]
res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D))
prey_list, predator_list = res.T
logging.info('done!')
return prey_list, predator_list
56 changes: 56 additions & 0 deletions causal_inference/base/runge_kutta_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import logging
from math import ceil
from causal_inference.base.lv_system import LotkaVolterra

class RungeKuttaSolver(LotkaVolterra):
'''
This class implements 4th order Runge-Kutta solver.
'''
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method')
self.time_stamp = [self.time]
self.prey_list = [self.prey_population]
self.predator_list = [self.predator_population]

def compute_prey_rate(self, current_prey, current_predators):
return self.A * current_prey - self.B * current_prey * current_predators

def compute_predator_rate(self, current_prey, current_predators):
return - self.C * current_predators + self.D * current_prey * current_predators

def runge_kutta_update(self, current_prey, current_predators):
self.time = self.time + self.step_size
self.time_stamp.append(self.time)

k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators)
k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators)

k2_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred)
k2_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred)

k3_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred)
k3_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred)

k4_prey = self.step_size * self.compute_prey_rate(current_prey + k3_prey, current_predators + k3_pred)
k4_pred = self.step_size * self.compute_predator_rate(current_prey + k3_prey, current_predators + k3_pred)

new_prey_population = current_prey + 1/6 * (k1_prey + 2 * k2_prey + 2 * k3_prey + k4_prey)
new_predator_population = current_predators + 1/6 * (k1_pred + 2 * k2_pred + 2 * k3_pred + k4_pred)

self.prey_list.append(new_prey_population)
self.predator_list.append(new_predator_population)

return new_prey_population, new_predator_population

def _solve(self):
current_prey, current_predators = self.prey_population, self.predator_population
logging.info('Computing population over time...')
for gen_idx in range(ceil(self.max_iterations/self.step_size)):
current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators)
msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
logging.info(msg)
print('Done!')
return self.prey_list, self.predator_list

0 comments on commit 9e3529a

Please sign in to comment.