From e84dea75602002b4d69978cb1eb906dd19324311 Mon Sep 17 00:00:00 2001 From: daquexian Date: Fri, 7 Jul 2023 21:29:23 +0800 Subject: [PATCH] overlap communication and computation Signed-off-by: daquexian --- rwkv_pip_package/src/rwkv/model.py | 136 ++++++++++++++++++++++++----- 1 file changed, 114 insertions(+), 22 deletions(-) diff --git a/rwkv_pip_package/src/rwkv/model.py b/rwkv_pip_package/src/rwkv/model.py index e50ad64f..2f408ae1 100644 --- a/rwkv_pip_package/src/rwkv/model.py +++ b/rwkv_pip_package/src/rwkv/model.py @@ -2,7 +2,8 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -import types, gc, os, time, re +from collections import ChainMap +import types, gc, os, time, re, contextlib import torch from torch.nn import functional as F torch.backends.cudnn.benchmark = True @@ -77,12 +78,13 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): class RWKV(MyModule): def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None): super().__init__() + self.prefetch_buffer = {} if verbose: prxxx = lambda *args, **kwargs: print(*args, **kwargs) else: prxxx = lambda *args, **kwargs: None - STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?(\d+)?)? *)+$" if not re.match(STRATEGY_REGEX, strategy): raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") @@ -91,6 +93,7 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = args = self.args args.MODEL_NAME = model args.strategy_string = strategy + args.prefetch_layers_num = 0 # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 @@ -144,9 +147,15 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = if len(si) > 2: ss = si[2] assert ss.startswith('*') - if ss.endswith('+'): - plan[i] = int(ss[1:-1]) + if '+' in ss: stream_i = i + if ss.endswith('+'): + plan[i] = int(ss[1:-1]) + args.prefetch_layers_num = 0 + else: + plan[i] = int(ss.split('+')[0][1:]) + args.prefetch_layers_num = int(ss.split('+')[-1]) + assert args.prefetch_layers_num >= 0 else: plan[i] = int(ss[1:]) allocated += plan[i] @@ -556,9 +565,12 @@ def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_ def forward(self, tokens, state, full_output=False): with torch.no_grad(): - w = self.w args = self.args + if not hasattr(self, 'prefetch_stream'): + self.prefetch_stream = torch.cuda.Stream() + self.prefetch_events = [None] * args.prefetch_layers_num + if state == None: state = [None] * args.n_layer * 5 for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx @@ -573,6 +585,41 @@ def forward(self, tokens, state, full_output=False): seq_mode = len(tokens) > 1 + pb = self.prefetch_buffer + + # init prefetch_buffer + for i in range(args.prefetch_layers_num): + dd = self.strategy[i] + dev = dd.device + if dd.stream: + att = f'blocks.{i}.att.' + kw_key = f'{att}key.weight' + vw_key = f'{att}value.weight' + rw_key = f'{att}receptance.weight' + ow_key = f'{att}output.weight' + + with torch.cuda.stream(self.prefetch_stream): + pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True) + pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True) + pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True) + pb[ow_key] = self.w[ow_key].to(device=dev, non_blocking=True) + + ffn = f'blocks.{i}.ffn.' + kw_key = f'{ffn}key.weight' + vw_key = f'{ffn}value.weight' + rw_key = f'{ffn}receptance.weight' + + with torch.cuda.stream(self.prefetch_stream): + pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True) + pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True) + pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + self.prefetch_events[i % args.prefetch_layers_num] = event + + # find in self.prefetch_buffer, and find in self.w if not found + w = ChainMap(pb, self.w) + x = w['emb.weight'][tokens if seq_mode else tokens[0]] for i in range(args.n_layer): @@ -580,6 +627,15 @@ def forward(self, tokens, state, full_output=False): att = f'blocks.{i}.att.' ffn = f'blocks.{i}.ffn.' dd = self.strategy[i] + if dd.stream and args.prefetch_layers_num > 0: + self.prefetch_events[i % args.prefetch_layers_num].synchronize() + self.prefetch_events[i % args.prefetch_layers_num] = None + + pf_layer = i + args.prefetch_layers_num + if pf_layer < args.n_layer: + prefetch = self.strategy[pf_layer].stream + else: + prefetch = False dev = dd.device atype = dd.atype wtype = dd.wtype @@ -595,15 +651,30 @@ def forward(self, tokens, state, full_output=False): x = x.to(dtype=atype, device=dev) - kw = w[f'{att}key.weight'] - vw = w[f'{att}value.weight'] - rw = w[f'{att}receptance.weight'] - ow = w[f'{att}output.weight'] - if dd.stream: - kw = kw.to(device=dev, non_blocking=True) - vw = vw.to(device=dev, non_blocking=True) - rw = rw.to(device=dev, non_blocking=True) - ow = ow.to(device=dev, non_blocking=True) + if prefetch: + pf_att = f'blocks.{pf_layer}.att.' + pf_kw_key = f'{pf_att}key.weight' + pf_vw_key = f'{pf_att}value.weight' + pf_rw_key = f'{pf_att}receptance.weight' + pf_ow_key = f'{pf_att}output.weight' + + with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext(): + pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True) + pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True) + pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True) + pb[pf_ow_key] = self.w[pf_ow_key].to(device=dev, non_blocking=True) + kw_key = f'{att}key.weight' + vw_key = f'{att}value.weight' + rw_key = f'{att}receptance.weight' + ow_key = f'{att}output.weight' + kw = w[kw_key] + vw = w[vw_key] + rw = w[rw_key] + ow = w[ow_key] + assert kw.device.type == dev, f'{kw.device.type} != {dev}, {att=}' + assert vw.device.type == dev, f'{vw.device.type} != {dev}, {att=}' + assert rw.device.type == dev, f'{rw.device.type} != {dev}, {att=}' + assert ow.device.type == dev, f'{ow.device.type} != {dev}, {att=}' kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x @@ -633,14 +704,32 @@ def forward(self, tokens, state, full_output=False): ) if dd.stream: del kw, vw, rw, ow - - kw = w[f'{ffn}key.weight'] - vw = w[f'{ffn}value.weight'] - rw = w[f'{ffn}receptance.weight'] - if dd.stream: - kw = kw.to(device=dev, non_blocking=True) - vw = vw.to(device=dev, non_blocking=True) - rw = rw.to(device=dev, non_blocking=True) + del pb[kw_key] + del pb[vw_key] + del pb[rw_key] + del pb[ow_key] + + if prefetch: + pf_ffn = f'blocks.{pf_layer}.ffn.' + pf_kw_key = f'{pf_ffn}key.weight' + pf_vw_key = f'{pf_ffn}value.weight' + pf_rw_key = f'{pf_ffn}receptance.weight' + + with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext(): + pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True) + pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True) + pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True) + if args.prefetch_layers_num > 0: + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + self.prefetch_events[i % args.prefetch_layers_num] = event + + kw_key = f'{ffn}key.weight' + vw_key = f'{ffn}value.weight' + rw_key = f'{ffn}receptance.weight' + kw = w[kw_key] + vw = w[vw_key] + rw = w[rw_key] kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x @@ -664,6 +753,9 @@ def forward(self, tokens, state, full_output=False): ) if dd.stream: del kw, vw, rw + del pb[kw_key] + del pb[vw_key] + del pb[rw_key] if self.RESCALE_LAYER > 0: if (i+1) % self.RESCALE_LAYER == 0: