Skip to content

Commit aab0d3d

Browse files
mszhanyimalfet
andauthored
Basic CNN RNN Test (pytorch#564)
* rnn cnn smoke test * only for cuda11 first * rm cuda11 condition * some refactor * add missing file * fix wording error * minor change * share the smoke tests for linux * Update check_binary.sh Co-authored-by: Nikita Shulga <[email protected]>
1 parent a15fac8 commit aab0d3d

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

check_binary.sh

+7
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,13 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != *"rocm"* ]]; then
362362

363363
echo "Checking that CuDNN is available"
364364
python -c 'import torch; exit(0 if torch.backends.cudnn.is_available() else 1)'
365+
366+
echo "Checking that basic RNN works"
367+
python ${TEST_CODE_DIR}/rnn_smoke.py
368+
369+
echo "Checking that basic CNN works"
370+
python ${TEST_CODE_DIR}/cnn_smoke.py
371+
365372
popd
366373
fi # if libtorch
367374
fi # if cuda

test_example_code/cnn_smoke.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
r"""
2+
It's used to check basic rnn features with cuda.
3+
For example, it would throw exception if some components are missing
4+
"""
5+
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
import torch.optim as optim
10+
11+
class SimpleCNN(nn.Module):
12+
def __init__(self):
13+
super().__init__()
14+
self.conv = nn.Conv2d(1, 1, 3)
15+
self.pool = nn.MaxPool2d(2, 2)
16+
17+
def forward(self, inputs):
18+
output = self.pool(F.relu(self.conv(inputs)))
19+
output = output.view(1)
20+
return output
21+
22+
# Mock one infer
23+
device = torch.device("cuda:0")
24+
net = SimpleCNN().to(device)
25+
net_inputs = torch.rand((1, 1, 5, 5), device=device)
26+
outputs = net(net_inputs)
27+
print(outputs)
28+
29+
criterion = nn.MSELoss()
30+
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.1)
31+
32+
# Mock one step training
33+
label = torch.full((1,), 1.0, dtype=torch.float, device=device)
34+
loss = criterion(outputs, label)
35+
loss.backward()
36+
optimizer.step()

test_example_code/rnn_smoke.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
r"""
2+
It's used to check basic rnn features with cuda.
3+
For example, it would throw exception if missing some components are missing
4+
"""
5+
6+
import torch
7+
import torch.nn as nn
8+
9+
device = torch.device("cuda:0")
10+
rnn = nn.RNN(10, 20, 2).to(device)
11+
inputs = torch.randn(5, 3, 10).to(device)
12+
h0 = torch.randn(2, 3, 20).to(device)
13+
output, hn = rnn(inputs, h0)

windows/internal/smoke_test.bat

+8
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ echo Checking that CuDNN is available
146146
python -c "import torch; exit(0 if torch.backends.cudnn.is_available() else 1)"
147147
if ERRORLEVEL 1 exit /b 1
148148

149+
echo Checking that basic RNN works
150+
python %BUILDER_ROOT%\test_example_code\rnn_smoke.py
151+
if ERRORLEVEL 1 exit /b 1
152+
153+
echo Checking that basic CNN works
154+
python %BUILDER_ROOT%\test_example_code\cnn_smoke.py
155+
if ERRORLEVEL 1 exit /b 1
156+
149157
goto end
150158

151159
:libtorch

0 commit comments

Comments
 (0)