Commit b27d8ab

allow uniform and normal xavier_init

lbluque committed Sep 26, 2023
1 parent e14716b
Showing 2 changed files with 34 additions and 7 deletions.
matgl/utils/training.py (24 changes: 17 additions & 7 deletions)

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import math
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import pytorch_lightning as pl
 import torch
@@ -444,19 +444,29 @@ def loss_fn(
 }
 
 
-def xavier_init(model: nn.Module) -> None:
+def xavier_init(model: nn.Module, gain: float = 1.0, distribution: Literal["uniform", "normal"] = "uniform") -> None:
     """Xavier initialization scheme for the model.
 
     Args:
         model (nn.Module): The model to be Xavier-initialized.
+        gain (float): Gain factor. Defaults to 1.0.
+        distribution (Literal["uniform", "normal"], optional): Distribution to use. Defaults to "uniform".
     """
+    if distribution == "uniform":
+        init_fn = nn.init.xavier_uniform_
+    elif distribution == "normal":
+        init_fn = nn.init.xavier_normal_
+    else:
+        raise ValueError(f"Invalid distribution: {distribution}")
+
     for name, param in model.named_parameters():
         if name.endswith(".bias"):
             param.data.fill_(0)
-        else:
-            if param.dim() < 2:
-                bound = math.sqrt(6) / math.sqrt(param.shape[0] + param.shape[0])
+        elif param.dim() < 2:  # torch.nn.init.xavier_* only supports tensors with >= 2 dims
+            bound = gain * math.sqrt(6) / math.sqrt(2 * param.shape[0])
+            if distribution == "uniform":
                 param.data.uniform_(-bound, bound)
             else:
-                bound = math.sqrt(6) / math.sqrt(param.shape[0] + param.shape[1])
-                param.data.uniform_(-bound, bound)
+                param.data.normal_(0, bound**2)
+        else:
+            init_fn(param.data, gain=gain)
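
For context, a minimal usage sketch of the updated function (not part of the commit). The default call is backward compatible; for 1-D parameters, which `torch.nn.init.xavier_*` cannot handle, the code falls back to a hand-computed bound that treats `fan_in = fan_out = param.shape[0]`. `TinyModel` here is a hypothetical stand-in for any `nn.Module`:

```python
from torch import nn

from matgl.utils.training import xavier_init


class TinyModel(nn.Module):
    """Hypothetical stand-in for any model with weights and biases."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


model = TinyModel()

# Default behavior is unchanged: Xavier (Glorot) uniform with gain 1.0.
xavier_init(model)

# New in this commit: Xavier normal with a custom gain.
xavier_init(model, gain=2.0, distribution="normal")

# Anything else raises before any parameter is touched.
try:
    xavier_init(model, distribution="lognormal")
except ValueError as err:
    print(err)  # Invalid distribution: lognormal
```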
tests/utils/test_training.py (17 changes: 17 additions & 0 deletions)

@@ -7,6 +7,7 @@
 from functools import partial
 
 import numpy as np
+import pytest
 import pytorch_lightning as pl
 import torch.backends.mps
 from dgl.data.utils import split_dataset
@@ -236,3 +237,19 @@ def teardown_class(cls):
         pass
 
         shutil.rmtree("lightning_logs")
+
+
+@pytest.mark.parametrize("distribution", ["normal", "uniform", "fake"])
+def test_xavier_init(distribution):
+    model = MEGNet()
+    # get a parameter
+    w = model.output_proj.layers[0].get_parameter("weight").clone()
+
+    if distribution == "fake":
+        with pytest.raises(ValueError, match=r"^Invalid distribution:."):
+            xavier_init(model, distribution=distribution)
+    else:
+        xavier_init(model, distribution=distribution)
+        print(w)
+        assert not torch.allclose(w, model.output_proj.layers[0].get_parameter("weight"))
+        assert torch.allclose(torch.tensor(0.0), model.output_proj.layers[0].get_parameter("bias"))
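
The committed test exercises the API against `MEGNet`. A slimmed-down sketch of the same parametrized pattern, using a plain `nn.Sequential` so it runs without building a matgl model (the wrapper matters: `xavier_init` only zeroes parameters whose names end in `.bias`):

```python
import pytest
import torch
from torch import nn

from matgl.utils.training import xavier_init


@pytest.mark.parametrize("distribution", ["normal", "uniform", "fake"])
def test_xavier_init_sketch(distribution):
    # Wrap the layer so parameter names are "0.weight"/"0.bias";
    # a bare nn.Linear's "bias" would not match the ".bias" suffix check.
    model = nn.Sequential(nn.Linear(4, 2))
    w = model[0].weight.clone()

    if distribution == "fake":
        # Unknown distributions fail fast, before any parameter is touched.
        with pytest.raises(ValueError, match="Invalid distribution"):
            xavier_init(model, distribution=distribution)
    else:
        xavier_init(model, distribution=distribution)
        # Weights are re-drawn, biases are zeroed.
        assert not torch.allclose(w, model[0].weight)
        assert torch.allclose(model[0].bias, torch.zeros_like(model[0].bias))
```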
