-
Notifications
You must be signed in to change notification settings - Fork 278
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Add OpenFlamingo (#2237)
Co-authored-by: Tony Lee <[email protected]> Co-authored-by: JosselinSomervilleRoberts <[email protected]>
1 parent
5004acf
commit 481d12e
Showing
17 changed files
with
1,203 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .src.flamingo import Flamingo | ||
from .src.factory import create_model_and_transforms |
Empty file.
147 changes: 147 additions & 0 deletions
147
src/helm/clients/vision_language/open_flamingo/src/factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
""" | ||
Source: https://github.com/mlfoundations/open_flamingo | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from helm.common.general import handle_module_not_found_error | ||
from .flamingo import Flamingo | ||
from .flamingo_lm import FlamingoLMMixin | ||
from .utils import extend_instance | ||
|
||
|
||
def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. When None, it is
            inferred from the language model's class name. Defaults to None.
        freeze_lm_embeddings (bool, optional): when True, the LM input embeddings stay frozen. Defaults to False.
        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
        **flamingo_kwargs: forwarded unchanged to the ``Flamingo`` constructor.

    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model

    Raises:
        ValueError: if ``decoder_layers_attr_name`` is None and cannot be inferred.
    """
    # open_clip is an optional dependency; the handler reports which extra to install.
    try:
        import open_clip
    except ModuleNotFoundError as e:
        handle_module_not_found_error(e, ["vlm"])

    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path,
        pretrained=clip_vision_encoder_pretrained,
        cache_dir=cache_dir,
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)

    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # The tokenizer grew by the special tokens added above, so the embedding table must grow too.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
    model.perceiver.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Flamingo model initialized with {num_trainable} trainable parameters")

    return model, image_processor, text_tokenizer
|
||
|
||
def _infer_decoder_layers_attr_name(model):
    """
    Guess the dotted attribute path of the decoder layer list from the model's class name.

    The lower-cased class name is matched by substring against the known keys, in table
    order; the first hit wins.

    Args:
        model: the language model instance to inspect.

    Returns:
        str: dotted attribute path to the decoder layers (e.g. "model.layers").

    Raises:
        ValueError: if no known key occurs in the class name.
    """
    class_name = model.__class__.__name__.lower()
    for key, attr_path in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key.lower() in class_name:
            return attr_path

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
        "Please supply this string manually."
    )


