Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

harmonization of simulation solver method imports #5

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 10 additions & 35 deletions causal_inference/base/lv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,12 @@
import logging
import datetime

import h5py
import matplotlib.pyplot as plt

from causal_inference.config import RESULTS_DIR
from causal_inference.utils.log_config import log_LV_params
from causal_inference.base.ode_solver import ODE_solver
from causal_inference.base.runge_kutta_solver import RungeKuttaSolver

def _save_population(prey_list, predator_list):
filename = os.path.join(RESULTS_DIR, 'populations.h5')
hf = h5py.File(filename, 'w')
hf.create_dataset('prey_pop', data=prey_list)
hf.create_dataset('pred_pop', data=predator_list)
hf.close()

def plot_population_over_time(prey_list, predator_list, save=True, filename='predator_prey'):
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(2, 1, 1)
PreyLine, = plt.plot(prey_list , color='g')
PredatorsLine, = plt.plot(predator_list, color='r')
ax.set_xscale('log')

plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
plt.ylabel('Population')
plt.xlabel('Time')
if save:
plt.savefig(os.path.join(RESULTS_DIR, f"{filename}.svg"),
format='svg', transparent=False, bbox_inches='tight')
else:
plt.show()
plt.close()
from causal_inference.utils.writer import _save_population
from causal_inference.utils.visualisations import plot_population_over_time

def get_solver(method):
'''
Expand All @@ -50,15 +25,15 @@ def get_solver(method):
raise AssertionError(f'{method} is not implemented!')
return solver

def main(method):
def main(method, results_dir):
'''
Main function that solves LV system.
'''
log_LV_params()
solver = get_solver(method)
prey_list, predator_list = solver._solve()
_save_population(prey_list, predator_list)
plot_population_over_time(prey_list, predator_list)
_save_population(prey_list, predator_list, solver.time_stamps, results_dir)
plot_population_over_time(prey_list, predator_list, solver.time_stamps, results_dir)

if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
Expand All @@ -68,12 +43,12 @@ def main(method):
choices=['RK4', 'ODE'], default='RK4')
ARGS = PARSER.parse_args()

RESULTS_DIR = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))
results_dir = os.path.join(RESULTS_DIR, '{}_{}'.format(datetime.datetime.now().strftime("%Y%h%d_%H_%M_%S"), str(ARGS.outdir)))

if not os.path.exists(RESULTS_DIR):
os.makedirs(RESULTS_DIR)
if not os.path.exists(results_dir):
os.makedirs(results_dir)

LOG_FILE = os.path.join(RESULTS_DIR, f"{ARGS.logfile}.txt") # write logg to this file
LOG_FILE = os.path.join(results_dir, f"{ARGS.logfile}.txt") # write logg to this file
logging.basicConfig(
level=logging.INFO,
handlers=[
Expand All @@ -82,4 +57,4 @@ def main(method):
]
)
solver = ARGS.solver
main(solver)
main(solver, results_dir)
33 changes: 22 additions & 11 deletions causal_inference/base/lv_system.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from causal_inference.config import LV_PARAMS

class LotkaVolterra():
'''
Class simulates predator-prey dynamics and solves it with 4th order Runge-Kutta method.
Base Lotka-Volterra Class that defines a predator-prey system.
'''
def __init__(self):
self.A = LV_PARAMS['A']
self.B = LV_PARAMS['B']
self.C = LV_PARAMS['C']
self.D = LV_PARAMS['D']
self.time = LV_PARAMS['INITIAL_TIME']
self.step_size = LV_PARAMS['STEP_SIZE']
self.max_iterations = LV_PARAMS['MAX_ITERATIONS']
def __init__(self,
A=LV_PARAMS['A'], B=LV_PARAMS['B'], C=LV_PARAMS['C'], D=LV_PARAMS['D'],
prey_population=LV_PARAMS['INITIAL_PREY_POPULATION'],
pred_population=LV_PARAMS['INITIAL_PREDATOR_POPULATION'],
total_time=LV_PARAMS['TOTAL_TIME'], step_size=LV_PARAMS['STEP_SIZE'],
max_iter=LV_PARAMS['MAX_ITERATIONS']):
# Lotka-Volterra parameters
self.A = A
self.B = B
self.C = C
self.D = D

self.prey_population = LV_PARAMS['INITIAL_PREY_POPULATION']
self.predator_population = LV_PARAMS['INITIAL_PREDATOR_POPULATION']
self.prey_population = prey_population # Initial prey population
self.predator_population = pred_population # Initial predator population

self.init_time = 0 # initial time
self.total_time = total_time # total time in units
self.step_size = step_size # increment for each time step
self.max_iterations = max_iter # tolerance parameter

self.time_stamps = np.arange(self.init_time, self.total_time, self.step_size)
31 changes: 19 additions & 12 deletions causal_inference/base/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from scipy import integrate

from causal_inference.config import LV_PARAMS
from causal_inference.base.lv_system import LotkaVolterra

class ODE_solver(LotkaVolterra):
Expand All @@ -14,20 +13,28 @@ class ODE_solver(LotkaVolterra):
'''
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics odeint solver')
logging.info('Simulating Lotka-Volterra predator-prey dynamics with odeint solver')

