From fcf39ee0ab5cefd619b026a6463bbd87b89d0584 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Wed, 27 Jan 2021 21:33:06 -0500 Subject: [PATCH] Teddy/final mac examples * finetuning image classification final * finetuning tabular classification final * finetuning text classification final * predict image final --- flash/core/data/datamodule.py | 7 ++++--- flash_examples/finetuning/image_classification.py | 10 +++++----- flash_examples/finetuning/tabular_classification.py | 13 ++++++------- flash_examples/finetuning/text_classification.py | 6 +++--- flash_examples/predict/image_classification.py | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/flash/core/data/datamodule.py b/flash/core/data/datamodule.py index dbb9faa689..9d1bc377fb 100644 --- a/flash/core/data/datamodule.py +++ b/flash/core/data/datamodule.py @@ -51,10 +51,11 @@ def __init__( self.batch_size = batch_size + # TODO: figure out best solution for setting num_workers + # if num_workers is None: + # num_workers = os.cpu_count() if num_workers is None: - num_workers = os.cpu_count() - if num_workers is None: - warnings.warn("Could not infer cpu count automatically, setting it to zero") + # warnings.warn("Could not infer cpu count automatically, setting it to zero") num_workers = 0 self.num_workers = num_workers diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index e201f308bd..f306529d7c 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -1,11 +1,11 @@ # import our libraries -import pytorch_lightning as pl import torch +import flash from flash.vision import ImageClassificationData, ImageClassifier -from flash.vision.classification.dataset import hymenoptera_data_download +from flash.core.data import download_data # 1. Download data -hymenoptera_data_download('data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') # 2. Organize our data datamodule = ImageClassificationData.from_folders( @@ -17,10 +17,10 @@ model = ImageClassifier(num_classes=datamodule.num_classes) # 4. Create trainer -trainer = pl.Trainer(max_epochs=1) +trainer = flash.Trainer(max_epochs=1) # 5. Train the model -trainer.fit(model, datamodule=datamodule) +trainer.finetune(model, datamodule=datamodule) # 6. Save model torch.save(model, "image_classification_model.pt") diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index 02e4d28f98..533b93d588 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -1,18 +1,17 @@ # import our libraries -import pandas as pd -import pytorch_lightning as pl import torch from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall +import flash from flash.tabular import TabularClassifier, TabularData -from flash.tabular.classification.data.dataset import titanic_data_download +from flash.core.data import download_data # 1. Download data -titanic_data_download("./data/titanic") +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') # 2. Organize our data - create a LightningDataModule -datamodule = TabularData.from_df( - pd.read_csv("./data/titanic/titanic.csv"), +datamodule = TabularData.from_csv( + "./data/titanic/titanic.csv", categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], numerical_input=["Fare"], target="Survived", @@ -23,7 +22,7 @@ model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()]) # 4. Create trainer -trainer = pl.Trainer(max_epochs=1) +trainer = flash.Trainer(max_epochs=1) # 5. Train model trainer.fit(model, datamodule=datamodule) diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index db368d2afa..d99c156cdf 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -1,7 +1,7 @@ # import our libraries -import pytorch_lightning as pl import torch +import flash from flash.core.data import download_data from flash.text import TextClassificationData, TextClassifier @@ -20,10 +20,10 @@ model = TextClassifier(num_classes=datamodule.num_classes) # 4. Create trainer - Make training slightly faster for demo. -trainer = pl.Trainer(max_epochs=1, limit_train_batches=8, limit_val_batches=8) +trainer = flash.Trainer(max_epochs=1, limit_train_batches=8, limit_val_batches=8) # 5. Finetune the model -trainer.fit(model, datamodule=datamodule) +trainer.finetune(model, datamodule=datamodule) # 6. Save model torch.save(model, "text_classification_model.pt") diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py index 6eaf659be8..46fdaea523 100644 --- a/flash_examples/predict/image_classification.py +++ b/flash_examples/predict/image_classification.py @@ -1,7 +1,7 @@ # import our libraries import torch -from pytorch_lightning import Trainer +from flash import Trainer from flash.vision import ImageClassificationData # 1. Load model from checkpoint