Commit bbf180a

Thiago Crepaldi authored and pytorchmergebot committed
Add new aten::device variant to TorchScript (pytorch#97023)
Fixes pytorch#96627

Pull Request resolved: pytorch#97023
Approved by: https://github.com/jgong5, https://github.com/BowenBao, https://github.com/davidberard98
1 parent d1e7434 commit bbf180a
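For context, the new overload lets TorchScript resolve torch.device calls that pass the device type and index as separate arguments. A minimal sketch of the user-facing behavior (illustrative only, not part of this commit; it assumes the two-argument form now resolves to the new schema):

import torch

@torch.jit.script
def make_device(index: int):
    # With the new aten::device overload registered, the (type, index) form
    # compiles instead of requiring a single "cuda:0"-style string.
    return torch.device("cuda", index)

print(make_device(0))  # prints: cuda:0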

File tree: 5 files changed, +61 -0 lines changed

- test/onnx/test_pytorch_onnx_no_runtime.py (+24)
- torch/csrc/jit/mobile/promoted_prim_ops.cpp (+8)
- torch/csrc/jit/mobile/promoted_prim_ops.h (+2)
- torch/csrc/jit/passes/peephole.cpp (+22)
- torch/csrc/jit/runtime/register_prim_ops.cpp (+5)

test/onnx/test_pytorch_onnx_no_runtime.py (+24)

@@ -1241,6 +1241,30 @@ def forward(self, x):
                 double_type_count += 1
         self.assertNotEqual(double_type_count, 0)
 
+    @pytorch_test_common.skipIfNoCuda
+    def test_aten_device_with_index(self):
+        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+        model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+        model = torch.compile(model, backend="onnxrt")
+        model = model.eval()
+        device = "cuda:0"
+        model = model.to(device)
+        ids = tokenizer.batch_encode_plus(["This is a test"], return_tensors="pt").to(
+            device
+        )
+
+        with torch.no_grad():
+            _ = model(
+                **{
+                    "input_ids": ids["input_ids"],
+                    "attention_mask": ids["attention_mask"],
+                    "decoder_input_ids": ids["input_ids"],
+                    "decoder_attention_mask": ids["attention_mask"],
+                }
+            )
+
 
 if __name__ == "__main__":
     common_utils.run_tests()

torch/csrc/jit/mobile/promoted_prim_ops.cpp (+8)

@@ -92,6 +92,14 @@ void device(Stack& stack) {
   push(stack, pop(stack).toTensor().device());
 }
 
+void device_with_index(Stack& stack) {
+  std::string type = pop(stack).toStringRef();
+  int index = pop(stack).toInt();
+  std::string device_str = type + ":" + std::to_string(index);
+  auto device = c10::Device(device_str);
+  push(stack, device);
+}
+
 void dtype(Stack& stack) {
   at::Tensor a;
   pop(stack, a);

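For intuition, the mobile op simply joins the type string and index into the familiar "type:index" form before constructing the device, so it yields the same result as building the combined string yourself. A small Python illustration (not code from the commit):

import torch

type_str, index = "cuda", 1
# The op builds "cuda:1" and parses it; the result equals the two-argument form.
assert torch.device(f"{type_str}:{index}") == torch.device(type_str, index)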
torch/csrc/jit/mobile/promoted_prim_ops.h (+2)

@@ -33,6 +33,8 @@ void sym_stride(Stack& stack);
 
 void device(Stack& stack);
 
+void device_with_index(Stack& stack);
+
 void dtype(Stack& stack);
 
 void layout(Stack& stack);

torch/csrc/jit/passes/peephole.cpp (+22)

@@ -269,6 +269,28 @@ struct PeepholeOptimizeImpl {
         node->output()->replaceAllUsesWith(output);
         changed = true;
       }
+    } else if (
+        node->matches("aten::device(str type, int index) -> Device") &&
+        shape_peepholes_) {
+      auto string_type = node->inputs().at(0)->type()->expect<StringType>();
+      if (string_type) {
+        WithInsertPoint guard(node);
+        std::string type_str = node->inputs().at(0)->node()->s(attr::value);
+        auto maybe_index = toIValue(node->inputs().at(1));
+        int64_t index = 0;
+        if (maybe_index) {
+          index = maybe_index->toInt();
+        }
+        auto device = c10::Device(type_str + ":" + std::to_string(index));
+        auto output = node->owningGraph()->insertConstant(device);
+        GRAPH_UPDATE(
+            "Replacing ",
+            getHeader(node),
+            " with a device constant ",
+            output->debugName());
+        node->output()->replaceAllUsesWith(output);
+        changed = true;
+      }
     } else if (
         node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
       auto ptt = node->input()->type()->expect<TensorType>();

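To see the fold, one can run the peephole pass on a scripted graph and check that a device call with constant inputs collapses into a Device constant. A minimal sketch, assuming the internal torch._C._jit_pass_peephole binding is callable with just a graph (internal APIs may change):

import torch

@torch.jit.script
def make_device():
    return torch.device("cuda", 0)

graph = make_device.graph
torch._C._jit_pass_peephole(graph)
# The aten::device call fed by a constant "cuda" string and constant index
# should now appear as a single Device-typed prim::Constant in the graph.
print(graph)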
torch/csrc/jit/runtime/register_prim_ops.cpp (+5)

@@ -2292,6 +2292,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
           push(stack, c10::Device(pop(stack).toStringRef()));
         },
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::device.with_index(str type, int index) -> Device"),
+        device_with_index,
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
         [](Stack& stack) {

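One way to check the registration from Python is to list the schemas known for aten::device. A sketch assuming the internal torch._C._jit_get_schemas_for_operator helper behaves as in PyTorch's own tests:

import torch

# Expect both the original single-string overload and the new indexed one,
# e.g. aten::device(str a) -> Device and
#      aten::device.with_index(str type, int index) -> Device.
for schema in torch._C._jit_get_schemas_for_operator("aten::device"):
    print(schema)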