Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRieutord committed Sep 7, 2023
1 parent 6e38b2b commit 57031ee
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 8 deletions.
14 changes: 13 additions & 1 deletion drafts/mlctnet_simple_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import torch
from torch import nn
from torchinfo import summary
from torchgeo.datasets.utils import BoundingBox
import numpy as np
import rasterio.crs
Expand Down Expand Up @@ -169,7 +170,18 @@

# Show results
#-----------

inner_shape = [min(s1,s2) for (s1,s2) in zip(y.shape[1:], y_true["mask"].shape[1:])]
ccrop = transforms.t.CenterCrop(size=inner_shape)
acc = (ccrop(y) == ccrop(y_true["mask"])).sum()/ccrop(y).numel()
print(f"Overall accuracy over this patch: {acc}")
out_lc.plot({"mask":y})
plt.show(block=False)

# Export
#-----------
model = nn.Sequential(
esawc_autoencoder.encoder,
esgp_autoencoder.decoder
)
summary(model, x.shape)
# torch.onnx.export(model, "esawc2esgp_test.onnx", verbose=True)
14 changes: 13 additions & 1 deletion mmt/graphs/models/universal_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
class AtrouMMU(nn.Module):
def __init__(self, inf, scale_factor=10, bias=False):
super(AtrouMMU, self).__init__()
print(inf, scale_factor)
self.conv1 = nn.Sequential(
nn.Conv2d(
inf, inf, kernel_size=3, padding=1, stride=scale_factor, bias=bias
Expand Down Expand Up @@ -365,6 +364,19 @@ def MemoryMonged_forward(self, x):

def forward(self, x):
return self.forward_method(x)

def check_shapes(self, x = None):
"""Display shapes of some tensors"""
if x is None:
x = torch.rand(10, 3, 60, 120)
print(f"Random input: x = {x.shape}")
else:
print(f"Given input: x = {x.shape}")

states = self.encoder_part(x)
for i, xx in enumerate(states):
print(f"x{i} = {xx.shape}")



class UnivEmb(nn.Module):
Expand Down
12 changes: 6 additions & 6 deletions questions/what_are_the_architectures_of_the_tested_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
"universal_embedding.json",
)
)
n_labels = 12
n_labels = 43
resize = None

baseline_model = universal_embedding.UnivEmb(
in_channels = n_labels + 1,
Expand All @@ -54,7 +55,7 @@
num_groups = mlulcconfig.group_norm,
decoder_depth = mlulcconfig.decoder_depth,
mode = mlulcconfig.mode,
resize = 6,
resize = resize,
cat=False,
pooling_factors = mlulcconfig.pooling_factors,
decoder_atrou = mlulcconfig.decoder_atrou,
Expand All @@ -74,7 +75,7 @@
num_groups = mlulcconfig.group_norm,
decoder_depth = mlulcconfig.decoder_depth,
mode = mlulcconfig.mode,
resize = 6,
resize = resize,
cat=False,
pooling_factors = mlulcconfig.pooling_factors,
decoder_atrou = mlulcconfig.decoder_atrou,
Expand All @@ -84,15 +85,14 @@
candidate_model.decoder,
)

x = torch.rand(mlulcconfig.train_batch_size, n_labels + 1, 100, 100)
x = torch.rand(15, n_labels + 1, 600, 600)

print(" Architecture of the BASELINE model")
baseline_summary = summary(baseline_model, x.shape)
print(baseline_summary)

print("\n Architecture of the CANDIDATE model")
candidate_summary = summary(candidate_model, x.shape)
print(candidate_summary)

print(" ")
print(f"The BASELINE model has {baseline_summary.trainable_params} trainable parameters and needs {baseline_summary.total_output_bytes/10**9} GB to run")
print(f"The CANDIDATE model has {candidate_summary.trainable_params} trainable parameters and needs {candidate_summary.total_output_bytes/10**9} GB to run")
3 changes: 3 additions & 0 deletions tests/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import mmt.graphs
import mmt.graphs.models
import mmt.graphs.models.custom_layers
from mmt import _repopath_ as mmt_repopath

from mmt.datasets import landcover_to_landcover
landcover_to_landcover.LandcoverToLandcoverDataLoader
Expand All @@ -61,6 +62,8 @@

from mmt.utils import plt_utils
plt_utils.plt_loss2
from mmt.utils import config as utilconf
utilconf.get_config_from_json

print("All imports passed successfully")
print(f"Package {mmt.__name__}-{mmt.__version__} from {mmt._repopath_}")

0 comments on commit 57031ee

Please sign in to comment.