Integrate EBM into the pytorch framework #521

Open
JWKKWJ123 opened this issue Mar 14, 2024 · 7 comments

@JWKKWJ123

Hi all,
I want to use EBM as a GAM to replace the fully connected layer at the end of a large CNN/Transformer to get interpretable output. However, I need to train the EBM like a deep learning model, with mini batches of data as input.
I would like to ask: is it possible to train the model step by step (batch by batch) instead of using the end-to-end fit() function? Or is anyone already working on this?
Yours Sincerely,
Wenjie Kang

@sunnycasmir

This is a novel and intriguing method: training an Energy-Based Model (EBM) as a Generalised Additive Model (GAM) inside a huge CNN or Transformer architecture. While training classic EBMs usually involves end-to-end optimisation techniques, it is possible to modify them to operate within a broader neural network architecture and train them incrementally (batch-by-batch).

Here is a code example of how you can achieve this:

# import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim

# define a neural network architecture with an EBM layer
class CNNWithEBM(nn.Module):
    def __init__(self):
        super(CNNWithEBM, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.ebm = nn.Linear(32 * 14 * 14, 1)  # Example linear EBM layer

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        energy = self.ebm(x)
        return energy

# Generate synthetic dataset
def generate_data(batch_size=32):
    # Generate random data and labels
    data = torch.randn(batch_size, 1, 28, 28)  # MNIST-like data
    labels = torch.randint(0, 2, (batch_size,))
    return data, labels

# Instantiate the model
model = CNNWithEBM()

# Define loss function (energy-based loss)
criterion = nn.MSELoss()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
batch_size = 32
num_batches = 100  # number of mini-batches per epoch
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx in range(num_batches):
        # Generate mini-batch data
        data, labels = generate_data(batch_size)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        energy = model(data)

        # Compute loss
        loss = criterion(energy.squeeze(), labels.float())  # Energy-based loss

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        # Accumulate total loss
        total_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}, Avg. Loss: {total_loss / num_batches:.4f}")

I hope that this helps.
Thank you

@JWKKWJ123
Author


Dear Sunnycasmir,
Thank you very much for your reply!
More specifically, I want to use EBM (Explainable Boosting Machine) as the output layer of a large CNN/Transformer. I considered using the EBM as a custom torch layer, but this would make the EBM untrainable. So my question is: how can the EBM be trained incrementally (batch by batch) as a custom layer in torch? I don't think the example code addresses this question.

@sunnycasmir

Is it possible to see the code you are working on, so I can see how to contribute more?

@JWKKWJ123
Author

JWKKWJ123 commented Mar 18, 2024

Hi all,
I have some updates this week:
I think the main difficulty is that deep learning models and GAMs (including EBM) have very different training strategies. GAMs need to read all the training data at once and update the weights of all shape functions on the residuals sequentially, while deep learning models have to take the training data in mini-batches because of memory limits (I use a batch size of 4 now) and update the model step by step.
I would like to use the EBM as the output block in a large end-to-end 3D CNN. The question then becomes: can the EBM be progressively updated step by step (mini-batch by mini-batch) simultaneously with the CNN?
I am trying to use merge_ebms() to train the EBM in batches, and it seems to work with a large batch size.
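Roughly, the batch-wise fitting I am trying looks like the sketch below (a minimal, simplified sketch; feature_batches and label_batches are placeholders for the per-batch features and labels extracted by the DNN):

import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier, merge_ebms

def fit_ebm_on_batches(feature_batches, label_batches):
    # fit one EBM per mini-batch, then combine them into a single model
    ebms = []
    for X_batch, y_batch in zip(feature_batches, label_batches):
        ebm = ExplainableBoostingClassifier()
        ebm.fit(np.asarray(X_batch), np.asarray(y_batch))
        ebms.append(ebm)
    # merge_ebms combines the per-batch EBMs into one EBM
    return merge_ebms(ebms)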
Below is the code where I put the EBM into a deep learning model. For now I have made the EBM untrainable inside the CNN, because I am going to train the EBM and the CNN alternately:

class EBM_layer(nn.Module):
    def __init__(self, **kwargs):
        super(EBM_layer, self).__init__(**kwargs)

    def forward(self, x, ebm):
        # detach() cuts the computation graph, so no gradients flow back
        # into the CNN through the EBM -- the EBM is untrainable here
        x = x.detach().cpu().numpy()
        output_pro_ebm = ebm.predict_proba(x)
        output_pro_ebm = output_pro_ebm[:, 1]  # probability of the positive class
        output_pro_ebm = torch.tensor(output_pro_ebm, requires_grad=True)
        output_pro_ebm = output_pro_ebm.unsqueeze(1)
        return output_pro_ebm

# forward() of the composite model (class not shown in full); I train the EBM and
# the CNNs alternately, so a trained ebm is passed into the model in each epoch
def forward(self, x, ebm):
    outs = []
    for i in range(len(self.cnnlist)):
        outs.append(self.cnnlist[i](x))
    out_all = torch.cat(outs, 1)  # concatenation of the features extracted by the multiple CNNs
    out_pro = self.EBM_layer(out_all, ebm)
    return out_pro

@paulbkoch
Collaborator

Hi @JWKKWJ123 -- This kind of federated learning approach isn't something that we support out of the box. You can kind of hack it as you've discovered using merge_ebms, but the implementation isn't ideal. At some point we'll provide a better interface for building EBMs one boosting round at a time, and from batches.

Your other point about DNNs and EBMs (based on decision trees) is quite pertinent too. The training strategies are quite different and it's not clear to me that bringing them together will result in an ideal union. An alternative approach that I might suggest would be to train the DNN as normal, then remove the last layer, and train the EBM to replace it on the now frozen DNN. Will this approach work for you?
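Roughly, that suggestion might look like the sketch below (illustrative only; backbone stands in for the trained DNN with its last layer removed, and train_loader for the data loader):

import torch
from interpret.glassbox import ExplainableBoostingClassifier

@torch.no_grad()
def extract_features(backbone, loader):
    # run the frozen DNN (minus its last layer) over the data and collect features
    backbone.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(backbone(x).cpu())
        labels.append(y)
    return torch.cat(feats).numpy(), torch.cat(labels).numpy()

# X, y = extract_features(backbone, train_loader)
# ebm = ExplainableBoostingClassifier().fit(X, y)  # interpretable head on top of the frozen features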

@JWKKWJ123
Author

Dear Paul,
Thank you very much for your reply! I'm glad I've made some progress now.
I found it is possible to use merge_ebms() to train the EBM in batches alongside the DNN. But I am now using a huge DNN, so I can only set the batch size to 4, and that training strategy does not work when the batch size is < 10.
So after trial and error, I developed a new training strategy (figure below), which trains the model alternately in two stages. I now use it in a case that takes both the whole image (global) and image patches (local) as input, where each pathway in the end-to-end model is a CNN:
[figure: diagram of the two-stage alternating training strategy]
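In simplified form, one round of the alternating scheme looks roughly like the toy sketch below (the tiny CNN, the temporary linear head used in stage 1, and the synthetic data are just placeholders for my real pathways, heads, and data):

import torch
import torch.nn as nn
import torch.nn.functional as F
from interpret.glassbox import ExplainableBoostingClassifier

backbone = nn.Sequential(                      # stand-in for the multi-pathway CNNs
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten()
)
temp_head = nn.Linear(8, 2)                    # temporary head used only in stage 1
optimizer = torch.optim.Adam(
    list(backbone.parameters()) + list(temp_head.parameters()), lr=1e-3
)

X_img = torch.randn(64, 1, 28, 28)             # synthetic images
y = torch.randint(0, 2, (64,))                 # synthetic binary labels

for round_idx in range(3):                     # alternate the two stages a few times
    # --- Stage 1: train the CNN pathways (with the temporary head) ---
    backbone.train()
    for start in range(0, 64, 4):              # mini-batches of 4, as in my setup
        xb, yb = X_img[start:start + 4], y[start:start + 4]
        optimizer.zero_grad()
        loss = F.cross_entropy(temp_head(backbone(xb)), yb)
        loss.backward()
        optimizer.step()

    # --- Stage 2: freeze the CNN, extract features, refit the EBM ---
    backbone.eval()
    with torch.no_grad():
        feats = backbone(X_img).numpy()
    ebm = ExplainableBoostingClassifier().fit(feats, y.numpy())
    # the refitted ebm is then passed into forward(x, ebm) for the next round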

@JWKKWJ123
Author

This training strategy works (I accidentally recorded the accuracy twice for the epoch between the two stages). It can provide the contributions of the different pathways in a large composite DNN without sacrificing performance:
[figure: training accuracy results for the two-stage strategy]
