From 9e3529add78c345440d60b7635dee5f0f4cd554b Mon Sep 17 00:00:00 2001 From: Pranjal Dhole Date: Tue, 14 May 2024 22:25:35 +0200 Subject: [PATCH] implementation of odeint solver as an alternative to RK4 --- causal_inference/__init__.py | 1 - causal_inference/base/lotka_volterra.py | 122 -------------------- causal_inference/base/lv_simulator.py | 85 ++++++++++++++ causal_inference/base/lv_system.py | 20 ++++ causal_inference/base/ode_solver.py | 33 ++++++ causal_inference/base/runge_kutta_solver.py | 56 +++++++++ 6 files changed, 194 insertions(+), 123 deletions(-) delete mode 100644 causal_inference/base/lotka_volterra.py create mode 100644 causal_inference/base/lv_simulator.py create mode 100644 causal_inference/base/lv_system.py create mode 100644 causal_inference/base/ode_solver.py create mode 100644 causal_inference/base/runge_kutta_solver.py diff --git a/causal_inference/__init__.py b/causal_inference/__init__.py index 8ee0331..f2c76d3 100644 --- a/causal_inference/__init__.py +++ b/causal_inference/__init__.py @@ -1,5 +1,4 @@ """ Importing core modules """ -from .base.lotka_volterra import LotkaVolterra from .config import LV_PARAMS \ No newline at end of file diff --git a/causal_inference/base/lotka_volterra.py b/causal_inference/base/lotka_volterra.py deleted file mode 100644 index 28242ec..0000000 --- a/causal_inference/base/lotka_volterra.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import argparse -import os -import logging -import datetime -import h5py -import matplotlib.pyplot as plt - -from causal_inference.config import LV_PARAMS, RESULTS_DIR -from causal_inference.utils.log_config import log_LV_params - - -class LotkaVolterra(): - ''' - Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method. - ''' - def __init__(self): - self.A = LV_PARAMS['A'] - self.B = LV_PARAMS['B'] - self.C = LV_PARAMS['C'] - self.D = LV_PARAMS['D'] - self.time = LV_PARAMS['INITIAL_TIME'] - self.step_size = LV_PARAMS['STEP_SIZE'] - self.max_iterations = LV_PARAMS['MAX_ITERATIONS'] - - self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION'] - self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION'] - - self.time_stamp = [self.time] - self.prey_list = [self.prey_population] - self.predator_list = [self.predator_population] - - def compute_prey_rate(self, current_prey, current_predators): - return self.A * current_prey - self.B * current_prey * current_predators - - def compute_predator_rate(self, current_prey, current_predators): - return - self.C * current_predators + self.D * current_prey * current_predators - - def runge_kutta_update(self, current_prey, current_predators): - self.time = self.time + self.step_size - self.time_stamp.append(self.time) - - k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators) - k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators) - - k2_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) - k2_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) - - k3_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) - k3_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) - - k4_prey = self.step_size * self.compute_prey_rate(current_prey + k3_prey, current_predators + k3_pred) - k4_pred = self.step_size * self.compute_predator_rate(current_prey + k3_prey, current_predators + k3_pred) - - new_prey_population = current_prey + 1/6 * (k1_prey + 2 * k2_prey + 2 * k3_prey + k4_prey) - new_predator_population = current_predators + 1/6 * (k1_pred + 2 * k2_pred + 2 * k3_pred + k4_pred) - - self.prey_list.append(new_prey_population) - self.predator_list.append(new_predator_population) - - return new_prey_population, new_predator_population - - def _save_population(self): - filename = os.path.join(RESULTS_DIR, 'populations.h5') - hf = h5py.File(filename, 'w') - hf.create_dataset('prey_pop', data=self.prey_list) - hf.create_dataset('pred_pop', data=self.predator_list) - hf.close() - - def _solve(self): - current_prey, current_predators = self.prey_population, self.predator_population - logging.info('Computing population over time...') - for gen_idx in range(self.max_iterations): - current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators) - logging.info('Gen: %d | Prey population: %d | Predator population: %d', gen_idx, current_prey, current_predators) - print('Done!') - - def plot_population_over_time(self, save=True, filename='predator_prey'): - _ = plt.figure(figsize=(15, 5)) - PreyLine, = plt.plot(self.prey_list , color='g') - PredatorsLine, = plt.plot(self.predator_list, color='r') - plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators']) - plt.ylabel('Population') - plt.xlabel('Time') - if save: - plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"), - format='svg', transparent=False, bbox_inches='tight') - else: - plt.show() - plt.close() - -def main(): - logging.info('Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method') - m = LotkaVolterra() - log_LV_params() - m._solve() - m._save_population() - m.plot_population_over_time() - -if __name__ == '__main__': - PARSER = argparse.ArgumentParser() - PARSER.add_argument('-log', '--logfile', help='name of the logfile', default='log') - PARSER.add_argument('-out', '--outdir', help='Where to place the results', default='lv_simulation') - ARGS = PARSER.parse_args() - - RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir))) - - if not os.path.exists(RESULTS_DIR): - os.makedirs(RESULTS_DIR) - - LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file - logging.basicConfig( - level=logging.INFO, - handlers=[ - logging.FileHandler(LOG_FILE), - logging.StreamHandler() - ] - ) - - main() diff --git a/causal_inference/base/lv_simulator.py b/causal_inference/base/lv_simulator.py new file mode 100644 index 0000000..d9bce29 --- /dev/null +++ b/causal_inference/base/lv_simulator.py @@ -0,0 +1,85 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import argparse +import os +import logging +import datetime + +import h5py +import matplotlib.pyplot as plt + +from causal_inference.config import RESULTS_DIR +from causal_inference.utils.log_config import log_LV_params +from causal_inference.base.ode_solver import ODE_solver +from causal_inference.base.runge_kutta_solver import RungeKuttaSolver + +def _save_population(prey_list, predator_list): + filename = os.path.join(RESULTS_DIR, 'populations.h5') + hf = h5py.File(filename, 'w') + hf.create_dataset('prey_pop', data=prey_list) + hf.create_dataset('pred_pop', data=predator_list) + hf.close() + +def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'): + fig = plt.figure(figsize=(15, 5)) + ax = fig.add_subplot(2, 1, 1) + PreyLine, = plt.plot(prey_list , color='g') + PredatorsLine, = plt.plot(predator_list, color='r') + ax.set_xscale('log') + + plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators']) + plt.ylabel('Population') + plt.xlabel('Time') + if save: + plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"), + format='svg', transparent=False, bbox_inches='tight') + else: + plt.show() + plt.close() + +def get_solver(method): + ''' + solving LV equation with scipy function. + ''' + if method == 'RK4': + solver = RungeKuttaSolver() + elif method == 'ODE': + solver = ODE_solver() + else: + raise AssertionError(f'{method} is not implemented!') + return solver + +def main(method): + ''' + Main function that solves LV system. + ''' + log_LV_params() + solver = get_solver(method) + prey_list, predator_list = solver._solve() + _save_population(prey_list, predator_list) + plot_population_over_time(prey_list, predator_list) + +if __name__ == '__main__': + PARSER = argparse.ArgumentParser() + PARSER.add_argument('-log', '--logfile', help='name of the logfile', default='log') + PARSER.add_argument('-out', '--outdir', help='Where to place the results', default='lv_simulation') + PARSER.add_argument('-s', '--solver', help='Mathematical solver for solving LV system', + choices=['RK4', 'ODE'], default='RK4') + ARGS = PARSER.parse_args() + + RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir))) + + if not os.path.exists(RESULTS_DIR): + os.makedirs(RESULTS_DIR) + + LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file + logging.basicConfig( + level=logging.INFO, + handlers=[ + logging.FileHandler(LOG_FILE), + logging.StreamHandler() + ] + ) + solver = ARGS.solver + main(solver) diff --git a/causal_inference/base/lv_system.py b/causal_inference/base/lv_system.py new file mode 100644 index 0000000..b580e78 --- /dev/null +++ b/causal_inference/base/lv_system.py @@ -0,0 +1,20 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +from causal_inference.config import LV_PARAMS + +class LotkaVolterra(): + ''' + Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method. + ''' + def __init__(self): + self.A = LV_PARAMS['A'] + self.B = LV_PARAMS['B'] + self.C = LV_PARAMS['C'] + self.D = LV_PARAMS['D'] + self.time = LV_PARAMS['INITIAL_TIME'] + self.step_size = LV_PARAMS['STEP_SIZE'] + self.max_iterations = LV_PARAMS['MAX_ITERATIONS'] + + self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION'] + self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION'] \ No newline at end of file diff --git a/causal_inference/base/ode_solver.py b/causal_inference/base/ode_solver.py new file mode 100644 index 0000000..7ddb0e1 --- /dev/null +++ b/causal_inference/base/ode_solver.py @@ -0,0 +1,33 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import logging +import numpy as np +from scipy import integrate + +from causal_inference.config import LV_PARAMS +from causal_inference.base.lv_system import LotkaVolterra + +class ODE_solver(LotkaVolterra): + ''' + This class implements scipy ODE solver for LV system. + ''' + def __init__(self): + super().__init__() + logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver') + + @staticmethod + def LV_derivative(X, t, alpha, beta, delta, gamma): + x, y = X + dotx = x * (alpha - beta * y) + doty = y * (-delta + gamma * x) + return np.array([dotx, doty]) + + def _solve(self): + logging.info('Computing population over time...') + t = np.arange(0.,self.max_iterations, self.step_size) + X0 = [self.prey_population, self.predator_population] + res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D)) + prey_list, predator_list = res.T + logging.info('done!') + return prey_list, predator_list \ No newline at end of file diff --git a/causal_inference/base/runge_kutta_solver.py b/causal_inference/base/runge_kutta_solver.py new file mode 100644 index 0000000..8a54c89 --- /dev/null +++ b/causal_inference/base/runge_kutta_solver.py @@ -0,0 +1,56 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import logging +from math import ceil +from causal_inference.base.lv_system import LotkaVolterra + +class RungeKuttaSolver(LotkaVolterra): + ''' + This class implements 4th order Runge-Kutta solver. + ''' + def __init__(self): + super().__init__() + logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method') + self.time_stamp = [self.time] + self.prey_list = [self.prey_population] + self.predator_list = [self.predator_population] + + def compute_prey_rate(self, current_prey, current_predators): + return self.A * current_prey - self.B * current_prey * current_predators + + def compute_predator_rate(self, current_prey, current_predators): + return - self.C * current_predators + self.D * current_prey * current_predators + + def runge_kutta_update(self, current_prey, current_predators): + self.time = self.time + self.step_size + self.time_stamp.append(self.time) + + k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators) + k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators) + + k2_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) + k2_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) + + k3_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) + k3_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) + + k4_prey = self.step_size * self.compute_prey_rate(current_prey + k3_prey, current_predators + k3_pred) + k4_pred = self.step_size * self.compute_predator_rate(current_prey + k3_prey, current_predators + k3_pred) + + new_prey_population = current_prey + 1/6 * (k1_prey + 2 * k2_prey + 2 * k3_prey + k4_prey) + new_predator_population = current_predators + 1/6 * (k1_pred + 2 * k2_pred + 2 * k3_pred + k4_pred) + + self.prey_list.append(new_prey_population) + self.predator_list.append(new_predator_population) + + return new_prey_population, new_predator_population + + def _solve(self): + current_prey, current_predators = self.prey_population, self.predator_population + logging.info('Computing population over time...') + for gen_idx in range(ceil(self.max_iterations/self.step_size)): + current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators) + msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}' + logging.info(msg) + print('Done!') + return self.prey_list, self.predator_list \ No newline at end of file