Skip to content

Commit

Permalink
Add Lambda nn.Module that calls an arbitrary function in the forw…
Browse files Browse the repository at this point in the history
…ard pass (#176)
  • Loading branch information
nathanpainchaud authored Oct 24, 2023
1 parent 991229f commit 58e47bd
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions vital/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from collections import OrderedDict
from functools import wraps
from typing import Any, Callable, Dict, List, Sequence, Tuple
Expand Down Expand Up @@ -170,3 +171,35 @@ def reparameterize(mu: Tensor, logvar: Tensor) -> Tensor:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std


class Lambda(nn.Module):
"""Layer to call an arbitrary function on an input tensor."""

def __init__(self, fn: Callable[[Tensor, Any, ...], Tensor], **kwargs):
"""Stores function object and parameters.
Args:
fn: Function to call on the input tensor in the forward pass.
**kwargs: Parameters to pass along to the arbitrary function.
"""
super().__init__()
self.fn = fn
self.kwargs = kwargs

def __repr__(self):
"""Overrides the default repr to display the name and arguments of the function."""
fn_obj = self.fn if not isinstance(self.fn, functools.partial) else self.fn.func
kwargs_str = ["?"] + [f"{k}={v}" for k, v in self.kwargs.items()]
return f"{self.__class__.__name__}({fn_obj.__name__}({', '.join(kwargs_str)}))"

def forward(self, x: Tensor) -> Tensor:
"""Call the function on a tensor.
Args:
x: Input tensor.
Returns:
Output tensor after applying the function to the input tensor.
"""
return self.fn(x, **self.kwargs)

0 comments on commit 58e47bd

Please sign in to comment.