Error in using multi GPUs #774
This error may occur when the version of transformers is higher than 4.44. You can add the following code when constructing the device map to fix this issue:
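For reference, the addition mentioned here (and in the linked fix below) is a single extra device-map entry inside `split_model`, which pins the rotary embedding module to GPU 0 along with the other modules mapped there:

```python
# Pin the rotary embedding module to GPU 0. With transformers > 4.44 it appears
# to be registered as a separate submodule, so without an explicit entry it can
# be left on the CPU, causing the cpu/cuda:0 device-mismatch error at inference.
device_map['language_model.model.rotary_emb'] = 0
```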
Thank you very much for your suggestion; that problem is now solved. I have another question: if I want to use a message-type input, which consists of several "role"s and "content"s, which function should I use, and how should I pass the parameters?
We organize the inputs into the message type in this part. You can implement your inference code by referring to the batch chat function of InternVL.
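In case it helps, here is a rough sketch of turning a role/content style message list into the plain question strings that `batch_chat` expects. The call follows the batch-inference example from the InternVL2 model card; `model`, `tokenizer`, `pixel_values`, and `num_patches_list` are assumed to be prepared exactly as in that example, and the message list itself is hypothetical:

```python
# Hypothetical message-style input: one list of {"role", "content"} dicts per sample.
messages = [
    [{"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "<image>\nDescribe the image in detail."}],
]

# batch_chat takes plain question strings rather than role/content messages,
# so flatten each conversation into a single prompt string.
questions = ["\n".join(turn["content"] for turn in conv) for conv in messages]

generation_config = dict(max_new_tokens=512, do_sample=False)

# pixel_values and num_patches_list are assumed to be built as in the
# official InternVL batch-inference example (one image batch per question).
responses = model.batch_chat(
    tokenizer,
    pixel_values,
    questions=questions,
    generation_config=generation_config,
    num_patches_list=num_patches_list,
)
for question, response in zip(questions, responses):
    print(question, '->', response)
```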
Related to #774. Add `rotary_emb` layer to the device map in the `split_model` function to fix the multi-GPU inference error.
* Add `device_map['language_model.model.rotary_emb'] = 0` to the `split_model` function in `internvl_chat/internvl/model/__init__.py`.
I'm sorry to bother you: when I use your code from Hugging Face for multi-GPU inference, I encounter the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
The core of the script is as follows:
```python
import math

import torch
from transformers import AutoModel, AutoTokenizer


def split_model(model_name):
    device_map = {}
    world_size = torch.cuda.device_count()
    num_layers = {
        'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32,
        'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    return device_map


path = 'OpenGVLab/InternVL2-40B'
device_map = split_model('InternVL2-40B')
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
```
So how should I adjust it to run multi-GPU inference on V100s? Thank you very much for your reply.