Training and Inference on sub-domains #7

Open · wants to merge 15 commits into base: tibo
178 changes: 173 additions & 5 deletions cfno/data/preprocessing.py
@@ -6,7 +6,7 @@
import torch
import glob
import numpy as np

import random
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from cfno.simulation.post import OutputFiles

@@ -110,18 +110,117 @@ def get_concat_data(self, task:str, nsamples:int, reader, multistep:bool=True):

class HDF5Dataset(Dataset):

def __init__(self, dataFile):
def __init__(self, dataFile,
use_domain_sampling=False,
use_fixed_domain=False,
use_ordered_sampling=False,
nPatch_per_sample=1,
use_min_limit=False,
padding=[0,0,0,0],
xPatch_start=0,
yPatch_start=0,
slices=[],
kX=12, kY=12):
"""
Dataset reader and getitem for DataLoader

Args:
dataFile (hdf5): data file
use_domain_sampling (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample
random sized patches of (sX,sY). Defaults to False.
use_fixed_domain (bool, optional): To divide full grid (nX,nY) into nPatch_per_sample of
(sX,sY) patches with overlapping. Defaults to False.
use_ordered_sampling (bool, optional): To divide full grid (nX,nY) into (nX//sX)*(nY//sY)
exactly divisible (sX,sY) size patches w/o overlapping.
Defaults to False.
nPatch_per_sample (int, optional): Number of sub-domains per sample. Defaults to 1.
use_min_limit (bool, optional): Restrict (sX,sY) to be > (2*kX -1, 2*kY-1). Defaults to False.
padding (list, optional): Columns and rows to decode inflow information
in format[left, right, bottom, top]. Defaults to [0,0,0,0]
xPatch_start (int, optional): Starting index of patch in x-axis. Defaults to 0.
yPatch_start (int, optional): Starting index of patch in y-axis. Defaults to 0.
slices (list, optional): Sizes of patch [[sX,sY]]. Defaults to [].
kX (int, optional): Number of fourier modes in x-axis. Defaults to 12.
kY (int, optional): Number of fourier modes in y-axis. Defaults to 12.

"""

self.file = h5py.File(dataFile, 'r')
self.inputs = self.file['inputs']
self.outputs = self.file['outputs']
self.use_domain_sampling = use_domain_sampling
self.use_fixed_domain = use_fixed_domain
self.use_ordered_sampling = use_ordered_sampling
self.nPatch_per_sample = nPatch_per_sample
self.use_min_limit = use_min_limit
self.kX = kX
self.kY = kY
xGrid, yGrid = self.grid
self.nX = xGrid.size
self.nY = yGrid.size
self.xPatch_start = xPatch_start
self.yPatch_start = yPatch_start

if len(slices) == 0:
self.slices = self.find_patch_size()
elif len(slices) == 1 and self.use_fixed_domain and self.use_ordered_sampling:
self.nPatch_per_sample = (self.nX // slices[0][0]) * (self.nY // slices[0][1])
self.slices = slices * self.nPatch_per_sample
else:
self.slices = slices

self.padding = padding #[left, right, bottom, top]

assert len(self.inputs) == len(self.outputs), \
f"different sample number for inputs and outputs ({len(self.inputs)},{len(self.outputs)})"

assert not self.use_ordered_sampling or self.use_fixed_domain, "If use_ordered_sampling is True, then use_fixed_domain must also be True"
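
# Usage sketch (illustrative; the file name and option values are assumptions, not from this PR):
# ds = HDF5Dataset("snapshots.h5", use_domain_sampling=True,
# use_fixed_domain=True, use_ordered_sampling=True,
# padding=[1, 1, 1, 1], kX=12, kY=12)
# tiles each full (nX, nY) snapshot into non-overlapping, divisor-sized patches.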

def __len__(self):
return len(self.inputs)

def __getitem__(self, idx):
if self.use_domain_sampling:
patch_padding = self.padding.copy()
iSample = idx // self.nPatch_per_sample
iPatch = idx % self.nPatch_per_sample
inpt_grid, outp_grid = self.sample(iSample)
inpt, outp = np.zeros_like(inpt_grid), np.zeros_like(outp_grid)
if len(self.slices) == 1:
sX, sY = self.slices[0]
xPatch_start = self.xPatch_start
yPatch_start = self.yPatch_start
else:
sX, sY = self.slices[iPatch]
if self.use_ordered_sampling:
xPatch_start = (iPatch // (self.nX//sX)) * sX
yPatch_start = (iPatch % (self.nY//sY)) * sY
else:
xPatch_start = random.randint(0, self.nX - sX)
yPatch_start = random.randint(0, self.nY - sY)

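# Drop padding on any side where the padded window would cross the full-grid boundary: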
patch_padding[0] = 0 if xPatch_start == 0 or (xPatch_start - patch_padding[0]) < 0 else patch_padding[0]
patch_padding[1] = 0 if (xPatch_start + sX + patch_padding[1]) >= self.nX else patch_padding[1]
patch_padding[2] = 0 if yPatch_start == 0 or (yPatch_start - patch_padding[2]) < 0 else patch_padding[2]
patch_padding[3] = 0 if (yPatch_start + sY + patch_padding[3]) >= self.nY else patch_padding[3]

# print(f"Input size: {inpt.shape}")
# print(f'For patch {iPatch} of sample {iSample}')
# print(f'(sx,sy): {sX,sY}, (x_start,y_start): {xPatch_start,yPatch_start}')
# print(f'padding: {patch_padding}, nPatch_per_sample:{self.nPatch_per_sample}')

inpt[:, :(sX + patch_padding[0] + patch_padding[1]),
:(sY + patch_padding[2] + patch_padding[3])] = inpt_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1],
yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]]
outp[:,:(sX + patch_padding[0] + patch_padding[1]),
:(sY + patch_padding[2] + patch_padding[3])] = outp_grid[:, xPatch_start - patch_padding[0]: (xPatch_start+sX) + patch_padding[1],
yPatch_start - patch_padding[2]: (yPatch_start+sY) + patch_padding[3]]
else:
inpt, outp = self.sample(idx)

return torch.tensor(inpt), torch.tensor(outp)

def __del__(self):
@@ -148,6 +247,44 @@ def outType(self):
@property
def outScaling(self):
return float(self.infos["outScaling"][()])

def find_patch_size(self):
"""
List containing patch sizes
"""
slices = []
if self.use_fixed_domain:
self.valid_sX = [sx for sx in range(1,self.nX) if self.nX % sx == 0]
self.valid_sY = [sy for sy in range(1,self.nY) if self.nY % sy == 0]
# select a (sX,sY) randomly
sX = int(random.choice(self.valid_sX))
sY = int(random.choice(self.valid_sY))
if self.use_ordered_sampling:
self.nPatch_per_sample = (self.nX // sX) * (self.nY // sY)
slices = [(sX,sY)]*self.nPatch_per_sample
else:
if self.use_min_limit:
nX_min = self.calc_slice_min(self.nX, self.kX)
nY_min = self.calc_slice_min(self.nY, self.kY)
else:
nX_min, nY_min = 1, 1  # at least one point per patch, avoids zero-size slices
for i in range(self.nPatch_per_sample):
sX = random.randint(nX_min, self.nX)
sY = random.randint(nY_min, self.nY)
slices.append((sX,sY))
return slices
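
# Illustration (values assumed): with nX = nY = 256 and use_fixed_domain, the valid
# sizes are the divisors {1, 2, 4, ..., 128}; picking sX = sY = 64 under
# use_ordered_sampling yields (256//64) * (256//64) = 16 patches per sample.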

def calc_slice_min(self, n, modes):
"""
Finding min number of points to satisfy
n/2 +1 >= modes
"""
slice_min = 2*(modes-1)
if slice_min < n:
return slice_min
else:
print("Insufficient number of points to slice")
return 0
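
# e.g. modes = 12 needs at least 22 points, so a 256-point axis can be sliced,
# while a 16-point axis cannot (the method then warns and returns 0).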

def printInfos(self):
xGrid, yGrid = self.grid
@@ -164,6 +301,16 @@ def printInfos(self):
print(f" -- dtInput : {infos['dtInput'][()]:1.2g}")
print(f" -- outType : {infos['outType'][()].decode('utf-8')}")
print(f" -- outScaling : {infos['outScaling'][()]:1.2g}")
print(f" --use_ordered_sampling: {self.use_ordered_sampling}")
if self.use_domain_sampling:
print(f"-- nPatch (per sample): {self.nPatch_per_sample}")
if self.use_fixed_domain:
print(f" --patches (per sample): {self.slices[0]}")
else:
print(f" --patches (per sample): {self.slices}")
print(f" --padding (per patch): {self.padding}")
if self.use_min_limit:
print(f" Min nX & nY for patch computed using {self.kX, self.kY} modes")

def createDataset(
dataDir, inSize, outStep, inStep, outType, outScaling, dataFile,
@@ -231,10 +378,31 @@ def createDataset(
dataset.close()
print(" -- done !")

def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20,
seed=None, use_domain_sampling=False,
use_fixed_domain=False,
use_ordered_sampling=False,
nPatch_per_sample=1,
use_min_limit=False,
padding=[0,0,0,0],
xPatch_start=0,
yPatch_start=0,
kX=12, kY=12, **kwargs):

slices = kwargs.get('slices', [])

dataset = HDF5Dataset(dataFile, use_domain_sampling,
use_fixed_domain, use_ordered_sampling,
nPatch_per_sample, use_min_limit,
padding, xPatch_start, yPatch_start,
slices, kX, kY)
dataset.printInfos()

def getDataLoaders(dataFile, trainRatio=0.8, batchSize=20, seed=None):
dataset = HDF5Dataset(dataFile)
nBatches = len(dataset)

trainSize = int(trainRatio*nBatches)
valSize = nBatches - trainSize

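For reviewers, a minimal usage sketch of the new sub-domain options (the file name, seed, and patch size below are assumptions, not values from this PR):

from cfno.data.preprocessing import getDataLoaders

trainLoader, valLoader, dataset = getDataLoaders(
    "snapshots.h5", trainRatio=0.8, batchSize=20, seed=42,
    use_domain_sampling=True,    # train on patches instead of the full grid
    use_fixed_domain=True,       # a single fixed patch size
    use_ordered_sampling=True,   # tile the grid without overlap
    padding=[1, 1, 1, 1],        # one extra column/row per side for inflow context
    kX=12, kY=12,
    slices=[[64, 64]],           # explicit patch size; omit to pick one at random
)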
23 changes: 21 additions & 2 deletions cfno/models/cfno2d.py
@@ -311,7 +311,10 @@ def __init__(self, da, dv, du, kX=4, kY=4,
fno_skip_type='linear',
use_postfnochannel_mlp=False,
channel_mlp_skip_type='soft-gating',
channel_mlp_expansion=4
channel_mlp_expansion=4,
get_subdomain_output=False,
iXBeg=0, iYBeg=0,
iXEnd=256, iYEnd=256,
):

super().__init__()
@@ -357,9 +360,19 @@
for _ in range(nLayers)])

self.memory = CudaMemoryDebugger(print_mem=True)
self.get_subdomain_output = get_subdomain_output
if self.get_subdomain_output:
self.iXBeg = iXBeg
self.iXEnd = iXEnd
self.iYBeg = iYBeg
self.iYEnd = iYEnd

def forward(self, x):
""" x[nBatch, nX, nY, da] -> [nBatch, du, nX, nY]"""
""" x[nBatch, nX, nY, da] -> [nBatch, du, nX, nY]
if use_subdomain_output:
x[nBatch, nX, nY, da] -> [nBatch, du, iEndX-iBegX, iEndY-iBegY]
"""

# x = x.permute(0,3,1,2)
x = self.P(x)

Expand All @@ -373,10 +386,16 @@ def forward(self, x):
x = self.channel_mlp[index](x) + x_skip_channel_mlp
if index < len(self.layers) - 1:
x = nn.functional.gelu(x)

# to get only a subdomain output
if self.get_subdomain_output:
print(f'Filtering to x-subdomain {self.iXBeg,self.iXEnd} & y-subdomain {self.iYBeg,self.iYEnd}')
x = x[:, :, self.iXBeg:self.iXEnd, self.iYBeg:self.iYEnd]

x = self.Q(x)
# x = self.pos(x)
# x = x.permute(0,2,3,1)
# print(f'Shape of x: {x.shape}')

return x

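To make the new slicing concrete, a minimal shape check (all sizes assumed for illustration):

import torch

nBatch, du, nX, nY = 2, 4, 256, 256
iXBeg, iXEnd, iYBeg, iYEnd = 0, 64, 0, 64
x = torch.randn(nBatch, du, nX, nY)          # hidden state just before Q
sub = x[:, :, iXBeg:iXEnd, iYBeg:iYEnd]      # the slice applied in forward()
assert sub.shape == (nBatch, du, iXEnd - iXBeg, iYEnd - iYBeg)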
29 changes: 26 additions & 3 deletions cfno/training/pySDC.py
@@ -34,16 +34,19 @@ def __init__(self,
self.device = th.device('cuda' if th.cuda.is_available() else 'cpu')

# Inference-only mode
if data is None and model is None and optim is None:
if data is None and optim is None:
assert checkpoint is not None, "need a checkpoint in inference-only evaluation"
if model is not None:
self.modelConfig = model
self.load(checkpoint, modelOnly=True)
return

# Data setup
assert "dataFile" in data, "missing dataFile in data configuration"
self.data_config = {**data}
self.xStep, self.yStep = data.pop("xStep", 1), data.pop("yStep", 1)
data.pop("outType", None)    # overwritten by dataset
data.pop("outScaling", None)
self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data)
self.trainLoader, self.valLoader, self.dataset = getDataLoaders(**data, kX=model['kX'], kY=model['kY'])
# sample : [batchSize, 4, nX, nY]
self.outType = self.dataset.outType
self.outScaling = self.dataset.outScaling
@@ -167,6 +170,9 @@ def printInfos(self):
print(f"Scheduler: {self.scheduler_name}")
for key,val in self.scheduler_config.items():
print(f" -- {key}: {val}")
print("Data settings")
for key, val in self.data_config.items():
print(f" -- {key}: {val}")
# TODO: add more details here ...
print("-"*80)

@@ -358,8 +364,12 @@ def load(self, filePath, modelOnly=False):
# for backward compatibility ...
checkpoint['model']["non_linearity"] = checkpoint['model'].pop("nonLinearity")
if hasattr(self, "modelConfig") and self.modelConfig != checkpoint['model']:
for key, value in self.modelConfig.items():
if key not in checkpoint['model']:
checkpoint['model'][key] = value
print("WARNING : different model settings in config file,"
" overwriting with config from checkpoint ...")
print(f"Model: {checkpoint['model']}")
self.setupModel(checkpoint['model'])
self.model.load_state_dict(checkpoint['model_state_dict'])
self.outType = checkpoint["outType"]
@@ -398,7 +408,20 @@ def __call__(self, u0, nEval=1):
outp = model(inpt)
if self.outType == "update":
outp /= self.outScaling
outp += inpt

# Mapping output to input shape to perform addition
if outp.shape == inpt.shape:
outp += inpt
else:
padded_tensor = th.zeros_like(inpt)
padded_tensor[:,:,
self.modelConfig['iXBeg']: self.modelConfig['iXEnd'],
self.modelConfig['iYBeg']: self.modelConfig['iYEnd'] ] = outp[:,:,:,:]
# print(f'Padded tensor: {padded_tensor.shape}')
padded_tensor += inpt
outp = padded_tensor[:,:,self.modelConfig['iXBeg']: self.modelConfig['iXEnd'],
self.modelConfig['iYBeg']: self.modelConfig['iYEnd']]
# print(f'Output shape: {outp.shape}')
inpt = outp

if not multi:
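A sketch of the update mapping for the sub-domain case, mirroring the logic above (shapes and config values assumed for illustration):

import torch as th

modelConfig = {'iXBeg': 0, 'iXEnd': 64, 'iYBeg': 0, 'iYEnd': 64}
inpt = th.randn(2, 4, 256, 256)    # full-grid input state
outp = th.randn(2, 4, 64, 64)      # sub-domain update from the model

# Embed the update in a zero tensor of full-grid shape, add the input,
# then crop back to the sub-domain:
padded = th.zeros_like(inpt)
padded[:, :, modelConfig['iXBeg']:modelConfig['iXEnd'],
       modelConfig['iYBeg']:modelConfig['iYEnd']] = outp
padded += inpt
outp = padded[:, :, modelConfig['iXBeg']:modelConfig['iXEnd'],
              modelConfig['iYBeg']:modelConfig['iYEnd']]
assert outp.shape == (2, 4, 64, 64)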