Skip to content

Commit

Permalink
Update load warnings (#126)
Browse files Browse the repository at this point in the history
farzadab authored Oct 18, 2024
1 parent 7295369 commit 9a2a6a8
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions ultravox/model/ultravox_model.py
Original file line number Diff line number Diff line change
@@ -34,14 +34,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

config_class = UltravoxConfig
config: UltravoxConfig # for type hinting
# We minimize the weights in state_dict in order to reduce the size of the checkpoint
# The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
# As such we have to tell is to ignore some keys that are not always in the model
_keys_to_ignore_on_load_unexpected = ["audio_tower.*", "language_model.*"]
# Usually we load encoder weights from a pretrained model, so we don't want to load the decoder weights
# Technically we never hit this issue because these keys are already removed from state_dict() however,
# but there's no harm in keeping it here for when we change that behavior.
_keys_to_ignore_on_load_missing = ["audio_tower.*"]
# Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
_keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]

def __init__(self, config: UltravoxConfig):
super().__init__(config)

0 comments on commit 9a2a6a8

Please sign in to comment.