Skip to content

Commit

Permalink
implement segformer model
Browse files Browse the repository at this point in the history
  • Loading branch information
dbuscombe-usgs committed Feb 10, 2023
1 parent ab0e0c0 commit b66ff15
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 188 deletions.
34 changes: 18 additions & 16 deletions doodleverse_utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# MIT License
#
# Copyright (c) 2021-2022, Marda Science LLC
# Copyright (c) 2021-23, Marda Science LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -175,7 +175,7 @@ def inpaint_nans(im):


# -----------------------------------
def plot_seg_history_iou(history, train_hist_fig):
def plot_seg_history_iou(history, train_hist_fig, MODEL):
"""
"plot_seg_history_iou(history, train_hist_fig)"
This function plots the training history of a model
Expand All @@ -189,21 +189,23 @@ def plot_seg_history_iou(history, train_hist_fig):
n = len(history.history["val_loss"])

plt.figure(figsize=(20, 10))
plt.subplot(121)
plt.plot(
np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy"
)
plt.plot(
np.arange(1, n + 1),
history.history["val_mean_iou"],
"k",
label="validation accuracy",
)
plt.xlabel("Epoch number", fontsize=10)
plt.ylabel("Mean IoU Coefficient", fontsize=10)
plt.legend(fontsize=10)
if MODEL!='segformer':
plt.subplot(121)
plt.plot(
np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy"
)
plt.plot(
np.arange(1, n + 1),
history.history["val_mean_iou"],
"k",
label="validation accuracy",
)
plt.xlabel("Epoch number", fontsize=10)
plt.ylabel("Mean IoU Coefficient", fontsize=10)
plt.legend(fontsize=10)

plt.subplot(122)

plt.subplot(122)
plt.plot(np.arange(1, n + 1), history.history["loss"], "b", label="train loss")
plt.plot(
np.arange(1, n + 1), history.history["val_loss"], "k", label="validation loss"
Expand Down
23 changes: 23 additions & 0 deletions doodleverse_utils/model_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
# keras functions for early stopping and model weights saving
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from transformers import TFSegformerForSemanticSegmentation

SEED = 42
np.random.seed(SEED)
AUTO = tf.data.experimental.AUTOTUNE # used in tf.data.Dataset API
Expand All @@ -43,6 +45,27 @@
###############################################################
### MODEL ARCHITECTURES
###############################################################
def segformer(
id2label,
num_classes=2,
):
"""
https://keras.io/examples/vision/segformer/
https://huggingface.co/nvidia/mit-b0
"""

label2id = {label: id for id, label in id2label.items()}
model_checkpoint = "nvidia/mit-b0"

model = TFSegformerForSemanticSegmentation.from_pretrained(
model_checkpoint,
num_labels=num_classes,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
return model


# -----------------------------------
def simple_resunet(
Expand Down
Loading

0 comments on commit b66ff15

Please sign in to comment.