From 57031ee5f1a173df3f28b8cbe6484f334e50dcdc Mon Sep 17 00:00:00 2001 From: Thomas Rieutord Date: Thu, 7 Sep 2023 16:31:32 +0000 Subject: [PATCH] Minor changes --- drafts/mlctnet_simple_inference.py | 14 +++++++++++++- mmt/graphs/models/universal_embedding.py | 14 +++++++++++++- ...t_are_the_architectures_of_the_tested_models.py | 12 ++++++------ tests/import_test.py | 3 +++ 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/drafts/mlctnet_simple_inference.py b/drafts/mlctnet_simple_inference.py index b02835f..2a3320a 100644 --- a/drafts/mlctnet_simple_inference.py +++ b/drafts/mlctnet_simple_inference.py @@ -15,6 +15,7 @@ import sys import torch from torch import nn +from torchinfo import summary from torchgeo.datasets.utils import BoundingBox import numpy as np import rasterio.crs @@ -169,7 +170,18 @@ # Show results #----------- - +inner_shape = [min(s1,s2) for (s1,s2) in zip(y.shape[1:], y_true["mask"].shape[1:])] +ccrop = transforms.t.CenterCrop(size=inner_shape) +acc = (ccrop(y) == ccrop(y_true["mask"])).sum()/ccrop(y).numel() +print(f"Overall accuracy over this patch: {acc}") out_lc.plot({"mask":y}) plt.show(block=False) +# Export +#----------- +model = nn.Sequential( + esawc_autoencoder.encoder, + esgp_autoencoder.decoder +) +summary(model, x.shape) +# torch.onnx.export(model, "esawc2esgp_test.onnx", verbose=True) diff --git a/mmt/graphs/models/universal_embedding.py b/mmt/graphs/models/universal_embedding.py index d8b5c6d..f81fd0d 100644 --- a/mmt/graphs/models/universal_embedding.py +++ b/mmt/graphs/models/universal_embedding.py @@ -14,7 +14,6 @@ class AtrouMMU(nn.Module): def __init__(self, inf, scale_factor=10, bias=False): super(AtrouMMU, self).__init__() - print(inf, scale_factor) self.conv1 = nn.Sequential( nn.Conv2d( inf, inf, kernel_size=3, padding=1, stride=scale_factor, bias=bias @@ -365,6 +364,19 @@ def MemoryMonged_forward(self, x): def forward(self, x): return self.forward_method(x) + + def check_shapes(self, x = None): + """Display shapes of some tensors""" + if x is None: + x = torch.rand(10, 3, 60, 120) + print(f"Random input: x = {x.shape}") + else: + print(f"Given input: x = {x.shape}") + + states = self.encoder_part(x) + for i, xx in enumerate(states): + print(f"x{i} = {xx.shape}") + class UnivEmb(nn.Module): diff --git a/questions/what_are_the_architectures_of_the_tested_models.py b/questions/what_are_the_architectures_of_the_tested_models.py index 6bd13b4..0969ef4 100644 --- a/questions/what_are_the_architectures_of_the_tested_models.py +++ b/questions/what_are_the_architectures_of_the_tested_models.py @@ -41,7 +41,8 @@ "universal_embedding.json", ) ) -n_labels = 12 +n_labels = 43 +resize = None baseline_model = universal_embedding.UnivEmb( in_channels = n_labels + 1, @@ -54,7 +55,7 @@ num_groups = mlulcconfig.group_norm, decoder_depth = mlulcconfig.decoder_depth, mode = mlulcconfig.mode, - resize = 6, + resize = resize, cat=False, pooling_factors = mlulcconfig.pooling_factors, decoder_atrou = mlulcconfig.decoder_atrou, @@ -74,7 +75,7 @@ num_groups = mlulcconfig.group_norm, decoder_depth = mlulcconfig.decoder_depth, mode = mlulcconfig.mode, - resize = 6, + resize = resize, cat=False, pooling_factors = mlulcconfig.pooling_factors, decoder_atrou = mlulcconfig.decoder_atrou, @@ -84,15 +85,14 @@ candidate_model.decoder, ) -x = torch.rand(mlulcconfig.train_batch_size, n_labels + 1, 100, 100) +x = torch.rand(15, n_labels + 1, 600, 600) print(" Architecture of the BASELINE model") baseline_summary = summary(baseline_model, x.shape) -print(baseline_summary) print("\n Architecture of the CANDIDATE model") candidate_summary = summary(candidate_model, x.shape) -print(candidate_summary) + print(" ") print(f"The BASELINE model has {baseline_summary.trainable_params} trainable parameters and needs {baseline_summary.total_output_bytes/10**9} GB to run") print(f"The CANDIDATE model has {candidate_summary.trainable_params} trainable parameters and needs {candidate_summary.total_output_bytes/10**9} GB to run") diff --git a/tests/import_test.py b/tests/import_test.py index e000c83..51b7217 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -44,6 +44,7 @@ import mmt.graphs import mmt.graphs.models import mmt.graphs.models.custom_layers +from mmt import _repopath_ as mmt_repopath from mmt.datasets import landcover_to_landcover landcover_to_landcover.LandcoverToLandcoverDataLoader @@ -61,6 +62,8 @@ from mmt.utils import plt_utils plt_utils.plt_loss2 +from mmt.utils import config as utilconf +utilconf.get_config_from_json print("All imports passed successfully") print(f"Package {mmt.__name__}-{mmt.__version__} from {mmt._repopath_}")