Skip to content

Commit

Permalink
Teddy/final mac examples
Browse files Browse the repository at this point in the history
* finetuning image classification final

* finetuning tabular classification final

* finetuning text classification final

* predict image final
  • Loading branch information
teddykoker authored and Borda committed Jan 28, 2021
1 parent 8149000 commit fcf39ee
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
7 changes: 4 additions & 3 deletions flash/core/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ def __init__(

self.batch_size = batch_size

# TODO: figure out best solution for setting num_workers
# if num_workers is None:
# num_workers = os.cpu_count()
if num_workers is None:
num_workers = os.cpu_count()
if num_workers is None:
warnings.warn("Could not infer cpu count automatically, setting it to zero")
# warnings.warn("Could not infer cpu count automatically, setting it to zero")
num_workers = 0
self.num_workers = num_workers

Expand Down
10 changes: 5 additions & 5 deletions flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# import our libraries
import pytorch_lightning as pl
import torch
import flash
from flash.vision import ImageClassificationData, ImageClassifier
from flash.vision.classification.dataset import hymenoptera_data_download
from flash.core.data import download_data

# 1. Download data
hymenoptera_data_download('data/')
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Organize our data
datamodule = ImageClassificationData.from_folders(
Expand All @@ -17,10 +17,10 @@
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create trainer
trainer = pl.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.fit(model, datamodule=datamodule)
trainer.finetune(model, datamodule=datamodule)

# 6. Save model
torch.save(model, "image_classification_model.pt")
13 changes: 6 additions & 7 deletions flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
# import our libraries
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall

import flash
from flash.tabular import TabularClassifier, TabularData
from flash.tabular.classification.data.dataset import titanic_data_download
from flash.core.data import download_data

# 1. Download data
titanic_data_download("./data/titanic")
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 2. Organize our data - create a LightningDataModule
datamodule = TabularData.from_df(
pd.read_csv("./data/titanic/titanic.csv"),
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
target="Survived",
Expand All @@ -23,7 +22,7 @@
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 4. Create trainer
trainer = pl.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1)

# 5. Train model
trainer.fit(model, datamodule=datamodule)
Expand Down
6 changes: 3 additions & 3 deletions flash_examples/finetuning/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# import our libraries
import pytorch_lightning as pl
import torch

import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

Expand All @@ -20,10 +20,10 @@
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create trainer - Make training slightly faster for demo.
trainer = pl.Trainer(max_epochs=1, limit_train_batches=8, limit_val_batches=8)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=8, limit_val_batches=8)

# 5. Finetune the model
trainer.fit(model, datamodule=datamodule)
trainer.finetune(model, datamodule=datamodule)

# 6. Save model
torch.save(model, "text_classification_model.pt")
2 changes: 1 addition & 1 deletion flash_examples/predict/image_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# import our libraries
import torch
from pytorch_lightning import Trainer

from flash import Trainer
from flash.vision import ImageClassificationData

# 1. Load model from checkpoint
Expand Down

0 comments on commit fcf39ee

Please sign in to comment.