Description
I wanted to share some thoughts on an alternative saving and loading strategy we might explore that revolves around pickling pytensor graphs instead of using config information to rebuild the model.
The weakness with configs is that they are brittle. If the state of the codebase changes, then old configs can cause unexpected behaviour after people update the code. We have to rely on the assumption that the config information can rebuild exactly the same model in the future. If we add a new feature or change the meaning of certain config options, that assumption breaks. It's hard to know every config option we want to offer in advance, so I expect this problem will occur from time to time. When retraining is very expensive, it's important to be able to keep old models working while experimenting with new ones.
One way to reduce complexity here is to make the model's structure a static artifact that can be loaded. Then you don't have to worry about aligning configs with model structure. It also means people can iterate on experimental features faster: you don't have to wait for a fully configurable version of your feature to arrive. You can build your feature any way you like and still keep saving and loading. So to save a model, we'd save off an `idata.nc` plus a `pytensor_graph.pkl`.
Training and saving a model could look like this:
```python
import pickle

import arviz as az
import pymc as pm
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph

with pm.Model(coords=coords) as model:  # coords assumed defined above
    # some model code here
    idata = pm.sample()

# convert the model to a pure pytensor object
fgraph = fgraph_from_model(model)[0]

# save the pytensor object
with open("pytensor_graph.pkl", "wb") as f:
    pickle.dump(fgraph, f, protocol=pickle.HIGHEST_PROTOCOL)

# save the trace
idata.to_netcdf("idata.nc")
```
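For ergonomics, both artifacts could be written by a single call. A minimal sketch, assuming a hypothetical `save_model` helper (not an existing PyMC API) and the file names from above as defaults:

```python
import pickle

from pymc.model.fgraph import fgraph_from_model


def save_model(model, idata, graph_path="pytensor_graph.pkl", trace_path="idata.nc"):
    # `save_model` is a hypothetical helper, not an existing PyMC API:
    # it pickles the model's pytensor graph and writes the trace to netCDF.
    fgraph = fgraph_from_model(model)[0]
    with open(graph_path, "wb") as f:
        pickle.dump(fgraph, f, protocol=pickle.HIGHEST_PROTOCOL)
    idata.to_netcdf(trace_path)
```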
Loading a model and then predicting off it could work like this:
```python
# load the pytensor graph
with open("pytensor_graph.pkl", "rb") as f:
    loaded_obj = pickle.load(f)

# rebuild a pymc model from the loaded graph
loaded_model = model_from_fgraph(loaded_obj)

# load the idata
trace_loaded = az.from_netcdf("idata.nc")

with loaded_model:
    predictions = pm.sample_posterior_predictive(
        trace=trace_loaded, var_names=["some variables"]
    )
```
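The inverse could live in a matching helper. Again a sketch; `load_model` is hypothetical:

```python
import pickle

import arviz as az
from pymc.model.fgraph import model_from_fgraph


def load_model(graph_path="pytensor_graph.pkl", trace_path="idata.nc"):
    # `load_model` is a hypothetical helper mirroring `save_model` above:
    # it unpickles the pytensor graph, rebuilds the model, and loads the trace.
    with open(graph_path, "rb") as f:
        fgraph = pickle.load(f)
    return model_from_fgraph(fgraph), az.from_netcdf(trace_path)
```

If plain pickle turns out to be fragile across environments, `cloudpickle` exposes the same `dump`/`load` interface and would be a drop-in swap here.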
To my mind, the three things people want to do after loading a model are:
- Make predictions for the future or for counterfactuals
- Make plots
- Optimize
In all three cases, all the model-specific information we need can be read off a pytensor graph. The tricky bit I see is enforcing conventions around what's a named node in the graph, what's a deterministic on the pymc side, and what they're all called, so that each of those tasks can rely on consistent information.
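As an illustration of reading that information off the loaded objects, here's a sketch using the `loaded_model` from above; which of these attributes the conventions would standardize is exactly the open question:

```python
# Standard pymc model introspection on the rebuilt model; the open question
# is which of these names downstream tasks would be allowed to rely on.
print(loaded_model.named_vars.keys())                  # all named variables
print([d.name for d in loaded_model.deterministics])   # deterministic nodes
print(loaded_model.coords)                             # dims/coords
```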