feat: prefetching #152

Open · wants to merge 1 commit into base: main
136 changes: 114 additions & 22 deletions rwkv_pip_package/src/rwkv/model.py
@@ -2,7 +2,8 @@
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 ########################################################################################################
 
-import types, gc, os, time, re
+from collections import ChainMap
+import types, gc, os, time, re, contextlib
 import torch
 from torch.nn import functional as F
 torch.backends.cudnn.benchmark = True
@@ -77,12 +78,13 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
 class RWKV(MyModule):
     def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
         super().__init__()
+        self.prefetch_buffer = {}
         if verbose:
             prxxx = lambda *args, **kwargs: print(*args, **kwargs)
         else:
             prxxx = lambda *args, **kwargs: None
 
-        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?(\d+)?)? *)+$"
         if not re.match(STRATEGY_REGEX, strategy):
             raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")
 
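For reference, here is roughly what the widened regex accepts; the sample strategy strings below are assumptions for illustration, not taken from the PR:

import re
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?(\d+)?)? *)+$"
assert re.match(STRATEGY_REGEX, 'cuda fp16 *10')    # resident layers only, no streaming
assert re.match(STRATEGY_REGEX, 'cuda fp16 *10+')   # stream remaining layers, no prefetch
assert re.match(STRATEGY_REGEX, 'cuda fp16 *10+2')  # new: stream and prefetch 2 layers ahead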
@@ -91,6 +93,7 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
         args = self.args
         args.MODEL_NAME = model
         args.strategy_string = strategy
+        args.prefetch_layers_num = 0
 
         # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
         self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
@@ -144,9 +147,15 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
             if len(si) > 2:
                 ss = si[2]
                 assert ss.startswith('*')
-                if ss.endswith('+'):
-                    plan[i] = int(ss[1:-1])
+                if '+' in ss:
+                    stream_i = i
+                    if ss.endswith('+'):
+                        plan[i] = int(ss[1:-1])
+                        args.prefetch_layers_num = 0
+                    else:
+                        plan[i] = int(ss.split('+')[0][1:])
+                        args.prefetch_layers_num = int(ss.split('+')[-1])
+                        assert args.prefetch_layers_num >= 0
                 else:
                     plan[i] = int(ss[1:])
                 allocated += plan[i]
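A worked example of the parsing above, with assumed values: for ss = '*10+2' the device keeps 10 resident layers and args.prefetch_layers_num becomes 2, while the pre-existing '*10+' form still streams with prefetching disabled:

ss = '*10+2'
plan_i = int(ss.split('+')[0][1:])            # 10 -> layers resident on this device
prefetch_layers_num = int(ss.split('+')[-1])  # 2  -> layers copied ahead while streaming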
@@ -556,9 +565,12 @@ def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_
 
     def forward(self, tokens, state, full_output=False):
         with torch.no_grad():
-            w = self.w
             args = self.args
 
+            if not hasattr(self, 'prefetch_stream'):
+                self.prefetch_stream = torch.cuda.Stream()
+                self.prefetch_events = [None] * args.prefetch_layers_num
+
             if state == None:
                 state = [None] * args.n_layer * 5
                 for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
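The stream and event list are created lazily on the first forward() call, so they are only allocated once the model actually runs. Everything that follows builds on the standard side-stream copy plus event handshake; a minimal self-contained sketch of that pattern (not code from the PR, and it assumes a CUDA device is available):

import torch

cpu_w = torch.empty(1024, 1024).pin_memory()     # pinned host memory enables real async copies
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    gpu_w = cpu_w.to('cuda', non_blocking=True)  # H2D copy enqueued on the side stream
event = torch.cuda.Event()
event.record(side)                               # completes once the copy has finished
# ... unrelated work can proceed on the default stream here ...
event.synchronize()                              # block until the copy is done
print(gpu_w.shape)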
@@ -573,13 +585,57 @@ def forward(self, tokens, state, full_output=False):
 
             seq_mode = len(tokens) > 1
 
+            pb = self.prefetch_buffer
+
+            # warm up the prefetch buffer: asynchronously copy the att/ffn weights
+            # of the first prefetch_layers_num blocks (if streamed) to their device
+            for i in range(args.prefetch_layers_num):
+                dd = self.strategy[i]
+                dev = dd.device
+                if dd.stream:
+                    att = f'blocks.{i}.att.'
+                    kw_key = f'{att}key.weight'
+                    vw_key = f'{att}value.weight'
+                    rw_key = f'{att}receptance.weight'
+                    ow_key = f'{att}output.weight'
+
+                    with torch.cuda.stream(self.prefetch_stream):
+                        pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True)
+                        pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True)
+                        pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True)
+                        pb[ow_key] = self.w[ow_key].to(device=dev, non_blocking=True)
+
+                    ffn = f'blocks.{i}.ffn.'
+                    kw_key = f'{ffn}key.weight'
+                    vw_key = f'{ffn}value.weight'
+                    rw_key = f'{ffn}receptance.weight'
+
+                    with torch.cuda.stream(self.prefetch_stream):
+                        pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True)
+                        pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True)
+                        pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True)
+                    event = torch.cuda.Event()
+                    # record on the prefetch stream so a later synchronize()
+                    # actually waits for the copies issued above
+                    event.record(self.prefetch_stream)
+                    self.prefetch_events[i % args.prefetch_layers_num] = event
+
+            # look up weights in the prefetch buffer first, fall back to self.w
+            w = ChainMap(pb, self.w)
+
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
 
             for i in range(args.n_layer):
                 bbb = f'blocks.{i}.'
                 att = f'blocks.{i}.att.'
                 ffn = f'blocks.{i}.ffn.'
                 dd = self.strategy[i]
+                # wait for this layer's prefetched weights to finish copying
+                if dd.stream and args.prefetch_layers_num > 0:
+                    self.prefetch_events[i % args.prefetch_layers_num].synchronize()
+                    self.prefetch_events[i % args.prefetch_layers_num] = None
+
+                # should the layer prefetch_layers_num blocks ahead be prefetched?
+                pf_layer = i + args.prefetch_layers_num
+                if pf_layer < args.n_layer:
+                    prefetch = self.strategy[pf_layer].stream
+                else:
+                    prefetch = False
                 dev = dd.device
                 atype = dd.atype
                 wtype = dd.wtype
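The ChainMap gives the prefetch buffer priority over the base weights: a key present in pb shadows self.w, and everything else falls through. A standalone illustration:

from collections import ChainMap
base = {'emb.weight': 'cpu tensor', 'blocks.0.att.key.weight': 'cpu tensor'}
pb = {'blocks.0.att.key.weight': 'cuda tensor'}       # prefetched copy
w = ChainMap(pb, base)
assert w['blocks.0.att.key.weight'] == 'cuda tensor'  # pb wins for prefetched keys
assert w['emb.weight'] == 'cpu tensor'                # everything else falls back to base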
Expand All @@ -595,15 +651,30 @@ def forward(self, tokens, state, full_output=False):

x = x.to(dtype=atype, device=dev)

kw = w[f'{att}key.weight']
vw = w[f'{att}value.weight']
rw = w[f'{att}receptance.weight']
ow = w[f'{att}output.weight']
if dd.stream:
kw = kw.to(device=dev, non_blocking=True)
vw = vw.to(device=dev, non_blocking=True)
rw = rw.to(device=dev, non_blocking=True)
ow = ow.to(device=dev, non_blocking=True)
if prefetch:
pf_att = f'blocks.{pf_layer}.att.'
pf_kw_key = f'{pf_att}key.weight'
pf_vw_key = f'{pf_att}value.weight'
pf_rw_key = f'{pf_att}receptance.weight'
pf_ow_key = f'{pf_att}output.weight'

with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext():
pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True)
pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True)
pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True)
pb[pf_ow_key] = self.w[pf_ow_key].to(device=dev, non_blocking=True)
kw_key = f'{att}key.weight'
vw_key = f'{att}value.weight'
rw_key = f'{att}receptance.weight'
ow_key = f'{att}output.weight'
kw = w[kw_key]
vw = w[vw_key]
rw = w[rw_key]
ow = w[ow_key]
assert kw.device.type == dev, f'{kw.device.type} != {dev}, {att=}'
assert vw.device.type == dev, f'{vw.device.type} != {dev}, {att=}'
assert rw.device.type == dev, f'{rw.device.type} != {dev}, {att=}'
assert ow.device.type == dev, f'{ow.device.type} != {dev}, {att=}'
kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
@@ -633,14 +704,32 @@ def forward(self, tokens, state, full_output=False):
                     )
                 if dd.stream:
                     del kw, vw, rw, ow
-
-                kw = w[f'{ffn}key.weight']
-                vw = w[f'{ffn}value.weight']
-                rw = w[f'{ffn}receptance.weight']
-                if dd.stream:
-                    kw = kw.to(device=dev, non_blocking=True)
-                    vw = vw.to(device=dev, non_blocking=True)
-                    rw = rw.to(device=dev, non_blocking=True)
+                    # drop the consumed att weights from the prefetch buffer
+                    del pb[kw_key]
+                    del pb[vw_key]
+                    del pb[rw_key]
+                    del pb[ow_key]
+
+                if prefetch:
+                    # queue layer pf_layer's ffn weights on the prefetch stream
+                    pf_ffn = f'blocks.{pf_layer}.ffn.'
+                    pf_kw_key = f'{pf_ffn}key.weight'
+                    pf_vw_key = f'{pf_ffn}value.weight'
+                    pf_rw_key = f'{pf_ffn}receptance.weight'
+
+                    with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext():
+                        pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True)
+                        pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True)
+                        pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True)
+                    if args.prefetch_layers_num > 0:
+                        event = torch.cuda.Event()
+                        # record on the prefetch stream so layer pf_layer's
+                        # synchronize() waits for both its att and ffn copies
+                        event.record(self.prefetch_stream)
+                        self.prefetch_events[i % args.prefetch_layers_num] = event
+
+                kw_key = f'{ffn}key.weight'
+                vw_key = f'{ffn}value.weight'
+                rw_key = f'{ffn}receptance.weight'
+                kw = w[kw_key]
+                vw = w[vw_key]
+                rw = w[rw_key]
                 kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x
                 krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x
                 kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x
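Taken together, each streamed layer i follows a double-buffered schedule; sketched below as comments, with P standing for args.prefetch_layers_num (the sketch assumes streaming starts at block 0, matching the warm-up loop above):

# 1. prefetch_events[i % P].synchronize()   wait for weights copied P layers ago
# 2. kw, vw, rw, ow = w[...]                read layer i's tensors out of pb
# 3. enqueue copies for layer i + P         att then ffn, on self.prefetch_stream
# 4. record an event for layer i + P        stored back into prefetch_events[i % P]
# 5. del pb[...]                            release layer i's tensors after use
# With P == 0 the with-statement falls back to contextlib.nullcontext(), so the
# copies run on the current stream and behaviour matches the original streaming.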
@@ -664,6 +753,9 @@ def forward(self, tokens, state, full_output=False):
                     )
                 if dd.stream:
                     del kw, vw, rw
+                    # drop the consumed ffn weights from the prefetch buffer
+                    del pb[kw_key]
+                    del pb[vw_key]
+                    del pb[rw_key]
 
             if self.RESCALE_LAYER > 0:
                 if (i+1) % self.RESCALE_LAYER == 0:
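With this change, enabling prefetch should only require the extended strategy string; a hypothetical usage sketch (model path and layer counts assumed):

from rwkv.model import RWKV

# '*10+2': keep 10 layers resident on the GPU, stream the remaining layers
# from CPU, and prefetch 2 layers ahead on a side stream.
model = RWKV(model='/path/to/RWKV-4-Pile-7B-v20230109-ctx4096', strategy='cuda fp16 *10+2')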