Skip to content

Latest commit

 

History

History
62 lines (48 loc) · 1.41 KB

README.md

File metadata and controls

62 lines (48 loc) · 1.41 KB

Recursive Feature Machines

There are two notebooks to test out RFM:

  • low_rank.ipynb (an example of low rank polynomials)
  • svhn.ipynb (for the SVHN dataset)

Installation

Can be installed using the command

 pip install git+https://github.com/aradha/recursive_feature_machines.git@pip_install

Requirements:

  • Python 3.8+
  • torchvision==0.14.0
  • hickle==5.0.2
  • tqdm

Stable behavior

Code has been tested using PyTorch 1.13, Python 3.8

Testing installation

import torch
from rfm import LaplaceRFM

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    # find GPU memory in GB, keeping aside 1GB for safety
    DEV_MEM_GB = torch.cuda.get_device_properties(DEVICE).total_memory//1024**3 - 1 
else:
    DEVICE = torch.device("cpu")
    DEV_MEM_GB = 8

def fstar(X):
    return torch.cat([(X[:,0]>0)[:,None], 
	(X[:,1]<0.5)[:,None]], axis=1).float()

model = LaplaceRFM(bandwidth=1., device=DEVICE, mem_gb=DEV_MEM_GB, diag=False)

n = 1000 # samples
d = 100  # dimension
c = 2    # classes

X_train = torch.randn(n, d, device=DEVICE)
X_test = torch.randn(n, d, device=DEVICE)
y_train = fstar(X_train)
y_test = fstar(X_test)

model.fit(
    (X_train, y_train), 
    (X_test, y_test), 
    loader=False, 
    iters=5,
    classif=False
)

Paper

Mechanism of feature learning in deep fully connected networks and kernel machines that recursively learn features