# Substring of the (lower-cased) model class name -> dotted path to its decoder layers.
# Order matters: earlier keys take precedence when several match.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
337 changes: 337 additions & 0 deletions
337
src/helm/clients/vision_language/open_flamingo/src/flamingo.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,337 @@ | ||
""" | ||
Source: https://github.com/mlfoundations/open_flamingo | ||
""" | ||
|
||
import torch | ||
from einops import rearrange | ||
from torch import nn | ||
from .helpers import PerceiverResampler | ||
from torch.distributed.fsdp.wrap import ( | ||
enable_wrap, | ||
wrap, | ||
) | ||
from transformers.modeling_outputs import CausalLMOutputWithPast | ||
from torch.distributed.fsdp import ( | ||
FullyShardedDataParallel as FSDP, | ||
) | ||
|
||
from .utils import apply_with_stopping_condition | ||
|
||
|
||
class Flamingo(nn.Module):
    """
    Vision-language model that conditions a causal language model on images.

    Images are encoded by a frozen vision encoder, resampled by a PerceiverResampler,
    and handed to the language model's decoder layers, which attend to them through
    gated cross-attention.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): vision model exposing a `.visual` submodule; only that
                submodule is kept.
            lang_encoder (nn.Module): HF causal language model, already extended with FlamingoLMMixin
            eoc_token_id (int): Token id for <|endofchunk|>
            media_token_id (int): Token id for <image>
            vis_dim (int): Dimension of the visual features.
                Visual features are projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
            gradient_checkpointing (bool, optional): Flag forwarded to the perceiver and the
                Flamingo layers. Defaults to False.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = vision_encoder.visual
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to false if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.

        Returns:
            Whatever the wrapped language model's forward returns.
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                images in the same chunk are collated along T_img, and frames are collated along F
                currently only F=1 is supported (single-frame videos)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
                max_length (int, optional): Maximum length of the output. Defaults to None.
                num_beams (int, optional): Number of beams. Defaults to 1.
                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
                temperature (float, optional): Temperature. Defaults to 1.0.
                top_k (int, optional): Top k. Defaults to 50.
                top_p (float, optional): Top p. Defaults to 1.0.
                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
                length_penalty (float, optional): Length penalty. Defaults to 1.0.
                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
                do_sample (bool, optional): Do sample. Defaults to False.
                early_stopping (bool, optional): Early stopping. Defaults to False.
                eos_token_id (int, optional): Defaults to the <|endofchunk|> token id.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            # Each beam needs its own copy of the visual input.
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_vision_x = True
        try:
            self._encode_vision_x(vision_x=vision_x)

            eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
            output = self.lang_encoder.generate(
                input_ids=lang_x,
                attention_mask=attention_mask,
                eos_token_id=eos_token_id,
                num_beams=num_beams,
                **kwargs,
            )
        finally:
            # Always restore the un-cached state, even when encoding or generation fails;
            # otherwise every later forward() would trip the "vision_x must be None" assertion.
            self.lang_encoder.clear_conditioned_layers()
            self.lang_encoder._use_cached_vision_x = False
        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """

        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        # The vision encoder is frozen, so no gradients are tracked through it.
        with torch.no_grad():
            vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        vision_x = self.perceiver(vision_x)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
        """
        Manually wraps submodules for FSDP and move other parameters to device_id.

        Why manually wrap?
        - all parameters within the FSDP wrapper must have the same requires_grad.
            We have a mix of frozen and unfrozen parameters.
        - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
            See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344

        The rough wrapping structure is:
        - FlamingoModel
            - FSDP(FSDP(vision_encoder))
            - FSDP(FSDP(perceiver))
            - lang_encoder
                - FSDP(FSDP(input_embeddings))
                - FlamingoLayers
                    - FSDP(FSDP(gated_cross_attn_layer))
                    - FSDP(FSDP(decoder_layer))
                - FSDP(FSDP(output_embeddings))
                - other parameters

        Known issues:
        - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
            train with DDP or set the --freeze_lm_embeddings flag to true.
        - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
            Although the training curves look okay, we found that downstream performance dramatically
            degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).

        FAQs about our FSDP wrapping strategy:
        Why double wrap?
        As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
        only free gathered parameters if the module is NOT FSDP root.

        Why unfreeze the decoder_layers?
        See https://github.com/pytorch/pytorch/issues/95805
        As of torch==2.0.1, FSDP's _post_backward_hook is only registered if the flat param
        requires_grad=True. We need the postback to fire to avoid OOM.
        To effectively freeze the decoder layers, we exclude them from the optimizer.

        What is assumed to be frozen v. unfrozen?
        We assume that the model is being trained under normal Flamingo settings
        with these lines being called in factory.py:
            ```
            # Freeze all parameters
            model.requires_grad_(False)
            assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

            # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
            model.perceiver.requires_grad_(True)
            model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
            [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
            ```

        Args:
            wrapper_kwargs: keyword arguments forwarded to the FSDP wrapper via ``enable_wrap``.
            device_id: device that non-FSDP-managed parameters are moved to.
        """
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.perceiver = wrap(wrap(self.perceiver))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None for layer in self.lang_encoder.gated_cross_attn_layers
            )
            # Rebuild the FlamingoLayers so they point at the freshly wrapped modules.
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(wrap(wrap(self.lang_encoder.get_input_embeddings())))
            self.lang_encoder.set_output_embeddings(wrap(wrap(self.lang_encoder.get_output_embeddings())))
            self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            self.perceiver.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """
        Compute the media token locations from lang_x and condition the language model on these.

        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
        """
        media_locations = input_ids == self.media_token_id

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
        """
        Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
        All subsequent calls to forward() will generate attending to the LAST
        image in vision_x.
        This is not meant to be used to cache things for generate().

        Args:
            input_ids (torch.Tensor): Language input
                shape (B, T_txt)
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        """
        self._encode_vision_x(vision_x=vision_x)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_vision_x = True

    def uncache_media(self):
        """
        Clear all conditioning.
        """
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
153 changes: 153 additions & 0 deletions
153
src/helm/clients/vision_language/open_flamingo/src/flamingo_lm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
""" | ||
Source: https://github.com/mlfoundations/open_flamingo | ||
""" | ||
|
||
import torch.nn as nn | ||
from .helpers import GatedCrossAttentionBlock | ||
from .utils import getattr_recursive, setattr_recursive | ||
|
||
|
||
class FlamingoLayer(nn.Module):
    """
    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.

    The layer is "conditioned" by storing the visual features and media locations on it
    before the forward pass; the cross-attention block reads them from there.
    """

    def __init__(self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False):
        """
        Args:
            gated_cross_attn_layer: cross-attention block applied before the decoder layer,
                or None for layers without cross attention.
            decoder_layer: the original decoder layer of the language model.
            gradient_checkpointing (bool, optional): flag stored on both wrapped layers.
        """
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        # Initialised here so forward() works even if condition_use_cached_media()
        # was never called on this layer.
        self.use_cached_media = False
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = gradient_checkpointing
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None and self.media_locations is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        """Store the visual features to attend to (None clears them)."""
        self.vis_x = vis_x

    def condition_media_locations(self, media_locations):
        """Store the mask of media-token positions (None clears it)."""
        self.media_locations = media_locations

    def condition_use_cached_media(self, use_cached_media):
        """Store the flag passed to the cross-attention block as ``use_cached_media``."""
        self.use_cached_media = use_cached_media

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        """
        Apply the optional gated cross attention, then the wrapped decoder layer.

        Args:
            lang_x: language features passed through both layers.
            attention_mask: forwarded to the decoder layer.
            **decoder_layer_kwargs: forwarded to the decoder layer unchanged.

        Returns:
            The output of the wrapped decoder layer.

        Raises:
            ValueError: if the layer has a cross-attention block but has not been
                conditioned on visual features and media locations.
        """
        # Cross attention
        if self.gated_cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")

            if self.media_locations is None:
                raise ValueError("media_locations must be conditioned before forward pass")

            lang_x = self.gated_cross_attn_layer(
                lang_x,
                self.vis_x,
                media_locations=self.media_locations,
                use_cached_media=self.use_cached_media,
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
        return lang_x
|
||
|
||
class FlamingoLMMixin(nn.Module):
    """
    Mixin that equips a language model with gated cross-attention layers.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Remember the dotted attribute path under which the decoder layers live."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Fetch the decoder layers stored at the configured attribute path."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layers stored at the configured attribute path."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        lang_hidden_size,
        vis_hidden_size,
        cross_attn_every_n_layers,
        gradient_checkpointing,
    ):
        """
        Set up Flamingo: create one gated cross-attention block for every
        `cross_attn_every_n_layers`-th decoder layer (None for the others), wrap the decoder
        layers in FlamingoLayers, and keep the media token id for locating media tokens.
        """
        self.old_decoder_blocks = self._get_decoder_layers()

        cross_attn_blocks = []
        for position, _ in enumerate(self._get_decoder_layers(), start=1):
            if position % cross_attn_every_n_layers == 0:
                cross_attn_blocks.append(GatedCrossAttentionBlock(dim=lang_hidden_size, dim_visual=vis_hidden_size))
            else:
                cross_attn_blocks.append(None)
        self.gated_cross_attn_layers = nn.ModuleList(cross_attn_blocks)

        self.init_flamingo_layers(gradient_checkpointing)
        self.media_token_id = media_token_id
        self.initialized_flamingo = True
        self._use_cached_vision_x = False

    def init_flamingo_layers(self, gradient_checkpointing):
        """
        Rebuild the FlamingoLayers from the current cross-attention and decoder blocks.

        Call this again after changing self.gated_cross_attn_layers or
        self.old_decoder_blocks so the decoder picks up the new modules.
        """
        paired_blocks = zip(self.gated_cross_attn_layers, self.old_decoder_blocks)
        flamingo_layers = [
            FlamingoLayer(cross_attn_block, decoder_block, gradient_checkpointing)
            for cross_attn_block, decoder_block in paired_blocks
        ]
        self._set_decoder_layers(nn.ModuleList(flamingo_layers))

    def forward(self, input_ids, attention_mask, **kwargs):
        """Condition the Flamingo layers on the media locations, then run the wrapped LM's forward()."""
        if not self.initialized_flamingo:
            raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")

        media_locations = input_ids == self.media_token_id

        # When media is already cached and this input holds no media tokens, assume every
        # input token attends to the most recent cached media. HF generate() depends on this:
        # it calls forward() one token at a time with no media tokens, and without this
        # check nothing after the first token would attend to any image.
        use_cached_media_locations = self._use_cached_vision_x and self.is_conditioned() and not media_locations.any()

        for decoder_layer in self._get_decoder_layers():
            if not use_cached_media_locations:
                decoder_layer.condition_media_locations(media_locations)
            decoder_layer.condition_use_cached_media(use_cached_media_locations)

        # The other parent's forward() has an unknown positional order, so hand
        # everything over as keyword arguments.
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Report whether every decoder layer has been conditioned."""
        for decoder_layer in self._get_decoder_layers():
            if not decoder_layer.is_conditioned():
                return False
        return True

    def clear_conditioned_layers(self):
        """Drop the visual features, media locations and cache flag from every decoder layer."""
        for decoder_layer in self._get_decoder_layers():
            decoder_layer.condition_vis_x(None)
            decoder_layer.condition_media_locations(None)
            decoder_layer.condition_use_cached_media(None)
267 changes: 267 additions & 0 deletions
267
src/helm/clients/vision_language/open_flamingo/src/helpers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
""" | ||
Based on: https://github.com/lucidrains/flamingo-pytorch | ||
""" | ||
|
||
import torch | ||
from einops import rearrange, repeat | ||
from einops_exts import rearrange_many | ||
from torch import einsum, nn | ||
|
||
|
||
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
||
|
||
def FeedForward(dim, mult=4):
    """
    Build a bias-free feed-forward stack: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: input and output feature size.
        mult: the hidden size is ``int(dim * mult)``.

    Returns:
        nn.Sequential holding the four layers in that order.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
|
||
|
||
class PerceiverAttention(nn.Module):
    """Multi-head attention in which latents attend to image features and to themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim: feature size of both the image features and the latents.
            dim_head: size of each attention head.
            heads: number of attention heads.
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)

        Returns:
            torch.Tensor: updated latents, projected back to the feature size D.
        """
        media = self.norm_media(x)
        normed_latents = self.norm_latents(latents)

        num_heads = self.heads

        queries = self.to_q(normed_latents)
        # Keys and values cover the image features followed by the latents themselves.
        keys_and_values = self.to_kv(torch.cat((media, normed_latents), dim=-2))
        keys, values = keys_and_values.chunk(2, dim=-1)
        queries, keys, values = rearrange_many((queries, keys, values), "b t n (h d) -> b h t n d", h=num_heads)
        queries = queries * self.scale

        # attention
        scores = einsum("... i d, ... j d -> ... i j", queries, keys)
        # Subtract the row-wise maximum before the softmax (detached from the graph).
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        attended = einsum("... i j, ... j d -> ... i d", weights, values)
        attended = rearrange(attended, "b h t n d -> b t n (h d)", h=num_heads)
        return self.to_out(attended)
|
||
|
||
class PerceiverResampler(nn.Module):
    """Resample image features into a fixed number of learned latent vectors."""

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Args:
            dim: feature size of the inputs and the latents.
            depth: number of (attention, feed-forward) layer pairs.
            dim_head: size of each attention head.
            heads: number of attention heads.
            num_latents: number of learned latent vectors.
            max_num_media: if given, learn that many media-time embeddings.
            max_num_frames: if given, learn that many frame embeddings.
            ff_mult: hidden-size multiplier of the feed-forward layers.
        """
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        if exists(max_num_frames):
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None

        if exists(max_num_media):
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)

        Returns:
            shape (b, T, n, D) where n is the number of latents
        """
        batch, num_media, num_frames, num_tokens = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=batch, T=num_media, v=num_tokens
            )
            x = x + frame_embs
        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:num_media]

        # blocks: each one updates the latents with a residual connection
        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=num_media)
        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents
        return self.norm(latents)
