TorchScript bad_alloc issue (pytorch#2542)
Summary: Pull Request resolved: pytorch#2542

Differential Revision: D65495806
PaulZhang12 authored and facebook-github-bot committed Nov 6, 2024
1 parent 509b0d2 commit 3b13683
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchrec/models/tests/test_deepfm.py
@@ -202,7 +202,7 @@ def test_fx_script(self) -> None:
             sparse_features=sparse_features,
         )
 
-        gm = symbolic_trace(deepfm_nn)
+        gm = torch.fx.GraphModule(deepfm_nn, Tracer().trace(deepfm_nn))
 
         scripted_gm = torch.jit.script(gm)
3 changes: 2 additions & 1 deletion torchrec/modules/tests/test_mc_modules.py
@@ -11,6 +11,7 @@
 from typing import Dict
 
 import torch
+from torchrec.fx import Tracer
 from torchrec.modules.mc_modules import (
     average_threshold_filter,
     DistanceLFU_EvictionPolicy,
@@ -357,5 +358,5 @@ def test_fx_jit_script_not_training(self) -> None:
         )
 
         model.train(False)
-        gm = torch.fx.symbolic_trace(model)
+        gm = torch.fx.GraphModule(model, Tracer().trace(model))
         torch.jit.script(gm)
6 changes: 3 additions & 3 deletions torchrec/modules/tests/test_mlp.py
@@ -14,7 +14,7 @@
 import torch
 from hypothesis import given, settings
 from torch import nn
-from torchrec.fx import symbolic_trace
+from torchrec.fx import symbolic_trace, Tracer
 from torchrec.modules.mlp import MLP, Perceptron
 
 
@@ -99,13 +99,13 @@ def test_fx_script_Perceptron(self) -> None:
         # Dry-run to initialize lazy module.
         m(torch.randn(batch_size, in_features))
 
-        gm = symbolic_trace(m)
+        gm = torch.fx.GraphModule(m, Tracer().trace(m))
         torch.jit.script(gm)
 
     def test_fx_script_MLP(self) -> None:
         in_features = 3
         layer_sizes = [16, 8, 4]
         m = MLP(in_features, layer_sizes)
 
-        gm = symbolic_trace(m)
+        gm = torch.fx.GraphModule(m, Tracer().trace(m))
         torch.jit.script(gm)
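For reference, a minimal sketch of the pattern this commit standardizes on, reusing the MLP setup from test_fx_script_MLP above: trace the module with torchrec's Tracer and wrap the resulting graph in a torch.fx.GraphModule explicitly, rather than calling symbolic_trace, before handing it to TorchScript. This assumes a working torchrec install; any fx-traceable nn.Module would do in place of MLP.

    import torch
    from torchrec.fx import Tracer
    from torchrec.modules.mlp import MLP

    # Same setup as test_fx_script_MLP.
    in_features = 3
    layer_sizes = [16, 8, 4]
    m = MLP(in_features, layer_sizes)

    # Build the GraphModule from torchrec's Tracer rather than calling
    # torch.fx.symbolic_trace; this is the tracing path the tests now use.
    gm = torch.fx.GraphModule(m, Tracer().trace(m))

    # The explicitly constructed GraphModule compiles with TorchScript.
    scripted = torch.jit.script(gm)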
