Skip to content

Commit 6d63fe6

Browse files
committed
fix
1 parent 9f81cb5 commit 6d63fe6

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

predict.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch.nn as nn
1313
import torch.nn.functional as F
1414

15-
from network import NetworkLarge as Network
15+
from utils import build_network
1616

1717
def predict(net, device, nb_grid):
1818
x = np.linspace(-1.5, 1.5, nb_grid)
@@ -46,7 +46,7 @@ def predict(net, device, nb_grid):
4646
use_cuda = torch.cuda.is_available()
4747
device = torch.device("cuda" if use_cuda else "cpu")
4848

49-
net = Network(input_dim=3)
49+
net = build_network(input_dim=3)
5050
net.to(device)
5151
net.load_state_dict(torch.load('./models/{}_model.pth'.format(name), map_location=device))
5252

predict_toy_3d.py

-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13-
from network import Network
1413
from utils import build_network
1514

1615
def predict(net):

0 commit comments

Comments
 (0)