@staticmethod
def LV_derivative(X, t, alpha, beta, delta, gamma):
x, y = X
dotx = x * (alpha - beta * y)
doty = y * (-delta + gamma * x)
def LV_derivative(t, Z, A, B, C, D):
'''
Returns the rate of change of predator and prey population
'''
x, y = Z
dotx = x * (A - B * y)
doty = y * (-C + D * x)
return np.array([dotx, doty])

def _solve(self):
logging.info('Computing population over time...')
t = np.arange(0.,self.max_iterations, self.step_size)
X0 = [self.prey_population, self.predator_population]
res = integrate.odeint(self.LV_derivative, X0, t, args=(self.A, self.B, self.C, self.D))
prey_list, predator_list = res.T
'''
ODE solver that returns the predator and prey populations at each time step in time series.
'''
logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

INIT_POP = [self.prey_population, self.predator_population]
sol = integrate.solve_ivp(self.LV_derivative, [self.init_time, self.total_time], INIT_POP, args=(self.A, self.B, self.C, self.D), dense_output=True)
prey_list, predator_list = sol.sol(self.time_stamps)

logging.info('done!')
return prey_list, predator_list

return prey_list, predator_list
16 changes: 9 additions & 7 deletions causal_inference/base/runge_kutta_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import logging
from math import ceil
from causal_inference.base.lv_system import LotkaVolterra

class RungeKuttaSolver(LotkaVolterra):
Expand All @@ -11,7 +10,6 @@ class RungeKuttaSolver(LotkaVolterra):
def __init__(self):
super().__init__()
logging.info('Solving Lotka-Volterra predator-prey dynamics with 4th order Runge-Kutta method')
self.time_stamp = [self.time]
self.prey_list = [self.prey_population]
self.predator_list = [self.predator_population]

Expand All @@ -22,8 +20,6 @@ def compute_predator_rate(self, current_prey, current_predators):
return - self.C * current_predators + self.D * current_prey * current_predators

def runge_kutta_update(self, current_prey, current_predators):
self.time = self.time + self.step_size
self.time_stamp.append(self.time)

k1_prey = self.step_size * self.compute_prey_rate(current_prey, current_predators)
k1_pred = self.step_size * self.compute_predator_rate(current_prey, current_predators)
Expand All @@ -46,11 +42,17 @@ def runge_kutta_update(self, current_prey, current_predators):
return new_prey_population, new_predator_population

def _solve(self):
'''
Runge-Kutta solver that returns the predator and prey populations at each time step in time series.
'''
#initial population
current_prey, current_predators = self.prey_population, self.predator_population
logging.info('Computing population over time...')
for gen_idx in range(ceil(self.max_iterations/self.step_size)):

logging.info(f'Computing population over {self.total_time} generation with step size of {self.step_size}...')

for step_idx in self.time_stamps[1:]:
current_prey, current_predators = self.runge_kutta_update(current_prey, current_predators)
msg= f'Gen: {gen_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
msg= f'Step: {step_idx} | Prey population: {current_prey} | Predator population: {current_predators}'
logging.info(msg)
print('Done!')
return self.prey_list, self.predator_list
2 changes: 1 addition & 1 deletion causal_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'C' : 3.0,
'D' : 5.0,
'STEP_SIZE' : 0.01,
'INITIAL_TIME' : 0,
'TOTAL_TIME' : 20,
'INITIAL_PREY_POPULATION' : 60,
'INITIAL_PREDATOR_POPULATION' : 25,
'MAX_ITERATIONS' : 200
Expand Down
21 changes: 21 additions & 0 deletions causal_inference/utils/visualisations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from os.path import join
import matplotlib.pyplot as plt

def plot_population_over_time(prey_list, predator_list, time_stamps, results_dir, save=True, filename='predator_prey'):
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(2, 1, 1)
PreyLine, = plt.plot(time_stamps, prey_list, color='g')
PredatorsLine, = plt.plot(time_stamps, predator_list, color='r')
ax.set_xscale('log')

plt.legend([PreyLine, PredatorsLine], ['Prey', 'Predators'])
plt.ylabel('Population')
plt.xlabel('Time')
if save:
plt.savefig(join(results_dir, f"{filename}.svg"),
format='svg', transparent=False, bbox_inches='tight')
else:
plt.show()
plt.close()
12 changes: 12 additions & 0 deletions causal_inference/utils/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from os.path import join
import h5py

def _save_population(prey_list, predator_list, time_stamps, results_dir):
filename = join(results_dir, 'populations.h5')
hf = h5py.File(filename, 'w')
hf.create_dataset('time_stamp', data=time_stamps)
hf.create_dataset('prey_pop', data=prey_list)
hf.create_dataset('pred_pop', data=predator_list)
hf.close()
Loading