make pylint happier
FelixBenning committed Feb 27, 2024
1 parent fca0d44 commit 5b91618
Showing 6 changed files with 170 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pyrfd/__init__.py
@@ -1,2 +1,2 @@
from .optimizer import RFD
from .covariance import SquaredExponential
from .covariance import SquaredExponential
13 changes: 6 additions & 7 deletions pyrfd/__main__.py
@@ -8,15 +8,14 @@
cov_model.auto_fit(
model_factory=ModelM3,
loss=torch.nn.functional.nll_loss,
data= tv.datasets.MNIST(
data=tv.datasets.MNIST(
root="mnistSimpleCNN/data",
train=True,
transform=tv.transforms.ToTensor()
transform=tv.transforms.ToTensor(),
),
cache="cache/CNN3_mnist.csv",
)
rfd = RFD(
ModelM3().parameters(),
covariance_model=cov_model
)
print(cov_model.scale)
rfd = RFD(ModelM3().parameters(), covariance_model=cov_model)


print(cov_model.scale)
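For context, here is a minimal end-to-end sketch of how this entry point is meant to be used after the change. Only SquaredExponential.auto_fit, RFD(..., covariance_model=...) and cov_model.scale are taken from the diff above; the ModelM3 import path and the closure-based training step are assumptions, not part of this commit.

import torch
import torchvision as tv

from pyrfd import RFD, SquaredExponential
from mnistSimpleCNN.models.modelM3 import ModelM3  # assumed import path

train_data = tv.datasets.MNIST(
    root="mnistSimpleCNN/data",
    train=True,
    transform=tv.transforms.ToTensor(),
)

# Fit the covariance model from loss/gradient-norm samples (cached as CSV).
cov_model = SquaredExponential()
cov_model.auto_fit(
    model_factory=ModelM3,
    loss=torch.nn.functional.nll_loss,
    data=train_data,
    cache="cache/CNN3_mnist.csv",
)
print(cov_model.scale)

# Use the fitted covariance model to drive the RFD optimizer.
model = ModelM3()
rfd = RFD(model.parameters(), covariance_model=cov_model)

loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
for x, y in loader:
    def closure():
        rfd.zero_grad()
        loss = torch.nn.functional.nll_loss(model(x), y)
        loss.backward()
        return loss

    rfd.step(closure)  # assumes the usual torch closure convention for step()
    break  # single step, for illustration only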
64 changes: 51 additions & 13 deletions pyrfd/batchsize.py
@@ -1,11 +1,15 @@
"""
Module for optimally sampling batch sizes for better variance estimates.
"""
from time import time
from ctypes import ArgumentError
import numpy as np
from scipy import optimize, stats
import pandas as pd
from time import time
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

from .sampling import _budget

# "arbitrary" design
DEFAULT_VAR_REG = LinearRegression()
@@ -14,13 +18,16 @@

CUTOFF = 20 # no batch-sizes below


def sq_error_var(var_reg, b):
""" calculate the 4th moment (for Gaussian rvs) i.e. variance of centered
squares, where the variance regression determines their variance for
different batchsizes"""
return 3 * var_reg.predict((1 / np.asarray(b)).reshape(-1, 1)) ** 2


def empirical_intercept_variance(counts, var_reg):
"""
"""
"""TODO: change or compare to a bootstrap version, this assumes theory applicable"""
n = sum(counts)
dist = stats.rv_discrete(
name="epirical batchsize distribution",
@@ -33,13 +40,28 @@ def empirical_intercept_variance(counts, var_reg):


def limit_intercept_variance(dist: stats.rv_discrete, var_reg):
"""limiting variance as sample budget grows to infinity, this assumes the
number of batch size samples n is related to the buget in the following way:
n E[B] < budget
where B is the random batch size sampled from dist. Thus n = budget/E[B],
and we can replace 1/n by E[B]/budget in the formula. Since we let budget to
infinity, we remove the budget to get an asymptotic/convergent formulation.
"""
theta = dist.expect(func=lambda x: 1 / sq_error_var(var_reg, x))
w_1st_mom = dist.expect(func=lambda x: 1 / (sq_error_var(var_reg, x) * x))
w_2nd_mom = dist.expect(func=lambda x: 1 / (sq_error_var(var_reg, x) * (x**2)))
return (dist.mean() * w_2nd_mom) / (theta * w_2nd_mom - w_1st_mom**2)
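Transcribing the computation above into a formula (my reading of the code, not a quotation from a paper): writing v(b) := 3 * var_reg.predict(1/b)^2 for the variance of the centered squares and B ~ dist, the returned limit is

\operatorname{Var}_\infty
  = \frac{\mathbb{E}[B]\,\mathbb{E}\!\left[\frac{1}{v(B)\,B^{2}}\right]}
         {\mathbb{E}\!\left[\frac{1}{v(B)}\right]\mathbb{E}\!\left[\frac{1}{v(B)\,B^{2}}\right]
          - \mathbb{E}\!\left[\frac{1}{v(B)\,B}\right]^{2}}

with theta, w_1st_mom and w_2nd_mom corresponding to E[1/v(B)], E[1/(v(B) B)] and E[1/(v(B) B^2)] respectively.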


def batchsize_dist(var_reg=DEFAULT_VAR_REG, logging=False):
"""Find the optimal batch size distribution (in a Gibbs distribution class)
under the assumption that the variance regression var_reg is true. The Gibbs
distribution class is of the form
p(b) = exp(w[0]/(3*var_reg.predict(1/b)**2) - w[1]b)
and optimized over `w`
"""
beta_0 = var_reg.intercept_
beta_1 = var_reg.coef_[0]
if beta_0 <= 0:
@@ -68,13 +90,16 @@ def gibbs_dist(w):

if logging:
tqdm.write("Optimizing over batchsize distribution using Nelder-Mead")

def callback(x):
tqdm.write(f"> current parameters: {np.exp(x)}) ", end="\r")
tqdm.write(
f"> current parameters: {np.exp(x)}) ", end="\r"
)

else:
def callback(x):
pass

def callback(_):
pass

start = time()
res = optimize.minimize(
@@ -86,7 +111,9 @@ def callback(x):
end = time()
weights = np.exp(res.x)
if logging:
tqdm.write(f"> Final batchsize distribution parameters: {weights} ")
tqdm.write(
f"> Final batchsize distribution parameters: {weights} "
)
tqdm.write(f"> {res.message}")
tqdm.write(f"> Time Elapsed: {end-start:.0f} (seconds)")

@@ -96,7 +123,15 @@ def callback(x):
def batchsize_counts(
budget, var_reg=DEFAULT_VAR_REG, existing_b_size_samples: pd.Series = pd.Series()
):
spent_budget = sum([b * count for b, count in existing_b_size_samples.items()])
"""Determines the optimal batchsize distribution (in a Gibbs distribution class),
then adjusts the distribution for existing batchsize samples. (i.e. sample fewer
from the existing sizes), such that after `budget` is spent, the distribution
should be as close as possible to the optimal distribution. Then samples from
this adjusted distribution.
return a series with counts of batchsizes (where the batchsizes are the index).
"""
spent_budget = _budget(existing_b_size_samples)
total = spent_budget + budget
optimal_dist: stats.rv_discrete = batchsize_dist(var_reg)

@@ -105,20 +140,23 @@ def batchsize_counts(
df = pd.DataFrame(index=support, data={"desired_dist": optimal_dist.pmf(support)})
df["desired_counts"] = df["desired_dist"] * total / optimal_dist.mean()

pd.set_option("future.no_silent_downcasting", True) # remove warning
pd.set_option("future.no_silent_downcasting", True) # remove warning
df = df.join(existing_b_size_samples.to_frame("existing_counts")).fillna(0)

df["required_counts"] = df["desired_counts"] - df["existing_counts"]
req_cts = df[df["required_counts"] > 0]["required_counts"]
df["required_distribution"] = req_cts / req_cts.sum()
df["required_distribution"] = req_cts / req_cts.sum()
df["required_distribution"] = df["required_distribution"].fillna(0)

required_dist = stats.rv_discrete(
values=(df.index, df["required_distribution"].to_numpy())
)

est_sample_num = np.ceil(budget / required_dist.mean()).astype(int)
b_size_samples = required_dist.rvs(size=est_sample_num + 500, random_state=int(time()))
b_size_samples = required_dist.rvs(
# estimate required samples and add padding
size=np.ceil(budget / required_dist.mean()).astype(int) + 500,
random_state=int(time()),
)
last_idx = np.searchsorted(np.cumsum(b_size_samples), budget) + 1

required_counts = (
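To make the batchsize_counts docstring concrete, a hypothetical usage sketch. The variance regression and the existing sample counts below are invented for illustration; only the signature and return type described above are relied on.

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

from pyrfd.batchsize import batchsize_counts

# Toy variance regression Var[L_b] ~ intercept + coef * (1/b), fitted on invented points.
var_reg = LinearRegression().fit(
    np.array([[1 / 32], [1 / 128], [1 / 512]]),  # design matrix: 1/batchsize
    np.array([0.90, 0.35, 0.20]),                # observed loss variances (illustrative)
)

# Batch sizes already sampled: index = batchsize, value = number of samples taken.
existing = pd.Series({64: 120, 256: 30})

# Spend an additional budget of 100_000 data points; the result is a Series of
# additional batch counts per batchsize, adjusted for what already exists.
new_counts = batchsize_counts(
    budget=100_000,
    var_reg=var_reg,
    existing_b_size_samples=existing,
)
print(new_counts)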
97 changes: 80 additions & 17 deletions pyrfd/covariance.py
@@ -1,3 +1,8 @@
"""
Module providing Covariance models to pass to RFD. I.e. they can be fitted
using loss samples and they provide a learning rate
"""

from abc import abstractmethod
from ctypes import ArgumentError
from logging import warning
@@ -19,6 +24,7 @@

from .sampling import CachedSamples, IsotropicSampler, _budget


def selection(sorted_list, num_elts):
"""
return a selection of num_elts from the sorted_list (evenly spaced in the index)
@@ -38,6 +44,13 @@ def fit_mean_var(
var_reg=DEFAULT_VAR_REG,
logging=False,
):
"""Bootstraps weighted least squares regression (WLS) to determine the mean,
and variance of the loss at varying batchsizes. Returns the mean and variance
regression.
An existing regression can be passed to act as a starting point
for the bootstrap.
"""
batch_sizes = np.array(batch_sizes)
batch_losses = np.array(batch_losses)
b_inv = 1 / batch_sizes
@@ -50,9 +63,11 @@

# bootstrapping Weighted Least Squares (WLS)
for idx in range(max_bootstrap):
vars = var_reg.predict(b_inv.reshape(-1, 1)) # variance at batchsizes 1/b in X
variances = var_reg.predict(
b_inv.reshape(-1, 1)
) # variance at batchsizes 1/b in X

mu = np.average(batch_losses, weights=(1 / vars))
mu = np.average(batch_losses, weights=1 / variances)
centered_squares = (batch_losses - mu) ** 2

old_intercept = var_reg.intercept_
@@ -62,7 +77,7 @@
var_reg.fit(
b_inv.reshape(-1, 1),
centered_squares,
sample_weight=1 / vars**2,
sample_weight=1 / variances**2,
)

if math.isclose(old_intercept, var_reg.intercept_):
@@ -82,12 +97,19 @@ def isotropic_derivative_var_estimation(
g_var_reg=DEFAULT_VAR_REG,
logging=False,
) -> LinearRegression:
"""Bootstraps weighted least squares regression (WLS) to determine the
expectation of gradient norms of the loss at varying batchsizes.
Returns the regression of the mean against 1/b where b is the batchsize.
An existing regression can be passed to act as a starting point
for the bootstrap.
"""
batch_sizes = np.array(batch_sizes)
b_inv: np.array = 1 / batch_sizes

# bootstrapping WLS
for idx in range(max_bootstrap):
vars: np.array = g_var_reg.predict(
variances: np.array = g_var_reg.predict(
b_inv.reshape(-1, 1)
) # variances at batchsize 1/b

@@ -96,7 +118,9 @@
# out in the weighting we also have a sum of squares (norm), but this
# also only results in a constant which does not matter
old_bias = g_var_reg.intercept_
g_var_reg.fit(b_inv.reshape(-1, 1), sq_grad_norms, sample_weight=(1 / vars**2))
g_var_reg.fit(
b_inv.reshape(-1, 1), sq_grad_norms, sample_weight=1 / variances**2
)

if math.isclose(old_bias, g_var_reg.intercept_):
if logging:
@@ -108,18 +132,33 @@


class IsotropicCovariance:
"""Abstract isotropic covariance class, providing some fallback methods.
Can be subclassed for specific covariance models (see e.g. SquaredExponential)
"""

__slots__ = "mean", "var_reg", "g_var_reg", "dims", "fitted"

def __init__(self) -> None:
self.fitted = False
self.var_reg = DEFAULT_VAR_REG
self.g_var_reg = DEFAULT_VAR_REG
self.dims = None
self.mean = None

@abstractmethod
def learning_rate(self, loss, grad_norm):
"""learning rate of this covariance model from the RFD paper"""
return NotImplemented

def fit(self, df: pd.DataFrame, dims):
""" " Fit the covariance model with loss and gradient norm samples
provided in a pandas dataframe, with columns containing:
- batchsize
- loss
- grad_norm or sq_grad_norm
"""
self.dims = dims
if ("sq_grad_norm" not in df) and ("grad_norm" in df):
df["sq_grad_norm"] = df["grad_norm"] ** 2
@@ -322,20 +361,28 @@ def auto_fit(
------
Parameters:
1. A `model_factory` which returns the same randomly initialized [!] model every time it is called
2. A `loss` function e.g. torch.nn.functional.nll_loss which accepts a prediction and a true value
3. data, which can be passed to `torch.utils.DataLoader` with different batch size parameters such
that it returns (x,y) tuples when iterated on
1. A `model_factory` which returns the same randomly initialized [!]
model every time it is called
2. A `loss` function e.g. torch.nn.functional.nll_loss which accepts
a prediction and a true value
3. data, which can be passed to `torch.utils.DataLoader` with
different batch size parameters such that it returns (x,y) tuples when
iterated on
"""
dims = sum(p.numel() for p in model_factory().parameters() if p.requires_grad)
print(f"\n\nAutomatically fitting Covariance Model: {repr(self)}")

sampler = IsotropicSampler(model_factory, loss, data)

if cache:
print("Tip: You can cancel sampling at any time, samples will be saved in the cache.")
print(
"Tip: You can cancel sampling at any time, samples will be saved in the cache."
)
else:
warning("Without a cache it is necessary to re-fit the covariance model every time. Please pass a filepath to the cache parameter")
warning(
"Without a cache it is necessary to re-fit the covariance model"
+ "every time. Please pass a filepath to the cache parameter"
)

cached_samples = CachedSamples(cache)
budget = initial_budget
@@ -353,7 +400,7 @@ def auto_fit(
self.fit(samples, dims)

var_mean = self.var_reg.intercept_
if var_mean <= 0:
if var_mean <= 0:
# negative variance est -> reset
self.fitted = False
self.var_reg = DEFAULT_VAR_REG
@@ -362,11 +409,15 @@ def auto_fit(
var_var = empirical_intercept_variance(bsize_counts, self.var_reg)
rel_error = np.sqrt(var_var) / var_mean
tqdm.write(f"\nCheckpoint {idx}:")
tqdm.write("-----------------------------------------------------------")
tqdm.write(
"-----------------------------------------------------------"
)
tqdm.write(f"Estimated relative error: {rel_error}")

if rel_error < tol:
tqdm.write(f"\nSucessfully fitted the Covariance model to a relative error <{tol}")
tqdm.write(
f"\nSucessfully fitted the Covariance model to a relative error <{tol}"
)
break # stop early

dist = stats.rv_discrete(
Expand Down Expand Up @@ -402,7 +453,6 @@ def auto_fit(
outer_pgb.refresh()
## PROGRESS ======================================================


needed_bsize_counts = batchsize_counts(
budget,
self.var_reg,
@@ -419,17 +469,30 @@


class SquaredExponential(IsotropicCovariance):
"""The Squared exponential covariance model. I.e.
C(x) = self.variance * exp(-x^2/(2*self.scale^2))
needs to be fitted using .auto_fit or .fit.
"""

@property
def variance(self):
"""the estimated variance (should only be accessed after fitting)"""
if self.fitted:
return self.var_reg.intercept_
raise ArgumentError("The covariance is not fitted yet, use `auto_fit` or `fit` before use")
raise ArgumentError(
"The covariance is not fitted yet, use `auto_fit` or `fit` before use"
)

@property
def scale(self):
"""the estimated scale (should only be accessed after fitting)"""
if self.fitted:
return np.sqrt(self.variance * self.dims / self.g_var_reg.intercept_)
raise ArgumentError("The covariance is not fitted yet, use `auto_fit` or `fit` before use")
raise ArgumentError(
"The covariance is not fitted yet, use `auto_fit` or `fit` before use"
)

def learning_rate(self, loss, grad_norm):
"""RFD learning rate from Random Function Descent paper"""
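As a companion to the fit docstring added in this file, a minimal sketch of the direct fitting path (as opposed to auto_fit). The sample values are invented, and it is assumed that a successful fit leaves the model in a state where variance and scale can be read; in practice the samples come from IsotropicSampler.

import pandas as pd

from pyrfd.covariance import SquaredExponential

# Loss / gradient-norm samples at a few batch sizes (values invented for illustration).
df = pd.DataFrame(
    {
        "batchsize": [32, 32, 128, 128, 512, 512],
        "loss": [2.31, 2.28, 2.30, 2.29, 2.302, 2.298],
        "grad_norm": [1.40, 1.30, 0.80, 0.75, 0.42, 0.40],
    }
)

cov = SquaredExponential()
# dims = number of trainable parameters of the model the samples came from
# (an example value; auto_fit computes this from the model_factory).
cov.fit(df, dims=1_199_882)

print(cov.variance)  # intercept of the loss-variance regression
print(cov.scale)     # combines variance, dims and the gradient-norm regression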