From b66ff158ffb68ece909fb4019667faf7acdafaec Mon Sep 17 00:00:00 2001
From: dbuscombe-usgs <dbuscombe@gmail.com>
Date: Fri, 10 Feb 2023 12:39:30 -0800
Subject: [PATCH] implement segformer model
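
Adds a segformer() model constructor to doodleverse_utils/model_imports.py, built
on Hugging Face's TFSegformerForSemanticSegmentation with the nvidia/mit-b0
checkpoint. do_seg() in prediction_imports.py gains segformer-specific handling
(channels-first input, logits resized back to TARGET_SIZE and transposed to
channels-last), and plot_seg_history_iou() now takes a MODEL argument so the
mean-IoU panel is only drawn for models whose history records mean_iou. The
commented-out CRF helpers are moved, unchanged, to the end of
prediction_imports.py.

A minimal usage sketch for the new constructor (the two-class id2label mapping
below is hypothetical; requires the transformers package and a network
connection to fetch the checkpoint):

    from doodleverse_utils.model_imports import segformer

    # hypothetical class-name mapping; any {int: str} dict works
    id2label = {0: "background", 1: "target"}
    model = segformer(id2label, num_classes=2)  # loads nvidia/mit-b0 weights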

---
 doodleverse_utils/imports.py            |  34 +--
 doodleverse_utils/model_imports.py      |  23 ++
 doodleverse_utils/prediction_imports.py | 352 ++++++++++++------------
 3 files changed, 221 insertions(+), 188 deletions(-)

diff --git a/doodleverse_utils/imports.py b/doodleverse_utils/imports.py
index 86bee4e..f9ace6d 100755
--- a/doodleverse_utils/imports.py
+++ b/doodleverse_utils/imports.py
@@ -3,7 +3,7 @@
 #
 # MIT License
 #
-# Copyright (c) 2021-2022, Marda Science LLC
+# Copyright (c) 2021-23, Marda Science LLC
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -175,7 +175,7 @@ def inpaint_nans(im):
 
 
 # -----------------------------------
-def plot_seg_history_iou(history, train_hist_fig):
+def plot_seg_history_iou(history, train_hist_fig, MODEL):
     """
     "plot_seg_history_iou(history, train_hist_fig)"
     This function plots the training history of a model
@@ -189,21 +189,23 @@ def plot_seg_history_iou(history, train_hist_fig):
     n = len(history.history["val_loss"])
 
     plt.figure(figsize=(20, 10))
-    plt.subplot(121)
-    plt.plot(
-        np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy"
-    )
-    plt.plot(
-        np.arange(1, n + 1),
-        history.history["val_mean_iou"],
-        "k",
-        label="validation accuracy",
-    )
-    plt.xlabel("Epoch number", fontsize=10)
-    plt.ylabel("Mean IoU Coefficient", fontsize=10)
-    plt.legend(fontsize=10)
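+    # the mean-IoU panel is only drawn for models whose history records mean_iou (not segformer)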
+    if MODEL!='segformer':
+        plt.subplot(121)
+        plt.plot(
+            np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy"
+        )
+        plt.plot(
+            np.arange(1, n + 1),
+            history.history["val_mean_iou"],
+            "k",
+            label="validation accuracy",
+        )
+        plt.xlabel("Epoch number", fontsize=10)
+        plt.ylabel("Mean IoU Coefficient", fontsize=10)
+        plt.legend(fontsize=10)
+
+        plt.subplot(122)
 
-    plt.subplot(122)
     plt.plot(np.arange(1, n + 1), history.history["loss"], "b", label="train loss")
     plt.plot(
         np.arange(1, n + 1), history.history["val_loss"], "k", label="validation loss"
diff --git a/doodleverse_utils/model_imports.py b/doodleverse_utils/model_imports.py
index 73314f0..bc2bc95 100755
--- a/doodleverse_utils/model_imports.py
+++ b/doodleverse_utils/model_imports.py
@@ -30,6 +30,8 @@
 # keras functions for early stopping and model weights saving
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 
+from transformers import TFSegformerForSemanticSegmentation
+
 SEED = 42
 np.random.seed(SEED)
 AUTO = tf.data.experimental.AUTOTUNE  # used in tf.data.Dataset API
@@ -43,6 +45,27 @@
 ###############################################################
 ### MODEL ARCHITECTURES
 ###############################################################
+def segformer(
+    id2label,
+    num_classes=2,
+):
+    """
+    https://keras.io/examples/vision/segformer/
+    https://huggingface.co/nvidia/mit-b0
+    """
+
+    label2id = {label: id for id, label in id2label.items()}
+    model_checkpoint = "nvidia/mit-b0"
+
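+    # load the pretrained MiT-B0 backbone and attach a segmentation head sized for
+    # num_classes; ignore_mismatched_sizes allows replacing any mismatched head weights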
+    model = TFSegformerForSemanticSegmentation.from_pretrained(
+        model_checkpoint,
+        num_labels=num_classes,
+        id2label=id2label,
+        label2id=label2id,
+        ignore_mismatched_sizes=True,
+    )
+    return model
+
 
 # -----------------------------------
 def simple_resunet(
diff --git a/doodleverse_utils/prediction_imports.py b/doodleverse_utils/prediction_imports.py
index 46d989e..3347c44 100755
--- a/doodleverse_utils/prediction_imports.py
+++ b/doodleverse_utils/prediction_imports.py
@@ -3,7 +3,7 @@
 #
 # MIT License
 #
-# Copyright (c) 2021-22, Marda Science LLC
+# Copyright (c) 2021-23, Marda Science LLC
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -25,14 +25,10 @@
 
 from .imports import standardize, label_to_colors, fromhex
 
-import os  # , time
-
+import os
 import numpy as np
 import matplotlib.pyplot as plt
 from scipy import io
-
-# from skimage.filters.rank import median
-# from skimage.morphology import disk
 from tkinter import filedialog
 from tkinter import *
 from tkinter import messagebox
@@ -40,23 +36,15 @@
 from skimage.io import imsave, imread
 from numpy.lib.stride_tricks import as_strided as ast
 from glob import glob
-
-# from joblib import Parallel, delayed
-# from skimage.morphology import remove_small_holes, remove_small_objects
-# from scipy.ndimage import maximum_filter
 from skimage.transform import resize
-
-# from tqdm import tqdm
 from skimage.filters import threshold_otsu
 import matplotlib.pyplot as plt
 
 import tensorflow as tf  # numerical operations on gpu
 import tensorflow.keras.backend as K
 
-# #crf
 # import pydensecrf.densecrf as dcrf
 # from pydensecrf.utils import create_pairwise_bilateral, unary_from_softmax
-# # unary_from_labels, 
 
 SEED = 42
 np.random.seed(SEED)
@@ -69,130 +57,6 @@
 print("GPU name: ", tf.config.experimental.list_physical_devices("GPU"))
 print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices("GPU")))
 
-
-
-# #-----------------------------------
-# def crf_refine_from_integer_labels(label, img, nclasses = 2, theta_col=100, theta_spat=3, compat=120):
-#     """
-#     "crf_refine(label, img)"
-#     This function refines a label image based on an input label image and the associated image
-#     Uses a conditional random field algorithm using spatial and image features
-#     INPUTS:
-#         * label [ndarray]: label image 2D matrix of integers
-#         * image [ndarray]: image 3D matrix of integers
-#     OPTIONAL INPUTS: None
-#     GLOBAL INPUTS: None
-#     OUTPUTS: label [ndarray]: label image 2D matrix of integers
-#     """
-
-#     gx,gy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
-#     # print(gx.shape)
-#     img = np.dstack((img,gx,gy))
-
-#     H = label.shape[0]
-#     W = label.shape[1]
-#     U = unary_from_labels(1+label,nclasses,gt_prob=0.51)
-#     d = dcrf.DenseCRF2D(H, W, nclasses)
-#     d.setUnaryEnergy(U)
-
-#     # to add the color-independent term, where features are the locations only:
-#     d.addPairwiseGaussian(sxy=(theta_spat, theta_spat),
-#                  compat=3,
-#                  kernel=dcrf.DIAG_KERNEL,
-#                  normalization=dcrf.NORMALIZE_SYMMETRIC)
-#     feats = create_pairwise_bilateral(
-#                           sdims=(theta_col, theta_col),
-#                           schan=(2,2,2),
-#                           img=img,
-#                           chdim=2)
-
-#     d.addPairwiseEnergy(feats, compat=compat,kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
-#     Q = d.inference(20)
-#     kl1 = d.klDivergence(Q)
-#     return np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8), kl1
-
-
-
-# ##========================================================
-# def crf_refine(label,
-#     img,n,
-#     crf_theta_slider_value,
-#     crf_mu_slider_value,
-#     crf_downsample_factor): #gt_prob
-#     """
-#     "crf_refine(label, img)"
-#     This function refines a label image based on an input label image and the associated image
-#     Uses a conditional random field algorithm using spatial and image features
-#     INPUTS:
-#         * label [ndarray]: label image 2D matrix of integers
-#         * image [ndarray]: image 3D matrix of integers
-#     OPTIONAL INPUTS: None
-#     GLOBAL INPUTS: None
-#     OUTPUTS: label [ndarray]: label image 2D matrix of integers
-#     """
-
-#     Horig = img.shape[0]
-#     Worig = img.shape[1]
-#     l_unique = label.shape[-1]
-
-#     label = label.reshape(Horig,Worig,l_unique)
-
-#     scale = 1+(5 * (np.array(img.shape).max() / 3000))
-
-#     # decimate by factor by taking only every other row and column
-#     try:
-#         img = img[::crf_downsample_factor,::crf_downsample_factor, :]
-#     except:
-#         img = img[::crf_downsample_factor,::crf_downsample_factor]
-
-#     # do the same for the label image
-#     label = label[::crf_downsample_factor,::crf_downsample_factor]
-#     # yes, I know this aliases, but considering the task, it is ok; the objective is to
-#     # make fast inference and resize the output
-
-#     H = img.shape[0]
-#     W = img.shape[1]
-#     # U = unary_from_labels(np.argmax(label,-1).astype('int'), n, gt_prob=gt_prob)
-#     # d = dcrf.DenseCRF2D(H, W, n)
-
-#     U = unary_from_softmax(np.ascontiguousarray(np.rollaxis(label,-1,0)))
-#     d = dcrf.DenseCRF2D(H, W, l_unique)
-
-#     d.setUnaryEnergy(U)
-
-#     # to add the color-independent term, where features are the locations only:
-#     d.addPairwiseGaussian(sxy=(3, 3),
-#                  compat=3,
-#                  kernel=dcrf.DIAG_KERNEL,
-#                  normalization=dcrf.NORMALIZE_SYMMETRIC)
-
-#     try:
-#         feats = create_pairwise_bilateral(
-#                             sdims=(crf_theta_slider_value, crf_theta_slider_value),
-#                             schan=(scale,scale,scale),
-#                             img=img,
-#                             chdim=2)
-
-#         d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
-#     except:
-#         feats = create_pairwise_bilateral(
-#                             sdims=(crf_theta_slider_value, crf_theta_slider_value),
-#                             schan=(scale,scale, scale),
-#                             img=np.dstack((img,img,img)),
-#                             chdim=2)
-
-#         d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
-
-#     Q = d.inference(10)
-#     result = np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8) +1
-
-#     # uniq = np.unique(result.flatten())
-
-#     result = resize(result, (Horig, Worig), order=0, anti_aliasing=False) #True)
-#     result = rescale(result, 1, l_unique).astype(np.uint8)
-
-#     return result, l_unique
-
 ##========================================================
 def rescale(dat,
     mn,
@@ -204,10 +68,8 @@ def rescale(dat,
     M = max(dat.flatten())
     return (mx-mn)*(dat-m)/(M-m)+mn
 
-
-
 # #-----------------------------------
-def seg_file2tensor_ND(f, TARGET_SIZE):  # , resize):
+def seg_file2tensor_ND(f, TARGET_SIZE):
     """
     "seg_file2tensor(f)"
     This function reads a NPZ image from file into a cropped and resized tensor,
@@ -226,7 +88,6 @@ def seg_file2tensor_ND(f, TARGET_SIZE):  # , resize):
     smallimage = resize(
         bigimage, (TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True
     )
-    # smallimage=bigimage.resize((TARGET_SIZE[1], TARGET_SIZE[0]))
     smallimage = np.array(smallimage)
     smallimage = tf.cast(smallimage, tf.uint8)
 
@@ -237,7 +98,7 @@ def seg_file2tensor_ND(f, TARGET_SIZE):  # , resize):
 
 
 # #-----------------------------------
-def seg_file2tensor_3band(f, TARGET_SIZE):  # , resize):
+def seg_file2tensor_3band(f, TARGET_SIZE):
     """
     "seg_file2tensor(f)"
     This function reads a jpeg image from file into a cropped and resized tensor,
@@ -250,11 +111,10 @@ def seg_file2tensor_3band(f, TARGET_SIZE):  # , resize):
     GLOBAL INPUTS: TARGET_SIZE
     """
 
-    bigimage = imread(f)  # Image.open(f)
+    bigimage = imread(f)
     smallimage = resize(
         bigimage, (TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True
     )
-    # smallimage=bigimage.resize((TARGET_SIZE[1], TARGET_SIZE[0]))
     smallimage = np.array(smallimage)
     smallimage = tf.cast(smallimage, tf.uint8)
 
@@ -306,6 +166,9 @@ def do_seg(
 
         image = standardize(image.numpy()).squeeze()
 
+        if MODEL=='segformer':
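+            # segformer expects channels-first input: (channels, height, width)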
+            image = tf.transpose(image, (2, 0, 1))
+
         E0 = []
         E1 = []
 
@@ -313,35 +176,61 @@ def do_seg(
             # heatmap = make_gradcam_heatmap(tf.expand_dims(image, 0) , model)
 
             try:
-                est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze()
+                if MODEL=='segformer':
+                    est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).logits
+                else:
+                    est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze()
             except:
-                est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).squeeze()
+                if MODEL=='segformer':
+                    est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).logits
+                else:
+                    est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).squeeze()
 
             if TESTTIMEAUG == True:
                 # return the flipped prediction
-                est_label2 = np.flipud(
-                    model.predict(
-                        tf.expand_dims(np.flipud(image), 0), batch_size=1
-                    ).squeeze()
-                )
-                est_label3 = np.fliplr(
-                    model.predict(
-                        tf.expand_dims(np.fliplr(image), 0), batch_size=1
-                    ).squeeze()
-                )
-                est_label4 = np.flipud(
-                    np.fliplr(
+                if MODEL=='segformer':
+                    est_label2 = np.flipud(
+                        model.predict(tf.expand_dims(np.flipud(image), 0), batch_size=1).logits
+                        )
+                else:
+                    est_label2 = np.flipud(
+                        model.predict(tf.expand_dims(np.flipud(image), 0), batch_size=1).squeeze()
+                        )
+
+                if MODEL=='segformer':
+                    est_label3 = np.fliplr(
                         model.predict(
-                            tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1
-                        ).squeeze()
-                    )
-                )
-
+                            tf.expand_dims(np.fliplr(image), 0), batch_size=1).logits
+                            )
+                else:
+                    est_label3 = np.fliplr(
+                        model.predict(
+                            tf.expand_dims(np.fliplr(image), 0), batch_size=1).squeeze()
+                            )
+
+                if MODEL=='segformer':
+                    est_label4 = np.flipud(
+                        np.fliplr(
+                            model.predict(
+                                tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1).logits)
+                                )
+                else:
+                    est_label4 = np.flipud(
+                        np.fliplr(
+                            model.predict(
+                                tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1).squeeze())
+                                )
+
                 # soft voting - sum the softmax scores to return the new TTA estimated softmax scores
                 est_label = est_label + est_label2 + est_label3 + est_label4
                 del est_label2, est_label3, est_label4
             
             est_label = est_label.astype('float32')
+
+            if MODEL=='segformer':
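+                # segformer logits come back at reduced resolution as (batch, classes, height, width);
+                # resize to TARGET_SIZE and move the class dimension last to match the other models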
+                est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze()
+                est_label = np.transpose(est_label, (1,2,0))
+
             E0.append(
                 resize(est_label[:, :, 0], (w, h), preserve_range=True, clip=True)
             )
@@ -353,11 +242,11 @@ def do_seg(
         # heatmap = resize(heatmap,(w,h), preserve_range=True, clip=True)
         K.clear_session()
 
-        e0 = np.average(np.dstack(E0), axis=-1)  # , weights=np.array(MW))
+        e0 = np.average(np.dstack(E0), axis=-1)
 
         del E0
 
-        e1 = np.average(np.dstack(E1), axis=-1)  # , weights=np.array(MW))
+        e1 = np.average(np.dstack(E1), axis=-1)
         del E1
 
         est_label = (e1 + (1 - e0)) / 2
@@ -386,7 +275,6 @@ def do_seg(
                 metadatadict["otsu_threshold"] = thres
 
         else:
-            # print("Not using Otsu threshold")
             est_label = (est_label > 0.5).astype("uint8")
             if WRITE_MODELMETADATA:
+                metadatadict["otsu_threshold"] = 0.5
@@ -396,13 +284,12 @@ def do_seg(
         if N_DATA_BANDS <= 3:
             image, w, h, bigimage = seg_file2tensor_3band(
                 f, TARGET_SIZE
-            )  # , resize=True)
+            )
             w = w.numpy()
             h = h.numpy()
         else:
             image, w, h, bigimage = seg_file2tensor_ND(f, TARGET_SIZE)
 
-        # image = tf.image.per_image_standardization(image)
         image = standardize(image.numpy())
         # return the base prediction
         if N_DATA_BANDS == 1:
@@ -479,10 +366,7 @@ def do_seg(
     ]
     # add classes for more than 10 classes
 
-    # if NCLASSES > 1:
     class_label_colormap = class_label_colormap[:NCLASSES]
-    # else:
-    #     class_label_colormap = class_label_colormap[:2]
 
     if WRITE_MODELMETADATA:
         metadatadict["color_segmentation_output"] = segfile
@@ -804,3 +688,127 @@ def sliding_window(a, ws, ss=None, flatten=True):
     # dim = filter(lambda i : i != 1,dim)
 
     return a.reshape(dim), newshape
+
+
+
+# #-----------------------------------
+# def crf_refine_from_integer_labels(label, img, nclasses = 2, theta_col=100, theta_spat=3, compat=120):
+#     """
+#     "crf_refine(label, img)"
+#     This function refines a label image based on an input label image and the associated image
+#     Uses a conditional random field algorithm using spatial and image features
+#     INPUTS:
+#         * label [ndarray]: label image 2D matrix of integers
+#         * image [ndarray]: image 3D matrix of integers
+#     OPTIONAL INPUTS: None
+#     GLOBAL INPUTS: None
+#     OUTPUTS: label [ndarray]: label image 2D matrix of integers
+#     """
+
+#     gx,gy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
+#     # print(gx.shape)
+#     img = np.dstack((img,gx,gy))
+
+#     H = label.shape[0]
+#     W = label.shape[1]
+#     U = unary_from_labels(1+label,nclasses,gt_prob=0.51)
+#     d = dcrf.DenseCRF2D(H, W, nclasses)
+#     d.setUnaryEnergy(U)
+
+#     # to add the color-independent term, where features are the locations only:
+#     d.addPairwiseGaussian(sxy=(theta_spat, theta_spat),
+#                  compat=3,
+#                  kernel=dcrf.DIAG_KERNEL,
+#                  normalization=dcrf.NORMALIZE_SYMMETRIC)
+#     feats = create_pairwise_bilateral(
+#                           sdims=(theta_col, theta_col),
+#                           schan=(2,2,2),
+#                           img=img,
+#                           chdim=2)
+
+#     d.addPairwiseEnergy(feats, compat=compat,kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
+#     Q = d.inference(20)
+#     kl1 = d.klDivergence(Q)
+#     return np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8), kl1
+
+
+
+# ##========================================================
+# def crf_refine(label,
+#     img,n,
+#     crf_theta_slider_value,
+#     crf_mu_slider_value,
+#     crf_downsample_factor): #gt_prob
+#     """
+#     "crf_refine(label, img)"
+#     This function refines a label image based on an input label image and the associated image
+#     Uses a conditional random field algorithm using spatial and image features
+#     INPUTS:
+#         * label [ndarray]: label image 2D matrix of integers
+#         * image [ndarray]: image 3D matrix of integers
+#     OPTIONAL INPUTS: None
+#     GLOBAL INPUTS: None
+#     OUTPUTS: label [ndarray]: label image 2D matrix of integers
+#     """
+
+#     Horig = img.shape[0]
+#     Worig = img.shape[1]
+#     l_unique = label.shape[-1]
+
+#     label = label.reshape(Horig,Worig,l_unique)
+
+#     scale = 1+(5 * (np.array(img.shape).max() / 3000))
+
+#     # decimate by factor by taking only every other row and column
+#     try:
+#         img = img[::crf_downsample_factor,::crf_downsample_factor, :]
+#     except:
+#         img = img[::crf_downsample_factor,::crf_downsample_factor]
+
+#     # do the same for the label image
+#     label = label[::crf_downsample_factor,::crf_downsample_factor]
+#     # yes, I know this aliases, but considering the task, it is ok; the objective is to
+#     # make fast inference and resize the output
+
+#     H = img.shape[0]
+#     W = img.shape[1]
+#     # U = unary_from_labels(np.argmax(label,-1).astype('int'), n, gt_prob=gt_prob)
+#     # d = dcrf.DenseCRF2D(H, W, n)
+
+#     U = unary_from_softmax(np.ascontiguousarray(np.rollaxis(label,-1,0)))
+#     d = dcrf.DenseCRF2D(H, W, l_unique)
+
+#     d.setUnaryEnergy(U)
+
+#     # to add the color-independent term, where features are the locations only:
+#     d.addPairwiseGaussian(sxy=(3, 3),
+#                  compat=3,
+#                  kernel=dcrf.DIAG_KERNEL,
+#                  normalization=dcrf.NORMALIZE_SYMMETRIC)
+
+#     try:
+#         feats = create_pairwise_bilateral(
+#                             sdims=(crf_theta_slider_value, crf_theta_slider_value),
+#                             schan=(scale,scale,scale),
+#                             img=img,
+#                             chdim=2)
+
+#         d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
+#     except:
+#         feats = create_pairwise_bilateral(
+#                             sdims=(crf_theta_slider_value, crf_theta_slider_value),
+#                             schan=(scale,scale, scale),
+#                             img=np.dstack((img,img,img)),
+#                             chdim=2)
+
+#         d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC)
+
+#     Q = d.inference(10)
+#     result = np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8) +1
+
+#     # uniq = np.unique(result.flatten())
+
+#     result = resize(result, (Horig, Worig), order=0, anti_aliasing=False) #True)
+#     result = rescale(result, 1, l_unique).astype(np.uint8)
+
+#     return result, l_unique
\ No newline at end of file