|
||
|
||
# gated cross attention | ||
class MaskedCrossAttention(nn.Module):
    """
    Cross-attention from text tokens (queries) to flattened media latents
    (keys/values), optionally masked so that each text token only sees the media
    selected by ``media_locations``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether for text to only attend to immediate preceding image, or all previous images
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None, use_cached_media=False):
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the dim of the latents
            media_locations: boolean mask identifying the media tokens in x
                shape (B, T_txt)
                If None, no text-to-media mask is applied and every text token
                attends to all media tokens.
            use_cached_media: bool
                If true, treat all of x as if they occur after the last media
                registered in media_locations. T_txt does not need to exactly
                equal media_locations.shape[1] in this case
        Returns:
            torch.Tensor: output of ``self.to_out``, one row per text token.
        """

        # Only validate the mask length when a mask was actually supplied; the
        # signature defaults (media_locations=None, use_cached_media=False) would
        # otherwise fail on ``None.shape`` even though the unmasked path below
        # handles ``media_locations is None``.
        if not use_cached_media and exists(media_locations):
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # Flatten (media index, latent index) into a single key/value axis.
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        if exists(media_locations):
            # Media are numbered 1..T_img so that 0 can mean "no media seen yet".
            media_time = torch.arange(T_img, device=x.device) + 1

            if use_cached_media:
                # text time is set to the last cached media location
                text_time = repeat(
                    torch.count_nonzero(media_locations, dim=1),
                    "b -> b i",
                    i=T_txt,
                )
            else:
                # at each boolean of True, increment the time counter (relative to media time)
                text_time = media_locations.cumsum(dim=-1)

            # text time must equal media time if only attending to most immediate image
            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(
                rearrange(text_time, "b i -> b 1 i 1"),
                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        # Subtract the (detached) row max before softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
||
|
||
class GatedCrossAttentionBlock(nn.Module):
    """
    Masked cross-attention followed by a feed-forward layer, each added to the
    residual stream through a learned tanh gate that is initialised to zero.
    """

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        attention_kwargs = dict(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn = MaskedCrossAttention(**attention_kwargs)
        # tanh(0) == 0, so the attention branch contributes nothing at init.
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        # Same zero-initialised gate for the feed-forward branch.
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_locations=None,
        use_cached_media=False,
    ):
        """
        Apply gated cross-attention then gated feed-forward, both with residuals.

        Args:
            x: text features passed to the attention layer as queries.
            media: visual features passed to the attention layer.
            media_locations: forwarded unchanged to the attention layer.
            use_cached_media: forwarded unchanged to the attention layer.
        Returns:
            The updated ``x``.
        """
        attended = self.attn(
            x,
            media,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        x = attended * self.attn_gate.tanh() + x

        fed_forward = self.ff(x)
        x = fed_forward * self.ff_gate.tanh() + x

        return x
47 changes: 47 additions & 0 deletions
47
src/helm/clients/vision_language/open_flamingo/src/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Source: https://github.com/mlfoundations/open_flamingo | ||
""" | ||
|
||
|
||
def extend_instance(obj, mixin):
    """
    Apply a mixin to an already-created instance.

    The instance's class is swapped for a freshly built class that has the same
    name as the original and inherits from ``(mixin, original class)``.
    """
    original_cls = obj.__class__
    # mixin needs to go first for our forward() logic to work
    combined_bases = (mixin, original_cls)
    obj.__class__ = type(original_cls.__name__, combined_bases, {})
|
||
|
||
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c

    An empty ``att`` returns ``obj`` itself.
    """
    current = obj
    remaining = att
    # Peel off one dotted component at a time until nothing is left.
    while remaining != "":
        name, _, remaining = remaining.partition(".")
        current = getattr(current, name)
    return current
|
||
|
||
def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    # Split on the last dot: everything before it names the owner object,
    # everything after it is the attribute to assign.
    owner_path, dot, leaf_name = att.rpartition(".")
    if dot:
        obj = getattr_recursive(obj, owner_path)
    setattr(obj, leaf_name, val)
|
||
|
||
def apply_with_stopping_condition(module, apply_fn, apply_condition=None, stopping_condition=None, **other_args):
    """
    Walk ``module`` and its descendants depth-first, calling ``apply_fn`` on selected nodes.

    Args:
        module: object exposing ``children()``, an iterable of child modules.
        apply_fn: called as ``apply_fn(module, **other_args)`` on each selected node.
        apply_condition: predicate choosing which nodes ``apply_fn`` runs on.
            ``None`` (the default) selects every visited node.
        stopping_condition: predicate that prunes the walk; when it returns true
            for a node, neither that node nor its descendants are processed.
            ``None`` (the default) never prunes.
        **other_args: extra keyword arguments forwarded to ``apply_fn``.
    """
    # The predicates default to None, so they must be checked before being called;
    # calling them unconditionally raises TypeError when the defaults are used.
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child, apply_fn, apply_condition=apply_condition, stopping_condition=stopping_condition, **other_args
        )
160 changes: 160 additions & 0 deletions
160
src/helm/clients/vision_language/open_flamingo_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from threading import Lock | ||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from huggingface_hub import hf_hub_download | ||
|
||
from helm.common.cache import CacheConfig | ||
from helm.common.hierarchical_logger import hlog, htrack_block | ||
from helm.common.images_utils import open_image | ||
from helm.common.gpu_utils import get_torch_device_name | ||
from helm.common.media_object import TEXT_TYPE | ||
from helm.common.optional_dependencies import handle_module_not_found_error | ||
from helm.common.request import Request, RequestResult, Sequence, Token | ||
from helm.common.request import wrap_request_time | ||
from helm.clients.vision_language.open_flamingo import create_model_and_transforms | ||
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt | ||
|
||
try: | ||
from PIL import Image | ||
except ModuleNotFoundError as e: | ||
handle_module_not_found_error(e, ["images"]) | ||
|
||
|
||
class OpenFlamingoClient(CachingClient):
    """
    OpenFlamingo is an open source implementation of DeepMind's Flamingo models.
    Implementation following:
    https://github.com/mlfoundations/open_flamingo
    https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b

    The model, image processor and tokenizer are created lazily on the first
    request and guarded by a class-level lock.
    """

    END_OF_CHUNK_TOKEN: str = "<|endofchunk|>"
    IMAGE_TOKEN: str = "<image>"

    # Shared by all instances so that model initialization is serialized.
    _model_lock: Lock = Lock()

    def __init__(
        self,
        cache_config: CacheConfig,
        checkpoint_path: Optional[str] = None,
        tokenizer_name: Optional[str] = None,
        cross_attn_every_n_layers: int = 4,
    ):
        super().__init__(cache_config)
        self._device: str = get_torch_device_name()
        self._checkpoint_path: Optional[str] = checkpoint_path
        self._tokenizer_name: Optional[str] = tokenizer_name
        self._cross_attn_every_n_layers: int = cross_attn_every_n_layers

        # Model
        # The model is only initialized when the first request is made
        # This is to avoid loading the model if it is not used
        self._model: Optional[torch.nn.Module] = None

    def _get_model(self):
        """
        Create the model, image processor and tokenizer if not already loaded.

        Raises:
            ValueError: if the checkpoint path or tokenizer name is missing.
        """
        if not self._checkpoint_path:
            raise ValueError("OpenFlamingoClient requires a checkpoint path")
        if not self._tokenizer_name:
            raise ValueError("OpenFlamingoClient requires a tokenizer name")
        with self._model_lock:
            # Re-check under the lock: another thread may have finished loading
            # while this one was waiting.
            if self._model is not None:
                return
            with htrack_block("Initializing OpenFlamingo model"):
                model, image_processor, tokenizer = create_model_and_transforms(
                    clip_vision_encoder_path="ViT-L-14",
                    clip_vision_encoder_pretrained="openai",
                    lang_encoder_path=self._tokenizer_name,
                    tokenizer_path=self._tokenizer_name,
                    cross_attn_every_n_layers=self._cross_attn_every_n_layers,
                )
                tokenizer.padding_side = "left"
                checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt")
                # Load onto CPU first so the checkpoint can be read on hosts without
                # the device it was saved from; the model is moved to the target
                # device right after.
                model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)
                model = model.to(self._device)
                hlog(f"Loaded model to {self._device}.")

                # Publish the model last: make_request checks `self._model is None`
                # without the lock, so it must only become visible once the weights
                # are loaded and the processor/tokenizer attributes are set.
                self.image_processor = image_processor
                self.tokenizer = tokenizer
                self._model = model

    def make_request(self, request: Request) -> RequestResult:
        """
        Generate completions for a multimodal (image + text) prompt.

        Images are replaced by ``IMAGE_TOKEN`` in the text prompt, and the prompt
        and end-of-chunk token are stripped from the decoded generations. A
        ``RuntimeError`` during generation, or a prompt without any image, yields
        a ``RequestResult`` with ``success=False``.

        Raises:
            ValueError: for a text media object without text, or a media object
                that is neither text nor an image with a location.
        """
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"

        # Load model if needed
        if self._model is None:
            self._get_model()

        # Build the prompt
        prompt_parts: List[str] = []
        images: List[Image.Image] = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                images.append(open_image(media_object.location))
                prompt_parts.append(self.IMAGE_TOKEN)
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                prompt_parts.append(media_object.text)
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
        prompt_text: str = "".join(prompt_parts)

        # torch.cat cannot concatenate an empty list; report a failed request
        # instead of letting that error escape from outside the try block below.
        if not images:
            return RequestResult(
                success=False,
                cached=False,
                error="OpenFlamingoClient requires at least one image in the multimodal prompt",
                completions=[],
                embedding=[],
            )

        # Preprocess
        # Stack the processed images, then add a singleton axis after the image
        # axis and a leading singleton axis.
        vision_x: torch.Tensor = torch.cat([self.image_processor(image).unsqueeze(0) for image in images], dim=0)
        vision_x = vision_x.unsqueeze(1).unsqueeze(0)

        lang_x = self.tokenizer(
            [prompt_text],
            return_tensors="pt",
        )

        # Generate
        try:
            generation_args = {
                "max_new_tokens": request.max_tokens,
                "num_beams": request.num_completions,
                "n": request.num_completions,
            }

            def do_it():
                tensors = self._model.generate(
                    vision_x=vision_x.to(self._device),
                    lang_x=lang_x["input_ids"].to(self._device),
                    attention_mask=lang_x["attention_mask"].to(self._device),
                    max_new_tokens=generation_args["max_new_tokens"],
                    num_beams=generation_args["num_beams"],
                    num_return_sequences=generation_args["num_beams"],
                )
                generated_completions: List[Tuple[str, List[str]]] = []
                for tensor in tensors:
                    generated_text: str = self.tokenizer.decode(tensor)
                    assert generated_text.startswith(
                        prompt_text
                    ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}"

                    # Remove the prompt from the generated text
                    generated_text = generated_text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip()
                    raw_tokens: List[str] = self.tokenizer.tokenize(generated_text)
                    generated_completions.append((generated_text, raw_tokens))

                return {"output": generated_completions}

            cache_key = CachingClient.make_cache_key(
                raw_request={
                    "model": request.model,
                    "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                    **generation_args,
                },
                request=request,
            )
            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as e:
            return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])

        # Log probabilities are not produced here, so they are reported as 0.
        completions: List[Sequence] = [
            Sequence(text=text, logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens])
            for text, tokens in result["output"]
        ]

        return RequestResult(
            success=True,
            cached=cached,
            request_time=result["request_time"],
            completions=completions,
            embedding=[],
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters