From f105e030973810267a1e151d229ec06a3af98c92 Mon Sep 17 00:00:00 2001
From: john2
Date: Tue, 21 Jan 2025 15:27:25 +0100
Subject: [PATCH 01/14] improved FNO arch

---
 cfno/layers/skip_connection.py    | 123 +++++++++++++++++++++
 cfno/models/cfno2d.py             | 176 ++++++++++++++++++++++++++++--
 cfno/training/pySDC.py            |  14 ++-
 scripts/configs/config_dte-3.yaml |  15 ++-
 4 files changed, 312 insertions(+), 16 deletions(-)
 create mode 100644 cfno/layers/skip_connection.py

diff --git a/cfno/layers/skip_connection.py b/cfno/layers/skip_connection.py
new file mode 100644
index 0000000..70b6e5a
--- /dev/null
+++ b/cfno/layers/skip_connection.py
@@ -0,0 +1,123 @@
+import torch
+from torch import nn
+
+
+def skip_connection(
+    in_features, out_features, n_dim=2, bias=False, skip_type="soft-gating"
+):
+    """A wrapper for several types of skip connections.
+    Returns an nn.Module skip connection, one of {'identity', 'linear', 'soft-gating'}.
+
+    Parameters
+    ----------
+    in_features : int
+        number of input features
+    out_features : int
+        number of output features
+    n_dim : int, default is 2
+        Dimensionality of the input (excluding batch-size and channels).
+        ``n_dim=2`` corresponds to a 2D spatial grid.
+    bias : bool, optional
+        whether to use a bias, by default False
+    skip_type : {'identity', 'linear', 'soft-gating'}
+        kind of skip connection to use, by default "soft-gating"
+
+    Returns
+    -------
+    nn.Module
+        module that takes in x and returns skip(x)
+    """
+    if skip_type.lower() == "soft-gating":
+        return SoftGating(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            n_dim=n_dim,
+        )
+    elif skip_type.lower() == "linear":
+        return Flattened1dConv(in_channels=in_features,
+                               out_channels=out_features,
+                               kernel_size=1,
+                               bias=bias,)
+    elif skip_type.lower() == "identity":
+        return nn.Identity()
+    else:
+        raise ValueError(
+            f"Got skip-connection type={skip_type}, expected one of"
+            f" {'soft-gating', 'linear', 'identity'}."
+        )
+
+
+class SoftGating(nn.Module):
+    """Applies soft-gating by weighting the channels of the given input
+
+    Given an input x of size `(batch-size, channels, height, width)`,
+    this returns `x * w`
+    where w is of shape `(1, channels, 1, 1)`
+
+    Parameters
+    ----------
+    in_features : int
+    out_features : None
+        this is provided for API compatibility with nn.Linear only
+    n_dim : int, default is 2
+        Dimensionality of the input (excluding batch-size and channels).
+        ``n_dim=2`` corresponds to a 2D spatial grid.
+    bias : bool, default is False
+    """
+
+    def __init__(self, in_features, out_features=None, n_dim=2, bias=False):
+        super().__init__()
+        if out_features is not None and in_features != out_features:
+            raise ValueError(
+                f"Got in_features={in_features} and out_features={out_features}, "
+                "but these two must be the same for soft-gating"
+            )
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = nn.Parameter(torch.ones(1, self.in_features, *(1,) * n_dim))
+        if bias:
+            self.bias = nn.Parameter(torch.ones(1, self.in_features, *(1,) * n_dim))
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        """Applies soft-gating to a batch of activations"""
+        if self.bias is not None:
+            return self.weight * x + self.bias
+        else:
+            return self.weight * x
+
+class Flattened1dConv(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, kernel_size, bias=False):
+        """Flattened1dConv is a Conv-based skip layer for
+        input tensors of ndim > 3 (batch, channels, d1, ...)
that flattens all dimensions
+        past the batch and channel dims into one dimension, applies the Conv,
+        and un-flattens.
+
+        Parameters
+        ----------
+        in_channels : int
+            in_channels of Conv1d
+        out_channels : int
+            out_channels of Conv1d
+        kernel_size : int
+            kernel_size of Conv1d
+        bias : bool, optional
+            bias of Conv1d, by default False
+        """
+        super().__init__()
+        self.conv = nn.Conv1d(in_channels=in_channels,
+                              out_channels=out_channels,
+                              kernel_size=kernel_size,
+                              bias=bias)
+    def forward(self, x):
+        # x.shape: (b, c, x1, ..., xn), with at least one spatial dim
+        size = list(x.shape)
+        # flatten everything past 1st data dim
+        x = x.view(*size[:2], -1)
+        x = self.conv(x)
+        # reshape x into an Nd tensor b, c, x1, x2, ...
+        x = x.view(size[0], self.conv.out_channels, *size[2:])
+        return x
+
\ No newline at end of file
diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py
index 66c4c7c..4cfec1a 100644
--- a/cfno/models/cfno2d.py
+++ b/cfno/models/cfno2d.py
@@ -5,7 +5,9 @@
 import torch.nn as nn
 import pandas as pd
 from torch_dct import dct, idct
+import torch.nn.functional as F
 
+from cfno.layers.skip_connection import skip_connection
 from cfno.utils import CudaMemoryDebugger, format_tensor_size, activation_selection
 
 class CF2DConv(nn.Module):
@@ -13,16 +15,26 @@ class CF2DConv(nn.Module):
 
     USE_T_CACHE = False
 
-    def __init__(self, kX, kY, dv, forceFFT=False, reorder=False):
+    def __init__(self, kX, kY, dv, forceFFT=False, reorder=False, bias=False, order=2):
         super().__init__()
         self.kX = kX
         self.kY = kY
         self.forceFFT = forceFFT
         self.reorder = reorder
+        self.order = order
 
         self.R = nn.Parameter(
             th.rand(dv, dv, kX*(2 if forceFFT else 1), kY, dtype=th.cfloat))
+
+        if bias:
+            self.init_std = (2 / (dv + dv))**0.5
+            self.bias = nn.Parameter(
+                self.init_std * th.randn(*(tuple([dv]) + (1,) * self.order))
+            )
+        else:
+            self.init_std = None
+            self.bias = None
 
         if forceFFT:
             if reorder:
@@ -119,6 +131,91 @@ def forward(self, x:th.tensor):
 
         # Transform back to Real space -> [nBatch, dv, nX, nY]
         x = self._toRealSpace(x)
+
+        if self.bias is not None:
+            x = x + self.bias
+
         return x
+
+
+class ChannelMLP(nn.Module):
+    """ChannelMLP applies an arbitrary number of layers of
+    1d convolution and nonlinearity to the channels of input
+    and is invariant to spatial resolution.
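+    Operating on the channels of the flattened spatial dimensions, it is
+    equivalent to a shared MLP applied independently at every grid point.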
+
+    Parameters
+    ----------
+    in_channels : int
+    out_channels : int, default is None
+        if None, same as in_channels
+    hidden_channels : int, default is None
+        if None, same as in_channels
+    n_layers : int, default is 2
+        number of linear layers in the MLP
+    non_linearity : default is F.gelu
+    dropout : float, default is 0
+        if > 0, dropout probability
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels=None,
+        hidden_channels=None,
+        n_layers=2,
+        n_dim=2,
+        non_linearity=F.gelu,
+        dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.hidden_channels = (
+            in_channels if hidden_channels is None else hidden_channels
+        )
+        self.non_linearity = non_linearity
+        self.dropout = (
+            nn.ModuleList([nn.Dropout(dropout) for _ in range(n_layers)])
+            if dropout > 0.0
+            else None
+        )
+
+        # we use nn.Conv1d for everything and roll data along the 1st data dim
+        self.fcs = nn.ModuleList()
+        for i in range(n_layers):
+            if i == 0 and i == (n_layers - 1):
+                self.fcs.append(nn.Conv1d(self.in_channels, self.out_channels, 1))
+            elif i == 0:
+                self.fcs.append(nn.Conv1d(self.in_channels, self.hidden_channels, 1))
+            elif i == (n_layers - 1):
+                self.fcs.append(nn.Conv1d(self.hidden_channels, self.out_channels, 1))
+            else:
+                self.fcs.append(nn.Conv1d(self.hidden_channels, self.hidden_channels, 1))
+
+    def forward(self, x):
+        reshaped = False
+        size = list(x.shape)
+        if x.ndim > 3:
+            # batch, channels, x1, x2... extra dims
+            # .reshape() is preferable but .view()
+            # cannot be called on non-contiguous tensors
+            x = x.reshape((*size[:2], -1))
+            reshaped = True
+
+        for i, fc in enumerate(self.fcs):
+            x = fc(x)
+            if i < self.n_layers - 1:
+                x = self.non_linearity(x)
+            if self.dropout is not None:
+                x = self.dropout[i](x)
+
+        # if x was an N-d tensor reshaped into 1d, undo the reshaping
+        # same logic as above: .reshape() handles contiguous tensors as well
+        if reshaped:
+            x = x.reshape((size[0], self.out_channels, *size[2:]))
+        return x
 
 
@@ -170,15 +267,24 @@ def forward(self, x):
 
 class CF2DLayer(nn.Module):
 
-    def __init__(self, kX, kY, dv, forceFFT=False, non_linearity='gelu', bias=True, reorder=False):
+    def __init__(self, kX, kY, dv,
+                 forceFFT=False,
+                 non_linearity='gelu',
+                 bias=True, reorder=False,
+                 use_fno_skip_connection=False,
+                 fno_skip_type='linear'
+                 ):
         super().__init__()
 
-        self.conv = CF2DConv(kX, kY, dv, forceFFT, reorder)
+        self.conv = CF2DConv(kX, kY, dv, forceFFT, reorder, bias)
         if non_linearity == 'gelu':
             self.sigma = nn.functional.gelu
         else:
             self.sigma = nn.ReLU(inplace=True)
-        self.W = Grid2DLinear(dv, dv, bias)
+        if use_fno_skip_connection:
+            self.W = skip_connection(dv, dv, skip_type=fno_skip_type)
+        else:
+            self.W = Grid2DLinear(dv, dv, bias)
 
     def forward(self, x):
@@ -194,20 +300,62 @@ def forward(self, x):
 
 class CFNO2D(nn.Module):
 
-    def __init__(self, da, dv, du, kX=4, kY=4, nLayers=1,
-                 forceFFT=False, non_linearity='gelu', bias=True, reorder=False):
+    def __init__(self, da, dv, du, kX=4, kY=4,
+                 nLayers=1,
+                 forceFFT=False,
+                 non_linearity='gelu',
+                 bias=True,
+                 reorder=False,
+                 use_prechannel_mlp=False,
+                 use_fno_skip_connection=False,
+                 fno_skip_type='linear',
+                 use_postfnochannel_mlp=False,
+                 channel_mlp_skip_type='soft-gating',
+                 channel_mlp_expansion=4
+                 ):
+
         super().__init__()
         self.config = {
             key: val for key, val in locals().items()
             if key != "self" and not key.startswith('__')}
+
+        self.use_postfnochannel_mlp = use_postfnochannel_mlp
+
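+        # Lifting (P) and projection (Q): either deep channel-wise MLPs
+        # or plain per-point linear layers, depending on the config.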
+        if use_prechannel_mlp:
+            self.P = ChannelMLP(
+                in_channels=da,
+                out_channels=dv,
+                hidden_channels=dv*channel_mlp_expansion,
+                n_layers=4
+            )
+            self.Q = ChannelMLP(
+                in_channels=dv,
+                out_channels=du,
+                hidden_channels=dv*channel_mlp_expansion,
+                n_layers=4
+            )
+        else:
+            self.P = Grid2DLinear(da, dv, bias)
+            self.Q = Grid2DLinear(dv, du, bias)
 
-        self.P = Grid2DLinear(da, dv, bias)
-        self.Q = Grid2DLinear(dv, du, bias)
         self.layers = nn.ModuleList(
-            [CF2DLayer(kX, kY, dv, forceFFT, non_linearity, bias, reorder)
+            [CF2DLayer(kX, kY, dv, forceFFT, non_linearity, bias, reorder,
+                       use_fno_skip_connection,
+                       fno_skip_type)
              for _ in range(nLayers)])
         # self.pos = Grid2DPartialPositiver([0, 0, 1, 1])
+        if self.use_postfnochannel_mlp:
+            postchannel_mlp_expansion = 0.5
+            self.channel_mlp = nn.ModuleList(
+                [ChannelMLP(in_channels=dv,
+                            hidden_channels=round(dv * postchannel_mlp_expansion))
+                 for _ in range(nLayers)])
+
+            self.channel_mlp_skips = nn.ModuleList(
+                [skip_connection(dv, dv, skip_type=channel_mlp_skip_type)
+                 for _ in range(nLayers)])
+
         self.memory = CudaMemoryDebugger(print_mem=True)
 
     def forward(self, x):
@@ -215,9 +363,17 @@ def forward(self, x):
 
         # x = x.permute(0,3,1,2)
         x = self.P(x)
-        for layer in self.layers:
+        for index,layer in enumerate(self.layers):
+            if self.use_postfnochannel_mlp:
+                x_skip_channel_mlp = self.channel_mlp_skips[index](x)
+
             x = layer(x)
+            if self.use_postfnochannel_mlp:
+                x = self.channel_mlp[index](x) + x_skip_channel_mlp
+                if index < len(self.layers) - 1:
+                    x = nn.functional.gelu(x)
+
         x = self.Q(x)
         # x = self.pos(x)
         # x = x.permute(0,2,3,1)
diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py
index 14cc0fa..c8b175b 100644
--- a/cfno/training/pySDC.py
+++ b/cfno/training/pySDC.py
@@ -130,13 +130,16 @@ def setupOptimizer(self, optim=None):
     def setupLRScheduler(self, lr_scheduler=None):
         if lr_scheduler is None:
             lr_scheduler = {"scheduler": "StepLR", "step_size": 100.0, "gamma": 0.98}
+        self.scheduler_config = lr_scheduler
         scheduler = lr_scheduler.pop('scheduler')
+        self.scheduler_name = scheduler
         if scheduler == "StepLR":
             self.lr_scheduler = th.optim.lr_scheduler.StepLR(self.optimizer, **lr_scheduler)
         elif scheduler == "CosAnnealingLR":
             self.lr_scheduler = th.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, **lr_scheduler)
         else:
             raise ValueError(f"LR scheduler {scheduler} not implemented yet")
+
 
     def setOptimizerParam(self, **params):
         """
@@ -161,6 +164,9 @@ def printInfos(self):
         print(f" -- name : {self.optim}")
         for key, val in self.optimConfig.items():
             print(f" -- {key}: {val}")
+        print(f"Scheduler: {self.scheduler_name}")
+        for key,val in self.scheduler_config.items():
+            print(f" -- {key}: {val}")
         # TODO: add more details here ...
print("-"*80) @@ -251,7 +257,7 @@ def closure(): avgLoss /= nBatches gradsEpoch /= nBatches - print(f"Training: \n Avg loss: {avgLoss:>8f} (id: {idLoss:>7f})\n") + print(f"Training: \n Avg loss: {avgLoss:>8f} (id: {idLoss:>7f}) -- lr: {optimizer.param_groups[0]['lr']}\n") self.losses["model"]["train"] = avgLoss self.gradientNormEpoch = gradsEpoch @@ -339,6 +345,7 @@ def save(self, filePath, modelOnly=False): # Optimizer config and state 'optim': self.optim, 'optimizer_state_dict': self.optimizer.state_dict(), + 'lr_scheduler_state_dict': self.lr_scheduler.state_dict() }) th.save(infos, fullPath) @@ -368,6 +375,11 @@ def load(self, filePath, modelOnly=False): optim = checkpoint['optim'] self.setupOptimizer({"name": optim}) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + try: + self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) + except KeyError: + print("Learning rate scheduler is restarted!") + # ------------------------------------------------------------------------- diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index b50fee4..a99b440 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -29,20 +29,25 @@ model: nLayers: 2 forceFFT: true bias: false + use_prechannel_mlp: true + channel_mlp_expansion: 4 + use_fno_skip_connection: false + fno_skip_type: soft-gating + use_postfnochannel_mlp: false + channel_mlp_skip_type: 'soft-gating' optim: name: adam - lr: 0.0001 + lr: 0.001 weight_decay: 1.0e-5 lr_scheduler: - scheduler: StepLR - step_size: 100.0 - gamma: 0.98 + scheduler: CosAnnealingLR + T_max: 10000 train: trainDir: training_dt1e-3_update_fft - nEpochs: 20000 + nEpochs: 10000 checkpoint: model.pt lossesFile: losses.txt saveEvery: 100 From a9099edfc220e691678761a65071a749e359f829 Mon Sep 17 00:00:00 2001 From: john2 Date: Tue, 21 Jan 2025 15:35:53 +0100 Subject: [PATCH 02/14] set default bias to False --- cfno/models/cfno2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py index 4cfec1a..bcf872c 100644 --- a/cfno/models/cfno2d.py +++ b/cfno/models/cfno2d.py @@ -270,7 +270,7 @@ class CF2DLayer(nn.Module): def __init__(self, kX, kY, dv, forceFFT=False, non_linearity='gelu', - bias=True, reorder=False, + bias=False, reorder=False, use_fno_skip_connection=False, fno_skip_type='linear' ): From 6478a02895f4453d8169d814f29691b9ebab1478 Mon Sep 17 00:00:00 2001 From: john2 Date: Wed, 22 Jan 2025 15:49:51 +0100 Subject: [PATCH 03/14] training with patches --- cfno/data/preprocessing.py | 105 ++++++++++++++++++++++++++++-- cfno/training/pySDC.py | 6 +- scripts/configs/config_dte-3.yaml | 10 ++- 3 files changed, 110 insertions(+), 11 deletions(-) diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py index 15a0ba1..122c5bb 100644 --- a/cfno/data/preprocessing.py +++ b/cfno/data/preprocessing.py @@ -6,7 +6,7 @@ import torch import glob import numpy as np - +import random from torch.utils.data import Dataset, DataLoader, random_split, Subset from cfno.simulation.post import OutputFiles @@ -110,10 +110,27 @@ def get_concat_data(self, task:str, nsamples:int, reader, multistep:bool=True): class HDF5Dataset(Dataset): - def __init__(self, dataFile): + def __init__(self, dataFile, + use_domain_sampling=False, + nPatch_per_sample=1, + use_min_limit=False, + padding=[0,0,0,0], + kX= 12, kY= 12): + self.file = h5py.File(dataFile, 'r') self.inputs = self.file['inputs'] self.outputs = self.file['outputs'] + 
self.use_domain_sampling = use_domain_sampling + self.nPatch_per_sample = nPatch_per_sample + self.use_min_limit = use_min_limit + self.kX = kX + self.kY = kY + xGrid, yGrid = self.grid + self.nX = xGrid.size + self.nY = yGrid.size + self.slices = self.find_patch_size() + self.padding = padding #[left, right, bottom, top] + assert len(self.inputs) == len(self.outputs), \ f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})" @@ -121,7 +138,38 @@ def __len__(self): return len(self.inputs) def __getitem__(self, idx): - inpt, outp = self.sample(idx) + if self.use_domain_sampling: + patch_padding = self.padding.copy() + iSample = idx // self.nPatch_per_sample + iPatch = idx % self.nPatch_per_sample + inpt_grid, outp_grid = self.sample(iSample) + inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid) + + sX, sY = self.slices[iPatch] + xPatch_start = random.randint(0, self.nX - sX) + yPatch_start = random.randint(0, self.nY - sY) + + if xPatch_start == 0: + patch_padding[0] = 0 + if xPatch_start == (self.nX-sX): + patch_padding[1] = 0 + if yPatch_start == 0: + patch_padding[2] = 0 + if yPatch_start == (self.nY-sY): + patch_padding[3] = 0 + + # print(f"Input size: {inpt.shape}") + # print(f'For patch {iPatch} of sample {iSample}') + # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') + inpt[:, :(sX + patch_padding[0] + patch_padding[1]), + :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], + yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] + outp[:,:(sX + patch_padding[0] + patch_padding[1]), + :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], + yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] + else: + inpt, outp = self.sample(idx) + return torch.tensor(inpt), torch.tensor(outp) def __del__(self): @@ -148,6 +196,34 @@ def outType(self): @property def outScaling(self): return float(self.infos["outScaling"][()]) + + def find_patch_size(self): + """ + List containing random patch sizes + """ + slices = [] + if self.use_min_limit: + nX_min = self.calc_slice_min(self.nX, self.kX) + nY_min = self.calc_slice_min(self.nY, self.kY) + else: + nX_min, nY_min = 0, 0 + for i in range(self.nPatch_per_sample): + sX = random.randint(nX_min, self.nX) + sY = random.randint(nY_min, self.nY) + slices.append((sX,sY)) + return slices + + def calc_slice_min(self, n, modes): + """ + Finding min number of points to satisfy + n/2 +1 >= modes + """ + slice_min = 2*(modes-1) + if slice_min < n: + return slice_min + else: + print("Insufficient number of points to slice") + return 0 def printInfos(self): xGrid, yGrid = self.grid @@ -164,6 +240,12 @@ def printInfos(self): print(f" -- dtInput : {infos['dtInput'][()]:1.2g}") print(f" -- outType : {infos['outType'][()].decode('utf-8')}") print(f" -- outScaling : {infos['outScaling'][()]:1.2g}") + if self.use_domain_sampling: + print(f"-- nPatch (per sample): {self.nPatch_per_sample}") + print(f" --patches (per sample): {self.slices}") + print(f" --padding (per patch): {self.padding}") + if self.use_min_limit: + print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes") def createDataset( dataDir, inSize, outStep, inStep, outType, outScaling, dataFile, @@ -231,10 +313,19 @@ def createDataset( dataset.close() print(" -- done !") - -def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, 
seed=None): - dataset = HDF5Dataset(dataFile) - nBatches = len(dataset) +def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, + seed=None, use_domain_sampling=False, + nPatch_per_sample=1,use_min_limit=False, + padding=[0,0,0,0],kX=12, kY=12): + + dataset = HDF5Dataset(dataFile,use_domain_sampling, + nPatch_per_sample,use_min_limit, + padding,kX,kY) + if use_domain_sampling: + nBatches = len(dataset)//nPatch_per_sample + else: + nBatches = len(dataset) + trainSize = int(trainRatio*nBatches) valSize = nBatches - trainSize diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py index c8b175b..b271ceb 100644 --- a/cfno/training/pySDC.py +++ b/cfno/training/pySDC.py @@ -41,9 +41,10 @@ def __init__(self, # Data setup assert "dataFile" in data, "" + self.data_config = {**data} self.xStep, self.yStep = data.pop("xStep", 1), data.pop("yStep", 1) data.pop("outType", None), data.pop("outScaling", None) # overwritten by dataset - self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data) + self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data, kX=model['kX'], kY=model['kY'] ) # sample : [batchSize, 4, nX, nY] self.outType = self.dataset.outType self.outScaling = self.dataset.outScaling @@ -167,6 +168,9 @@ def printInfos(self): print(f"Scheduler: {self.scheduler_name}") for key,val in self.scheduler_config.items(): print(f" -- {key}: {val}") + print("Data settings") + for key, val in self.data_config.items(): + print(f" -- {key}: {val}") # TODO: add more details here ... print("-"*80) diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index a99b440..b231f30 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -19,6 +19,10 @@ data: batchSize: 20 xStep: 1 yStep: 1 + use_domain_sampling: true + nPatch_per_sample: 100 + use_min_limit: false + padding : [1,1,1,1] model: da: 4 @@ -43,11 +47,11 @@ optim: lr_scheduler: scheduler: CosAnnealingLR - T_max: 10000 + T_max: 50000 train: - trainDir: training_dt1e-3_update_fft - nEpochs: 10000 + trainDir: training_dt1e-3_update_fft_run26 + nEpochs: 20000 checkpoint: model.pt lossesFile: losses.txt saveEvery: 100 From 50c26f230ad0661c57dc86e5cd4043b19aaaf9f6 Mon Sep 17 00:00:00 2001 From: john2 Date: Wed, 22 Jan 2025 15:52:24 +0100 Subject: [PATCH 04/14] Revert "training with patches" This reverts commit 6478a02895f4453d8169d814f29691b9ebab1478. 
--- cfno/data/preprocessing.py | 105 ++---------------------------- cfno/training/pySDC.py | 6 +- scripts/configs/config_dte-3.yaml | 10 +-- 3 files changed, 11 insertions(+), 110 deletions(-) diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py index 122c5bb..15a0ba1 100644 --- a/cfno/data/preprocessing.py +++ b/cfno/data/preprocessing.py @@ -6,7 +6,7 @@ import torch import glob import numpy as np -import random + from torch.utils.data import Dataset, DataLoader, random_split, Subset from cfno.simulation.post import OutputFiles @@ -110,27 +110,10 @@ def get_concat_data(self, task:str, nsamples:int, reader, multistep:bool=True): class HDF5Dataset(Dataset): - def __init__(self, dataFile, - use_domain_sampling=False, - nPatch_per_sample=1, - use_min_limit=False, - padding=[0,0,0,0], - kX= 12, kY= 12): - + def __init__(self, dataFile): self.file = h5py.File(dataFile, 'r') self.inputs = self.file['inputs'] self.outputs = self.file['outputs'] - self.use_domain_sampling = use_domain_sampling - self.nPatch_per_sample = nPatch_per_sample - self.use_min_limit = use_min_limit - self.kX = kX - self.kY = kY - xGrid, yGrid = self.grid - self.nX = xGrid.size - self.nY = yGrid.size - self.slices = self.find_patch_size() - self.padding = padding #[left, right, bottom, top] - assert len(self.inputs) == len(self.outputs), \ f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})" @@ -138,38 +121,7 @@ def __len__(self): return len(self.inputs) def __getitem__(self, idx): - if self.use_domain_sampling: - patch_padding = self.padding.copy() - iSample = idx // self.nPatch_per_sample - iPatch = idx % self.nPatch_per_sample - inpt_grid, outp_grid = self.sample(iSample) - inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid) - - sX, sY = self.slices[iPatch] - xPatch_start = random.randint(0, self.nX - sX) - yPatch_start = random.randint(0, self.nY - sY) - - if xPatch_start == 0: - patch_padding[0] = 0 - if xPatch_start == (self.nX-sX): - patch_padding[1] = 0 - if yPatch_start == 0: - patch_padding[2] = 0 - if yPatch_start == (self.nY-sY): - patch_padding[3] = 0 - - # print(f"Input size: {inpt.shape}") - # print(f'For patch {iPatch} of sample {iSample}') - # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') - inpt[:, :(sX + patch_padding[0] + patch_padding[1]), - :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], - yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] - outp[:,:(sX + patch_padding[0] + patch_padding[1]), - :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], - yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] - else: - inpt, outp = self.sample(idx) - + inpt, outp = self.sample(idx) return torch.tensor(inpt), torch.tensor(outp) def __del__(self): @@ -196,34 +148,6 @@ def outType(self): @property def outScaling(self): return float(self.infos["outScaling"][()]) - - def find_patch_size(self): - """ - List containing random patch sizes - """ - slices = [] - if self.use_min_limit: - nX_min = self.calc_slice_min(self.nX, self.kX) - nY_min = self.calc_slice_min(self.nY, self.kY) - else: - nX_min, nY_min = 0, 0 - for i in range(self.nPatch_per_sample): - sX = random.randint(nX_min, self.nX) - sY = random.randint(nY_min, self.nY) - slices.append((sX,sY)) - return slices - - def calc_slice_min(self, n, modes): - """ - Finding min 
number of points to satisfy - n/2 +1 >= modes - """ - slice_min = 2*(modes-1) - if slice_min < n: - return slice_min - else: - print("Insufficient number of points to slice") - return 0 def printInfos(self): xGrid, yGrid = self.grid @@ -240,12 +164,6 @@ def printInfos(self): print(f" -- dtInput : {infos['dtInput'][()]:1.2g}") print(f" -- outType : {infos['outType'][()].decode('utf-8')}") print(f" -- outScaling : {infos['outScaling'][()]:1.2g}") - if self.use_domain_sampling: - print(f"-- nPatch (per sample): {self.nPatch_per_sample}") - print(f" --patches (per sample): {self.slices}") - print(f" --padding (per patch): {self.padding}") - if self.use_min_limit: - print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes") def createDataset( dataDir, inSize, outStep, inStep, outType, outScaling, dataFile, @@ -313,19 +231,10 @@ def createDataset( dataset.close() print(" -- done !") -def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, - seed=None, use_domain_sampling=False, - nPatch_per_sample=1,use_min_limit=False, - padding=[0,0,0,0],kX=12, kY=12): - - dataset = HDF5Dataset(dataFile,use_domain_sampling, - nPatch_per_sample,use_min_limit, - padding,kX,kY) - if use_domain_sampling: - nBatches = len(dataset)//nPatch_per_sample - else: - nBatches = len(dataset) - + +def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, seed=None): + dataset = HDF5Dataset(dataFile) + nBatches = len(dataset) trainSize = int(trainRatio*nBatches) valSize = nBatches - trainSize diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py index b271ceb..c8b175b 100644 --- a/cfno/training/pySDC.py +++ b/cfno/training/pySDC.py @@ -41,10 +41,9 @@ def __init__(self, # Data setup assert "dataFile" in data, "" - self.data_config = {**data} self.xStep, self.yStep = data.pop("xStep", 1), data.pop("yStep", 1) data.pop("outType", None), data.pop("outScaling", None) # overwritten by dataset - self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data, kX=model['kX'], kY=model['kY'] ) + self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data) # sample : [batchSize, 4, nX, nY] self.outType = self.dataset.outType self.outScaling = self.dataset.outScaling @@ -168,9 +167,6 @@ def printInfos(self): print(f"Scheduler: {self.scheduler_name}") for key,val in self.scheduler_config.items(): print(f" -- {key}: {val}") - print("Data settings") - for key, val in self.data_config.items(): - print(f" -- {key}: {val}") # TODO: add more details here ... 
print("-"*80) diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index b231f30..a99b440 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -19,10 +19,6 @@ data: batchSize: 20 xStep: 1 yStep: 1 - use_domain_sampling: true - nPatch_per_sample: 100 - use_min_limit: false - padding : [1,1,1,1] model: da: 4 @@ -47,11 +43,11 @@ optim: lr_scheduler: scheduler: CosAnnealingLR - T_max: 50000 + T_max: 10000 train: - trainDir: training_dt1e-3_update_fft_run26 - nEpochs: 20000 + trainDir: training_dt1e-3_update_fft + nEpochs: 10000 checkpoint: model.pt lossesFile: losses.txt saveEvery: 100 From ef6bb4e2c9fde081830942cc9405f11346e5672b Mon Sep 17 00:00:00 2001 From: john2 Date: Wed, 22 Jan 2025 15:49:51 +0100 Subject: [PATCH 05/14] training with patches --- cfno/data/preprocessing.py | 105 ++++++++++++++++++++++++++++-- cfno/training/pySDC.py | 6 +- scripts/configs/config_dte-3.yaml | 10 ++- 3 files changed, 110 insertions(+), 11 deletions(-) diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py index 15a0ba1..122c5bb 100644 --- a/cfno/data/preprocessing.py +++ b/cfno/data/preprocessing.py @@ -6,7 +6,7 @@ import torch import glob import numpy as np - +import random from torch.utils.data import Dataset, DataLoader, random_split, Subset from cfno.simulation.post import OutputFiles @@ -110,10 +110,27 @@ def get_concat_data(self, task:str, nsamples:int, reader, multistep:bool=True): class HDF5Dataset(Dataset): - def __init__(self, dataFile): + def __init__(self, dataFile, + use_domain_sampling=False, + nPatch_per_sample=1, + use_min_limit=False, + padding=[0,0,0,0], + kX= 12, kY= 12): + self.file = h5py.File(dataFile, 'r') self.inputs = self.file['inputs'] self.outputs = self.file['outputs'] + self.use_domain_sampling = use_domain_sampling + self.nPatch_per_sample = nPatch_per_sample + self.use_min_limit = use_min_limit + self.kX = kX + self.kY = kY + xGrid, yGrid = self.grid + self.nX = xGrid.size + self.nY = yGrid.size + self.slices = self.find_patch_size() + self.padding = padding #[left, right, bottom, top] + assert len(self.inputs) == len(self.outputs), \ f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})" @@ -121,7 +138,38 @@ def __len__(self): return len(self.inputs) def __getitem__(self, idx): - inpt, outp = self.sample(idx) + if self.use_domain_sampling: + patch_padding = self.padding.copy() + iSample = idx // self.nPatch_per_sample + iPatch = idx % self.nPatch_per_sample + inpt_grid, outp_grid = self.sample(iSample) + inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid) + + sX, sY = self.slices[iPatch] + xPatch_start = random.randint(0, self.nX - sX) + yPatch_start = random.randint(0, self.nY - sY) + + if xPatch_start == 0: + patch_padding[0] = 0 + if xPatch_start == (self.nX-sX): + patch_padding[1] = 0 + if yPatch_start == 0: + patch_padding[2] = 0 + if yPatch_start == (self.nY-sY): + patch_padding[3] = 0 + + # print(f"Input size: {inpt.shape}") + # print(f'For patch {iPatch} of sample {iSample}') + # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') + inpt[:, :(sX + patch_padding[0] + patch_padding[1]), + :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], + yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] + outp[:,:(sX + patch_padding[0] + patch_padding[1]), + :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_start - 
patch_padding[0]: (xPatch_start+sX) + patch_padding[1], + yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] + else: + inpt, outp = self.sample(idx) + return torch.tensor(inpt), torch.tensor(outp) def __del__(self): @@ -148,6 +196,34 @@ def outType(self): @property def outScaling(self): return float(self.infos["outScaling"][()]) + + def find_patch_size(self): + """ + List containing random patch sizes + """ + slices = [] + if self.use_min_limit: + nX_min = self.calc_slice_min(self.nX, self.kX) + nY_min = self.calc_slice_min(self.nY, self.kY) + else: + nX_min, nY_min = 0, 0 + for i in range(self.nPatch_per_sample): + sX = random.randint(nX_min, self.nX) + sY = random.randint(nY_min, self.nY) + slices.append((sX,sY)) + return slices + + def calc_slice_min(self, n, modes): + """ + Finding min number of points to satisfy + n/2 +1 >= modes + """ + slice_min = 2*(modes-1) + if slice_min < n: + return slice_min + else: + print("Insufficient number of points to slice") + return 0 def printInfos(self): xGrid, yGrid = self.grid @@ -164,6 +240,12 @@ def printInfos(self): print(f" -- dtInput : {infos['dtInput'][()]:1.2g}") print(f" -- outType : {infos['outType'][()].decode('utf-8')}") print(f" -- outScaling : {infos['outScaling'][()]:1.2g}") + if self.use_domain_sampling: + print(f"-- nPatch (per sample): {self.nPatch_per_sample}") + print(f" --patches (per sample): {self.slices}") + print(f" --padding (per patch): {self.padding}") + if self.use_min_limit: + print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes") def createDataset( dataDir, inSize, outStep, inStep, outType, outScaling, dataFile, @@ -231,10 +313,19 @@ def createDataset( dataset.close() print(" -- done !") - -def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, seed=None): - dataset = HDF5Dataset(dataFile) - nBatches = len(dataset) +def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, + seed=None, use_domain_sampling=False, + nPatch_per_sample=1,use_min_limit=False, + padding=[0,0,0,0],kX=12, kY=12): + + dataset = HDF5Dataset(dataFile,use_domain_sampling, + nPatch_per_sample,use_min_limit, + padding,kX,kY) + if use_domain_sampling: + nBatches = len(dataset)//nPatch_per_sample + else: + nBatches = len(dataset) + trainSize = int(trainRatio*nBatches) valSize = nBatches - trainSize diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py index c8b175b..b271ceb 100644 --- a/cfno/training/pySDC.py +++ b/cfno/training/pySDC.py @@ -41,9 +41,10 @@ def __init__(self, # Data setup assert "dataFile" in data, "" + self.data_config = {**data} self.xStep, self.yStep = data.pop("xStep", 1), data.pop("yStep", 1) data.pop("outType", None), data.pop("outScaling", None) # overwritten by dataset - self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data) + self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data, kX=model['kX'], kY=model['kY'] ) # sample : [batchSize, 4, nX, nY] self.outType = self.dataset.outType self.outScaling = self.dataset.outScaling @@ -167,6 +168,9 @@ def printInfos(self): print(f"Scheduler: {self.scheduler_name}") for key,val in self.scheduler_config.items(): print(f" -- {key}: {val}") + print("Data settings") + for key, val in self.data_config.items(): + print(f" -- {key}: {val}") # TODO: add more details here ... 
        print("-"*80)
 
diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml
index a99b440..b231f30 100644
--- a/scripts/configs/config_dte-3.yaml
+++ b/scripts/configs/config_dte-3.yaml
@@ -19,6 +19,10 @@ data:
   batchSize: 20
   xStep: 1
   yStep: 1
+  use_domain_sampling: true
+  nPatch_per_sample: 100
+  use_min_limit: false
+  padding: [1,1,1,1]
 
 model:
   da: 4
@@ -43,11 +47,11 @@ optim:
 
 lr_scheduler:
   scheduler: CosAnnealingLR
-  T_max: 10000
+  T_max: 50000
 
 train:
-  trainDir: training_dt1e-3_update_fft
-  nEpochs: 10000
+  trainDir: training_dt1e-3_update_fft_run26
+  nEpochs: 20000
   checkpoint: model.pt
   lossesFile: losses.txt
   saveEvery: 100

From 89cc1315d1f7aee1b644ebe40536a31c413812f3 Mon Sep 17 00:00:00 2001
From: john2
Date: Thu, 23 Jan 2025 15:37:43 +0100
Subject: [PATCH 06/14] evaluation: to get domain output

---
 cfno/data/preprocessing.py        | 28 ++++++---
 cfno/models/cfno2d.py             | 22 ++++++-
 cfno/training/pySDC.py            | 23 +++++++-
 scripts/11_modelOutput.py         | 95 +++++++++++++++++++++++++++----
 scripts/configs/config_dte-3.yaml |  5 ++
 5 files changed, 150 insertions(+), 23 deletions(-)

diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py
index 122c5bb..6cbe896 100644
--- a/cfno/data/preprocessing.py
+++ b/cfno/data/preprocessing.py
@@ -115,7 +115,10 @@ def __init__(self, dataFile,
                  nPatch_per_sample=1,
                  use_min_limit=False,
                  padding=[0,0,0,0],
-                 kX= 12, kY= 12):
+                 kX= 12, kY= 12,
+                 slices=[(16,16)],
+                 xPatch_start=0,
+                 yPatch_start=0):
 
         self.file = h5py.File(dataFile, 'r')
         self.inputs = self.file['inputs']
@@ -128,7 +131,14 @@ def __init__(self, dataFile,
         xGrid, yGrid = self.grid
         self.nX = xGrid.size
         self.nY = yGrid.size
-        self.slices = self.find_patch_size()
+        self.xPatch_start = xPatch_start
+        self.yPatch_start = yPatch_start
+
+        if len(slices) != nPatch_per_sample:
+            self.slices = self.find_patch_size()
+        else:
+            self.slices = slices
+
         self.padding = padding  #[left, right, bottom, top]
 
         assert len(self.inputs) == len(self.outputs), \
@@ -146,8 +156,12 @@ def __getitem__(self, idx):
             inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid)
 
             sX, sY = self.slices[iPatch]
-            xPatch_start = random.randint(0, self.nX - sX)
-            yPatch_start = random.randint(0, self.nY - sY)
+            if len(self.slices) == 1:
+                xPatch_start = self.xPatch_start
+                yPatch_start = self.yPatch_start
+            else:
+                xPatch_start = random.randint(0, self.nX - sX)
+                yPatch_start = random.randint(0, self.nY - sY)
 
             if xPatch_start == 0:
diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py
index bcf872c..120bd3b 100644
--- a/cfno/models/cfno2d.py
+++ b/cfno/models/cfno2d.py
@@ -311,7 +311,9 @@ def __init__(self, da, dv, du, kX=4, kY=4,
                  fno_skip_type='linear',
                  use_postfnochannel_mlp=False,
                  channel_mlp_skip_type='soft-gating',
-                 channel_mlp_expansion=4
+                 channel_mlp_expansion=4,
+                 get_subdomain_output=False,
+                 iXBeg=0,iYBeg=0,iXEnd=256,iYEnd=256,
                  ):
 
         super().__init__()
@@ -357,9 +359,19 @@ def __init__(self, da, dv, du, kX=4, kY=4,
                  for _ in range(nLayers)])
 
         self.memory = CudaMemoryDebugger(print_mem=True)
+        self.get_subdomain_output = get_subdomain_output
+        if self.get_subdomain_output:
+            self.iXBeg = iXBeg
+            self.iXEnd = iXEnd
+            self.iYBeg = iYBeg
+            self.iYEnd = iYEnd
 
     def forward(self, x):
-        """ x[nBatch, nX, nY, da] -> [nBatch, du, nX, nY]"""
+        """ x[nBatch, nX, nY, da] -> [nBatch, du, nX, nY]
+        if get_subdomain_output:
+            x[nBatch, nX, nY, da] -> [nBatch, du, iXEnd-iXBeg, iYEnd-iYBeg]
+        """
+
         # x = x.permute(0,3,1,2)
         x = self.P(x)
 
@@ -373,10 +385,16 @@ def forward(self, x):
             x = self.channel_mlp[index](x) + x_skip_channel_mlp
             if index < len(self.layers) - 1:
                 x = nn.functional.gelu(x)
+
+        # to get only a subdomain output
+        if self.get_subdomain_output:
+            print(f'Filtering to x-subdomain {self.iXBeg,self.iXEnd} & y-subdomain {self.iYBeg,self.iYEnd}')
+            x = x[:, :, self.iXBeg:self.iXEnd, self.iYBeg:self.iYEnd]
 
         x = self.Q(x)
         # x = self.pos(x)
         # x = x.permute(0,2,3,1)
+        print(f'Shape of x: {x.shape}')
         return x
 
diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py
index b271ceb..74bc295 100644
--- a/cfno/training/pySDC.py
+++ b/cfno/training/pySDC.py
@@ -34,8 +34,10 @@ def __init__(self,
         self.device = th.device('cuda' if th.cuda.is_available() else 'cpu')
 
         # Inference-only mode
-        if data is None and model is None and optim is None:
+        if data is None and optim is None:
             assert checkpoint is not None, "need a checkpoint in inference-only evaluation"
+            if model is not None:
+                self.modelConfig = model
             self.load(checkpoint, modelOnly=True)
             return
 
@@ -362,8 +364,12 @@ def load(self, filePath, modelOnly=False):
             # for backward compatibility ...
             checkpoint['model']["non_linearity"] = checkpoint['model'].pop("nonLinearity")
         if hasattr(self, "modelConfig") and self.modelConfig != checkpoint['model']:
+            for key, value in self.modelConfig.items():
+                if key not in checkpoint['model']:
+                    checkpoint['model'][key] = value
             print("WARNING : different model settings in config file,"
                   " overwriting with config from checkpoint ...")
+        # print(f"Model: {checkpoint['model']}")
         self.setupModel(checkpoint['model'])
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.outType = checkpoint["outType"]
@@ -402,7 +408,20 @@ def __call__(self, u0, nEval=1):
                 outp = model(inpt)
                 if self.outType == "update":
                     outp /= self.outScaling
-                    outp += inpt
+
+                    # Mapping output to input shape to perform addition
+                    if outp.shape == inpt.shape:
+                        outp += inpt
+                    else:
+                        padded_tensor = th.zeros_like(inpt)
+                        padded_tensor[:,:,
+                            self.modelConfig['iXBeg']: self.modelConfig['iXEnd'],
+                            self.modelConfig['iYBeg']: self.modelConfig['iYEnd'] ] = outp[:,:,:,:]
+                        # print(f'Padded tensor: {padded_tensor.shape}')
+                        padded_tensor += inpt
+                        outp = padded_tensor[:,:,self.modelConfig['iXBeg']: self.modelConfig['iXEnd'],
+                                             self.modelConfig['iYBeg']: self.modelConfig['iYEnd']]
+                    # print(f'Output shape: {outp.shape}')
                 inpt = outp
 
         if not multi:
diff --git a/scripts/11_modelOutput.py b/scripts/11_modelOutput.py
index 56d7aa5..7866898 100755
--- a/scripts/11_modelOutput.py
+++ b/scripts/11_modelOutput.py
@@ -3,10 +3,11 @@
 import argparse
 import sys
 import os
-sys.path.insert(2, os.getcwd())
+import numpy as np
 from cfno.data.preprocessing import HDF5Dataset
 from cfno.training.pySDC import FourierNeuralOp
 from cfno.simulation.post import contourPlot
+from cfno.utils import readConfig
 
 varChoices = ["vx", "vz", "b", "p"]
 
@@ -24,12 +25,28 @@
     "--var", default="b", help="variable to view", choices=varChoices)
 parser.add_argument(
     "--iSample", default=0, help="sample index", type=int)
+parser.add_argument(
+    "--iXBeg", default=0, help="xPatch start index", type=int)
+parser.add_argument(
+    "--iYBeg", default=0, help="yPatch start index", type=int)
+parser.add_argument( + "--iXEnd", default=256, help="xPatch end index", type=int) +parser.add_argument( + "--iYEnd", default=64, help="yPatch end index", type=int) parser.add_argument( "--outType", default="solution", help="type of output", choices=["solution", "update"]) parser.add_argument( "--refScales", action="store_true", help="use the same scales as the reference field") parser.add_argument( "--saveFig", default="modelView.jpg", help="output name for contour figure") +parser.add_argument( + "--config", default="config.yaml", help="configuration file") +parser.add_argument( + "--get_subdomain_output", action="store_true", help="Get subdomain output" +) +parser.add_argument( + "--use_full_input", action="store_true", help="Use full input" +) args = parser.parse_args() dataFile = args.dataFile @@ -40,6 +57,21 @@ refScales = args.refScales saveFig = args.saveFig +# Need model config to get output of shape different to input +# .i.e when using args.get_subdomain_ouput +if args.get_subdomain_output: + if args.config is not None: + config = readConfig(args.config) + args.__dict__.update(**config["model"]) + modelConfig = dict(config.model) + else: + raise ValueError("Model configuration is required for FNO to output subdomain.") + +iXBeg = args.iXBeg +iYBeg = args.iYBeg +iXEnd = args.iXEnd +iYEnd = args.iYEnd +print(f'Args: {args}') # ----------------------------------------------------------------------------- # Script execution # ----------------------------------------------------------------------------- @@ -50,25 +82,66 @@ xGrid, yGrid = dataset.grid u0, uRef = dataset.sample(iSample) +print(f'Shape of u0: {u0.shape}, uRef: {uRef.shape}') + +if args.use_full_input: + uInit = u0[varChoices.index(var)].T + input = u0 +else: + uInit = u0[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].T + input = u0[:,iXBeg:iXEnd, iYBeg:iYEnd] -uInit = u0[varChoices.index(var)].T -uRef = uRef[varChoices.index(var)].T.copy() +if args.get_subdomain_output: + model = FourierNeuralOp(model=modelConfig, checkpoint=checkpoint) +else: + iXBeg = 0 + iYBeg = 0 + iXEnd = input.shape[1] + iYEnd = input.shape[2] + model = FourierNeuralOp(checkpoint=checkpoint) -model = FourierNeuralOp(checkpoint=checkpoint) -uPred = model(u0)[varChoices.index(var)].T.copy() + +uRef = uRef[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].T.copy() +print(f'Shape of uInit: {uInit.T.shape}, uRef:{uRef.T.shape}') + +uPred = model(input)[varChoices.index(var)].T.copy() if dataset.outType == "update": uRef /= dataset.outScaling + +if uRef.shape != uInit.shape: + padded_uRef = np.zeros_like(uInit) + padded_uPred = np.zeros_like(uInit) + padded_uRef[iYBeg:iYEnd, iXBeg:iXEnd] = uRef[:,:] + padded_uPred[iYBeg:iYEnd, iXBeg:iXEnd] = uPred[:,:] +else: + padded_uRef = uRef.copy() + padded_uPred = uPred.copy() + + if outType == "solution" and dataset.outType == "update": - uRef += uInit + padded_uRef += uInit if outType == "update" and dataset.outType == "solution": - uRef -= uInit - + padded_uRef -= uInit if outType == "update": - uPred -= uInit + padded_uPred -= uInit + +uPred = padded_uPred[iYBeg:iYEnd, iXBeg:iXEnd] +uRef = padded_uRef[iYBeg:iYEnd, iXBeg:iXEnd] +print(f'Shape of uPred: {uPred.T.shape}, uRef: {uRef.T.shape}') +xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1) +yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1) + +# Create 2D grids for pcolormesh +X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij') contourPlot( - uPred, xGrid, yGrid, title=f"Model {outType} for {var} using sample {iSample}", + 
uPred, X, Y, title=f"Model {outType} for {var} using sample {iSample}", refField=uRef, refTitle=f"Dedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)", - saveFig=saveFig, closeFig=False, refScales=refScales) + saveFig=f'{saveFig}_{outType}.jpg', closeFig=False, refScales=refScales) print(f" -- saved {var} contour for sample {iSample}") +contourPlot( + uPred-uRef, X, Y, title=f"Model {outType} error for {var} using sample {iSample}\nDedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)", + refField=None, refTitle=None, + saveFig=f'{saveFig}_{outType}_error.jpg', closeFig=False, refScales=refScales) +print(f" -- saved {var} contour for sample {iSample}") \ No newline at end of file diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index c68a2b0..769d679 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -45,6 +45,11 @@ model: fno_skip_type: soft-gating use_postfnochannel_mlp: false channel_mlp_skip_type: 'soft-gating' + get_subdomain_output: true + iXBeg: 16 + iYBeg: 16 + iXEnd: 64 + iYEnd: 64 optim: name: adam From 60b33f3fd60fb2b48e7eec4d42d275b7f693d417 Mon Sep 17 00:00:00 2001 From: john2 Date: Thu, 23 Jan 2025 17:18:15 +0100 Subject: [PATCH 07/14] removing redundant args --- scripts/configs/config_dte-3.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index 769d679..05a5684 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -39,12 +39,6 @@ model: fno_skip_type: soft-gating use_postfnochannel_mlp: false channel_mlp_skip_type: 'soft-gating' - use_prechannel_mlp: true - channel_mlp_expansion: 4 - use_fno_skip_connection: false - fno_skip_type: soft-gating - use_postfnochannel_mlp: false - channel_mlp_skip_type: 'soft-gating' get_subdomain_output: true iXBeg: 16 iYBeg: 16 From 90a2f9ebe635ec7648b90287a86e452e23b4a70b Mon Sep 17 00:00:00 2001 From: john2 Date: Thu, 23 Jan 2025 18:11:14 +0100 Subject: [PATCH 08/14] cleaning up code --- cfno/data/preprocessing.py | 16 +++++++--------- cfno/models/cfno2d.py | 5 +++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py index 6cbe896..b660431 100644 --- a/cfno/data/preprocessing.py +++ b/cfno/data/preprocessing.py @@ -163,18 +163,16 @@ def __getitem__(self, idx): xPatch_start = random.randint(0, self.nX - sX) yPatch_start = random.randint(0, self.nY - sY) - if xPatch_start == 0: - patch_padding[0] = 0 - if xPatch_start == (self.nX-sX): - patch_padding[1] = 0 - if yPatch_start == 0: - patch_padding[2] = 0 - if yPatch_start == (self.nY-sY): - patch_padding[3] = 0 - + patch_padding[0] = 0 if xPatch_start == 0 or (xPatch_start - patch_padding[0]) < 0 else patch_padding[0] + patch_padding[1] = 0 if (xPatch_start + sX + patch_padding[1]) >= self.nX else patch_padding[1] + patch_padding[2] = 0 if yPatch_start == 0 or (yPatch_start - patch_padding[2]) < 0 else patch_padding[2] + patch_padding[3] = 0 if (yPatch_start + sY + patch_padding[3]) >= self.nY else patch_padding[3] + # print(f"Input size: {inpt.shape}") # print(f'For patch {iPatch} of sample {iSample}') # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') + # print(f'padding: {patch_padding}') + inpt[:, :(sX + patch_padding[0] + patch_padding[1]), :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], yPatch_start - 
patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py index 120bd3b..dc548b2 100644 --- a/cfno/models/cfno2d.py +++ b/cfno/models/cfno2d.py @@ -313,7 +313,8 @@ def __init__(self, da, dv, du, kX=4, kY=4, channel_mlp_skip_type='soft-gating', channel_mlp_expansion=4, get_subdomain_output=False, - iXBeg=0,iYBeg=0,iXEnd=256,iYEnd=256, + iXBeg=0,iYBeg=0, + iXEnd=256,iYEnd=256, ): super().__init__() @@ -394,7 +395,7 @@ def forward(self, x): x = self.Q(x) # x = self.pos(x) # x = x.permute(0,2,3,1) - print(f'Shape of x: {x.shape}') + # print(f'Shape of x: {x.shape}') return x From 1f6ab4e1698ec02f948a133a803372ba9eb99abf Mon Sep 17 00:00:00 2001 From: john2 Date: Thu, 30 Jan 2025 19:37:57 +0100 Subject: [PATCH 09/14] corrected contour scales and errors --- cfno/training/pySDC.py | 2 +- scripts/11_modelOutput.py | 85 ++++++++++++++++++++------------------- 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py index 74bc295..10a840f 100644 --- a/cfno/training/pySDC.py +++ b/cfno/training/pySDC.py @@ -369,7 +369,7 @@ def load(self, filePath, modelOnly=False): checkpoint['model'][key] = value print("WARNING : different model settings in config file," " overwriting with config from checkpoint ...") - # print(f"Model: {checkpoint['model']}") + print(f"Model: {checkpoint['model']}") self.setupModel(checkpoint['model']) self.model.load_state_dict(checkpoint['model_state_dict']) self.outType = checkpoint["outType"] diff --git a/scripts/11_modelOutput.py b/scripts/11_modelOutput.py index 7866898..e48b792 100755 --- a/scripts/11_modelOutput.py +++ b/scripts/11_modelOutput.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse -import sys -import os +import sys, os +sys.path.insert(2, os.getcwd()) import numpy as np from cfno.data.preprocessing import HDF5Dataset from cfno.training.pySDC import FourierNeuralOp @@ -38,15 +38,13 @@ parser.add_argument( "--refScales", action="store_true", help="use the same scales as the reference field") parser.add_argument( - "--saveFig", default="modelView.jpg", help="output name for contour figure") + "--saveFig", default="modelView", help="output name for contour figure") parser.add_argument( "--config", default="config.yaml", help="configuration file") parser.add_argument( - "--get_subdomain_output", action="store_true", help="Get subdomain output" -) + "--get_subdomain_output", action="store_true", help="Get subdomain output") parser.add_argument( - "--use_full_input", action="store_true", help="Use full input" -) + "--use_full_input", action="store_true", help="Use full input") args = parser.parse_args() dataFile = args.dataFile @@ -81,30 +79,37 @@ assert iSample < nSamples, f"iSample={iSample} to big for {nSamples} samples" xGrid, yGrid = dataset.grid -u0, uRef = dataset.sample(iSample) -print(f'Shape of u0: {u0.shape}, uRef: {uRef.shape}') +u0, uRef_full = dataset.sample(iSample) +print(f'Shape of u0: {u0.shape}, uRef: {uRef_full.shape}') if args.use_full_input: - uInit = u0[varChoices.index(var)].T + uInit = u0[varChoices.index(var)] input = u0 + if args.get_subdomain_output: + uRef = uRef_full[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].copy() + else: + uRef = uRef_full[varChoices.index(var)].copy() + iXBeg = 0 + iYBeg = 0 + iXEnd = input.shape[1] + iYEnd = input.shape[2] else: - uInit = u0[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].T + uInit = u0[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd] input = 
u0[:,iXBeg:iXEnd, iYBeg:iYEnd] + uRef = uRef_full[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].copy() if args.get_subdomain_output: model = FourierNeuralOp(model=modelConfig, checkpoint=checkpoint) else: - iXBeg = 0 - iYBeg = 0 - iXEnd = input.shape[1] - iYEnd = input.shape[2] model = FourierNeuralOp(checkpoint=checkpoint) +uPred = model(input)[varChoices.index(var)].copy() +print(f'Shape of uInit: {uInit.shape}, uRef:{uRef.shape}, uPred: {uPred.shape}') -uRef = uRef[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd].T.copy() -print(f'Shape of uInit: {uInit.T.shape}, uRef:{uRef.T.shape}') - -uPred = model(input)[varChoices.index(var)].T.copy() +xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1) +yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1) +# Create 2D grids for pcolormesh +X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij') if dataset.outType == "update": uRef /= dataset.outScaling @@ -112,36 +117,32 @@ if uRef.shape != uInit.shape: padded_uRef = np.zeros_like(uInit) padded_uPred = np.zeros_like(uInit) - padded_uRef[iYBeg:iYEnd, iXBeg:iXEnd] = uRef[:,:] - padded_uPred[iYBeg:iYEnd, iXBeg:iXEnd] = uPred[:,:] + padded_uRef[iXBeg:iXEnd, iYBeg:iYEnd] = uRef[:,:].copy() + padded_uPred[iXBeg:iXEnd, iYBeg:iYEnd] = uPred[:,:].copy() + print(f'padded uRef: {padded_uRef.shape}, uPred: {padded_uPred.shape}') + if outType == "solution" and dataset.outType == "update": + padded_uRef += uInit + if outType == "update" and dataset.outType == "solution": + padded_uRef -= uInit + if outType == "update": + padded_uPred -= uInit + uPred[:,:] = padded_uPred[iXBeg:iXEnd, iYBeg:iYEnd] + uRef[:,:] = padded_uRef[iXBeg:iXEnd, iYBeg:iYEnd] else: - padded_uRef = uRef.copy() - padded_uPred = uPred.copy() - - -if outType == "solution" and dataset.outType == "update": - padded_uRef += uInit -if outType == "update" and dataset.outType == "solution": - padded_uRef -= uInit -if outType == "update": - padded_uPred -= uInit + if outType == "solution" and dataset.outType == "update": + uRef += uInit + if outType == "update" and dataset.outType == "solution": + uRef -= uInit + if outType == "update": + uPred -= uInit -uPred = padded_uPred[iYBeg:iYEnd, iXBeg:iXEnd] -uRef = padded_uRef[iYBeg:iYEnd, iXBeg:iXEnd] -print(f'Shape of uPred: {uPred.T.shape}, uRef: {uRef.T.shape}') - -xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1) -yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1) - -# Create 2D grids for pcolormesh -X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij') contourPlot( uPred, X, Y, title=f"Model {outType} for {var} using sample {iSample}", refField=uRef, refTitle=f"Dedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)", saveFig=f'{saveFig}_{outType}.jpg', closeFig=False, refScales=refScales) print(f" -- saved {var} contour for sample {iSample}") contourPlot( - uPred-uRef, X, Y, title=f"Model {outType} error for {var} using sample {iSample}\nDedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)", + np.abs(uPred-uRef), X, Y, title=f"Model {outType} error for {var} using sample {iSample}\nDedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)", refField=None, refTitle=None, - saveFig=f'{saveFig}_{outType}_error.jpg', closeFig=False, refScales=refScales) + saveFig=f'{saveFig}_{outType}_error.jpg', closeFig=False, refScales=False) print(f" -- saved {var} contour for sample {iSample}") \ No newline at end of file From f7046428ea6d215e69bf25b00a55628687a45a97 Mon Sep 17 00:00:00 2001 From: john2 Date: Mon, 3 Feb 2025 09:25:07 +0100 
Subject: [PATCH 10/14] fixed domain size sampling with & w/o overlapping --- cfno/data/preprocessing.py | 119 +++++++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 26 deletions(-) diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py index b660431..08f4daf 100644 --- a/cfno/data/preprocessing.py +++ b/cfno/data/preprocessing.py @@ -112,18 +112,45 @@ class HDF5Dataset(Dataset): def __init__(self, dataFile, use_domain_sampling=False, + use_fixed_domain=False, + use_ordered_sampling=False, nPatch_per_sample=1, use_min_limit=False, padding=[0,0,0,0], - kX= 12, kY= 12, - slices=[(16,16)], xPatch_start=0, - yPatch_start=0): + yPatch_start=0, + slices=[], + kX=12, kY=12): + """ + Dataset reader and getitem for DataLoader + + Args: + dataFile (hdf5): data file + use_domain_sampling (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample + random sized patches of (sX,sY). Defaults to False. + use_fixed_domain (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample of + (sX,sY) patches with overlapping. Defaults to False. + use_ordered_sampling (bool, optional): To divide full grid (nX,nY) into (nX//sX)*(nY//sY) + exactly divisible (sX,sY) size patches w/o overlapping. + Defaults to False. + nPatch_per_sample (int, optional): Number of sub-domains per sample. Defaults to 1. + use_min_limit (bool, optional): Restrict (sX,sY) to be > (2*kX -1, 2*kY-1). Defaults to False. + padding (list, optional): Columns and rows to decode inflow information + in format[left, right, bottom, top]. Defaults to [0,0,0,0] + xPatch_start (int, optional): Starting index of patch in x-axis. Defaults to 0. + yPatch_start (int, optional): Starting index of patch in y-axis. Defaults to 0. + slices (list, optional): Sizes of patch [[sX,sY]]. Defaults to []. + kX (int, optional): Number of fourier modes in x-axis. Defaults to 12. + kY (int, optional): Number of fourier modes in y-axis. Defaults to 12. 
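+
+            Note: each padding entry requests extra columns/rows around a
+            patch (to carry neighbouring inflow context); an entry is reset
+            to 0 whenever the padded patch would read past that domain edge.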
+ + """ self.file = h5py.File(dataFile, 'r') self.inputs = self.file['inputs'] self.outputs = self.file['outputs'] self.use_domain_sampling = use_domain_sampling + self.use_fixed_domain = use_fixed_domain + self.use_ordered_sampling = use_ordered_sampling self.nPatch_per_sample = nPatch_per_sample self.use_min_limit = use_min_limit self.kX = kX @@ -133,16 +160,25 @@ def __init__(self, dataFile, self.nY = yGrid.size self.xPatch_start = xPatch_start self.yPatch_start = yPatch_start - - if len(slices) != nPatch_per_sample: + + if len(slices) == 0: self.slices = self.find_patch_size() + elif len(slices) == 1: + if self.use_fixed_domain: + if self.use_ordered_sampling: + self.nPatch_per_sample = (self.nX // slices[0][0]) * (self.nY // slices[0][1]) + self.slices = slices * self.nPatch_per_sample + else: + self.slices = slices else: self.slices = slices - + self.padding = padding #[left, right, bottom, top] assert len(self.inputs) == len(self.outputs), \ f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})" + + assert not self.use_ordered_sampling or self.use_fixed_domain, "If use_ordered_sampling is True, then use_fixed_domain must also be True" def __len__(self): return len(self.inputs) @@ -154,15 +190,18 @@ def __getitem__(self, idx): iPatch = idx % self.nPatch_per_sample inpt_grid, outp_grid = self.sample(iSample) inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid) - sX, sY = self.slices[iPatch] if len(self.slices) == 1: xPatch_start = self.xPatch_start.copy() yPatch_start = self.yPatch_start.copy() else: - xPatch_start = random.randint(0, self.nX - sX) - yPatch_start = random.randint(0, self.nY - sY) - + if self.use_ordered_sampling: + xPatch_start = (iPatch // (self.nX//sX)) * sX + yPatch_start = (iPatch % (self.nY//sY)) * sY + else: + xPatch_start = random.randint(0, self.nX - sX) + yPatch_start = random.randint(0, self.nY - sY) + patch_padding[0] = 0 if xPatch_start == 0 or (xPatch_start - patch_padding[0]) < 0 else patch_padding[0] patch_padding[1] = 0 if (xPatch_start + sX + patch_padding[1]) >= self.nX else patch_padding[1] patch_padding[2] = 0 if yPatch_start == 0 or (yPatch_start - patch_padding[2]) < 0 else patch_padding[2] @@ -171,7 +210,7 @@ def __getitem__(self, idx): # print(f"Input size: {inpt.shape}") # print(f'For patch {iPatch} of sample {iSample}') # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') - # print(f'padding: {patch_padding}') + # print(f'padding: {patch_padding}, nPatch_per_sample:{self.nPatch_per_sample}') inpt[:, :(sX + patch_padding[0] + patch_padding[1]), :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], @@ -211,18 +250,28 @@ def outScaling(self): def find_patch_size(self): """ - List containing random patch sizes + List containing patch sizes """ slices = [] - if self.use_min_limit: - nX_min = self.calc_slice_min(self.nX, self.kX) - nY_min = self.calc_slice_min(self.nY, self.kY) + if self.use_fixed_domain: + self.valid_sX = [sx for sx in range(1,self.nX) if self.nX % sx == 0] + self.valid_sY = [sy for sy in range(1,self.nY) if self.nY % sy == 0] + # select a (sX,sY) randomly + sX = int(random.choice(self.valid_sX)) + sY = int(random.choice(self.valid_sY)) + if self.use_ordered_sampling: + self.nPatch_per_sample = (self.nX // sX) * (self.nY // sY) + slices = [(sX,sY)]*self.nPatch_per_sample else: - nX_min, nY_min = 0, 0 - for i in range(self.nPatch_per_sample): - sX = random.randint(nX_min, self.nX) - sY 
= random.randint(nY_min, self.nY)
-            slices.append((sX,sY))
+            if self.use_min_limit:
+                nX_min = self.calc_slice_min(self.nX, self.kX)
+                nY_min = self.calc_slice_min(self.nY, self.kY)
+            else:
+                nX_min, nY_min = 0, 0
+            for i in range(self.nPatch_per_sample):
+                sX = random.randint(nX_min, self.nX)
+                sY = random.randint(nY_min, self.nY)
+                slices.append((sX,sY))
         return slices
 
     def calc_slice_min(self, n, modes):
@@ -252,9 +301,13 @@ def printInfos(self):
         print(f" -- dtInput : {infos['dtInput'][()]:1.2g}")
         print(f" -- outType : {infos['outType'][()].decode('utf-8')}")
         print(f" -- outScaling : {infos['outScaling'][()]:1.2g}")
+        print(f" --use_ordered_sampling: {self.use_ordered_sampling}")
         if self.use_domain_sampling:
             print(f"-- nPatch (per sample): {self.nPatch_per_sample}")
-            print(f" --patches (per sample): {self.slices}")
+            if self.use_fixed_domain:
+                print(f" --patches (per sample): {self.slices[0]}")
+            else:
+                print(f" --patches (per sample): {self.slices}")
             print(f" --padding (per patch): {self.padding}")
             if self.use_min_limit:
                 print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes")
@@ -327,12 +380,26 @@ def createDataset(
 
 def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20,
                    seed=None, use_domain_sampling=False,
-                   nPatch_per_sample=1,use_min_limit=False,
-                   padding=[0,0,0,0],kX=12, kY=12):
+                   use_fixed_domain=False,
+                   use_ordered_sampling=False,
+                   nPatch_per_sample=1,
+                   use_min_limit=False,
+                   padding=[0,0,0,0],
+                   xPatch_start=0,
+                   yPatch_start=0,
+                   kX=12, kY=12, **kwargs):
+
+    if 'slices' in kwargs:
+        slices = kwargs['slices']
+    else:
+        slices = []
 
-    dataset = HDF5Dataset(dataFile,use_domain_sampling,
-                          nPatch_per_sample,use_min_limit,
-                          padding,kX,kY)
+    dataset = HDF5Dataset(dataFile, use_domain_sampling,
+                          use_fixed_domain, use_ordered_sampling,
+                          nPatch_per_sample, use_min_limit,
+                          padding, xPatch_start, yPatch_start,
+                          slices, kX, kY)
+
     dataset.printInfos()
 
     nBatches = len(dataset)

From b862e570fa1dca68f51756a18961abcd05133cac Mon Sep 17 00:00:00 2001
From: john2
Date: Mon, 3 Feb 2025 16:25:37 +0100
Subject: [PATCH 11/14] using data grid for contour

---
 scripts/11_modelOutput.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/scripts/11_modelOutput.py b/scripts/11_modelOutput.py
index e48b792..eb153e8 100755
--- a/scripts/11_modelOutput.py
+++ b/scripts/11_modelOutput.py
@@ -106,10 +106,11 @@
 uPred = model(input)[varChoices.index(var)].copy()
 print(f'Shape of uInit: {uInit.shape}, uRef:{uRef.shape}, uPred: {uPred.shape}')
 
-xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1)
-yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1)
-# Create 2D grids for pcolormesh
-X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij')
+# xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1)
+# yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1)
+# # Create 2D grids for pcolormesh
+# X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij')
+Y, X = xGrid[iXBeg:iXEnd], yGrid[iYBeg:iYEnd]
 
 if dataset.outType == "update":
     uRef /= dataset.outScaling

From 080f1bfe00f5a90bd5d9459bc7d7ed126c0a4f3d Mon Sep 17 00:00:00 2001
From: john2
Date: Mon, 3 Feb 2025 18:29:16 +0100
Subject: [PATCH 12/14] used slicing instead of padding, changed i{X,Y}End default

---
 scripts/11_modelOutput.py | 41 ++++++++++++---------------------------
 1 file changed, 12 insertions(+), 29 deletions(-)

diff --git a/scripts/11_modelOutput.py b/scripts/11_modelOutput.py
index eb153e8..2a46bf8 100755
--- a/scripts/11_modelOutput.py
+++ b/scripts/11_modelOutput.py
@@ -30,9 +30,9 @@
 parser.add_argument(
     "--iYBeg", default=0, help="yPatch start index", type=int)
 parser.add_argument(
-    "--iXEnd", default=256, help="xPatch end index", type=int)
+    "--iXEnd", default=None, help="xPatch end index", type=int)
 parser.add_argument(
-    "--iYEnd", default=64, help="yPatch end index", type=int)
+    "--iYEnd", default=None, help="yPatch end index", type=int)
 parser.add_argument(
     "--outType", default="solution", help="type of output", choices=["solution", "update"])
 parser.add_argument(
@@ -104,39 +104,22 @@
     model = FourierNeuralOp(checkpoint=checkpoint)
 
 uPred = model(input)[varChoices.index(var)].copy()
-print(f'Shape of uInit: {uInit.shape}, uRef:{uRef.shape}, uPred: {uPred.shape}')
-
-# xExpandedGrid = np.linspace(iXBeg, iXEnd, uPred.shape[0] + 1)
-# yExpandedGrid = np.linspace(iYBeg, iYEnd, uPred.shape[1] + 1)
-# # Create 2D grids for pcolormesh
-# X, Y = np.meshgrid(xExpandedGrid, yExpandedGrid, indexing='ij')
-Y, X = xGrid[iXBeg:iXEnd], yGrid[iYBeg:iYEnd]
 
 if dataset.outType == "update":
     uRef /= dataset.outScaling
 
 if uRef.shape != uInit.shape:
-    padded_uRef = np.zeros_like(uInit)
-    padded_uPred = np.zeros_like(uInit)
-    padded_uRef[iXBeg:iXEnd, iYBeg:iYEnd] = uRef[:,:].copy()
-    padded_uPred[iXBeg:iXEnd, iYBeg:iYEnd] = uPred[:,:].copy()
-    print(f'padded uRef: {padded_uRef.shape}, uPred: {padded_uPred.shape}')
-    if outType == "solution" and dataset.outType == "update":
-        padded_uRef += uInit
-    if outType == "update" and dataset.outType == "solution":
-        padded_uRef -= uInit
-    if outType == "update":
-        padded_uPred -= uInit
-    uPred[:,:] = padded_uPred[iXBeg:iXEnd, iYBeg:iYEnd]
-    uRef[:,:] = padded_uRef[iXBeg:iXEnd, iYBeg:iYEnd]
-else:
-    if outType == "solution" and dataset.outType == "update":
-        uRef += uInit
-    if outType == "update" and dataset.outType == "solution":
-        uRef -= uInit
-    if outType == "update":
-        uPred -= uInit
+    uInit = u0[varChoices.index(var), iXBeg:iXEnd, iYBeg:iYEnd]
+
+print(f'Shape of uInit: {uInit.shape}, uRef:{uRef.shape}, uPred: {uPred.shape}')
 
+if outType == "solution" and dataset.outType == "update":
+    uRef += uInit
+if outType == "update" and dataset.outType == "solution":
+    uRef -= uInit
+if outType == "update":
+    uPred -= uInit
 
+Y, X = xGrid[iXBeg:iXEnd], yGrid[iYBeg:iYEnd]
 contourPlot(
     uPred, X, Y, title=f"Model {outType} for {var} using sample {iSample}",
     refField=uRef, refTitle=f"Dedalus reference (dt={dataset.infos['dtInput'][()]:1.2g}s)",

From 95a66bc46165301d0c5b53a5ba1b0a53fa25d53d Mon Sep 17 00:00:00 2001
From: john2
Date: Mon, 10 Feb 2025 12:03:27 +0100
Subject: [PATCH 13/14] fixing irfft2 dimensions

---
 cfno/models/cfno2d.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py
index dc548b2..a7347e3 100644
--- a/cfno/models/cfno2d.py
+++ b/cfno/models/cfno2d.py
@@ -88,9 +88,9 @@ def _toFourierSpace_FORCE_FFT(self, x):
         x = th.fft.rfft2(x, norm="ortho") # RFFT on last 2 dimensions
         return x
 
-    def _toRealSpace_FORCE_FFT(self, x):
+    def _toRealSpace_FORCE_FFT(self, x, org_size):
         """ x[nBatch, dv, fX = nX/2+1, fY = nY/2+1] -> [nBatch, dv, nX, nY]"""
-        x = th.fft.irfft2(x, norm="ortho") # IRFFT on last 2 dimensions
+        x = th.fft.irfft2(x, s=org_size, norm="ortho") # IRFFT on last 2 dimensions
         return x
 
     def _toFourierSpace_FORCE_FFT_REORDER(self, x):
@@ -115,6 +115,7 @@ def _toRealSpace_FORCE_FFT_REORDER(self, x):
 
     def forward(self, x:th.tensor):
         """ x[nBatch, dv, nX, nY] -> [nBatch, dv, nX, nY] """
+        org_size = x.shape
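+        # rfft maps an axis of length 2m or 2m+1 to m+1 coefficients, so the
+        # original length cannot be inferred from the spectrum alone;
+        # org_size lets the inverse transform restore the exact input shape.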
        # Transform to Fourier space -> [nBatch, dv, fX, fY]
         x = self._toFourierSpace(x)
 
         # Truncate and keep only first modes -> [nBatch, dv, kX, kY]
@@ -130,7 +131,9 @@
             x = th.einsum("xa,yb,eiab->eixy", Tx.T, Ty.T, x)
 
         # Transform back to Real space -> [nBatch, dv, nX, nY]
-        x = self._toRealSpace(x)
+        # Need to pass the signal's original shape to irfft2(),
+        # since it cannot be inferred when the last dim is odd
+        x = self._toRealSpace(x, org_size[-2:])
 
         if self.bias is not None:
             x = x + self.bias
@@ -314,7 +317,7 @@ def __init__(self, da, dv, du, kX=4, kY=4,
                  channel_mlp_expansion=4,
                  get_subdomain_output=False,
                  iXBeg=0,iYBeg=0,
-                 iXEnd=256,iYEnd=256,
+                 iXEnd=None,iYEnd=None,
                  ):
         super().__init__()
@@ -387,7 +390,7 @@ def forward(self, x):
             if index < len(self.layers) - 1:
                 x = nn.functional.gelu(x)
 
-        # to get only a subdomain output
+        # to get only a subdomain output during inference
         if self.get_subdomain_output:
             print(f'Filtering to x-subdomain {self.iXBeg,self.iXEnd} & y-subdomain {self.iYBeg,self.iYEnd}')
             x = x[:, :, self.iXBeg:self.iXEnd, self.iYBeg:self.iYEnd]

From 851717c1a0c9d5aaf2c5c44376fe91f786a6c045 Mon Sep 17 00:00:00 2001
From: john2
Date: Mon, 10 Feb 2025 19:17:58 +0100
Subject: [PATCH 14/14] refactoring into sub Dataset classes and allowing varying patch size DataLoader

---
 cfno/data/preprocessing.py        | 488 +++++++++++++++++++++---------
 cfno/models/cfno2d.py             |   8 +-
 cfno/training/pySDC.py            |  56 +++-
 scripts/configs/config_dte-3.yaml |  22 +-
 4 files changed, 406 insertions(+), 168 deletions(-)

diff --git a/cfno/data/preprocessing.py b/cfno/data/preprocessing.py
index 08f4daf..81ac90c 100644
--- a/cfno/data/preprocessing.py
+++ b/cfno/data/preprocessing.py
@@ -8,6 +8,7 @@
 import numpy as np
 import random
 from torch.utils.data import Dataset, DataLoader, random_split, Subset
+from collections import defaultdict
 from cfno.simulation.post import OutputFiles
 
 class FNOData():
@@ -110,117 +111,30 @@ class HDF5Dataset(Dataset):
 
-    def __init__(self, dataFile,
-                 use_domain_sampling=False,
-                 use_fixed_domain=False,
-                 use_ordered_sampling=False,
-                 nPatch_per_sample=1,
-                 use_min_limit=False,
-                 padding=[0,0,0,0],
-                 xPatch_start=0,
-                 yPatch_start=0,
-                 slices=[],
-                 kX=12, kY=12):
+    def __init__(self, dataFile, **kwargs):
         """
         Dataset reader and getitem for DataLoader
 
         Args:
             dataFile (hdf5): data file
-            use_domain_sampling (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample
-                                                  random sized patches of (sX,sY). Defaults to False.
-            use_fixed_domain (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample of
-                                               (sX,sY) patches with overlapping. Defaults to False.
-            use_ordered_sampling (bool, optional): To divide full grid (nX,nY) into (nX//sX)*(nY//sY)
-                                                   exactly divisible (sX,sY) size patches w/o overlapping.
-                                                   Defaults to False.
-            nPatch_per_sample (int, optional): Number of sub-domains per sample. Defaults to 1.
-            use_min_limit (bool, optional): Restrict (sX,sY) to be > (2*kX -1, 2*kY-1). Defaults to False.
-            padding (list, optional): Columns and rows to decode inflow information
-                                      in format[left, right, bottom, top]. Defaults to [0,0,0,0]
-            xPatch_start (int, optional): Starting index of patch in x-axis. Defaults to 0.
-            yPatch_start (int, optional): Starting index of patch in y-axis. Defaults to 0.
-            slices (list, optional): Sizes of patch [[sX,sY]]. Defaults to [].
-            kX (int, optional): Number of fourier modes in x-axis. Defaults to 12.
-            kY (int, optional): Number of fourier modes in y-axis. Defaults to 12.
""" self.file = h5py.File(dataFile, 'r') self.inputs = self.file['inputs'] self.outputs = self.file['outputs'] - self.use_domain_sampling = use_domain_sampling - self.use_fixed_domain = use_fixed_domain - self.use_ordered_sampling = use_ordered_sampling - self.nPatch_per_sample = nPatch_per_sample - self.use_min_limit = use_min_limit - self.kX = kX - self.kY = kY xGrid, yGrid = self.grid self.nX = xGrid.size self.nY = yGrid.size - self.xPatch_start = xPatch_start - self.yPatch_start = yPatch_start - - if len(slices) == 0: - self.slices = self.find_patch_size() - elif len(slices) == 1: - if self.use_fixed_domain: - if self.use_ordered_sampling: - self.nPatch_per_sample = (self.nX // slices[0][0]) * (self.nY // slices[0][1]) - self.slices = slices * self.nPatch_per_sample - else: - self.slices = slices - else: - self.slices = slices - - self.padding = padding #[left, right, bottom, top] - + assert len(self.inputs) == len(self.outputs), \ f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})" - assert not self.use_ordered_sampling or self.use_fixed_domain, "If use_ordered_sampling is True, then use_fixed_domain must also be True" - def __len__(self): return len(self.inputs) def __getitem__(self, idx): - if self.use_domain_sampling: - patch_padding = self.padding.copy() - iSample = idx // self.nPatch_per_sample - iPatch = idx % self.nPatch_per_sample - inpt_grid, outp_grid = self.sample(iSample) - inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid) - sX, sY = self.slices[iPatch] - if len(self.slices) == 1: - xPatch_start = self.xPatch_start.copy() - yPatch_start = self.yPatch_start.copy() - else: - if self.use_ordered_sampling: - xPatch_start = (iPatch // (self.nX//sX)) * sX - yPatch_start = (iPatch % (self.nY//sY)) * sY - else: - xPatch_start = random.randint(0, self.nX - sX) - yPatch_start = random.randint(0, self.nY - sY) - - patch_padding[0] = 0 if xPatch_start == 0 or (xPatch_start - patch_padding[0]) < 0 else patch_padding[0] - patch_padding[1] = 0 if (xPatch_start + sX + patch_padding[1]) >= self.nX else patch_padding[1] - patch_padding[2] = 0 if yPatch_start == 0 or (yPatch_start - patch_padding[2]) < 0 else patch_padding[2] - patch_padding[3] = 0 if (yPatch_start + sY + patch_padding[3]) >= self.nY else patch_padding[3] - - # print(f"Input size: {inpt.shape}") - # print(f'For patch {iPatch} of sample {iSample}') - # print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}') - # print(f'padding: {patch_padding}, nPatch_per_sample:{self.nPatch_per_sample}') - - inpt[:, :(sX + patch_padding[0] + patch_padding[1]), - :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], - yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] - outp[:,:(sX + patch_padding[0] + patch_padding[1]), - :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1], - yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]] - else: - inpt, outp = self.sample(idx) - + inpt, outp = self.sample(idx) return torch.tensor(inpt), torch.tensor(outp) def __del__(self): @@ -247,37 +161,124 @@ def outType(self): @property def outScaling(self): return float(self.infos["outScaling"][()]) + + def printInfos(self): + xGrid, yGrid = self.grid + infos = self.infos + print(f" -- grid shape : ({xGrid.size}, {yGrid.size})") + print(f" -- grid domain : [{xGrid.min():.1f}, {xGrid.max():.1f}] x 
[{yGrid.min():.1f}, {yGrid.max():.1f}]")
+        print(f" -- nSimu : {infos['nSimu'][()]}")
+        print(f" -- dtData : {infos['dtData'][()]:1.2g}")
+        print(f" -- inSize : {infos['inSize'][()]}")    # T_in
+        print(f" -- outStep : {infos['outStep'][()]}")  # T
+        print(f" -- inStep : {infos['inStep'][()]}")    # tStep
+        print(f" -- nSamples (per simu) : {infos['nSamples'][()]}")
+        print(f" -- nSamples (total) : {infos['nSamples'][()]*infos['nSimu'][()]}")
+        print(f" -- dtInput : {infos['dtInput'][()]:1.2g}")
+        print(f" -- outType : {infos['outType'][()].decode('utf-8')}")
+        print(f" -- outScaling : {infos['outScaling'][()]:1.2g}")
+
+class RandomDomainDataset(HDF5Dataset):
+    """
+    Creating dataset by dividing full grid (nX,nY) into nPatch_per_sample different
+    random sized patches per epoch of (sX,sY).
+
+    Args:
+        dataFile (hdf5): data file
+        pad_to_fullGrid (bool, optional): Embeds each (sX,sY) patch into a zero-filled
+                                          (nX,nY) grid. Defaults to False.
+        use_fixedPatch_startIdx (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample
+                                                  (sX,sY) sized patches starting from the same index
+                                                  per epoch. Defaults to True.
+        nPatch_per_sample (int, optional): Number of patches per sample. Defaults to 1.
+        use_minLimit (bool, optional): Restrict (sX,sY) to be > (2*kX-1, 2*kY-1). Defaults to True.
+        padding (list, optional): Columns and rows to decode inflow information
+                                  in format [left, right, bottom, top]. Defaults to [0,0,0,0].
+        slices (list, optional): Sizes of patch [[sX,sY]]. Defaults to [].
+        patch_startIdx (list, optional): Starting index of patch. Defaults to [[0,0]].
+        kX (int, optional): Number of Fourier modes in x-axis. Defaults to 12.
+        kY (int, optional): Number of Fourier modes in y-axis. Defaults to 12.
+
+    """
+    def __init__(self, dataFile,
+                 pad_to_fullGrid=False,
+                 use_fixedPatch_startIdx=True,
+                 nPatch_per_sample=1,
+                 use_minLimit=True,
+                 padding=[0,0,0,0],
+                 **kwargs):
+
+        super().__init__(dataFile)
+        self.nPatch_per_sample = nPatch_per_sample
+        self.pad_to_fullGrid = pad_to_fullGrid
+        self.use_fixedPatch_startIdx = use_fixedPatch_startIdx
+        self.use_minLimit = use_minLimit
+        self.kX = kwargs.get('kX', 12)
+        self.kY = kwargs.get('kY', 12)
+
+        if not self.pad_to_fullGrid:
+            self.use_minLimit = True
+
+        self.slices = kwargs.get('slices', self.find_patchSize())
+        patch_startIdx = kwargs.get('patch_startIdx', [[0,0]])
+        if self.use_fixedPatch_startIdx:
+            if len(patch_startIdx) == len(self.slices):
+                self.patch_startIdx = patch_startIdx
+            else:
+                self.patch_startIdx = self.find_patch_startIdx()
+
+        self.padding = padding  #[left, right, bottom, top]
+
+        assert len(self.slices) == self.nPatch_per_sample, "Number of slices doesn't match patches per sample"
+
+
+    def __getitem__(self, idx):
+        patch_padding = self.padding.copy()
+        iSample = idx // self.nPatch_per_sample
+        iPatch = idx % self.nPatch_per_sample
+        inpt_grid, outp_grid = self.sample(iSample)
+        sX, sY = self.slices[iPatch]
+        if self.use_fixedPatch_startIdx:
+            xPatch_startIdx = self.patch_startIdx[iPatch][0]
+            yPatch_startIdx = self.patch_startIdx[iPatch][1]
+        else:
+            xPatch_startIdx = random.randint(0, self.nX - sX)
+            yPatch_startIdx = random.randint(0, self.nY - sY)
+
+        patch_padding[0] = 0 if xPatch_startIdx == 0 or (xPatch_startIdx - patch_padding[0]) < 0 else patch_padding[0]
+        patch_padding[1] = 0 if (xPatch_startIdx + sX + patch_padding[1]) >= self.nX else patch_padding[1]
+        patch_padding[2] = 0 if yPatch_startIdx == 0 or (yPatch_startIdx - patch_padding[2]) < 0 else patch_padding[2]
+        patch_padding[3] = 0 if (yPatch_startIdx + sY + patch_padding[3]) >= self.nY else patch_padding[3]
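+        # The four clamps above zero out any pad width that would cross the
+        # global grid boundary, so patches at the domain edge never index
+        # outside [0, nX) x [0, nY).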
+
+        if self.pad_to_fullGrid:
+            inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid)
+            inpt[:, :(sX + patch_padding[0] + patch_padding[1]),
+                 :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1],
+                                                                          yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+            outp[:, :(sX + patch_padding[0] + patch_padding[1]),
+                 :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1],
+                                                                          yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+        else:
+            inpt = inpt_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1], yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+            outp = outp_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1], yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+
+        return torch.tensor(inpt), torch.tensor(outp)
+
+    def find_patchSize(self):
+        """
+        List containing patch sizes
+        """
+        slices = []
+        nX_min, nY_min = (self.calc_sliceMin(self.nX, self.kX), self.calc_sliceMin(self.nY, self.kY)) if self.use_minLimit else (0, 0)
+        for _ in range(self.nPatch_per_sample):
+            sX = random.randint(nX_min, self.nX)
+            sY = random.randint(nY_min, self.nY)
+            slices.append((sX, sY))
         return slices
 
-    def calc_slice_min(self, n, modes):
+    def calc_sliceMin(self, n, modes):
         """
         Finding min number of points to satisfy
-        n/2 +1 >= modes
+        n/2 + 1 >= Fourier modes
         """
         slice_min = 2*(modes-1)
         if slice_min < n:
@@ -286,6 +287,17 @@
             print("Insufficient number of points to slice")
             return 0
 
+    def find_patch_startIdx(self):
+        """
+        List containing patch starting index
+        """
+        patch_start = []
+        for i in range(len(self.slices)):
+            xPatch_startIdx = random.randint(0, self.nX - self.slices[i][0])
+            yPatch_startIdx = random.randint(0, self.nY - self.slices[i][1])
+            patch_start.append((xPatch_startIdx, yPatch_startIdx))
+        return patch_start
+
     def printInfos(self):
         xGrid, yGrid = self.grid
         infos = self.infos
@@ -301,16 +313,168 @@
         print(f" -- dtInput : {infos['dtInput'][()]:1.2g}")
         print(f" -- outType : {infos['outType'][()].decode('utf-8')}")
         print(f" -- outScaling : {infos['outScaling'][()]:1.2g}")
-        print(f" --use_ordered_sampling: {self.use_ordered_sampling}")
-        if self.use_domain_sampling:
-            print(f"-- nPatch (per sample): {self.nPatch_per_sample}")
-            if self.use_fixed_domain:
-                print(f" --patches (per sample): {self.slices[0]}")
+        print(f" -- pad_to_fullGrid: {self.pad_to_fullGrid}")
+        print(f" -- nPatch (per sample): {self.nPatch_per_sample}")
+        print(f" -- patches (per sample): {self.slices}")
+        print(f" -- padding (per patch): {self.padding}")
+        if self.use_minLimit:
+            print(f"Min nX & nY for patch computed using {self.kX, self.kY} modes")
+        if self.use_fixedPatch_startIdx:
+            print(f" -- patch start index (per epoch): {self.patch_startIdx}")
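+# Minimal usage sketch (hypothetical file name, for illustration only):
+#     ds = RandomDomainDataset("train.h5", nPatch_per_sample=4, kX=12, kY=12)
+#     inpt, outp = ds[9]   # iSample = 9 // 4 = 2, iPatch = 9 % 4 = 1
+# Each index decomposes into (iSample, iPatch), so consecutive indices walk
+# through the patches of one sample before moving on to the next sample.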
print(f" -- padding (per patch): {self.padding}") + if self.use_minLimit: + print(f"Min nX & nY for patch computed using {self.kX, self.kY} modes") + if self.use_fixedPatch_startIdx: + print(f" -- patch start index (per epoch): {self.patch_startIdx}") + +class FixedDomainDataset(HDF5Dataset): + """ + Creating dataset by dividing full grid (nX,nY) into nPatch_per_sample of + (sX,sY) patches with overlapping + + Args: + dataFile (hdf5): data file + use_orderedSampling (bool, optional): To divide full grid (nX,nY) into (nX//sX)*(nY//sY) + exactly divisible (sX,sY) size patches w/o overlapping. + Defaults to False. + pad_to_fullGrid (bool, optional): Embeds (sX,sY) into (nX,nY) zero grid + use_fixedPatch_startIdx (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample + (sX,sY) sized patches starting from same index + per epoch . Defaults to False. + nPatch_per_sample (int, optional): Number of patches per sample. Defaults to 1. + use_minLimit (bool, optional): Restrict (sX,sY) to be > (2*kX -1, 2*kY-1). Defaults to False. + padding (list, optional): Columns and rows to decode inflow information + in format[left, right, bottom, top]. Defaults to [0,0,0,0] + slices (list, optional): Sizes of patch [[sX,sY]]. Defaults to []. + patch_startIdx (list, optional): Starting index of patch. Defaults to [[0,0]]. + kX (int, optional): Number of fourier modes in x-axis. Defaults to 12. + kY (int, optional): Number of fourier modes in y-axis. Defaults to 12. + + """ + def __init__(self, dataFile, + use_orderedSampling=False, + pad_to_fullGrid=False, + use_fixedPatch_startIdx=True, + nPatch_per_sample=1, + use_minLimit=True, + padding=[0,0,0,0], + **kwargs): + + super().__init__(dataFile) + self.nPatch_per_sample = nPatch_per_sample + self.pad_to_fullGrid = pad_to_fullGrid + self.use_fixedPatch_startIdx = use_fixedPatch_startIdx + self.use_minLimit = use_minLimit + self.kX = kwargs.get('kX', 12) + self.kY = kwargs.get('kY', 12) + self.use_orderedSampling = use_orderedSampling + + if not self.pad_to_fullGrid: + self.use_minLimit = True + + slices = kwargs.get('slices', []) + if len(slices) == 0: + single_slice = self.find_patchSize() + else: + single_slice = slices + + assert len(single_slice) == 1, f"{len(single_slice)} patch size given for uniform domain sampling" + + if self.use_orderedSampling: + self.nPatch_per_sample = (self.nX // single_slice[0][0]) * (self.nY // single_slice[0][1]) + self.slices = single_slice * self.nPatch_per_sample + + assert not (self.use_fixedPatch_startIdx and self.use_orderedSampling), \ + "use_fixedPatch_startIdx and use_orderedSampling cannot be True at the same time." 
+
+        if self.use_fixedPatch_startIdx:
+            patch_startIdx = kwargs.get('patch_startIdx', [])
+            if len(patch_startIdx) == len(self.slices):
+                self.patch_startIdx = patch_startIdx
             else:
-            print(f" --patches (per sample): {self.slices}")
-            print(f" --padding (per patch): {self.padding}")
-            if self.use_min_limit:
-                print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes")
+                self.patch_startIdx = self.find_patch_startIdx()
+
+        self.padding = padding  #[left, right, bottom, top]
+
+    def __getitem__(self, idx):
+        patch_padding = self.padding.copy()
+        iSample = idx // self.nPatch_per_sample
+        iPatch = idx % self.nPatch_per_sample
+        inpt_grid, outp_grid = self.sample(iSample)
+        sX, sY = self.slices[iPatch]
+
+        if self.use_fixedPatch_startIdx:
+            xPatch_startIdx = self.patch_startIdx[iPatch][0]
+            yPatch_startIdx = self.patch_startIdx[iPatch][1]
+        elif self.use_orderedSampling:
+            xPatch_startIdx = (iPatch // (self.nX//sX)) * sX
+            yPatch_startIdx = (iPatch % (self.nY//sY)) * sY
+        else:
+            xPatch_startIdx = random.randint(0, self.nX - sX)
+            yPatch_startIdx = random.randint(0, self.nY - sY)
+
+        patch_padding[0] = 0 if xPatch_startIdx == 0 or (xPatch_startIdx - patch_padding[0]) < 0 else patch_padding[0]
+        patch_padding[1] = 0 if (xPatch_startIdx + sX + patch_padding[1]) >= self.nX else patch_padding[1]
+        patch_padding[2] = 0 if yPatch_startIdx == 0 or (yPatch_startIdx - patch_padding[2]) < 0 else patch_padding[2]
+        patch_padding[3] = 0 if (yPatch_startIdx + sY + patch_padding[3]) >= self.nY else patch_padding[3]
+
+        if self.pad_to_fullGrid:
+            inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid)
+            inpt[:, :(sX + patch_padding[0] + patch_padding[1]),
+                 :(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1],
+                                                                          yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+            outp[:, :(sX + patch_padding[0] + patch_padding[1]),
+                 :(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1],
+                                                                          yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+        else:
+            inpt = inpt_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1], yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+            outp = outp_grid[:, xPatch_startIdx - patch_padding[0]: (xPatch_startIdx+sX) + patch_padding[1], yPatch_startIdx - patch_padding[2]: (yPatch_startIdx+sY) + patch_padding[3]]
+
+        return torch.tensor(inpt), torch.tensor(outp)
+
+    def find_patchSize(self):
+        """
+        List containing patch sizes
+        """
+        slices = []
+        self.valid_sX = [sx for sx in range(1,self.nX) if self.nX % sx == 0]
+        self.valid_sY = [sy for sy in range(1,self.nY) if self.nY % sy == 0]
+        # select a (sX,sY) randomly
+        sX = int(random.choice(self.valid_sX))
+        sY = int(random.choice(self.valid_sY))
+        slices.append((sX,sY))
+        return slices
+
+    def find_patch_startIdx(self):
+        """
+        List containing patch starting index
+        """
+        patch_start = []
+        for i in range(len(self.slices)):
+            xPatch_startIdx = random.randint(0, self.nX - self.slices[i][0])
+            yPatch_startIdx = random.randint(0, self.nY - self.slices[i][1])
+            patch_start.append((xPatch_startIdx, yPatch_startIdx))
+        return patch_start
+
+    def printInfos(self):
+        xGrid, yGrid = self.grid
+        infos = self.infos
+        print(f" -- grid shape : ({xGrid.size}, {yGrid.size})")
+        print(f" -- grid domain : [{xGrid.min():.1f}, {xGrid.max():.1f}] x [{yGrid.min():.1f}, {yGrid.max():.1f}]")
+        print(f" -- 
nSimu : {infos['nSimu'][()]}") + print(f" -- dtData : {infos['dtData'][()]:1.2g}") + print(f" -- inSize : {infos['inSize'][()]}") # T_in + print(f" -- outStep : {infos['outStep'][()]}") # T + print(f" -- inStep : {infos['inStep'][()]}") # tStep + print(f" -- nSamples (per simu) : {infos['nSamples'][()]}") + print(f" -- nSamples (total) : {infos['nSamples'][()]*infos['nSimu'][()]}") + print(f" -- dtInput : {infos['dtInput'][()]:1.2g}") + print(f" -- outType : {infos['outType'][()].decode('utf-8')}") + print(f" -- outScaling : {infos['outScaling'][()]:1.2g}") + print(f" -- use_orderedSampling: {self.use_orderedSampling}") + print(f" -- pad_to_fullGrid: {self.pad_to_fullGrid}") + print(f" -- nPatch (per sample): {self.nPatch_per_sample}") + print(f" -- patches (per sample): {self.slices[0]}") + print(f" -- padding (per patch): {self.padding}") + if self.use_minLimit: + print(f"Min nX & nY for patch computed using {self.kX, self.kY} modes") + if self.use_fixedPatch_startIdx: + print(f" -- patch start index (per epoch): {self.patch_startIdx}") def createDataset( dataDir, inSize, outStep, inStep, outType, outScaling, dataFile, @@ -378,31 +542,45 @@ def createDataset( dataset.close() print(" -- done !") -def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, - seed=None, use_domain_sampling=False, - use_fixed_domain=False, - use_ordered_sampling=False, +def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, seed=None, + use_domainSampling=False, + use_fixedPatchSize=False, + pad_to_fullGrid=False, + use_orderedSampling=False, + use_fixedPatch_startIdx=False, nPatch_per_sample=1, - use_min_limit=False, + use_minLimit=False, padding=[0,0,0,0], - xPatch_start=0, - yPatch_start=0, - kX= 12, kY= 12, **kwargs): - - if 'slices' in kwargs: - slices = kwargs['slices'] + **kwargs): + + if not use_domainSampling: + dataset = HDF5Dataset(dataFile) else: - slices = [] - - dataset = HDF5Dataset(dataFile, use_domain_sampling, - use_fixed_domain, use_ordered_sampling, - nPatch_per_sample, use_min_limit, - padding,xPatch_start, yPatch_start, - slices, kX, kY) + if use_fixedPatchSize: + dataset = FixedDomainDataset(dataFile, + use_orderedSampling, + pad_to_fullGrid, + use_fixedPatch_startIdx, + nPatch_per_sample, + use_minLimit, + padding, + **kwargs) + else: + dataset = RandomDomainDataset(dataFile, + pad_to_fullGrid, + use_fixedPatch_startIdx, + nPatch_per_sample, + use_minLimit, + padding, + **kwargs) + dataset.printInfos() nBatches = len(dataset) - + collate_fn = None + + train_batchSize = batchSize + valid_batchSize = batchSize trainSize = int(trainRatio*nBatches) valSize = nBatches - trainSize @@ -415,8 +593,32 @@ def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, generator = torch.Generator().manual_seed(seed) trainSet, valSet = random_split( dataset, [trainSize, valSize], generator=generator) - - trainLoader = DataLoader(trainSet, batch_size=batchSize, shuffle=True) - valLoader = DataLoader(valSet, batch_size=batchSize, shuffle=False) + + if use_domainSampling and not pad_to_fullGrid: + train_batchSize = len(trainSet) + valid_batchSize = len(valSet) + collate_fn = variable_tensor_collate_fn + + trainLoader = DataLoader(trainSet, batch_size=train_batchSize, shuffle=True, num_workers=0, collate_fn=collate_fn, pin_memory=True) + valLoader = DataLoader(valSet, batch_size=valid_batchSize, shuffle=False, num_workers=0, collate_fn=collate_fn, pin_memory=True) return trainLoader, valLoader, dataset + +def variable_tensor_collate_fn(batch): + """ + Groups tensors of the same shape together and 
batches them separately.
+    """
+    grouped_tensors_inp = defaultdict(list)
+    grouped_tensors_out = defaultdict(list)
+
+    for element in batch:
+        key = tuple(element[0].shape)  # input and output have same shape
+                                       # [trainSamples//nPatch_per_sample,4,sx,sy]
+        grouped_tensors_inp[key].append(element[0])
+        grouped_tensors_out[key].append(element[1])
+
+    # Stack tensors in each group to give [nPatch_per_sample, trainSamples//nPatch_per_sample, 4, sx, sy]
+    batched_tensors_inp = [torch.stack(tensors) for tensors in grouped_tensors_inp.values()]
+    batched_tensors_out = [torch.stack(tensors) for tensors in grouped_tensors_out.values()]
+
+    return (batched_tensors_inp, batched_tensors_out)
\ No newline at end of file
diff --git a/cfno/models/cfno2d.py b/cfno/models/cfno2d.py
index a7347e3..25ebd82 100644
--- a/cfno/models/cfno2d.py
+++ b/cfno/models/cfno2d.py
@@ -77,9 +77,9 @@ def _toFourierSpace(self, x):
         x = th.fft.rfft(x, dim=-2, norm="ortho") # RFFT on before-last dimension
         return x
 
-    def _toRealSpace(self, x):
+    def _toRealSpace(self, x, org_size):
         """ x[nBatch, dv, fX = nX/2+1, fY = nY] -> [nBatch, dv, nX, nY] """
-        x = th.fft.irfft(x, dim=-2, norm="ortho") # IRFFT on before-last dimension
+        x = th.fft.irfft(x, dim=-2, n=org_size[0], norm="ortho") # IRFFT on before-last dimension
         x = idct(x, norm="ortho") # IDCT on last dimension
         return x
 
@@ -101,9 +101,9 @@ def _toFourierSpace_FORCE_FFT_REORDER(self, x):
         x = th.fft.rfft2(x, norm="ortho") # RFFT on last 2 dimensions
         return x
 
-    def _toRealSpace_FORCE_FFT_REORDER(self, x):
+    def _toRealSpace_FORCE_FFT_REORDER(self, x, org_size):
         """ x[nBatch, dv, fX = nX/2+1, fY = nY/2+1] -> [nBatch, dv, nX, nY]"""
-        x = th.fft.irfft2(x, norm="ortho") # IRFFT on last 2 dimensions
+        x = th.fft.irfft2(x, s=org_size, norm="ortho") # IRFFT on last 2 dimensions
         nY = x.shape[-1]
         reorder = np.zeros(nY, dtype=int)
         reorder[: nY - nY % 2 : 2] = np.arange(nY // 2)
diff --git a/cfno/training/pySDC.py b/cfno/training/pySDC.py
index 10a840f..c18f788 100644
--- a/cfno/training/pySDC.py
+++ b/cfno/training/pySDC.py
@@ -185,11 +185,22 @@ def idLoss(self, dataset="valid"):
         else:
             ValueError(f"cannot compute id loss on {loader} dataset")
         nBatches = len(loader)
+        data_iter = iter(loader)
+
+        if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']:
+            inp_list, out_list = next(data_iter)  #[nBatches=nPatch_per_sample, batchSize=nSamples/nBatches, 4, nX, ny]
+            nBatches = len(inp_list)
+
         avgLoss = 0
         outType = self.outType
         with th.no_grad():
-            for inputs, outputs in loader:
+            # for inputs, outputs in loader:
+            for iBatch in range(nBatches):
+                if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']:
+                    inputs, outputs = (inp_list[iBatch], out_list[iBatch])
+                else:
+                    inputs, outputs = next(data_iter)
                 if outType == "solution":
                     avgLoss += self.lossFunction(inputs, outputs, inputs).item()
                 elif outType == "update":
@@ -213,9 +224,17 @@ def fullPath(cls, filePath):
     # -------------------------------------------------------------------------
     def train(self):
         """Train the model for one epoch"""
+        nSamples = len(self.trainLoader.dataset)
         nBatches = len(self.trainLoader)
         batchSize = self.trainLoader.batch_size
+        data_iter = iter(self.trainLoader)
+
+        if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']:
+            inp_list, out_list = next(data_iter)  # [nBatches=nPatch_per_sample, batchSize=nSamples/nBatches, 4, nX, ny]
+            nBatches = len(inp_list)
+            batchSize = len(inp_list[0])
+
         model = self.model
         optimizer = 
self.optimizer scheduler = self.lr_scheduler @@ -224,7 +243,12 @@ def train(self): idLoss = self.losses['id']['train'] model.train() - for iBatch, data in enumerate(self.trainLoader): + for iBatch in range(nBatches): + # for iBatch, data in enumerate(self.trainLoader): + if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']: + data = (inp_list[iBatch], out_list[iBatch]) + else: + data = next(data_iter) inp = data[0][..., ::self.xStep, ::self.yStep].to(self.device) ref = data[1][..., ::self.xStep, ::self.yStep].to(self.device) @@ -274,12 +298,22 @@ def valid(self): model = self.model avgLoss = 0 idLoss = self.losses['id']['valid'] + data_iter = iter(self.valLoader) + + if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']: + inp_list, out_list = next(data_iter) #[nBatches=nPatch_per_sample, batchSize=nSamples/nBatches, 4, nX, ny] + nBatches = len(inp_list) + batchSize = len(inp_list[0]) model.eval() with th.no_grad(): - for data in self.valLoader: - inp = data[0].to(self.device) - ref = data[1].to(self.device) + for iBatch in range(nBatches): + if self.data_config['use_domainSampling'] and not self.data_config['pad_to_fullGrid']: + data = (inp_list[iBatch], out_list[iBatch]) + else: + data = next(data_iter) + inp = data[0][..., ::self.xStep, ::self.yStep].to(self.device) + ref = data[1][..., ::self.xStep, ::self.yStep].to(self.device) pred = model(inp) avgLoss += self.lossFunction(pred, ref, inp).item() @@ -413,15 +447,11 @@ def __call__(self, u0, nEval=1): if outp.shape == inpt.shape: outp += inpt else: - padded_tensor = th.zeros_like(inpt) - padded_tensor[:,:, + sliced_inpt = inpt[:,:, self.modelConfig['iXBeg']: self.modelConfig['iXEnd'], - self.modelConfig['iYBeg']: self.modelConfig['iYEnd'] ] = outp[:,:,:,:] - # print(f'Padded tensor: {padded_tensor.shape}') - padded_tensor += inpt - outp = padded_tensor[:,:,self.modelConfig['iXBeg']: self.modelConfig['iXEnd'], - self.modelConfig['iYBeg']: self.modelConfig['iYEnd']] - # print(f'Ouptut shape: {outp.shape}') + self.modelConfig['iYBeg']: self.modelConfig['iYEnd']] + # print(f'Sliced Input: {sliced_inpt.shape}') + outp += sliced_inpt inpt = outp if not multi: diff --git a/scripts/configs/config_dte-3.yaml b/scripts/configs/config_dte-3.yaml index 05a5684..875e133 100644 --- a/scripts/configs/config_dte-3.yaml +++ b/scripts/configs/config_dte-3.yaml @@ -19,10 +19,16 @@ data: batchSize: 20 xStep: 1 yStep: 1 - use_domain_sampling: true + use_domainSampling: true + use_fixedPatchSize: false # fixed domain with overlapping + use_orderedSampling: false # fixed domain w/o overlapping + pad_to_fullGrid: false + use_fixedPatch_startIdx: true nPatch_per_sample: 100 - use_min_limit: false + use_minLimit: true padding : [1,1,1,1] + # slices: [] + # patch_startIdx: [] model: da: 4 @@ -39,11 +45,11 @@ model: fno_skip_type: soft-gating use_postfnochannel_mlp: false channel_mlp_skip_type: 'soft-gating' - get_subdomain_output: true - iXBeg: 16 - iYBeg: 16 - iXEnd: 64 - iYEnd: 64 + get_subdomain_output: false # set to true only for inference + iXBeg: 5 + iYBeg: 12 + iXEnd: 30 + iYEnd: 60 optim: name: adam @@ -55,7 +61,7 @@ lr_scheduler: T_max: 50000 train: - trainDir: training_dt1e-3_update_fft_run26 + trainDir: training_dt1e-3_update_fft_test nEpochs: 20000 checkpoint: model.pt lossesFile: losses.txt
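+# NOTE (usage sketch, assuming these keys are forwarded unchanged to
+# getDataLoaders): use_domainSampling enables the patch-based datasets;
+# use_fixedPatchSize selects FixedDomainDataset instead of
+# RandomDomainDataset; with pad_to_fullGrid false, use_minLimit is forced
+# on and batches are collated with variable_tensor_collate_fn, i.e. one
+# stacked group per patch shape instead of a single fixed-shape tensor.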