using parallel jax solver #61

Open
martinjrobins opened this issue Nov 8, 2021 · 6 comments
Labels
enhancement New feature or request

Comments

@martinjrobins

@TomTranter: here is an example script that runs 20 instances of the SPM model in parallel using JAX. Let me know if this is what you need.

```python
import os
import time

# XLA_FLAGS is read when JAX initialises its backend, so set it before
# importing jax (and pybamm, which may import jax internally).
# Specify 20 logical CPU devices for execution.
ncpu = 20
os.environ['XLA_FLAGS'] = (
    '--xla_force_host_platform_device_count={}'.format(ncpu)
)

import jax
import numpy as np
import pybamm

# print out the available devices
print('devices', jax.devices())

pybamm.set_logging_level("INFO")
model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
model.events = []

# create geometry
geometry = model.default_geometry

# load parameter values, make the electrode height an input parameter,
# then process the model and geometry
param = model.default_parameter_values
parameter = "Electrode height [m]"
value = param[parameter]
param.update({parameter: "[input]"})
param.process_model(model)
param.process_geometry(geometry)

# set mesh
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)

# discretise model
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# solve model for 1 hour
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver()

# the model is set up and compiled (expensive) during this call
solution = solver.solve(
    model, t_eval,
    inputs={parameter: value},
)

# create a new jax function that can take an array of inputs
mapped_jax_function = jax.pmap(
    solver.get_solve(model, t_eval),
)

# now our input is an array of "parameter"; its leading dimension
# must not exceed the number of devices (here it equals ncpu)
inputs_array = {
    parameter: jax.numpy.linspace(value / 10, value * 10, ncpu)
}

# the multiple inputs are executed in parallel here;
# the result is a 3D array of shape (Ni, Ns, Nt), where
# Ni is the size of the parameter array, Ns is the size of the
# full state vector, and Nt is the number of timesteps in t_eval.
# Note: this first call also compiles the pmapped function,
# so the elapsed time includes compilation.
print('running in parallel')
tic = time.perf_counter()
result_array = mapped_jax_function(inputs_array)

# JAX dispatches asynchronously: wait for the result so it is actually computed
result_array.block_until_ready()
toc = time.perf_counter()
print('time elapsed: {} sec'.format(toc - tic))
```

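For anyone who wants to try the pattern without PyBaMM, here is a minimal self-contained sketch that assumes only JAX is installed. `solve_decay` is a hypothetical stand-in for `solver.get_solve(model, t_eval)`: it applies forward Euler to dy/dt = -k y. `N_DEVICES = 4` is an arbitrary choice. The sketch also shows one way to sweep more parameter values than there are devices, by nesting `jax.vmap` inside `jax.pmap`. It times a first (compiling) call and a second (cached) call with `block_until_ready()`.

```python
import os

# Assumption: 4 logical CPU devices. XLA reads this flag when JAX initialises
# its backend, so it must be set before jax is imported.
N_DEVICES = 4
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(N_DEVICES)

import time

import jax
import jax.numpy as jnp


def solve_decay(inputs, t_eval):
    """Hypothetical stand-in for solver.get_solve(model, t_eval).

    Integrates dy/dt = -k * y, y(0) = 1 with forward Euler and returns the
    solution at t_eval as an (Ns, Nt) = (1, len(t_eval)) array.
    """
    k = inputs["k"]
    n_sub = 100  # Euler sub-steps between consecutive output times

    def step_interval(y, dt):
        h = dt / n_sub
        y_next = jax.lax.fori_loop(0, n_sub, lambda _, y_i: y_i - h * k * y_i, y)
        return y_next, y_next  # (new carry, value stored for this output time)

    y0 = jnp.ones(())
    _, ys = jax.lax.scan(step_interval, y0, jnp.diff(t_eval))
    return jnp.concatenate([y0[None], ys])[None, :]


t_eval = jnp.linspace(0.0, 1.0, 11)

# pmap's mapped axis must not exceed the device count, so to sweep more values
# than devices we reshape to (n_dev, per_device) and vmap within each device.
n_dev = jax.local_device_count()
n_values = 2 * n_dev
k_values = jnp.linspace(0.5, 2.0, n_values)
inputs = {"k": k_values.reshape(n_dev, -1)}

batched_solve = jax.pmap(jax.vmap(lambda inp: solve_decay(inp, t_eval)))

tic = time.perf_counter()
result = batched_solve(inputs).block_until_ready()  # includes compilation
first_call = time.perf_counter() - tic

tic = time.perf_counter()
batched_solve(inputs).block_until_ready()  # cached: execution only
second_call = time.perf_counter() - tic

# (n_dev, per_device, Ns, Nt) -> (Ni, Ns, Nt), in the same order as k_values
result = result.reshape(n_values, 1, t_eval.shape[0])
print("devices:", n_dev, "result shape:", result.shape)
print("first call {:.3f} s, second call {:.3f} s".format(first_call, second_call))
```

If the flag does not take effect, for example because JAX was already initialised in the process, `n_dev` falls back to 1. The `jax.vmap` layer then still produces the full batched result on a single device.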
@TomTranter
Collaborator

Fantastic, thanks. Will you be joining the dev meeting later? I'll play with it now and may have a few questions.

@martinjrobins
Author

I've got a shortlisting meeting that should end at 2pm. I can join after that, or give you a call once it's finished.

@wigging
Collaborator

wigging commented Nov 8, 2021

@martinjrobins Can JAX run in parallel across multiple machines?

@TomTranter TomTranter added the enhancement New feature or request label Nov 19, 2021
@wigging
Collaborator

wigging commented Feb 9, 2022

I tried to run the JAX example that @martinjrobins posted above, but I get the following error:

ValueError: model.timescale must be a Scalar after parameter processing
(cannot contain 'InputParameter's). You have probably set one of the
parameters used to calculate the timescale to an InputParameter. To avoid
this error, hardcode model.timescale to a constant value by passing the
option {'timescale': value} to the model.

@TomTranter
Collaborator

@wigging yes, we recently changed how the timescale is implemented in PyBaMM. Can you run it with the suggested option for a scalar timescale?

@wigging
Collaborator

wigging commented Feb 10, 2022

Using `model = pybamm.lithium_ion.SPM(options={"timescale": 1.0})` makes the example work.
