Best practices to convert `torch.nn.Module` to `eqx.Module` #396
Thanks for the great package!

I was wondering whether there is some documentation regarding best practices for converting a `torch.nn.Module` to an `eqx.Module`? In particular:

- `register_parameters` would be replaced by an attribute, e.g. `weights: Array`
- what replaces `register_buffer`?
- what replaces `add_module(name, intertwiner_basis)`? Especially when the `name` is not known in advance, e.g. `f"module_{variable}"`.

Thanks a lot!

(For context, I'm looking at porting `escnn` to `jax`, cf. QUVA-Lab/escnn#55.)

Comments
Hey there! Buffers are most simply handled by storing them as an array (just like a parameter) and then calling `lax.stop_gradient` on them when they're used. Alternatively, you can follow the frozen-parameter example, which involves passing them through a custom filter specification.

Modules are typically stored as attributes just like parameters, e.g. see the source code for the built-in `eqx.nn` layers.

If you need dynamically-named parameters then you can store those in a dictionary, and then store the dictionary on the parent module. Modules themselves are not variadically-sized. (Note that this dynamism should only happen at `__init__` time.)
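A minimal sketch of the dictionary approach, under the assumption that the arrays are ordinary learnt parameters (`DynamicParams` and its field names are made up for illustration):

```python
import equinox as eqx
import jax.random as jr
from jaxtyping import Array

class DynamicParams(eqx.Module):
    # A dict of arrays is itself a pytree, so everything stored here is
    # visible to eqx.filter_jit / eqx.filter_grad like any other parameter.
    params: dict[str, Array]

    def __init__(self, names, key):
        keys = jr.split(key, len(names))
        # Dynamically-named entries, e.g. f"module_{variable}" as in the
        # original question -- built once, at __init__ time.
        self.params = {name: jr.normal(k, (3, 3)) for name, k in zip(names, keys)}

model = DynamicParams([f"module_{i}" for i in range(4)], jr.PRNGKey(0))
```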
Thanks @patrick-kidger for your answer! :) I'd still have a question on how best to handle the following scenario: I have a linear layer whose weight matrix is constrained (e.g. symmetric), and I can either resolve the constraint on-the-fly every time the layer is called, or convert the model to an unconstrained form before inference.

I would be really keen on knowing your opinion, as neither of the two options seems ideal :)
I think doing the conversion before inference time probably makes most sense. Here's an example of training a linear layer with a symmetric weight matrix:

```python
#
# Train-time: resolve on-the-fly.
#

class Symmetric(eqx.Module):
    array: Array

    def get(self):
        return 0.5 * (self.array + self.array.T)

is_symmetric = lambda x: isinstance(x, Symmetric)

@eqx.filter_jit
def train_step(model, ...):
    model = jax.tree_util.tree_map(
        lambda x: x.get() if is_symmetric(x) else x, model, is_leaf=is_symmetric
    )
    ...  # compute gradients, update, etc.

model = eqx.nn.Linear(...)
model = eqx.tree_at(lambda m: m.weight, model, replace_fn=Symmetric)
for _ in range(steps):
    model = train_step(model, ...)

#
# Inference time: perform conversion.
#

inference_model = eqx.tree_inference(model, True)
inference_model = jax.tree_util.tree_map(
    lambda x: x.get() if is_symmetric(x) else x, inference_model, is_leaf=is_symmetric
)
inference_model(...)  # evaluate
```

Doing some kind of train->inference conversion is pretty common -- e.g. quantisation, pruning, absorbing adjacent batchnorm and linear layers into a single linear transformation, etc.

Also, note that I don't do something like mutating the model in place: every step (`eqx.tree_at`, `train_step`, the final conversion) returns a new model. This is a deliberate design choice, as it helps to reason about changes in the presence of jit, grad, etc.
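As a concrete instance of such a conversion, here's a sketch of absorbing an inference-mode batchnorm into the preceding linear layer. This is just the standard arithmetic, not an Equinox API; `absorb_batchnorm` is a hypothetical helper:

```python
import jax.numpy as jnp

def absorb_batchnorm(W, b, mean, var, gamma, beta, eps=1e-5):
    # Fold y = gamma * ((W @ x + b) - mean) / sqrt(var + eps) + beta
    # into a single affine map y = W_new @ x + b_new.
    scale = gamma / jnp.sqrt(var + eps)  # per-output-channel scale
    W_new = W * scale[:, None]           # rescale each row of W
    b_new = scale * (b - mean) + beta
    return W_new, b_new
```

The new weights can then be swapped in with `eqx.tree_at`, which (matching the immutable style above) returns a new model rather than modifying the old one.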
Thanks, that's really useful -- I wasn't aware of that approach. I eventually implemented something along those lines.
@patrick-kidger
Yep, this is totally possible. First of all, if you just want to have non-learnt arrays, then call `lax.stop_gradient` on them whenever they're used:

```python
def __call__(self, ...):
    buffer = lax.stop_gradient(self.buffer)
    # now use `buffer` wherever you want to use it
```

If you need to do something more complicated with filtering, then you can use a wrapper class:

```python
class FooArray(eqx.Module):
    array: Array

class Model(eqx.Module):
    foo: FooArray  # field annotation, as required for eqx.Module attributes

    def __init__(self, ...):
        self.foo = FooArray(some_array)
        ...

model = Model(...)
is_foo = lambda x: isinstance(x, FooArray)
has_foo, no_foo = eqx.partition(model, is_foo, is_leaf=is_foo)
```

Here's a fully-fledged example for creating a linear transformation with a symmetric matrix.
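A sketch of how the two halves might then be consumed -- for instance, to keep the `FooArray` contents out of differentiation. This assumes data arrays `x`, `y`, and is only one way to use `eqx.partition`'s output:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

def loss_fn(no_foo, has_foo, x, y):
    # Recombine the halves inside the function: eqx.filter_grad only
    # differentiates with respect to its first argument, so the arrays
    # hidden inside `has_foo` receive no gradient.
    model = eqx.combine(no_foo, has_foo)
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

grads = eqx.filter_grad(loss_fn)(no_foo, has_foo, x, y)
```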
Thanks @patrick-kidger! If that's something you're interested in, and/or you have any suggestions/remarks, I'd be keen on hearing them :)
Thanks! I've just had a quick look.
Regarding this, I completely agree -- but how can I achieve it? With something like the following?

```python
is_layer = lambda m: isinstance(m, eqx.Module)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)
```
@patrick-kidger would you have an idea, by any chance, whether it's usually better/faster in JAX to (1) update arrays in place, or (2) build new arrays?
Also, to handle both stateful and stateless modules, I found myself adding something like:

```python
for layer in self.layers:
    if "state" in inspect.signature(layer).parameters:
        x, state = layer(x, state)
    else:
        x = layer(x)
```

Is there any way around this? I could wrap the stateless modules with something like:

```python
def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)
```

Would it be worth adding something like this to Equinox?
I'd recommend against this. Equinox modules are really just pytrees like any other, so it's not appropriate to special-case them. Moreover, what if there is some non-E3NN module that doesn't implement a `train` method?

In the spirit of nominative subtyping, I would instead recommend the following pattern:

```python
# Declare that this method should exist.
class E3NNModule(eqx.Module):
    @abc.abstractmethod
    def train(self, mode):
        ...

# Now go looking for such layers, knowing that the train method must exist.
is_layer = lambda m: isinstance(m, E3NNModule)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)

# On your concrete classes, go ahead and provide an implementation.
class SomeModule(E3NNModule):
    def train(self, mode):
        ...
```

If you have nested `E3NNModule`s, bear in mind that `is_leaf` stops `tree_map` at the outermost one, so each `train` implementation should handle any `E3NNModule`s it contains itself.

Incidentally, the above is exactly the sort of thing I do very widely across my JAX libraries -- I'm a big fan of using ABCs to explicitly declare what interfaces are available.
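To illustrate the nesting caveat, a hedged sketch (`Inner` and `Outer` are invented names): since `is_leaf` makes `tree_map` stop at the outermost `E3NNModule`, an outer module's `train` has to propagate the call explicitly:

```python
import abc
import equinox as eqx

class E3NNModule(eqx.Module):
    @abc.abstractmethod
    def train(self, mode: bool) -> "E3NNModule":
        ...

class Inner(E3NNModule):
    mode: bool = False

    def train(self, mode):
        # Modules are immutable, so return an updated copy.
        return eqx.tree_at(lambda m: m.mode, self, mode)

class Outer(E3NNModule):
    inner: Inner

    def train(self, mode):
        # tree_map never reaches `inner` (Outer itself is the leaf),
        # so recurse manually.
        return eqx.tree_at(lambda m: m.inner, self, self.inner.train(mode))
```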
I would recommend (2). JAX's heuristics for in-place updates are sometimes not great.
Hmm, these aren't really designed to be used interchangeably. After all, one could easily define a module with a completely arbitrary custom signature; it's not like the only two valid ones are `x -> x` and `(x, state) -> (x, state)`.

Stateful layers are pretty unusual -- in particular, batchnorm is used very infrequently. What's your use case?

```python
def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)
```

Note that this snippet is dangerous: the returned `lambda` is not a pytree, so the parameters of the `layer` it closes over disappear from the model as far as `eqx.filter_jit`, `eqx.filter_grad` and friends are concerned.
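For what it's worth, a safer variant would keep the wrapper itself a pytree by making it a module, so the wrapped layer's parameters stay visible to filtering (`StatefulWrapper` is a made-up name, not Equinox API):

```python
import inspect
import equinox as eqx

class StatefulWrapper(eqx.Module):
    layer: eqx.Module  # stored as a field, so it remains part of the pytree

    def __call__(self, x, state):
        if "state" in inspect.signature(self.layer).parameters:
            return self.layer(x, state)
        return self.layer(x), state
```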