Commit

docs: add arxiv link and ack., update abs., replace pipeline.png | style: black | add: .gitignore
wtyuan96 committed Jul 22, 2022
1 parent 3986066 commit ce3bc9e
Showing 24 changed files with 1,568 additions and 928 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
+**/__pycache__
+**/data
+**/logs
+**/*.swp
33 changes: 16 additions & 17 deletions Experiments/audio_regression/dataset.py
@@ -7,32 +7,31 @@
 
 
 def get_data(data_root, filename, factor):
-    rate, wav = wavfile.read(os.path.join(data_root, filename))
 
+    rate, wav = wavfile.read(os.path.join(data_root, filename))
     print("Rate: %d" % rate)
     print("Raw data shape: ", wav.shape)
 
     wav = torch.tensor(wav).reshape(-1, 1)
     scale = torch.max(torch.abs(wav))
-    wav = wav / scale # (N, 1)
+    wav = wav / scale  # (N, 1)
 
-    grad = kornia.filters.spatial_gradient(wav.unsqueeze(0).unsqueeze(0), mode='diff', order=1, normalized=True).squeeze() # (2, N)
-    grad = grad[1, :].reshape(-1, 1) # (N, 1)
+    grad = kornia.filters.spatial_gradient(
+        wav.unsqueeze(0).unsqueeze(0), mode="diff", order=1, normalized=True
+    ).squeeze()  # (2, N)
+    grad = grad[1, :].reshape(-1, 1)  # (N, 1)
 
-    coordinate = torch.linspace(0, len(wav) - 1, len(wav)).reshape(-1, 1) # (N, 1)
+    coordinate = torch.linspace(0, len(wav) - 1, len(wav)).reshape(-1, 1)  # (N, 1)
 
     downsampled_wav = wav[::factor, :]
    downsampled_grad = grad[::factor, :]
    downsampled_coordinate = coordinate[::factor, :]
 
     return {
-        'wav': wav,
-        'grad': grad,
-        'coordinate': coordinate,
-
-        'downsampled_wav': downsampled_wav,
-        'downsampled_grad': downsampled_grad,
-        'downsampled_coordinate': downsampled_coordinate,
-    }
-
-
+        "wav": wav,
+        "grad": grad,
+        "coordinate": coordinate,
+        "downsampled_wav": downsampled_wav,
+        "downsampled_grad": downsampled_grad,
+        "downsampled_coordinate": downsampled_coordinate,
+    }
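
For reference, a minimal sketch of how the reformatted get_data is consumed (the wav file name below is a placeholder, not a file shipped in this diff):

    from dataset import get_data

    # "gt_bach.wav" is a hypothetical file under --data_root
    data = get_data(data_root="../../data/Audio", filename="gt_bach.wav", factor=5)

    # full-resolution tensors, each of shape (N, 1)
    wav, grad, coordinate = data["wav"], data["grad"], data["coordinate"]

    # every `factor`-th sample, used as the training set in main.py
    print(data["downsampled_wav"].shape, data["downsampled_coordinate"].shape)
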
1 change: 0 additions & 1 deletion Experiments/audio_regression/diff_operators.py
@@ -6,4 +6,3 @@ def gradient(y, x, grad_outputs=None):
         grad_outputs = torch.ones_like(y)
     grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
     return grad
-
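
A minimal usage sketch for this autograd helper (the sine function stands in for a network forward pass and is not part of the repository):

    import torch
    from diff_operators import gradient

    x = torch.linspace(-1, 1, 100).reshape(-1, 1).requires_grad_(True)
    y = torch.sin(30 * x)  # stand-in for model(x)

    dy_dx = gradient(y, x)  # shape (100, 1), computed via torch.autograd.grad
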
6 changes: 3 additions & 3 deletions Experiments/audio_regression/loss.py
@@ -8,11 +8,11 @@ def mse(x, y):
 def val_mse(gt, pred):
     val_loss = mse(gt, pred)
 
-    return {'val_loss': val_loss}
+    return {"val_loss": val_loss}
 
 
 def der_mse(gt_grad, pred_grad):
     weights = torch.ones(gt_grad.shape[1]).to(gt_grad.device)
     der_loss = torch.mean((weights * (gt_grad - pred_grad).pow(2)).sum(-1))
 
-    return {'der_loss': der_loss}
+    return {"der_loss": der_loss}
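
A minimal usage sketch for the two loss helpers, assuming ground truth and prediction are (N, 1) tensors on the same device (the random tensors are illustrative only):

    import torch
    from loss import val_mse, der_mse

    gt = torch.randn(100, 1)
    pred = torch.randn(100, 1)

    print(val_mse(gt, pred))   # {'val_loss': tensor(...)}
    print(der_mse(gt, pred))   # {'der_loss': tensor(...)}

    # main.py combines them as val_loss + lambda_der * der_loss when --supervision val_der
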
199 changes: 130 additions & 69 deletions Experiments/audio_regression/main.py
@@ -4,7 +4,7 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from dataset import get_data
+from dataset import get_data
 from model import MLP
 from loss import *
 from utils import *
@@ -15,52 +15,95 @@
 def config_parser():
 
     import configargparse
 
     parser = configargparse.ArgumentParser()
-    parser.add_argument('--config', is_config_file=True, help="Path of config file.")
+    parser.add_argument("--config", is_config_file=True, help="Path of config file.")
 
     # logging options
-    parser.add_argument('--logging_root', type=str, default='./logs/', help="Where to store ckpts and logs.")
-    parser.add_argument('--epochs_til_ckpt', type=int, default=1000, help="Time interval in epochs until checkpoint is saved.")
-    parser.add_argument('--epochs_til_summary', type=int, default=100, help="Time interval in epochs until tensorboard summary is saved.")
+    parser.add_argument(
+        "--logging_root",
+        type=str,
+        default="./logs/",
+        help="Where to store ckpts and logs.",
+    )
+    parser.add_argument(
+        "--epochs_til_ckpt",
+        type=int,
+        default=1000,
+        help="Time interval in epochs until checkpoint is saved.",
+    )
+    parser.add_argument(
+        "--epochs_til_summary",
+        type=int,
+        default=100,
+        help="Time interval in epochs until tensorboard summary is saved.",
+    )
 
     # training options
-    parser.add_argument('--lrate', type=float, default='5e-5')
-    parser.add_argument('--num_epochs', type=int, default=8000, help="Number of epochs to train for.")
+    parser.add_argument("--lrate", type=float, default="5e-5")
+    parser.add_argument(
+        "--num_epochs", type=int, default=8000, help="Number of epochs to train for."
+    )
 
     # experiment options
-    parser.add_argument('--exp_name', type=str, default='supervision_val_der',
-                        help="Name of experiment.")
-    parser.add_argument('--supervision', type=str, default='val_der', choices=('val', 'der', 'val_der'))
-    parser.add_argument('--activations', nargs='+', default=['sine', 'sine', 'sine', 'sine'])
-    parser.add_argument('--w0', type=float, default='30.')
-    parser.add_argument('--has_pos_encoding', action='store_true')
-    parser.add_argument('--lambda_der', type=float, default='1.')
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        default="supervision_val_der",
+        help="Name of experiment.",
+    )
+    parser.add_argument(
+        "--supervision", type=str, default="val_der", choices=("val", "der", "val_der")
+    )
+    parser.add_argument(
+        "--activations", nargs="+", default=["sine", "sine", "sine", "sine"]
+    )
+    parser.add_argument("--w0", type=float, default="30.")
+    parser.add_argument("--has_pos_encoding", action="store_true")
+    parser.add_argument("--lambda_der", type=float, default="1.")
 
     # model options
-    parser.add_argument('--hidden_features', type=int, default=256)
-    parser.add_argument('--num_hidden_layers', type=int, default=3)
+    parser.add_argument("--hidden_features", type=int, default=256)
+    parser.add_argument("--num_hidden_layers", type=int, default=3)
 
     # dataset options
-    parser.add_argument('--data_root', type=str, default='../../data/Audio', help="Root path to audio dataset.")
-    parser.add_argument('--filename', type=str, help="Name of wav file.")
-    parser.add_argument('--factor', type=int, default=5, help="Factor of downsampling.")
-
-    return parser
+    parser.add_argument(
+        "--data_root",
+        type=str,
+        default="../../data/Audio",
+        help="Root path to audio dataset.",
+    )
+    parser.add_argument("--filename", type=str, help="Name of wav file.")
+    parser.add_argument("--factor", type=int, default=5, help="Factor of downsampling.")
+
+    return parser
 
-def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_checkpoint, logging_dir, train_summary_fn, test_summary_fn, log_f):
 
-    summaries_dir = os.path.join(logging_dir, 'summaries')
+def train(
+    args,
+    model,
+    data,
+    epochs,
+    lrate,
+    epochs_til_summary,
+    epochs_til_checkpoint,
+    logging_dir,
+    train_summary_fn,
+    test_summary_fn,
+    log_f,
+):
+
+    summaries_dir = os.path.join(logging_dir, "summaries")
     os.makedirs(summaries_dir)
     writer = SummaryWriter(summaries_dir)
 
-    checkpoints_dir = os.path.join(logging_dir, 'checkpoints')
+    checkpoints_dir = os.path.join(logging_dir, "checkpoints")
     os.makedirs(checkpoints_dir)
 
-    out_train_imgs_dir = os.path.join(logging_dir, 'out_train_imgs')
+    out_train_imgs_dir = os.path.join(logging_dir, "out_train_imgs")
     os.makedirs(out_train_imgs_dir)
 
-    out_test_imgs_dir = os.path.join(logging_dir, 'out_test_imgs')
+    out_test_imgs_dir = os.path.join(logging_dir, "out_test_imgs")
     os.makedirs(out_test_imgs_dir)
 
     optim = torch.optim.Adam(lr=lrate, params=model.parameters())
@@ -71,16 +114,16 @@ def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_check
     for epoch in range(1, epochs + 1):
 
         # forward and calculate loss
-        model_output = model(data['downsampled_coordinate'], mode='train')
+        model_output = model(data["downsampled_coordinate"], mode="train")
         losses = {}
-        losses.update(val_mse(data['downsampled_wav'], model_output['pred']))
-        losses.update(der_mse(data['downsampled_grad'], model_output['pred_grad']))
-        if args.supervision == 'val':
-            train_loss = losses['val_loss']
-        elif args.supervision == 'der':
-            train_loss = losses['der_loss']
-        elif args.supervision == 'val_der':
-            train_loss = 1. * losses['val_loss'] + args.lambda_der * losses['der_loss']
+        losses.update(val_mse(data["downsampled_wav"], model_output["pred"]))
+        losses.update(der_mse(data["downsampled_grad"], model_output["pred_grad"]))
+        if args.supervision == "val":
+            train_loss = losses["val_loss"]
+        elif args.supervision == "der":
+            train_loss = losses["der_loss"]
+        elif args.supervision == "val_der":
+            train_loss = 1.0 * losses["val_loss"] + args.lambda_der * losses["der_loss"]
         # tensorboard
         for loss_name, loss in losses.items():
             writer.add_scalar(loss_name, loss, epoch)
@@ -94,7 +137,9 @@ def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_check
         if (not epoch % epochs_til_summary) or (epoch == epochs):
 
             # training summary
-            psnr = train_summary_fn(data, model_output, writer, epoch, out_train_imgs_dir)
+            psnr = train_summary_fn(
+                data, model_output, writer, epoch, out_train_imgs_dir
+            )
             str_print = "[Train] epoch: (%d/%d) " % (epoch, epochs)
             for loss_name, loss in losses.items():
                 str_print += loss_name + ": %0.6f, " % loss
@@ -104,17 +149,28 @@ def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_check
 
             # test summary
             with torch.no_grad():
-                model_output = model(data['coordinate'], mode='test')
-                psnr = test_summary_fn(data, model_output, writer, epoch, out_test_imgs_dir, args.factor, args.filename)
+                model_output = model(data["coordinate"], mode="test")
+                psnr = test_summary_fn(
+                    data,
+                    model_output,
+                    writer,
+                    epoch,
+                    out_test_imgs_dir,
+                    args.factor,
+                    args.filename,
+                )
             str_print = "[Test]: PSNR: %.3f" % (psnr)
             print(str_print)
             print(str_print, file=log_f)
 
         # save checkpoint
         if (not epoch % epochs_til_checkpoint) or (epoch == epochs):
-            torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_epoch_%05d.pth' % epoch))
+            torch.save(
+                model.state_dict(),
+                os.path.join(checkpoints_dir, "model_epoch_%05d.pth" % epoch),
+            )
 
-    torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_final.pth'))
+    torch.save(model.state_dict(), os.path.join(checkpoints_dir, "model_final.pth"))
@@ -124,44 +180,49 @@ def main():
 
     logging_dir = os.path.join(args.logging_root, args.exp_name)
     if os.path.exists(logging_dir):
-        if input("The logging directory %s exists. Overwrite? (y/n)" % logging_dir) == 'y':
+        if (
+            input("The logging directory %s exists. Overwrite? (y/n)" % logging_dir)
+            == "y"
+        ):
             shutil.rmtree(logging_dir)
     os.makedirs(logging_dir)
 
-    with open(os.path.join(logging_dir, 'log.txt'), 'w') as log_f:
+    with open(os.path.join(logging_dir, "log.txt"), "w") as log_f:
 
         print("Args:\n", args)
         print("Args:\n", args, file=log_f)
 
         data = get_data(args.data_root, args.filename, args.factor)
-        print('Shape of original wav:', data['wav'].shape)
-        print('Shape of downsampled wav:', data['downsampled_wav'].shape)
+        print("Shape of original wav:", data["wav"].shape)
+        print("Shape of downsampled wav:", data["downsampled_wav"].shape)
 
         model = MLP(
-            in_features=1,
-            out_features=1,
-            w0=args.w0,
-            activations=args.activations,
-            hidden_features=args.hidden_features,
-            num_hidden_layers=args.num_hidden_layers,
-            has_pos_encoding=args.has_pos_encoding,
-            length=len(data['wav']),
-            fn_samples=len(data['downsampled_wav']))
+            in_features=1,
+            out_features=1,
+            w0=args.w0,
+            activations=args.activations,
+            hidden_features=args.hidden_features,
+            num_hidden_layers=args.num_hidden_layers,
+            has_pos_encoding=args.has_pos_encoding,
+            length=len(data["wav"]),
+            fn_samples=len(data["downsampled_wav"]),
+        )
         model.cuda()
 
         train(
-            args=args,
-            model=model,
-            data=data,
-            epochs=args.num_epochs,
-            lrate=args.lrate,
-            epochs_til_summary=args.epochs_til_summary,
-            epochs_til_checkpoint=args.epochs_til_ckpt,
-            logging_dir=logging_dir,
-            train_summary_fn=write_train_summary,
-            test_summary_fn=write_test_summary,
-            log_f=log_f)
-
-
-if __name__=='__main__':
+            args=args,
+            model=model,
+            data=data,
+            epochs=args.num_epochs,
+            lrate=args.lrate,
+            epochs_til_summary=args.epochs_til_summary,
+            epochs_til_checkpoint=args.epochs_til_ckpt,
+            logging_dir=logging_dir,
+            train_summary_fn=write_train_summary,
+            test_summary_fn=write_test_summary,
+            log_f=log_f,
+        )
+
+
+if __name__ == "__main__":
     main()
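
A minimal sketch of how the script is driven, assuming it is run from Experiments/audio_regression with a wav file in place (the file and experiment names below are placeholders):

    # CLI: python main.py --filename gt_bach.wav --supervision val_der
    # The same arguments can be parsed programmatically:
    parser = config_parser()
    args = parser.parse_args(
        ["--filename", "gt_bach.wav", "--supervision", "val_der", "--exp_name", "bach_val_der"]
    )
    print(args.lrate, args.num_epochs, args.lambda_der)  # defaults: 5e-05 8000 1.0
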