Commit

chore: adjust forward call
thomaschhh committed Jun 11, 2024
1 parent 137d4f2 commit d93e6e5
Showing 1 changed file with 57 additions and 83 deletions.
140 changes: 57 additions & 83 deletions src/modalities/models/coca/coca_model.py
@@ -19,11 +19,6 @@
 from modalities.nn.attention import AttentionConfig
 
 
-class AVConfig(BaseModel):
-    audio_transformer_config: AudioTransformerConfig
-    vision_transformer_config: VisionTransformerConfig
-
-
 class TextDecoderConfig(BaseModel):
     sample_key: str
     prediction_key: str
@@ -43,17 +38,19 @@ class TextDecoderConfig(BaseModel):
 
 class CoCaConfig(BaseModel):
     prediction_key: str = "logits"
-    modality_key: str = "modality"
-    modality_embd_prediction_key: str
+    audio_embd_prediction_key: str  # same key as vision encoder
+    vision_embd_prediction_key: str  # same key as vision encoder
     text_embd_prediction_key: str
-    modality_cls_prediction_key: str
+    audio_cls_prediction_key: str
+    vision_cls_prediction_key: str
     text_cls_prediction_key: str
     logit_scale_prediction_key: str
-    modality_encoder_config: AudioTransformerConfig | VisionTransformerConfig | AVConfig
+    vision_encoder_config: VisionTransformerConfig
+    audio_encoder_config: AudioTransformerConfig
     text_decoder_config: TextDecoderConfig
     n_pool_head: Annotated[int, Field(ge=1)]
-    n_vision_queries: Annotated[int, Field(ge=1)] | None
-    n_audio_queries: Annotated[int, Field(ge=1)] | None
+    n_audio_queries: Annotated[int, Field(ge=1)]
+    n_vision_queries: Annotated[int, Field(ge=1)]
     bias_attn_pool: bool
     epsilon_attn_pool: Annotated[float, Field(ge=0.0)]
     weight_init: WeightInitializationConfig
@@ -72,64 +69,35 @@ class CoCa(NNModel):
     def __init__(
         self,
         prediction_key: str,
-        modality_key: str,
-        modality_embd_prediction_key: str,
+        vision_cls_prediction_key: str,
+        audio_cls_prediction_key: str,
+        text_cls_prediction_key: str,
+        vision_embd_prediction_key: str,
         text_embd_prediction_key: str,
+        audio_embd_prediction_key: str,
         logit_scale_prediction_key: str,
-        modality_cls_prediction_key: str,
-        text_cls_prediction_key: str,
+        n_vision_queries: int,
+        n_audio_queries: int,
         n_pool_head: int,
         bias_attn_pool: bool,
         epsilon_attn_pool: float,
-        modality_encoder_config: VisionTransformerConfig | AudioTransformerConfig | AVConfig,
+        vision_encoder_config: VisionTransformerConfig,
+        audio_encoder_config: AudioTransformerConfig,
         text_decoder_config: TextDecoderConfig,
         weight_init: WeightInitializationConfig,
     ) -> None:
         super().__init__()
 
-        self.AUDIO = 0
-        self.VISION = 1
-
         self.prediction_key = prediction_key
-        self.modality_key = modality_key
-        self.modality_embd_prediction_key = modality_embd_prediction_key
+        self.vision_cls_prediction_key = vision_cls_prediction_key
+        self.text_cls_prediction_key = text_cls_prediction_key
+        self.audio_cls_prediction_key = audio_cls_prediction_key
+        self.vision_embd_prediction_key = vision_embd_prediction_key
         self.text_embd_prediction_key = text_embd_prediction_key
+        self.audio_embd_prediction_key = audio_embd_prediction_key
         self.logit_scale_prediction_key = logit_scale_prediction_key
-
-        self.modality_cls_prediction_key = modality_cls_prediction_key
-        self.text_cls_prediction_key = text_cls_prediction_key
-
-        self.n_pool_head = n_pool_head
-        self.bias_attn_pool = bias_attn_pool
-        self.epsilon_attn_pool = epsilon_attn_pool
-        self.text_decoder_config = text_decoder_config
-
-        if isinstance(modality_encoder_config, VisionTransformerConfig):
-            self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality(
-                VisionTransformer,
-                modality_encoder_config,
-                n_vision_queries,
-            )
-        elif isinstance(modality_encoder_config, AudioTransformerConfig):
-            self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality(
-                AudioTransformer,
-                modality_encoder_config,
-                n_audio_queries,
-            )
-        else:
-            self.vision_encoder, self.vision_queries, self.vision_attn_pool = self._init_modality(
-                VisionTransformer,
-                modality_encoder_config.vision_transformer_config,
-                n_vision_queries,
-            )
-            self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality(
-                AudioTransformer,
-                modality_encoder_config.audio_transformer_config,
-                n_audio_queries,
-            )
-
+        self.vision_encoder = VisionTransformer(**dict(vision_encoder_config))
+        self.audio_encoder = AudioTransformer(**dict(audio_encoder_config))
         self.text_decoder = TextDecoder(
             sample_key=text_decoder_config.sample_key,
             prediction_key=text_embd_prediction_key,
@@ -165,6 +133,23 @@ def __init__(
             self.multimodal_decoder.lm_head.weight
         ) # https://paperswithcode.com/method/weight-tying
 
+        # vision_queries: 256 queries for multimodal cross attention and 1 as vision cls token for contrastive learning
+        self.vision_queries = nn.Parameter(torch.randn(n_vision_queries + 1, vision_encoder_config.n_embd))
+        self.audio_queries = nn.Parameter(torch.randn(n_audio_queries + 1, audio_encoder_config.n_embd))
+        self.attn_pool = AttentionPooling(
+            n_embd=vision_encoder_config.n_embd,
+            n_head=n_pool_head,
+            bias=bias_attn_pool,
+            epsilon=epsilon_attn_pool,
+            attention_config=text_decoder_config.attention_config,
+        )
+        self.audio_attn_pool = AttentionPooling(
+            n_embd=audio_encoder_config.n_embd,
+            n_head=n_pool_head,
+            bias=bias_attn_pool,
+            epsilon=epsilon_attn_pool,
+            attention_config=text_decoder_config.attention_config,
+        )
         # Logit scale for contrastive loss
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
@@ -180,18 +165,6 @@ def __init__(
             / math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)),
         )
 
-    def _init_modality(self, encoder_class, encoder_config, n_queries):
-        encoder = encoder_class(**dict(encoder_config))
-        queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd))
-        attn_pool = AttentionPooling(
-            n_embd=encoder_config.n_embd,
-            n_head=self.n_pool_head,
-            bias=self.bias_attn_pool,
-            epsilon=self.epsilon_attn_pool,
-            attention_config=self.text_decoder_config.attention_config,
-        )
-        return encoder, queries, attn_pool
-
     def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
@@ -201,43 +174,44 @@ def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConf
             torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
 
     def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        # TODO: The "modality_key" needs to be implemented.
-        if inputs[self.modality_key][0] == self.AUDIO:
-            modality_embd, modality_cls_token = self._forward_encode_audio(inputs)
-        if inputs[self.modality_key][0] == self.VISION:
-            modality_embd, modality_cls_token = self._forward_encode_vision(inputs)
+        # TODO: select encodings based on input modalities. Adjust return accordingly.
+        # vision_embd, vision_cls_token = self._forward_encode_vision(inputs)
         text_embd, text_cls_token = self._forward_encode_text(inputs)
-        logits = self._forward_decode(text_embd, modality_embd)
+        audio_embd, audio_cls_token = self._forward_encode_audio(inputs)
+        logits = self._forward_decode(text_embd, audio_embd)
         return {
             self.prediction_key: logits,
-            self.modality_cls_prediction_key: modality_cls_token,
+            # self.vision_cls_prediction_key: vision_cls_token,
+            self.audio_cls_prediction_key: audio_cls_token,
             self.text_cls_prediction_key: text_cls_token,
             self.logit_scale_prediction_key: self.logit_scale.exp(),
         }
 
-    def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        vision_embd = self.vision_encoder(inputs)[self.modality_embd_prediction_key]
-        queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0])
-        vision_embd = self.vision_attn_pool(queries, context=vision_embd)
-        vision_embd, vision_cls_token = vision_embd[:, :-1, :], F.normalize(vision_embd[:, -1, :], dim=-1)
-        return vision_embd, vision_cls_token
-
     def _forward_encode_audio(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        audio_embd = self.audio_encoder(inputs)[self.modality_embd_prediction_key]
+        audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key]
         queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0])
         audio_embd = self.audio_attn_pool(queries, context=audio_embd)
         audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1:, :], dim=-1)
         return audio_embd, audio_cls_token
 
+    def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key]
+        queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0])
+        vision_embd = self.attn_pool(queries, context=vision_embd)
+        vision_embd, vision_cls_token = vision_embd[:, :-1, :], F.normalize(vision_embd[:, -1, :], dim=-1)
+        return vision_embd, vision_cls_token
+
     def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key]
         text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1)
         return text_embd, text_cls_token
 
-    def _forward_decode(self, text_embd: torch.Tensor, modality_embd: torch.Tensor) -> torch.Tensor:
+    def _forward_decode(self, text_embd: torch.Tensor, audio_embd: torch.Tensor) -> torch.Tensor:
+        # TODO: set decoder inputs based on input modalities.
+
         decoder_inputs = {
             self.text_embd_prediction_key: text_embd,
-            "context": modality_embd,
+            "context": audio_embd,
         }
         decoder_outputs = self.multimodal_decoder(decoder_inputs)
         logits = decoder_outputs[self.multimodal_decoder.prediction_key]
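For readers skimming the diff: a minimal, self-contained sketch (not code from the repository) of the slicing that _forward_encode_audio and _forward_encode_vision perform after attention pooling. The pooled sequence is split into cross-attention tokens plus one L2-normalized cls token for the contrastive loss. The batch size, query count, and embedding width below are illustrative assumptions, and a random tensor stands in for the AttentionPooling output.

import torch
import torch.nn.functional as F
from einops import repeat

batch, n_queries, n_embd = 2, 256, 768              # assumed sizes, for illustration only
queries = torch.randn(n_queries + 1, n_embd)        # n_queries pooling queries + 1 extra cls query
queries = repeat(queries, "n d -> b n d", b=batch)  # broadcast the learned queries over the batch, as in the diff
pooled = queries                                    # stand-in for attn_pool(queries, context=encoder_output)

# Split off the last query as the contrastive cls token; the rest go to the multimodal decoder via cross attention
embd, cls_token = pooled[:, :-1, :], F.normalize(pooled[:, -1, :], dim=-1)
print(embd.shape, cls_token.shape)                  # torch.Size([2, 256, 768]) torch.Size([2, 768])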
