From b66ff158ffb68ece909fb4019667faf7acdafaec Mon Sep 17 00:00:00 2001 From: dbuscombe-usgs <dbuscombe@gmail.com> Date: Fri, 10 Feb 2023 12:39:30 -0800 Subject: [PATCH] implement segformer model --- doodleverse_utils/imports.py | 34 +-- doodleverse_utils/model_imports.py | 23 ++ doodleverse_utils/prediction_imports.py | 352 ++++++++++++------------ 3 files changed, 221 insertions(+), 188 deletions(-) diff --git a/doodleverse_utils/imports.py b/doodleverse_utils/imports.py index 86bee4e..f9ace6d 100755 --- a/doodleverse_utils/imports.py +++ b/doodleverse_utils/imports.py @@ -3,7 +3,7 @@ # # MIT License # -# Copyright (c) 2021-2022, Marda Science LLC +# Copyright (c) 2021-23, Marda Science LLC # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -175,7 +175,7 @@ def inpaint_nans(im): # ----------------------------------- -def plot_seg_history_iou(history, train_hist_fig): +def plot_seg_history_iou(history, train_hist_fig, MODEL): """ "plot_seg_history_iou(history, train_hist_fig)" This function plots the training history of a model @@ -189,21 +189,23 @@ def plot_seg_history_iou(history, train_hist_fig): n = len(history.history["val_loss"]) plt.figure(figsize=(20, 10)) - plt.subplot(121) - plt.plot( - np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy" - ) - plt.plot( - np.arange(1, n + 1), - history.history["val_mean_iou"], - "k", - label="validation accuracy", - ) - plt.xlabel("Epoch number", fontsize=10) - plt.ylabel("Mean IoU Coefficient", fontsize=10) - plt.legend(fontsize=10) + if MODEL!='segformer': + plt.subplot(121) + plt.plot( + np.arange(1, n + 1), history.history["mean_iou"], "b", label="train accuracy" + ) + plt.plot( + np.arange(1, n + 1), + history.history["val_mean_iou"], + "k", + label="validation accuracy", + ) + plt.xlabel("Epoch number", fontsize=10) + plt.ylabel("Mean IoU Coefficient", fontsize=10) + plt.legend(fontsize=10) + + plt.subplot(122) - plt.subplot(122) plt.plot(np.arange(1, n + 1), history.history["loss"], "b", label="train loss") plt.plot( np.arange(1, n + 1), history.history["val_loss"], "k", label="validation loss" diff --git a/doodleverse_utils/model_imports.py b/doodleverse_utils/model_imports.py index 73314f0..bc2bc95 100755 --- a/doodleverse_utils/model_imports.py +++ b/doodleverse_utils/model_imports.py @@ -30,6 +30,8 @@ # keras functions for early stopping and model weights saving from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint +from transformers import TFSegformerForSemanticSegmentation + SEED = 42 np.random.seed(SEED) AUTO = tf.data.experimental.AUTOTUNE # used in tf.data.Dataset API @@ -43,6 +45,27 @@ ############################################################### ### MODEL ARCHITECTURES ############################################################### +def segformer( + id2label, + num_classes=2, +): + """ + https://keras.io/examples/vision/segformer/ + https://huggingface.co/nvidia/mit-b0 + """ + + label2id = {label: id for id, label in id2label.items()} + model_checkpoint = "nvidia/mit-b0" + + model = TFSegformerForSemanticSegmentation.from_pretrained( + model_checkpoint, + num_labels=num_classes, + id2label=id2label, + label2id=label2id, + ignore_mismatched_sizes=True, + ) + return model + # ----------------------------------- def simple_resunet( diff --git a/doodleverse_utils/prediction_imports.py b/doodleverse_utils/prediction_imports.py index 46d989e..3347c44 100755 --- a/doodleverse_utils/prediction_imports.py +++ b/doodleverse_utils/prediction_imports.py @@ -3,7 +3,7 @@ # # MIT License # -# Copyright (c) 2021-22, Marda Science LLC +# Copyright (c) 2021-23, Marda Science LLC # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,14 +25,10 @@ from .imports import standardize, label_to_colors, fromhex -import os # , time - +import os import numpy as np import matplotlib.pyplot as plt from scipy import io - -# from skimage.filters.rank import median -# from skimage.morphology import disk from tkinter import filedialog from tkinter import * from tkinter import messagebox @@ -40,23 +36,15 @@ from skimage.io import imsave, imread from numpy.lib.stride_tricks import as_strided as ast from glob import glob - -# from joblib import Parallel, delayed -# from skimage.morphology import remove_small_holes, remove_small_objects -# from scipy.ndimage import maximum_filter from skimage.transform import resize - -# from tqdm import tqdm from skimage.filters import threshold_otsu import matplotlib.pyplot as plt import tensorflow as tf # numerical operations on gpu import tensorflow.keras.backend as K -# #crf # import pydensecrf.densecrf as dcrf # from pydensecrf.utils import create_pairwise_bilateral, unary_from_softmax -# # unary_from_labels, SEED = 42 np.random.seed(SEED) @@ -69,130 +57,6 @@ print("GPU name: ", tf.config.experimental.list_physical_devices("GPU")) print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices("GPU"))) - - -# #----------------------------------- -# def crf_refine_from_integer_labels(label, img, nclasses = 2, theta_col=100, theta_spat=3, compat=120): -# """ -# "crf_refine(label, img)" -# This function refines a label image based on an input label image and the associated image -# Uses a conditional random field algorithm using spatial and image features -# INPUTS: -# * label [ndarray]: label image 2D matrix of integers -# * image [ndarray]: image 3D matrix of integers -# OPTIONAL INPUTS: None -# GLOBAL INPUTS: None -# OUTPUTS: label [ndarray]: label image 2D matrix of integers -# """ - -# gx,gy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0])) -# # print(gx.shape) -# img = np.dstack((img,gx,gy)) - -# H = label.shape[0] -# W = label.shape[1] -# U = unary_from_labels(1+label,nclasses,gt_prob=0.51) -# d = dcrf.DenseCRF2D(H, W, nclasses) -# d.setUnaryEnergy(U) - -# # to add the color-independent term, where features are the locations only: -# d.addPairwiseGaussian(sxy=(theta_spat, theta_spat), -# compat=3, -# kernel=dcrf.DIAG_KERNEL, -# normalization=dcrf.NORMALIZE_SYMMETRIC) -# feats = create_pairwise_bilateral( -# sdims=(theta_col, theta_col), -# schan=(2,2,2), -# img=img, -# chdim=2) - -# d.addPairwiseEnergy(feats, compat=compat,kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) -# Q = d.inference(20) -# kl1 = d.klDivergence(Q) -# return np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8), kl1 - - - -# ##======================================================== -# def crf_refine(label, -# img,n, -# crf_theta_slider_value, -# crf_mu_slider_value, -# crf_downsample_factor): #gt_prob -# """ -# "crf_refine(label, img)" -# This function refines a label image based on an input label image and the associated image -# Uses a conditional random field algorithm using spatial and image features -# INPUTS: -# * label [ndarray]: label image 2D matrix of integers -# * image [ndarray]: image 3D matrix of integers -# OPTIONAL INPUTS: None -# GLOBAL INPUTS: None -# OUTPUTS: label [ndarray]: label image 2D matrix of integers -# """ - -# Horig = img.shape[0] -# Worig = img.shape[1] -# l_unique = label.shape[-1] - -# label = label.reshape(Horig,Worig,l_unique) - -# scale = 1+(5 * (np.array(img.shape).max() / 3000)) - -# # decimate by factor by taking only every other row and column -# try: -# img = img[::crf_downsample_factor,::crf_downsample_factor, :] -# except: -# img = img[::crf_downsample_factor,::crf_downsample_factor] - -# # do the same for the label image -# label = label[::crf_downsample_factor,::crf_downsample_factor] -# # yes, I know this aliases, but considering the task, it is ok; the objective is to -# # make fast inference and resize the output - -# H = img.shape[0] -# W = img.shape[1] -# # U = unary_from_labels(np.argmax(label,-1).astype('int'), n, gt_prob=gt_prob) -# # d = dcrf.DenseCRF2D(H, W, n) - -# U = unary_from_softmax(np.ascontiguousarray(np.rollaxis(label,-1,0))) -# d = dcrf.DenseCRF2D(H, W, l_unique) - -# d.setUnaryEnergy(U) - -# # to add the color-independent term, where features are the locations only: -# d.addPairwiseGaussian(sxy=(3, 3), -# compat=3, -# kernel=dcrf.DIAG_KERNEL, -# normalization=dcrf.NORMALIZE_SYMMETRIC) - -# try: -# feats = create_pairwise_bilateral( -# sdims=(crf_theta_slider_value, crf_theta_slider_value), -# schan=(scale,scale,scale), -# img=img, -# chdim=2) - -# d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) -# except: -# feats = create_pairwise_bilateral( -# sdims=(crf_theta_slider_value, crf_theta_slider_value), -# schan=(scale,scale, scale), -# img=np.dstack((img,img,img)), -# chdim=2) - -# d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) - -# Q = d.inference(10) -# result = np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8) +1 - -# # uniq = np.unique(result.flatten()) - -# result = resize(result, (Horig, Worig), order=0, anti_aliasing=False) #True) -# result = rescale(result, 1, l_unique).astype(np.uint8) - -# return result, l_unique - ##======================================================== def rescale(dat, mn, @@ -204,10 +68,8 @@ def rescale(dat, M = max(dat.flatten()) return (mx-mn)*(dat-m)/(M-m)+mn - - # #----------------------------------- -def seg_file2tensor_ND(f, TARGET_SIZE): # , resize): +def seg_file2tensor_ND(f, TARGET_SIZE): """ "seg_file2tensor(f)" This function reads a NPZ image from file into a cropped and resized tensor, @@ -226,7 +88,6 @@ def seg_file2tensor_ND(f, TARGET_SIZE): # , resize): smallimage = resize( bigimage, (TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True ) - # smallimage=bigimage.resize((TARGET_SIZE[1], TARGET_SIZE[0])) smallimage = np.array(smallimage) smallimage = tf.cast(smallimage, tf.uint8) @@ -237,7 +98,7 @@ def seg_file2tensor_ND(f, TARGET_SIZE): # , resize): # #----------------------------------- -def seg_file2tensor_3band(f, TARGET_SIZE): # , resize): +def seg_file2tensor_3band(f, TARGET_SIZE): """ "seg_file2tensor(f)" This function reads a jpeg image from file into a cropped and resized tensor, @@ -250,11 +111,10 @@ def seg_file2tensor_3band(f, TARGET_SIZE): # , resize): GLOBAL INPUTS: TARGET_SIZE """ - bigimage = imread(f) # Image.open(f) + bigimage = imread(f) smallimage = resize( bigimage, (TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True ) - # smallimage=bigimage.resize((TARGET_SIZE[1], TARGET_SIZE[0])) smallimage = np.array(smallimage) smallimage = tf.cast(smallimage, tf.uint8) @@ -306,6 +166,9 @@ def do_seg( image = standardize(image.numpy()).squeeze() + if MODEL=='segformer': + image = tf.transpose(image, (2, 0, 1)) + E0 = [] E1 = [] @@ -313,35 +176,61 @@ def do_seg( # heatmap = make_gradcam_heatmap(tf.expand_dims(image, 0) , model) try: - est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze() + if MODEL=='segformer': + est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).logits + else: + est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze() except: - est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).squeeze() + if MODEL=='segformer': + est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).logits + else: + est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).squeeze() if TESTTIMEAUG == True: # return the flipped prediction - est_label2 = np.flipud( - model.predict( - tf.expand_dims(np.flipud(image), 0), batch_size=1 - ).squeeze() - ) - est_label3 = np.fliplr( - model.predict( - tf.expand_dims(np.fliplr(image), 0), batch_size=1 - ).squeeze() - ) - est_label4 = np.flipud( - np.fliplr( + if MODEL=='segformer': + est_label2 = np.flipud( + model.predict(tf.expand_dims(np.flipud(image), 0), batch_size=1).logits + ) + else: + est_label2 = np.flipud( + model.predict(tf.expand_dims(np.flipud(image), 0), batch_size=1).squeeze() + ) + + if MODEL=='segformer': + est_label3 = np.fliplr( model.predict( - tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1 - ).squeeze() - ) - ) - + tf.expand_dims(np.fliplr(image), 0), batch_size=1).logits + ) + else: + est_label3 = np.fliplr( + model.predict( + tf.expand_dims(np.fliplr(image), 0), batch_size=1).squeeze() + ) + + if MODEL=='segformer': + est_label4 = np.flipud( + np.fliplr( + model.predict( + tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1).logits) + ) + else: + est_label4 = np.flipud( + np.fliplr( + model.predict( + tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1).squeeze()) + ) + # soft voting - sum the softmax scores to return the new TTA estimated softmax scores est_label = est_label + est_label2 + est_label3 + est_label4 del est_label2, est_label3, est_label4 est_label = est_label.astype('float32') + + if MODEL=='segformer': + est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze() + est_label = np.transpose(est_label, (1,2,0)) + E0.append( resize(est_label[:, :, 0], (w, h), preserve_range=True, clip=True) ) @@ -353,11 +242,11 @@ def do_seg( # heatmap = resize(heatmap,(w,h), preserve_range=True, clip=True) K.clear_session() - e0 = np.average(np.dstack(E0), axis=-1) # , weights=np.array(MW)) + e0 = np.average(np.dstack(E0), axis=-1) del E0 - e1 = np.average(np.dstack(E1), axis=-1) # , weights=np.array(MW)) + e1 = np.average(np.dstack(E1), axis=-1) del E1 est_label = (e1 + (1 - e0)) / 2 @@ -386,7 +275,6 @@ def do_seg( metadatadict["otsu_threshold"] = thres else: - # print("Not using Otsu threshold") est_label = (est_label > 0.5).astype("uint8") if WRITE_MODELMETADATA: metadatadict["otsu_threshold"] = 0.5 @@ -396,13 +284,12 @@ def do_seg( if N_DATA_BANDS <= 3: image, w, h, bigimage = seg_file2tensor_3band( f, TARGET_SIZE - ) # , resize=True) + ) w = w.numpy() h = h.numpy() else: image, w, h, bigimage = seg_file2tensor_ND(f, TARGET_SIZE) - # image = tf.image.per_image_standardization(image) image = standardize(image.numpy()) # return the base prediction if N_DATA_BANDS == 1: @@ -479,10 +366,7 @@ def do_seg( ] # add classes for more than 10 classes - # if NCLASSES > 1: class_label_colormap = class_label_colormap[:NCLASSES] - # else: - # class_label_colormap = class_label_colormap[:2] if WRITE_MODELMETADATA: metadatadict["color_segmentation_output"] = segfile @@ -804,3 +688,127 @@ def sliding_window(a, ws, ss=None, flatten=True): # dim = filter(lambda i : i != 1,dim) return a.reshape(dim), newshape + + + +# #----------------------------------- +# def crf_refine_from_integer_labels(label, img, nclasses = 2, theta_col=100, theta_spat=3, compat=120): +# """ +# "crf_refine(label, img)" +# This function refines a label image based on an input label image and the associated image +# Uses a conditional random field algorithm using spatial and image features +# INPUTS: +# * label [ndarray]: label image 2D matrix of integers +# * image [ndarray]: image 3D matrix of integers +# OPTIONAL INPUTS: None +# GLOBAL INPUTS: None +# OUTPUTS: label [ndarray]: label image 2D matrix of integers +# """ + +# gx,gy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0])) +# # print(gx.shape) +# img = np.dstack((img,gx,gy)) + +# H = label.shape[0] +# W = label.shape[1] +# U = unary_from_labels(1+label,nclasses,gt_prob=0.51) +# d = dcrf.DenseCRF2D(H, W, nclasses) +# d.setUnaryEnergy(U) + +# # to add the color-independent term, where features are the locations only: +# d.addPairwiseGaussian(sxy=(theta_spat, theta_spat), +# compat=3, +# kernel=dcrf.DIAG_KERNEL, +# normalization=dcrf.NORMALIZE_SYMMETRIC) +# feats = create_pairwise_bilateral( +# sdims=(theta_col, theta_col), +# schan=(2,2,2), +# img=img, +# chdim=2) + +# d.addPairwiseEnergy(feats, compat=compat,kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) +# Q = d.inference(20) +# kl1 = d.klDivergence(Q) +# return np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8), kl1 + + + +# ##======================================================== +# def crf_refine(label, +# img,n, +# crf_theta_slider_value, +# crf_mu_slider_value, +# crf_downsample_factor): #gt_prob +# """ +# "crf_refine(label, img)" +# This function refines a label image based on an input label image and the associated image +# Uses a conditional random field algorithm using spatial and image features +# INPUTS: +# * label [ndarray]: label image 2D matrix of integers +# * image [ndarray]: image 3D matrix of integers +# OPTIONAL INPUTS: None +# GLOBAL INPUTS: None +# OUTPUTS: label [ndarray]: label image 2D matrix of integers +# """ + +# Horig = img.shape[0] +# Worig = img.shape[1] +# l_unique = label.shape[-1] + +# label = label.reshape(Horig,Worig,l_unique) + +# scale = 1+(5 * (np.array(img.shape).max() / 3000)) + +# # decimate by factor by taking only every other row and column +# try: +# img = img[::crf_downsample_factor,::crf_downsample_factor, :] +# except: +# img = img[::crf_downsample_factor,::crf_downsample_factor] + +# # do the same for the label image +# label = label[::crf_downsample_factor,::crf_downsample_factor] +# # yes, I know this aliases, but considering the task, it is ok; the objective is to +# # make fast inference and resize the output + +# H = img.shape[0] +# W = img.shape[1] +# # U = unary_from_labels(np.argmax(label,-1).astype('int'), n, gt_prob=gt_prob) +# # d = dcrf.DenseCRF2D(H, W, n) + +# U = unary_from_softmax(np.ascontiguousarray(np.rollaxis(label,-1,0))) +# d = dcrf.DenseCRF2D(H, W, l_unique) + +# d.setUnaryEnergy(U) + +# # to add the color-independent term, where features are the locations only: +# d.addPairwiseGaussian(sxy=(3, 3), +# compat=3, +# kernel=dcrf.DIAG_KERNEL, +# normalization=dcrf.NORMALIZE_SYMMETRIC) + +# try: +# feats = create_pairwise_bilateral( +# sdims=(crf_theta_slider_value, crf_theta_slider_value), +# schan=(scale,scale,scale), +# img=img, +# chdim=2) + +# d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) +# except: +# feats = create_pairwise_bilateral( +# sdims=(crf_theta_slider_value, crf_theta_slider_value), +# schan=(scale,scale, scale), +# img=np.dstack((img,img,img)), +# chdim=2) + +# d.addPairwiseEnergy(feats, compat=crf_mu_slider_value, kernel=dcrf.DIAG_KERNEL,normalization=dcrf.NORMALIZE_SYMMETRIC) + +# Q = d.inference(10) +# result = np.argmax(Q, axis=0).reshape((H, W)).astype(np.uint8) +1 + +# # uniq = np.unique(result.flatten()) + +# result = resize(result, (Horig, Worig), order=0, anti_aliasing=False) #True) +# result = rescale(result, 1, l_unique).astype(np.uint8) + +# return result, l_unique \ No newline at end of file