diff --git a/CHANGELOG.md b/CHANGELOG.md index e1ca71bf9b..7ab3453c84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- ### Changed @@ -17,7 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - +- Fixed `_strip` for eventually `int` ([#1667](https://github.com/Lightning-AI/lightning-flash/pull/1667)) +- Fixed checking `target_formatter` as attrib type ([#1665](https://github.com/Lightning-AI/lightning-flash/pull/1665)) +- Fixed remote check in `download_data` ([#1666](https://github.com/Lightning-AI/lightning-flash/pull/1666)) ## [0.8.2] - 2023-06-30 diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 3d19f0cbf3..305d9900bd 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -1,7 +1,7 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torchvision <=0.15.2 -timm >0.4.5, <=0.9.2 # effdet 0.3.0 depends on timm>=0.4.12 +timm >0.4.5, <=0.9.5 # effdet 0.3.0 depends on timm>=0.4.12 lightning-bolts >=0.7.0, <0.8.0 Pillow >8.0, <=10.0.0 albumentations >1.0.0, <=1.3.1 diff --git a/requirements/serve.txt b/requirements/serve.txt index 1f49801318..165f6c5317 100644 --- a/requirements/serve.txt +++ b/requirements/serve.txt @@ -6,8 +6,8 @@ cytoolz >0.11, <=0.12.2 graphviz >=0.19, <=0.20.1 tqdm >4.60, <=4.65.0 fastapi >0.65, <=0.100.0 -pydantic >1.8.1, <=2.0.3 -starlette <=0.30.0 +pydantic >1.8.1, <=2.1.1 +starlette <=0.31.0 uvicorn[standard] >=0.12.0, <=0.23.2 aiofiles >22.1.0, <=23.1.0 jinja2 >=3.0.0, <3.2.0 diff --git a/requirements/testing_audio.txt b/requirements/testing_audio.txt index 7a0fcb577b..59b93175ae 100644 --- a/requirements/testing_audio.txt +++ b/requirements/testing_audio.txt @@ -3,7 +3,7 @@ torch ==2.0.1 torchaudio ==2.0.2 torchvision ==0.15.2 -timm >0.4.5, <=0.9.2 # effdet 0.3.0 depends 
on timm>=0.4.12 +timm >0.4.5, <=0.9.5 # effdet 0.3.0 depends on timm>=0.4.12 lightning-bolts >=0.7.0, <0.8.0 Pillow >8.0, <=10.0.0 albumentations >1.0.0, <=1.3.1 diff --git a/src/flash/audio/classification/input.py b/src/flash/audio/classification/input.py index d0174fdb57..0ed30a155f 100644 --- a/src/flash/audio/classification/input.py +++ b/src/flash/audio/classification/input.py @@ -138,6 +138,7 @@ def load_data( # If we had binary multi-class targets then we also know the labels (column names) if ( self.training + and hasattr(self, "target_formatter") and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index 19a40e0449..6bd6992a5a 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -42,8 +42,11 @@ def _as_list(x: Union[List, Tensor, np.ndarray]) -> List: return x -def _strip(x: str) -> str: - return x.strip(", ") +def _strip(x: Union[str, int]) -> str: + """Replace both ` ` and `,` from str.""" + if isinstance(x, str): + return x.strip(", ") + return str(x) @dataclass diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py index e142615d74..8a6827a145 100644 --- a/src/flash/core/data/utils.py +++ b/src/flash/core/data/utils.py @@ -59,7 +59,7 @@ } -def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: +def download_data(url: str, path: str = "data/", verbose: bool = False, chunk_size: int = 1024) -> None: """Download file with progressbar. # Code adapted from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603 @@ -78,39 +78,42 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: [...] 
def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
    """Extract every member of the tar archive at ``file_path`` into ``extract_path``.

    Args:
        file_path: Path of the tar archive. If it does not exist, nothing is done.
        extract_path: Directory the members are extracted into.
        mode: Mode string passed to ``tarfile.open`` (e.g. ``"r:gz"``).

    Raises:
        PermissionError: If a member cannot be extracted because of missing permissions.
    """
    if not os.path.exists(file_path):
        return
    with tarfile.open(file_path, mode=mode) as tar_ref:
        # NOTE(review): member names are used as-is; if archives may come from an untrusted
        # source, consider validating paths or using a tarfile extraction filter -- confirm.
        for member in tar_ref.getmembers():
            try:
                # ``set_attrs=False`` skips restoring owner/mode/mtime from the archive.
                tar_ref.extract(member, path=extract_path, set_attrs=False)
            except PermissionError as err:
                # Chain explicitly so the original OS error stays attached as the cause.
                raise PermissionError(f"Could not extract tar file {file_path}") from err
isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(targets, List) ):