This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Merge branch 'master' into dependabot-pip-requirements-protobuf-lte-4.24.2
Borda authored Oct 8, 2023
2 parents 9d43cc8 + 3c55c76 commit 51d8a5f
Showing 86 changed files with 238 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
@@ -56,7 +56,7 @@ jobs:
DATASETS_VERBOSITY: warning

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
4 changes: 2 additions & 2 deletions .github/workflows/docs-check.yml
@@ -24,7 +24,7 @@ jobs:
make-docs:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: true
- uses: actions/setup-python@v4
@@ -66,7 +66,7 @@ jobs:
test-docs:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: true
- uses: actions/setup-python@v4
2 changes: 1 addition & 1 deletion .github/workflows/pypi-release.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.14.0
hooks:
- id: pyupgrade
args: [--py38-plus]
@@ -48,20 +48,20 @@ repos:
- id: nbstripout

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.3
rev: v1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.9.1
hooks:
- id: black
name: Format code

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.276
rev: v0.0.292
hooks:
- id: ruff
args: ["--fix"]
3 changes: 3 additions & 0 deletions docs/source/quickstart.rst
@@ -86,6 +86,7 @@ Inference is the process of generating predictions from trained models. To use a
Here's an example of inference:

.. testcode::
:skipif: flash.core.utilities.imports._TRANSFORMERS_GREATER_EQUAL_4_0

# import our libraries
from flash import Trainer
@@ -110,11 +111,13 @@ Here's an example of inference:
We get the following output:

.. testoutput::
:skipif: flash.core.utilities.imports._TRANSFORMERS_GREATER_EQUAL_4_0
:hide:

...

.. testcode::
:skipif: flash.core.utilities.imports._TRANSFORMERS_GREATER_EQUAL_4_0
:hide:

assert all(
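The ``:skipif:`` options added to the quickstart doctests above gate them on a module-level flag exported by ``flash.core.utilities.imports``. As a rough sketch of how such a flag might be defined — the helper name and the ``packaging``/``importlib.metadata`` approach are assumptions, not the actual implementation:

    # Hypothetical sketch only; the real flag lives in flash.core.utilities.imports
    # and may be implemented differently (e.g. via lightning-utilities helpers).
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import Version


    def _package_at_least(package: str, minimum: str) -> bool:
        """Return True when `package` is installed with a version >= `minimum`."""
        try:
            return Version(version(package)) >= Version(minimum)
        except PackageNotFoundError:
            return False


    _TRANSFORMERS_GREATER_EQUAL_4_0 = _package_at_least("transformers", "4.0.0")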
4 changes: 2 additions & 2 deletions requirements/base.txt
@@ -9,7 +9,7 @@ pytorch-lightning >1.8.0, <2.0.0 # strict
pyDeprecate >0.2.0
pandas >1.1.0, <=2.0.3
jsonargparse[signatures] >=4.22.0, <4.23.0
click >=7.1.2, <=8.1.6
click >=7.1.2, <8.2.0
protobuf <4.25.0
fsspec[http] >=2022.5.0,<=2023.6.0
fsspec[http] >=2022.5.0,<=2023.9.0
lightning-utilities >=0.4.1
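Most of the pin changes in this commit relax an inclusive upper bound (e.g. ``Pillow <=10.0.0``) to an exclusive one (``Pillow <10.1.0``), which also admits future patch releases such as 10.0.1. A small illustrative check of how such a pip-style specifier is evaluated — this snippet is not part of the repository:

    # Illustration of the pin syntax used in these requirement files; not part of the repo.
    from importlib.metadata import version

    from packaging.requirements import Requirement

    req = Requirement("click >=7.1.2, <8.2.0")
    installed = version("click")  # e.g. "8.1.7"
    print(f"{req.name} {installed} satisfies '{req.specifier}':", req.specifier.contains(installed))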
2 changes: 1 addition & 1 deletion requirements/datatype_audio.txt
@@ -4,5 +4,5 @@ numpy <1.26
torchaudio <=2.0.2
torchvision <=0.15.2
librosa >=0.8.1, <=0.10.0.post2
transformers >=4.13.0, <=4.30.2
transformers >=4.13.0, <4.34.0
datasets >1.16.1, <=2.14.3
4 changes: 2 additions & 2 deletions requirements/datatype_image.txt
@@ -3,11 +3,11 @@
torchvision <=0.15.2
timm >0.4.5, <=0.9.5 # effdet 0.3.0 depends on timm>=0.4.12
lightning-bolts >=0.7.0, <0.8.0
Pillow >8.0, <=10.0.0
Pillow >8.0, <10.1.0
albumentations >1.0.0, <=1.3.1
pystiche >1.0.0, <=1.0.1
ftfy >6.0.0, <=6.1.1
regex <=2023.6.3
regex <=2023.8.8
sahi >=0.8.19, <0.11 # strict - Fixes compatibility with icevision

icevision >0.8, <0.13.0
2 changes: 1 addition & 1 deletion requirements/datatype_image_extras.txt
@@ -6,7 +6,7 @@ classy-vision <=0.7.0
effdet <=0.4.1
kornia >0.5.1, <=0.6.12
learn2learn <=0.1.7; platform_system != "Windows" # dead
fastface <=0.1.3 # dead
fastface <=0.1.4 # dead
fairscale

# pinned PL so we force a compatible TM version
2 changes: 1 addition & 1 deletion requirements/datatype_pointcloud.txt
@@ -3,4 +3,4 @@
open3d >=0.17.0, <0.18.0
# torch >=1.8.0, <1.9.0
# torchvision >0.9.0, <0.10.0
tensorboard <=2.13.0
tensorboard <=2.14.0
8 changes: 4 additions & 4 deletions requirements/datatype_text.txt
@@ -2,10 +2,10 @@

torchvision <=0.15.2
sentencepiece >=0.1.95, <=0.1.99
filelock <=3.12.2
transformers >=4.13.0, <=4.30.2
torchmetrics[text] >0.5.0, <1.1.0
filelock <=3.12.3
transformers >=4.13.0, <4.34.0
torchmetrics[text] >0.5.0, <1.3.0
datasets >=2.0.0, <=2.14.3
sentence-transformers <=2.2.2
ftfy <=6.1.1
regex <=2023.6.3
regex <=2023.8.8
2 changes: 1 addition & 1 deletion requirements/datatype_video.txt
@@ -1,7 +1,7 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup

torchvision <=0.15.2
Pillow >7.1, <=10.0.0
Pillow >7.1, <10.1.0
kornia >=0.5.1, <=0.6.12
pytorchvideo ==0.1.5

6 changes: 3 additions & 3 deletions requirements/serve.txt
@@ -1,13 +1,13 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup

pillow >9.0.0, <=10.0.0
pillow >9.0.0, <10.1.0
pyyaml >5.4, <=6.0.1
cytoolz >0.11, <=0.12.2
graphviz >=0.19, <=0.20.1
tqdm >4.60, <=4.65.0
tqdm >4.60, <=4.66.1
fastapi >0.65, <=0.103.0
pydantic >1.8.1, <2.0.0 # strict
starlette <=0.31.0
starlette <0.32.0
uvicorn[standard] >=0.12.0, <=0.23.2
aiofiles >22.1.0, <=23.1.0
jinja2 >=3.0.0, <3.2.0
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -1,7 +1,7 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup

coverage[toml]
pytest ==7.4.0
pytest ==7.4.1
pytest-doctestplus ==0.13.0
pytest-rerunfailures ==12.0
pytest-forked ==1.6.0
2 changes: 1 addition & 1 deletion requirements/testing_audio.txt
@@ -5,6 +5,6 @@ torchvision ==0.15.2

timm >0.4.5, <=0.9.5 # effdet 0.3.0 depends on timm>=0.4.12
lightning-bolts >=0.7.0, <0.8.0
Pillow >8.0, <=10.0.0
Pillow >8.0, <10.1.0
albumentations >1.0.0, <=1.3.1
pystiche >1.0.0, <=1.0.1
10 changes: 8 additions & 2 deletions src/flash/audio/classification/data.py
@@ -42,8 +42,8 @@


class AudioClassificationData(DataModule):
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
class methods for loading data for audio classification."""
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of class
methods for loading data for audio classification."""

input_transform_cls = AudioClassificationInputTransform

@@ -141,6 +141,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)]
"""

ds_kw = {
@@ -275,6 +276,7 @@ def from_folders(
>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")
"""

ds_kw = {
@@ -365,6 +367,7 @@ def from_numpy(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = {
@@ -453,6 +456,7 @@ def from_tensors(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = {
@@ -607,6 +611,7 @@ def from_data_frame(
>>> shutil.rmtree("predict_folder")
>>> del train_data_frame
>>> del predict_data_frame
"""

ds_kw = {
@@ -854,6 +859,7 @@ def from_csv(
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")
"""

ds_kw = {
4 changes: 4 additions & 0 deletions src/flash/audio/speech_recognition/data.py
@@ -117,6 +117,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""

ds_kw = {"sampling_rate": sampling_rate}
@@ -302,6 +303,7 @@ def from_csv(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")
"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate}
@@ -424,6 +426,7 @@ def from_json(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
>>> os.remove("train_data.json")
>>> os.remove("predict_data.json")
"""

ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate, "field": field}
@@ -570,6 +573,7 @@ def from_datasets(
>>> import os
>>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""

ds_kw = {"sampling_rate": sampling_rate}
1 change: 1 addition & 0 deletions src/flash/audio/speech_recognition/model.py
@@ -45,6 +45,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
2 changes: 2 additions & 0 deletions src/flash/core/adapter.py
@@ -35,6 +35,7 @@ def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
"""Instantiate the adapter from the given :class:`~flash.core.model.Task`.
This includes resolution / creation of backbones / heads and any other provider specific options.
"""

def forward(self, x: Any) -> Any:
@@ -73,6 +74,7 @@ class AdapterTask(Task):
Args:
adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.
"""

def __init__(self, adapter: Adapter, **kwargs):
1 change: 1 addition & 0 deletions src/flash/core/data/base_viz.py
@@ -96,6 +96,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
As the :class:`~flash.core.data.io.input_transform.InputTransform` hooks are injected within
the threaded workers of the DataLoader,
the data won't be accessible when using ``num_workers > 0``.
"""

def _show(
2 changes: 2 additions & 0 deletions src/flash/core/data/callback.py
@@ -19,6 +19,7 @@ class FlashCallback(Callback):
Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer::
trainer = Trainer(callbacks=[MyCustomCallback()])
"""

def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
@@ -146,6 +147,7 @@ def from_inputs(
'val': {},
'predict': {}
}
"""

batches: dict
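The ``FlashCallback`` docstring above notes that, as in PyTorch Lightning, callbacks are passed directly to the ``Trainer``. A minimal, hypothetical sketch — the class name and body are invented; only the hook signature shown in the diff is assumed:

    # Hypothetical example; not part of this commit.
    from flash import Trainer
    from flash.core.data.callback import FlashCallback


    class PrintSampleShape(FlashCallback):
        """Log the shape of each sample as it passes through the per-sample transform."""

        def on_per_sample_transform(self, sample, running_stage):
            print(f"[{running_stage}] sample shape: {getattr(sample, 'shape', None)}")


    trainer = Trainer(callbacks=[PrintSampleShape()])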
5 changes: 3 additions & 2 deletions src/flash/core/data/data_module.py
@@ -44,8 +44,7 @@

class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`
"""
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`"""

def load_sample(self, sample: Any) -> Dict[str, Any]:
if isinstance(sample, tuple) and len(sample) == 2:
@@ -103,6 +102,7 @@ class DataModule(pl.LightningDataModule):
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.WeightedRandomSampler object at ...>
"""

input_transform_cls = InputTransform
@@ -399,6 +399,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
"""This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`.
Override with your custom one.
"""
return BaseDataFetcher()

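``DataModule.configure_data_fetcher`` above is documented as the override point for supplying a custom ``flash.core.data.callback.BaseDataFetcher``. A minimal sketch of such an override — both subclass names below are invented for illustration:

    # Hypothetical illustration; the subclasses here are not part of the codebase.
    from flash.core.data.callback import BaseDataFetcher
    from flash.core.data.data_module import DataModule


    class DebugDataFetcher(BaseDataFetcher):
        """A data fetcher that could, for example, record batches for later inspection."""


    class MyDataModule(DataModule):
        @staticmethod
        def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
            # Return the custom fetcher instead of the default BaseDataFetcher().
            return DebugDataFetcher()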
2 changes: 2 additions & 0 deletions src/flash/core/data/io/classification_input.py
@@ -25,6 +25,7 @@ class ClassificationInputMixin(Properties):
targets and store metadata like ``labels`` and ``num_classes``.
* In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
tasks.
"""

target_formatter: TargetFormatter
@@ -46,6 +47,7 @@ def load_target_metadata(
rather than inferring from the targets.
add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and
``num_classes`` are being inferred.
"""
self.target_formatter = target_formatter
if target_formatter is None and targets is not None:
… (diffs for the remaining changed files were not loaded)
