Stable Cascade support, new ReturnedEmbeddingsType #104

Open · wants to merge 1 commit into main
src/compel/embeddings_provider.py (11 additions & 3 deletions)
@@ -22,6 +22,7 @@ class ReturnedEmbeddingsType(Enum):
     LAST_HIDDEN_STATES_NORMALIZED = 0    # SD1/2 regular
     PENULTIMATE_HIDDEN_STATES_NORMALIZED = 1    # SD1.5 with "clip skip"
     PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED = 2    # SDXL
+    STABLE_CASCADE = 3    # Stable Cascade


 class EmbeddingsProvider:
@@ -234,7 +235,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool =
         return result

     def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None, device: Optional[str]=None) -> Optional[torch.Tensor]:
-
+        device = device or self.device

         token_ids = self.get_token_ids(texts, padding="max_length", truncation_override=True)
@@ -243,7 +244,10 @@ def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch
         text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)
         pooled = text_encoder_output.text_embeds

-        return pooled
+        if self.returned_embeddings_type is ReturnedEmbeddingsType.STABLE_CASCADE:
+            return pooled.unsqueeze(1)
+        else:
+            return pooled


     def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str
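
Review note (not part of the diff): the `unsqueeze(1)` gives the pooled projection a sequence axis, turning `(batch, dim)` into `(batch, 1, dim)`, which appears to be the shape the Stable Cascade prior consumes for pooled text embeddings. A minimal shape sketch; the 1280-wide embedding is illustrative, not taken from this PR:

import torch

pooled = torch.randn(2, 1280)         # (batch, dim) pooled CLIP text embeds; 1280 is illustrative
pooled_sc = pooled.unsqueeze(1)       # (batch, 1, dim)
print(pooled.shape, pooled_sc.shape)  # torch.Size([2, 1280]) torch.Size([2, 1, 1280])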
@@ -386,7 +390,8 @@ def build_weighted_embedding_tensor(self,
     def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
                                         attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
         needs_hidden_states = (self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED or
-                               self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED)
+                               self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED or
+                               self.returned_embeddings_type == ReturnedEmbeddingsType.STABLE_CASCADE)
         text_encoder_output = self.text_encoder(token_ids,
                                                 attention_mask,
                                                 output_hidden_states=needs_hidden_states,
@@ -400,6 +405,9 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
         elif self.returned_embeddings_type is ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED:
             # already normalized
             return text_encoder_output.last_hidden_state
+        elif self.returned_embeddings_type is ReturnedEmbeddingsType.STABLE_CASCADE:
+            # non-intuitive: last_hidden_state has the final layer norm applied, which is not what Stable Cascade expects
+            return text_encoder_output.hidden_states[-1]

         assert False, f"unrecognized ReturnEmbeddingsType: {self.returned_embeddings_type}"
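
Review note (not part of the diff): in Hugging Face transformers, the entries of `hidden_states` are captured before the CLIP text model's `final_layer_norm`, so `hidden_states[-1]` is not equal to `last_hidden_state`. A quick sanity check of that relationship, assuming `text_encoder` is a transformers `CLIPTextModel` (or `CLIPTextModelWithProjection`) and `token_ids` is a batch of token ids as in the surrounding code:

import torch

out = text_encoder(token_ids, output_hidden_states=True, return_dict=True)
# last_hidden_state == final_layer_norm(hidden_states[-1]), so the two differ
assert not torch.equal(out.last_hidden_state, out.hidden_states[-1])
torch.testing.assert_close(
    text_encoder.text_model.final_layer_norm(out.hidden_states[-1]),
    out.last_hidden_state,
)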

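For context, a hedged usage sketch of the new embeddings type. It assumes this branch of compel, reuses the `requires_pooled` mechanism that compel already exposes for SDXL, and `prior_pipe` stands in for a diffusers Stable Cascade prior pipeline (illustrative, not taken from this PR):

from compel import Compel, ReturnedEmbeddingsType

compel = Compel(tokenizer=prior_pipe.tokenizer,
                text_encoder=prior_pipe.text_encoder,
                returned_embeddings_type=ReturnedEmbeddingsType.STABLE_CASCADE,
                requires_pooled=True)
# with requires_pooled=True, compel returns both sequence and pooled embeddings
conditioning, pooled = compel("a forest cottage, cozy++ lighting")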