Skip to content

Commit

Permalink
Small testings
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRieutord committed Sep 15, 2023
1 parent a73badf commit 569ee2c
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions drafts/mlctnet_simple_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Weather-oriented physiography toolbox (WOPT)
"""Multiple land-cover/land-use Maps Translation (MMT)
Program to make prediction with MLCT-net.
Expand Down Expand Up @@ -35,13 +35,15 @@
from mmt.utils import config as utilconf


# Must run in the MLULC code directory
# Configs
#---------
usegpu = True
device = torch.device("cuda" if usegpu else "cpu")

woptconfig = wopt.ml.utils.load_config()
print(f"Executing program {sys.argv[0]} from {os.getcwd()}")

# Configs
#---------
xp_name = "test_if_it_runs"
xp_name = "vanilla_with_esgp_v2"
mlulcconfig, _ = utilconf.get_config_from_json(
os.path.join(
mmt_repopath,
Expand Down Expand Up @@ -137,6 +139,21 @@

ce = mmt_transforms.CoordEnc(None)

esawc2esgp = nn.Sequential(
esawc_autoencoder.encoder,
esgp_autoencoder.decoder
)
esawc2esgp2esgp = nn.Sequential(
esawc_autoencoder.encoder,
esgp_autoencoder.decoder,
esgp_autoencoder.encoder,
esgp_autoencoder.decoder
)
esgp2esgp = nn.Sequential(
esgp_autoencoder.encoder,
esgp_autoencoder.decoder
)

# Loading query
#----------------
qdomain = domains.dublin_city
Expand Down Expand Up @@ -165,6 +182,8 @@
emb, x2 = esawc_autoencoder(x.float(), full = True, res = res)
emb2, z2 = esgp_autoencoder(z.float(), full = True, res = res)
logits = esgp_autoencoder.decoder(emb)
logits = esawc2esgp(x.float())
# logits = esawc2esgp2esgp(x.float())
proba = logits.detach().softmax(1)
y = logits.detach().argmax(1)

Expand All @@ -184,4 +203,4 @@
esgp_autoencoder.decoder
)
summary(model, x.shape)
# torch.onnx.export(model, "esawc2esgp_test.onnx", verbose=True)
# torch.onnx.export(model, x.float(), "esawc2esgp_test.onnx", verbose=True)

0 comments on commit 569ee2c

Please sign in to comment.