Cannot compile on cuda:1 on multi-GPUs #2668

Closed
alibool opened this issue Mar 3, 2024 · 5 comments

@alibool

alibool commented Mar 3, 2024

The following code works well with 'cuda:0', but raises an error when using 'cuda:1':

import torch
from torchvision import models
import torch_tensorrt


dtype = torch.float32
# works fine when cuda:0 but error on cuda:1
device = torch.device("cuda:0")

model = models.resnet50().to(dtype).to(device)
model.eval()
inputs_1 = torch.rand(12, 3, 224, 224).to(device).to(dtype)

optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, fullgraph=True,
                                options={
                                    "precision": dtype,
                                    "device": device
                                })

with torch.no_grad():
    useless = optimized_model(inputs_1)
print(useless)

Both GPUs are A100s. Environment: PyTorch 2.2.1+cu118, torch-tensorrt 2.2.0+cu118, TensorRT 8.6.0.

alibool added the bug label Mar 3, 2024
@gs-olive
Collaborator

gs-olive commented Mar 4, 2024

Hi - thanks for the report - could you share the error logs when compiling with cuda:1, preferably with "debug": True?
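
For reference, the flag goes through the same options dict as in the original snippet (a sketch; the rest of the call mirrors the script above):

optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, fullgraph=True,
                                options={
                                    "precision": dtype,
                                    "device": device,
                                    "debug": True,  # emit verbose compilation logs
                                })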

@gs-olive
Collaborator

gs-olive commented Mar 5, 2024

As an update on this, I am seeing the following output when running your script:

ERROR: [Torch-TensorRT] - 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
ERROR: [Torch-TensorRT] - 1: [defaultAllocator.cpp::deallocate::61] Error Code 1: Cuda Runtime (an illegal memory access was encountered)

This appears to happen because the active CUDA device in the context is being overwritten or incorrectly set. I am looking into the issue. As a temporary workaround, consider prefixing the script invocation with CUDA_VISIBLE_DEVICES=1 so that cuda:1 acts as cuda:0 from within the script.
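
For example (a sketch; the script name repro.py is chosen here only for illustration):

# Shell: expose only the second physical GPU, which the process then sees as cuda:0
#   CUDA_VISIBLE_DEVICES=1 python repro.py

# Or equivalently from within Python, before CUDA is initialized:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # physical cuda:1 becomes cuda:0

import torch
device = torch.device("cuda:0")  # now refers to the physical cuda:1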

gs-olive added the bug: triaged [verified] label Mar 5, 2024
@gs-olive
Collaborator

gs-olive commented Mar 5, 2024

I have some additional details to share on this, as well as the proposed fix.

In the initial code, the device parameter was set to "cuda:1", but the ambient active CUDA device in the thread was the default ("cuda:0"). In order to have the active thread use "cuda:1", it is necessary to add the line device = torch.device("cuda:1"); torch.cuda.set_device(device) before compilation. This makes the compilation and inference functional on my end.

If the intended use case is to compile/run on multiple GPUs, I would suggest using separate Python threads, each with a set_device call to ensure the ambient device in that thread is the intended one; a sketch of this pattern is shown below. Alternatively, you could enable torch_tensorrt.runtime.set_multi_device_safe_mode(True), described after the sketch.
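
A minimal sketch of the per-thread approach (illustrative only; the helper name and thread structure are assumptions, not an official API):

import threading

import torch
from torchvision import models
import torch_tensorrt  # registers the "torch_tensorrt" backend


def compile_and_run(device_index: int):
    # Pin this thread's ambient CUDA device before compiling.
    device = torch.device(f"cuda:{device_index}")
    torch.cuda.set_device(device)

    model = models.resnet50().eval().to(device)
    inputs = torch.rand(12, 3, 224, 224, device=device)

    optimized = torch.compile(model, backend="torch_tensorrt", dynamic=False,
                              fullgraph=True, options={"device": device})
    with torch.no_grad():
        return optimized(inputs)


# One thread per GPU; set_device only affects the calling thread.
threads = [threading.Thread(target=compile_and_run, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()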

We recently added a feature, Multi-Device Runtime Safety, which gives more verbose messages regarding device contexts and will automatically switch devices if a mismatch is detected, at the cost of runtime device-checking. This can be enabled via torch_tensorrt.runtime.set_multi_device_safe_mode(True). The default is False.
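
Enabling it is a single call before compilation and inference, e.g.:

import torch_tensorrt

# Opt in to runtime device checking; mismatched device contexts are then
# detected and corrected automatically, with a small per-call overhead.
torch_tensorrt.runtime.set_multi_device_safe_mode(True)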

This code works on my end, please let me know if it also resolves your case:

import torch
from torchvision import models
import torch_tensorrt


dtype = torch.float32
device = torch.device("cuda:1")
torch.cuda.set_device(device)

model = models.resnet50().to(dtype).to(device)
model.eval()
inputs_1 = torch.rand(12, 3, 224, 224).to(device).to(dtype)

optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, fullgraph=True,
                                options={
                                    "precision": dtype,
                                    "device": device
                                })

with torch.no_grad():
    useless = optimized_model(inputs_1)
print(useless)

@alibool
Author

alibool commented Mar 5, 2024

Thanks! When will this nice feature be released, and should I close this issue?

@gs-olive
Collaborator

gs-olive commented Mar 5, 2024

Hi - this feature should already be in the version you are using. Once you have verified the resolution works for you, feel free to close this bug and open a new one if there are other issues.

alibool closed this as completed Mar 5, 2024