-
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implementation of odeint solver as an alternative to RK4
- Loading branch information
1 parent
54c1ba2
commit 9e3529a
Showing
6 changed files
with
194 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
""" | ||
Importing core modules | ||
""" | ||
from .base.lotka_volterra import LotkaVolterra | ||
from .config import LV_PARAMS |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
|
||
import argparse | ||
import os | ||
import logging | ||
import datetime | ||
|
||
import h5py | ||
import matplotlib.pyplot as plt | ||
|
||
from causal_inference.config import RESULTS_DIR | ||
from causal_inference.utils.log_config import log_LV_params | ||
from causal_inference.base.ode_solver import ODE_solver | ||
from causal_inference.base.runge_kutta_solver import RungeKuttaSolver | ||
|
||
def _save_population(prey_list, predator_list): | ||
filename = os.path.join(RESULTS_DIR, 'populations.h5') | ||
hf = h5py.File(filename, 'w') | ||
hf.create_dataset('prey_pop', data=prey_list) | ||
hf.create_dataset('pred_pop', data=predator_list) | ||
hf.close() | ||
|
||
def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'): | ||
fig = plt.figure(figsize=(15, 5)) | ||
ax = fig.add_subplot(2, 1, 1) | ||
PreyLine, = plt.plot(prey_list , color='g') | ||
PredatorsLine, = plt.plot(predator_list, color='r') | ||
ax.set_xscale('log') | ||
|
||
plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators']) | ||
plt.ylabel('Population') | ||
plt.xlabel('Time') | ||
if save: | ||
plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"), | ||
format='svg', transparent=False, bbox_inches='tight') | ||
else: | ||
plt.show() | ||
plt.close() | ||
|
||
def get_solver(method): | ||
''' | ||
solving LV equation with scipy function. | ||
''' | ||
if method == 'RK4': | ||
solver = RungeKuttaSolver() | ||
elif method == 'ODE': | ||
solver = ODE_solver() | ||
else: | ||
raise AssertionError(f'{method} is not implemented!') | ||
return solver | ||
|
||
def main(method): | ||
''' | ||
Main function that solves LV system. | ||
''' | ||
log_LV_params() | ||
solver = get_solver(method) | ||
prey_list, predator_list = solver._solve() | ||
_save_population(prey_list, predator_list) | ||
plot_population_over_time(prey_list, predator_list) | ||
|
||
if __name__ == '__main__': | ||
PARSER = argparse.ArgumentParser() | ||
PARSER.add_argument('-log', '--logfile', help='name of the logfile', default='log') | ||
PARSER.add_argument('-out', '--outdir', help='Where to place the results', default='lv_simulation') | ||
PARSER.add_argument('-s', '--solver', help='Mathematical solver for solving LV system', | ||
choices=['RK4', 'ODE'], default='RK4') | ||
ARGS = PARSER.parse_args() | ||
|
||
RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir))) | ||
|
||
if not os.path.exists(RESULTS_DIR): | ||
os.makedirs(RESULTS_DIR) | ||
|
||
LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
handlers=[ | ||
logging.FileHandler(LOG_FILE), | ||
logging.StreamHandler() | ||
] | ||
) | ||
solver = ARGS.solver | ||
main(solver) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
|
||
from causal_inference.config import LV_PARAMS | ||
|
||
class LotkaVolterra(): | ||
''' | ||
Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method. | ||
''' | ||
def __init__(self): | ||
self.A = LV_PARAMS['A'] | ||
self.B = LV_PARAMS['B'] | ||
self.C = LV_PARAMS['C'] | ||
self.D = LV_PARAMS['D'] | ||
self.time = LV_PARAMS['INITIAL_TIME'] | ||
self.step_size = LV_PARAMS['STEP_SIZE'] | ||
self.max_iterations = LV_PARAMS['MAX_ITERATIONS'] | ||
|
||
self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION'] | ||
self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
|
||
import logging | ||
import numpy as np | ||
from scipy import integrate | ||
|
||
from causal_inference.config import LV_PARAMS | ||
from causal_inference.base.lv_system import LotkaVolterra | ||
|
||
class ODE_solver(LotkaVolterra): | ||
''' | ||
This class implements scipy ODE solver for LV system. | ||
''' | ||
def __init__(self): | ||
super().__init__() | ||
logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver') | ||
|
||
@staticmethod | ||
def LV_derivative(X, t, alpha, beta, delta, gamma): | ||
x, y = X | ||
dotx = x * (alpha - beta * y) | ||
doty = y * (-delta + gamma * x) | ||
return np.array([dotx, doty]) | ||
|
||
def _solve(self): | ||
logging.info('Computing population over time...') | ||
t = np.arange(0.,self.max_iterations, self.step_size) | ||
X0 = [self.prey_population, self.predator_population] | ||
res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D)) | ||
prey_list, predator_list = res.T | ||
logging.info('done!') | ||
return prey_list, predator_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
import logging | ||
from math import ceil | ||
from causal_inference.base.lv_system import LotkaVolterra | ||
|
||
class RungeKuttaSolver(LotkaVolterra): | ||
''' | ||
This class implements 4th order Runge-Kutta solver. | ||
''' | ||
def __init__(self): | ||
super().__init__() | ||
logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method') | ||
self.time_stamp = [self.time] | ||
self.prey_list = [self.prey_population] | ||
self.predator_list = [self.predator_population] | ||
|
||
def compute_prey_rate(self, current_prey, current_predators): | ||
return self.A * current_prey - self.B * current_prey * current_predators | ||
|
||
def compute_predator_rate(self, current_prey, current_predators): | ||
return - self.C * current_predators + self.D * current_prey * current_predators | ||
|
||
def runge_kutta_update(self, current_prey, current_predators): | ||
self.time = self.time + self.step_size | ||
self.time_stamp.append(self.time) | ||
|
||
k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators) | ||
k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators) | ||
|
||
k2_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) | ||
k2_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k1_prey, current_predators + 0.5 * k1_pred) | ||
|
||
k3_prey = self.step_size * self.compute_prey_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) | ||
k3_pred = self.step_size * self.compute_predator_rate(current_prey + 0.5 * k2_prey, current_predators + 0.5 * k2_pred) | ||
|
||
k4_prey = self.step_size * self.compute_prey_rate(current_prey + k3_prey, current_predators + k3_pred) | ||
k4_pred = self.step_size * self.compute_predator_rate(current_prey + k3_prey, current_predators + k3_pred) | ||
|
||
new_prey_population = current_prey + 1/6 * (k1_prey + 2 * k2_prey + 2 * k3_prey + k4_prey) | ||
new_predator_population = current_predators + 1/6 * (k1_pred + 2 * k2_pred + 2 * k3_pred + k4_pred) | ||
|
||
self.prey_list.append(new_prey_population) | ||
self.predator_list.append(new_predator_population) | ||
|
||
return new_prey_population, new_predator_population | ||
|
||
def _solve(self): | ||
current_prey, current_predators = self.prey_population, self.predator_population | ||
logging.info('Computing population over time...') | ||
for gen_idx in range(ceil(self.max_iterations/self.step_size)): | ||
current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators) | ||
msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}' | ||
logging.info(msg) | ||
print('Done!') | ||
return self.prey_list, self.predator_list |