0001-feat-enable-xpu-support-for-meta-reference-stack.patch
From 2018c3c386db716259b94b8ddcbea17e207b97a2 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin <[email protected]>
Date: Mon, 18 Nov 2024 16:00:55 -0800
Subject: [PATCH] feat: enable xpu support for meta-reference stack

Requires: https://github.com/meta-llama/llama-models/pull/233

Signed-off-by: Dmitry Rogozhkin <[email protected]>
---
 .../inference/meta_reference/generation.py | 45 +++++++++++++------
 1 file changed, 32 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index a96409ca..66f97b42 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -98,7 +98,10 @@ class Llama:
         """
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if torch.cuda.is_available():
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
 
         model_parallel_size = llama_model.pth_file_count
 
@@ -106,7 +109,14 @@ class Llama:
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if torch.cuda.is_available():
+            device = "cuda"
+            torch.cuda.set_device(local_rank)
+        elif torch.xpu.is_available():
+            device = "xpu"
+            torch.xpu.set_device(local_rank)
+        else:
+            raise NotImplementedError("Devices other than CUDA and XPU are not supported yet")
 
         # seed must be the same in all processes
         if config.torch_seed is not None:
@@ -189,10 +199,17 @@ class Llama:
                     "Currently int4 and fp8 are the only supported quantization methods."
                 )
         else:
-            if torch.cuda.is_bf16_supported():
-                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-            else:
-                torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            if device == "cuda":
+                if torch.cuda.is_bf16_supported():
+                    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+                else:
+                    torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            elif device == "xpu":
+                torch.set_default_device(device)
+                if torch.xpu.is_bf16_supported():
+                    torch.set_default_dtype(torch.bfloat16)
+                else:
+                    torch.set_default_dtype(torch.half)
             if model_args.vision_chunk_size > 0:
                 model = CrossAttentionTransformer(model_args)
                 model.setup_cache(model_args.max_batch_size, torch.bfloat16)
@@ -200,6 +217,8 @@ class Llama:
                 model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
 
+        model.to(device)
+
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
         return Llama(model, tokenizer, model_args, llama_model_id)
 
@@ -208,7 +227,7 @@ class Llama:
         model: Transformer,
         tokenizer: Tokenizer,
         args: ModelArgs,
-        llama_model: str,
+        llama_model: str
     ):
         self.args = args
         self.model = model
@@ -266,14 +285,14 @@ class Llama:
             )
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
         if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
@@ -285,11 +304,11 @@ class Llama:
                 ignore_index=pad_id,
             )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
         for cur_pos in range(min_prompt_len, total_len):
            if is_vision:
                position_ids = torch.arange(
-                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
+                    prev_pos, cur_pos, dtype=torch.long
                )
                logits = self.model.forward(
                    position_ids,
--
2.34.1
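
Note: the device-selection logic this patch adds to Llama.build can be exercised on its own. The sketch below is illustrative only and is not part of the patch; the helper names pick_device, init_distributed, and set_default_dtype are hypothetical, and it assumes a PyTorch build that exposes torch.xpu (as the patched code requires).

# Illustrative sketch only -- not part of the patch above.
# Assumes a PyTorch build that exposes torch.xpu, as the patched code does.
import os

import torch
import torch.distributed as dist


def init_distributed() -> None:
    # NCCL only supports CUDA; fall back to gloo for other devices such as XPU.
    if dist.is_initialized():
        return
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend)


def pick_device(local_rank: int) -> str:
    # Prefer CUDA, then XPU, mirroring the order used in the patched Llama.build.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.set_device(local_rank)
        return "xpu"
    raise NotImplementedError("Devices other than CUDA and XPU are not supported yet")


def set_default_dtype(device: str) -> None:
    # bf16 where the device supports it, fp16 otherwise; the XPU path uses the
    # default-device/default-dtype APIs instead of the legacy CUDA tensor types.
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
    elif device == "xpu":
        torch.set_default_device(device)
        torch.set_default_dtype(torch.bfloat16 if torch.xpu.is_bf16_supported() else torch.half)


if __name__ == "__main__":
    # init_distributed() is defined above but not called here, since a real
    # process group also needs MASTER_ADDR/MASTER_PORT and a rank/world size.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = pick_device(local_rank)
    set_default_dtype(device)
    print(f"device={device}, default dtype={torch.get_default_dtype()}")

Run on a single process with LOCAL_RANK unset or set to 0. In the patched generation.py the same checks run once per rank before the checkpoint is loaded, and the model is then moved to the chosen device with model.to(device).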