-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(docs): add half-precision training section in using_simulator docs #678
Open
PabloCarmona
wants to merge
5
commits into
master
Choose a base branch
from
feat-677-half-precision-docs
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
7ee0f44
feat(docs): add half-precision training section in using_simulator docs
PabloCarmona 0877c23
feat(examples): add example 31 for half precision training
PabloCarmona 2ba2058
Merge branch 'master' into feat-677-half-precision-docs
PabloCarmona 4e1bf93
format example with black
PabloCarmona 7c509ca
Merge branch 'master' into feat-677-half-precision-docs
kaoutar55 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# type: ignore | ||
# pylint: disable-all | ||
# -*- coding: utf-8 -*- | ||
|
||
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
"""aihwkit example 31: Using half precision training. | ||
|
||
This example demonstrates how to use half precision training with aihwkit. | ||
|
||
""" | ||
# pylint: disable=invalid-name | ||
|
||
import tqdm | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torchvision import datasets, transforms | ||
from aihwkit.simulator.configs import InferenceRPUConfig, TorchInferenceRPUConfig | ||
from aihwkit.nn.conversion import convert_to_analog | ||
from aihwkit.optim import AnalogSGD | ||
from aihwkit.simulator.parameters.enums import RPUDataType | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 32, 3, 1) | ||
self.conv2 = nn.Conv2d(32, 64, 3, 1) | ||
self.dropout1 = nn.Dropout(0.25) | ||
self.dropout2 = nn.Dropout(0.5) | ||
self.fc1 = nn.Linear(9216, 128) | ||
self.fc2 = nn.Linear(128, 10) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
x = F.relu(x) | ||
x = self.conv2(x) | ||
x = F.relu(x) | ||
x = F.max_pool2d(x, 2) | ||
x = self.dropout1(x) | ||
x = torch.flatten(x, 1) | ||
x = self.fc1(x) | ||
x = F.relu(x) | ||
x = self.dropout2(x) | ||
x = self.fc2(x) | ||
output = F.log_softmax(x, dim=1) | ||
return output | ||
|
||
|
||
if __name__ == "__main__": | ||
model = Net() | ||
rpu_config = TorchInferenceRPUConfig() | ||
model = convert_to_analog(model, rpu_config) | ||
nll_loss = torch.nn.NLLLoss() | ||
transform = transforms.Compose( | ||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | ||
) | ||
dataset = datasets.MNIST("data", train=True, download=True, transform=transform) | ||
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32) | ||
|
||
model = model.to(device=device, dtype=torch.bfloat16) | ||
optimizer = AnalogSGD(model.parameters(), lr=0.1) | ||
model = model.train() | ||
|
||
pbar = tqdm.tqdm(enumerate(train_loader)) | ||
for batch_idx, (data, target) in pbar: | ||
data, target = data.to(device=device, dtype=torch.bfloat16), target.to( | ||
device=device | ||
) | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = F.nll_loss(output.float(), target) | ||
loss.backward() | ||
optimizer.step() | ||
pbar.set_description(f"Loss {loss:.4f}") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this also work with other tiles that TorchInference? I had added a HALF type, like in the above:
rpu_config.runtime.data_type = RPUDataType.HALF
because the HALF could either be bfloat16 of float16 depending on the compilation options.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There also a
runtime.data_type.as_torch()
or something function to convert it to the corresponding torch type, see https://github.com/IBM/aihwkit/blob/master/src/aihwkit/simulator/parameters/enums.py#L23