
Commit

further improvements
AlanPonnachan committed Jan 7, 2025
1 parent 92a2223 commit 94352ce
Showing 2 changed files with 54 additions and 46 deletions.
8 changes: 8 additions & 0 deletions src/transformers/__init__.py
@@ -158,6 +158,7 @@
],
# Models
"models": [],
"models.aimv2": ["AIMv2Config"],
"models.albert": ["AlbertConfig"],
"models.align": [
"AlignConfig",
@@ -1404,6 +1405,13 @@

# PyTorch models structure

_import_structure["models.aimv2"].extend(
[
"AIMv2Model",
"AIMv2PreTrainedModel",
]
)

_import_structure["models.albert"].extend(
[
"AlbertForMaskedLM",
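Registering the submodule in `_import_structure` is what exposes the new classes at the top level of the package through the library's lazy-import machinery. A minimal sketch of the resulting usage, assuming an installation that includes this branch:

```python
# With models.aimv2 registered above, the classes resolve lazily
# from the top-level transformers namespace:
from transformers import AIMv2Config, AIMv2Model, AIMv2PreTrainedModel

config = AIMv2Config()      # default hyperparameters
model = AIMv2Model(config)  # randomly initialized weights

# AIMv2Model subclasses AIMv2PreTrainedModel, the usual transformers pattern
print(isinstance(model, AIMv2PreTrainedModel))  # True
```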
92 changes: 46 additions & 46 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -280,59 +280,59 @@ def forward(
BaseModelOutputWithNoAttention,
]:
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`~AutoImageProcessor.__call__`] for details.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`~AutoImageProcessor.__call__`] for details.
mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0
indicates the position is masked.
mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to apply to the attention scores. A value of 1 indicates the position is not masked, and a value of 0
indicates the position is masked.
<Tip>
<Tip>
What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0,
indicating the position *should not* attend. For example, if your input sequence length is 5 and you only
want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`.
What is the mask? Most models expect a value of 1, indicating the position *should* attend, and 0,
indicating the position *should not* attend. For example, if your input sequence length is 5 and you only
want to attend to the first 3 positions, the mask should be `[1, 1, 1, 0, 0]`.
</Tip>
</Tip>
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Returns:
Returns a tuple if not dictionary if config.use_return_dict is set to True, else a tuple.
x (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Hidden states of the output at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is passed or if `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AIMv2Model
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224")
>>> model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224")
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
```
Returns a tuple if not dictionary if config.use_return_dict is set to True, else a tuple.
x (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Hidden states of the output at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is passed or if `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AIMv2Model
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-224")
>>> model = AIMv2Model.from_pretrained("apple/aimv2-large-patch14-224")
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
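The assignments above follow the library's standard fallback pattern: a per-call flag takes precedence, and `None` falls back to the config default. A hedged sketch of how the resolved flags shape the return value, reusing `model` and `inputs` from the docstring example:

```python
# ModelOutput access by name (config.use_return_dict is True by default):
outputs = model(**inputs, output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(len(outputs.hidden_states))       # embedding output + one entry per layer

# The same call with return_dict=False yields a plain tuple instead:
last_hidden_state, hidden_states = model(**inputs, output_hidden_states=True, return_dict=False)
```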

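For the `mask` argument described in the docstring's Tip, a sketch of masking out trailing patch positions; the sequence length of 256 is an assumption based on a 224x224 input with patch size 14, not something stated in this diff:

```python
import torch

seq_len = 256  # assumed: (224 // 14) ** 2 patches for a 224x224 image
mask = torch.zeros(1, seq_len, dtype=torch.long)
mask[:, :200] = 1  # 1 = attend, 0 = do not attend (per the Tip above)

outputs = model(**inputs, mask=mask)
```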