
first commit
mh-amani committed Nov 25, 2023
1 parent e9b8612 commit 95b86b6
Showing 12 changed files with 635 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,7 +21,7 @@ git clone https://github.com/mh-amani/neural_discrete_reasoning
cd neural_discrete_reasoning

# [OPTIONAL] create conda environment
- conda create -n myenv python=3.11
+ conda create -n ndr python=3.11
conda activate ndr

# install pytorch according to instructions
38 changes: 38 additions & 0 deletions configs/experiment/pvr.yaml
@@ -0,0 +1,38 @@
name: "pvr"
run_name: "${model_key}-${discretizer_key}"

# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
- override /data: mnist
- override /model: mnist
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["mnist", "simple_dense_net"]

seed: 12345

trainer:
min_epochs: 10
max_epochs: 10
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002


data:
batch_size: 64

logger:
wandb:
tags: ${tags}
group: "mnist"
3 changes: 2 additions & 1 deletion configs/logger/wandb.yaml
@@ -7,7 +7,8 @@ wandb:
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "lightning-hydra-template"
project: ${name}
name: ${run_name}
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
45 changes: 45 additions & 0 deletions configs/model/transformer_dbn_classifier.yaml
@@ -0,0 +1,45 @@
_target_: src.models.transformer_dbn_classifier.TransformerDBNClassifier

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

####################################################
# compile model for faster training with pytorch 2.0
compile: false
embedding_dim: 256
dbn_after_each_layer: True
num_transformer_layers: 3

discrete_layer:
  _target_: src.models.components.discrete_layers.vqvae.VQVAEDiscreteLayer
  key: 'vqvae'
  temperature: 1.0
  label_smoothing_scale: 0.0
  dist_ord: 2
  dictionary_dim: ${model.embedding_dim}
  hard: True
  projection_method: "layer norm" # "unit-sphere", "scale", "layer norm", or "None"
  beta: 0.25

transformer_layer:
  _target_: src.models.components.transformer.TransformerLayer
  num_heads: 8
  dim_feedforward: ${model.embedding_dim}
  dropout: 0.1
  activation: "relu"
  dim: ${model.embedding_dim}
  norm: "layer_norm"
  batch_first: True
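
As an aside, a hedged sketch (not in the commit) of what the `_target_` / `_partial_` keys do when Hydra instantiates this config. The optimizer block mirrors the one above; everything else is illustrative.

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "_partial_": True,
        "lr": 0.001,
        "weight_decay": 0.0,
    }
})

# _partial_: true makes instantiate() return a functools.partial, so the
# LightningModule can pass in its parameters later (e.g. in configure_optimizers).
optimizer_factory = instantiate(cfg.optimizer)
optimizer = optimizer_factory(params=[torch.nn.Parameter(torch.zeros(3))])
print(type(optimizer).__name__)  # Adam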



8 changes: 6 additions & 2 deletions configs/train.yaml
@@ -7,7 +7,7 @@ defaults:
- data: mnist
- model: mnist
- callbacks: default
- - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+ - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
@@ -27,6 +27,10 @@ defaults:
# debugging config (enable through command line, e.g. `python train.py debug=default`)
- debug: null

+ # determines the log directory's identifier
+ run_name: ???
+ name: ???

# task name, determines output directory path
task_name: "train"

@@ -46,4 +50,4 @@ test: True
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
- seed: null
+ seed: 42
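
A small sketch (not part of the commit) of what the `???` markers added above mean to OmegaConf: mandatory values that must be filled by an experiment config such as pvr.yaml or by a command-line override before they are read.

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create("name: ???\nrun_name: ???")
try:
    _ = cfg.name  # nothing has filled the mandatory value yet
except MissingMandatoryValue:
    print("set `name`, e.g. via an experiment config or a `name=...` override")

cfg.name = "pvr"  # roughly what `experiment=pvr` does via its global-package keys
print(cfg.name)   # pvr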
6 changes: 3 additions & 3 deletions environment.yaml
@@ -5,7 +5,7 @@
# - conda allows for installing packages without requiring certain compilers or
# libraries to be available in the system, since it installs precompiled binaries

- name: myenv
+ name: ndr

channels:
- pytorch
@@ -21,7 +21,7 @@ channels:
# compatibility is usually guaranteed

dependencies:
- - python=3.10
+ - python=3.11
- pytorch=2.*
- torchvision=0.*
- lightning=2.*
@@ -32,7 +32,7 @@ dependencies:
- pytest=7.*

# --------- loggers --------- #
- # - wandb
+ - wandb
# - neptune-client
# - mlflow
# - comet-ml
2 changes: 1 addition & 1 deletion requirements.txt
@@ -10,7 +10,7 @@ hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0

# --------- loggers --------- #
- # wandb
+ wandb
# neptune-client
# mlflow
# comet-ml
60 changes: 60 additions & 0 deletions src/models/components/discrete_layers/abstract_discrete_layer.py
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
import torch.nn as nn
import torch
from torch.nn import LayerNorm


class AbstractDiscreteLayer(nn.Module, ABC):
    def __init__(self, dims, **kwargs) -> None:
        super().__init__()
        self.input_dim = dims['input_dim']    # fed by the model, after x->z and z->x models are instantiated
        self.output_dim = dims['output_dim']  # fed by the model, after x->z and z->x models are instantiated
        self.vocab_size = dims['vocab_size']
        self.dictionary_dim = kwargs['dictionary_dim']

        self.temperature = kwargs.get('temperature', 1)
        self.label_smoothing_scale = kwargs.get('label_smoothing_scale', 0.001)

        self.out_layer_norm = LayerNorm(self.dictionary_dim)

        self.dictionary = nn.Embedding(self.vocab_size, self.dictionary_dim)

        self.output_embedding = nn.Linear(self.output_dim, self.dictionary_dim)
        self.encoder_embedding = nn.Linear(self.dictionary_dim, self.input_dim)
        self.decoder_embedding = nn.Linear(self.dictionary_dim, self.output_dim)

    def decoder_to_discrete_embedding(self, x):
        out_x = self.output_embedding(x)
        return out_x

    def discrete_embedding_to_decoder(self, x):
        return self.decoder_embedding(x)

    def discrete_embedding_to_encoder(self, x):
        return self.encoder_embedding(x)

    def project_matrix(self, x, **kwargs):
        return x

    def project_embedding_matrix(self):
        self.dictionary.weight = torch.nn.Parameter(self.project_matrix(self.dictionary.weight))

    def forward(self, x, **kwargs):
        continuous_vector = self.decoder_to_discrete_embedding(x)
        # scores are between 0 and 1, and sum to 1 over the vocab dimension.
        id, score, quantized_vector, quantization_loss = self.discretize(continuous_vector, **kwargs)
        return id, score, quantized_vector, quantization_loss

    def embed_enc_from_id(self, x):
        embeds = self.dictionary(x)
        return self.discrete_embedding_to_encoder(embeds)

    def embed_dec_from_id(self, x):
        embeds = self.dictionary(x)
        return self.discrete_embedding_to_decoder(embeds)

    @abstractmethod
    def discretize(self, x, **kwargs) -> tuple:
        pass
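
To illustrate the contract this base class expects from `discretize` (ids, scores, quantized vectors, and a loss), here is a toy sketch, not part of the commit; the dims, shapes, and the nearest-neighbour rule are invented, and the import assumes the repository root is on PYTHONPATH.

import torch
from src.models.components.discrete_layers.abstract_discrete_layer import AbstractDiscreteLayer

class NearestNeighborDiscreteLayer(AbstractDiscreteLayer):
    """Toy example: nearest-codebook quantization, no gradient trick, zero loss."""
    def discretize(self, x, **kwargs) -> tuple:
        # x: (batch, length, dictionary_dim), produced by decoder_to_discrete_embedding
        dictionary = self.dictionary.weight.unsqueeze(0).expand(x.size(0), -1, -1)
        dist = torch.cdist(x, dictionary)   # (batch, length, vocab_size)
        score = torch.softmax(-dist, dim=-1)
        ids = dist.argmin(dim=-1)
        quantized = self.dictionary(ids)    # (batch, length, dictionary_dim)
        return ids, score, quantized, torch.tensor(0.0)

dims = {"input_dim": 32, "output_dim": 16, "vocab_size": 10}
layer = NearestNeighborDiscreteLayer(dims, dictionary_dim=8)
ids, score, quantized, loss = layer(torch.randn(2, 5, 16))  # forward(): embed, then discretize
print(ids.shape, score.shape, quantized.shape)              # (2, 5), (2, 5, 10), (2, 5, 8)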



18 changes: 18 additions & 0 deletions src/models/components/discrete_layers/gumbel.py
@@ -0,0 +1,18 @@
from .abstract_discrete_layer import AbstractDiscreteLayer
import torch
from torch.nn.functional import gumbel_softmax


class GumbelDiscreteLayer(AbstractDiscreteLayer):
    def __init__(self, dims, **kwargs) -> None:
        super().__init__(dims, **kwargs)
        # if True, use argmax in the forward pass, else use the gumbel-softmax sample;
        # the backward pass is the same in both cases.
        self.hard = kwargs['hard']
        self.output_embedding = torch.nn.Linear(self.output_dim, self.vocab_size)

    def discretize(self, x, **kwargs) -> tuple:
        score = gumbel_softmax(x, tau=self.temperature, hard=self.hard, dim=-1)
        x_quantized = torch.matmul(score, self.dictionary.weight)
        id = torch.argmax(score, dim=-1)
        quantization_loss = 0
        return id, score, x_quantized, quantization_loss
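
A brief usage sketch (not part of the commit; the dims and shapes are invented) of the layer above. With hard=True the scores come out one-hot in the forward pass while gradients follow the soft Gumbel-softmax sample.

import torch
from src.models.components.discrete_layers.gumbel import GumbelDiscreteLayer

dims = {"input_dim": 32, "output_dim": 16, "vocab_size": 10}
layer = GumbelDiscreteLayer(dims, dictionary_dim=8, temperature=1.0, hard=True)

x = torch.randn(2, 5, 16)                 # (batch, length, output_dim)
ids, score, quantized, q_loss = layer(x)  # output_embedding maps to vocab logits before discretize
print(score.sum(dim=-1))                  # all ones: each score row is one-hot when hard=True
print(ids.shape, quantized.shape)         # (2, 5), (2, 5, 8)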

86 changes: 86 additions & 0 deletions src/models/components/discrete_layers/vqvae.py
@@ -0,0 +1,86 @@
from .abstract_discrete_layer import AbstractDiscreteLayer
import torch
from torch import nn
# from vector_quantize_pytorch import VectorQuantize
from entmax import sparsemax


class VQVAEDiscreteLayer(AbstractDiscreteLayer):
    def __init__(self, dims, **kwargs) -> None:
        super().__init__(dims, **kwargs)

        self.projection_method = kwargs.get("projection_method", None)

        self.dictionary = nn.Embedding(self.vocab_size, self.dictionary_dim)
        self.dictionary.weight = torch.nn.Parameter(self.project_matrix(self.dictionary.weight))

        self.dist_ord = kwargs.get('dist_ord', 2)
        self.embedding_loss = torch.nn.functional.mse_loss  # torch.nn.CosineSimilarity(dim=-1)
        self.hard = kwargs['hard']
        self.kernel = nn.Softmax(dim=-1)
        self.beta = kwargs.get("beta", 0.25)  # 0.25 is the beta used in the vq-vae paper

    ###################
    # Probably can remove these as we are using the matrix projection now
    # def fetch_embeddings_by_index(self, indices):
    #     if self.normalize_embeddings:
    #         return nn.functional.normalize(self.dictionary(indices), dim=-1)
    #     # ~else
    #     return self.dictionary(indices)

    # def fetch_embedding_matrix(self):
    #     if self.normalize_embeddings:
    #         return nn.functional.normalize(self.dictionary.weight, dim=-1)
    #     # ~else
    #     return self.dictionary.weight
    ###################

    def project_matrix(self, x):
        if self.projection_method == "unit-sphere":
            return torch.nn.functional.normalize(x, dim=-1)
        if self.projection_method == "scale":
            # divide the vector by the square root of the dimension
            return x / (self.dictionary_dim ** 0.5)
        if self.projection_method == "layer norm":
            return self.out_layer_norm(x)
        return x

    def discretize(self, x, **kwargs) -> tuple:
        probs = self.kernel(-self.codebook_distances(x) / self.temperature)
        x = self.project_matrix(x)
        indices = torch.argmax(probs, dim=-1)

        if self.hard:
            # Apply STE for hard quantization
            quantized = self.dictionary(indices)  # self.fetch_embeddings_by_index(indices)
            quantized = quantized + x - x.detach()
        else:
            quantized = torch.matmul(probs, self.dictionary.weight)

        if kwargs.get("supervision", False):
            true_quantized = self.dictionary(kwargs.get("true_ids", None))
            commitment_loss = self.embedding_loss(true_quantized.detach(), x)
            embedding_loss = self.embedding_loss(true_quantized, x.detach())
        else:
            commitment_loss = self.embedding_loss(quantized.detach(), x)
            embedding_loss = self.embedding_loss(quantized, x.detach())

        vq_loss = self.beta * commitment_loss + embedding_loss

        return indices, probs, quantized, vq_loss

    def codebook_distances(self, x):
        # dictionary_expanded = self.fetch_embedding_matrix().unsqueeze(0).unsqueeze(1)  # Shape: (batch, 1, vocab, dim)
        dictionary_expanded = self.dictionary.weight.unsqueeze(0).unsqueeze(1)
        x_expanded = x.unsqueeze(2)
        # if self.normalize_embeddings:
        #     x_expanded = nn.functional.normalize(x, dim=-1).unsqueeze(2)  # Shape: (batch, length, 1, dim)
        # else:
        #     x_expanded = x.unsqueeze(2)  # Shape: (batch, length, 1, dim)

        # Compute pairwise distances between x and the dictionary entries
        dist = torch.linalg.vector_norm(x_expanded - dictionary_expanded, ord=self.dist_ord, dim=-1)
        return dist
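
And a self-contained sketch (not part of the commit) of the quantization math in `discretize` above: nearest-codebook lookup, the straight-through estimator, and the beta-weighted VQ-VAE loss. All sizes and tensors are invented for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim, beta = 10, 8, 0.25
codebook = torch.randn(vocab_size, dim, requires_grad=True)  # stand-in for the nn.Embedding weight
x = torch.randn(2, 5, dim, requires_grad=True)               # stand-in for the encoder output

# pairwise L2 distances to every codebook row: (batch, length, vocab)
dist = torch.linalg.vector_norm(x.unsqueeze(2) - codebook.unsqueeze(0).unsqueeze(1), ord=2, dim=-1)
probs = torch.softmax(-dist, dim=-1)
indices = probs.argmax(dim=-1)

quantized = codebook[indices]             # hard nearest-entry lookup
quantized = quantized + x - x.detach()    # straight-through: gradients flow back to x

commitment_loss = F.mse_loss(quantized.detach(), x)  # pulls the encoder output toward the codebook
embedding_loss = F.mse_loss(quantized, x.detach())   # pulls the codebook toward the encoder output
vq_loss = beta * commitment_loss + embedding_loss
vq_loss.backward()
print(indices.shape, quantized.shape, x.grad.shape)  # (2, 5), (2, 5, 8), (2, 5, 8)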

