Commit 203c123

add codes

losyer committed Jun 4, 2023
1 parent 043908c commit 203c123
Showing 25 changed files with 2,468 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -1,3 +1,3 @@
# whitening_effect

The code will be released after the paper for this study is accepted.

1 change: 1 addition & 0 deletions recon_with_debias/global_utils/global_utils/__init__.py
@@ -0,0 +1 @@
from global_utils.global_utils import *
31 changes: 31 additions & 0 deletions recon_with_debias/global_utils/global_utils/global_utils.py
@@ -0,0 +1,31 @@

colors = ['aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey', 'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue', 'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite', 'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray', 'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white', 'whitesmoke', 'yellow', 'yellowgreen']


def test():
    print('test')


def prepare_directory(path, loop_num=100, sleep_sec=3, test=False):
    """Create a timestamped directory based on `path`, retrying on name conflicts."""
    from time import sleep
    import datetime
    import os

    # Timestamps are generated in Japan Standard Time (UTC+9)
    t_delta = datetime.timedelta(hours=9)
    JST = datetime.timezone(t_delta, 'JST')

    for _ in range(loop_num):
        if test:
            raise NotImplementedError
        else:
            modified_path = path + '_' + \
                datetime.datetime.now(JST).strftime("%Y-%m-%d_%H-%M-%S")
        try:
            os.makedirs(modified_path)
            break
        except FileExistsError:
            print(modified_path)
            print(f'Directory name conflict: sleep {sleep_sec} sec.', flush=True)
            sleep(sleep_sec)

    return modified_path
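
For orientation, a minimal usage sketch of prepare_directory; the base path 'results/run' is an illustrative placeholder, not from the repository:

from global_utils import prepare_directory

out_dir = prepare_directory('results/run')  # e.g. 'results/run_2023-06-04_12-00-00'
print(out_dir)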
7 changes: 7 additions & 0 deletions recon_with_debias/global_utils/setup.py
@@ -0,0 +1,7 @@
from setuptools import setup, find_packages

setup(
    name='global_utils',
    version='0.1',
    packages=find_packages()
)
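
With this setup.py, the package can be installed in editable mode so that the global_utils import used by the training scripts resolves; a likely invocation, with the path assumed from the file tree above:

pip install -e recon_with_debias/global_utils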
21 changes: 21 additions & 0 deletions recon_with_debias/src/losses/loss.py
@@ -0,0 +1,21 @@
from torch import nn
import torch.nn.functional as F


class ReconstructionLoss(nn.Module):

    def __init__(self, w2v_model):
        super(ReconstructionLoss, self).__init__()
        self.w2v_model = w2v_model

    def forward(self, word_idx, ref_vector, freq=None, normalize=False):
        vec = self.w2v_model(word_idx)
        if normalize:
            vec = F.normalize(vec)
            ref_vector = F.normalize(ref_vector)
        loss_fn = nn.MSELoss(reduction='none')
        # Per-word MSE: average the squared error over the embedding dimension
        loss = loss_fn(vec, ref_vector).mean(dim=1)
        if freq is not None:
            # Weight each word's loss by the log of its corpus frequency
            loss = loss * freq.log()

        return loss.mean()
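
A minimal sketch of driving this loss, assuming the ReconModel defined below; the batch size, dimensions, and frequency values are illustrative:

import torch
from models.model import ReconModel
from losses.loss import ReconstructionLoss

model = ReconModel(vocab_size=10, embed_dim=4)
criterion = ReconstructionLoss(model)
word_idx = torch.tensor([0, 1, 2])        # indices of three words
ref_vector = torch.randn(3, 4)            # their reference embeddings
freq = torch.tensor([100.0, 50.0, 10.0])  # corpus frequencies
loss = criterion(word_idx, ref_vector, freq=freq, normalize=True)
loss.backward()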
18 changes: 18 additions & 0 deletions recon_with_debias/src/models/model.py
@@ -0,0 +1,18 @@
from torch import nn


class ReconModel(nn.Module):

    def __init__(self, vocab_size, embed_dim):
        super(ReconModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.init_weights()

    def init_weights(self):
        # Small-variance Gaussian initialization of the embedding table
        self.embedding.weight.data.normal_(mean=0, std=0.01)

    def forward(self, idx):
        embeddings = self.embedding(idx)
        return embeddings
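
The model is a plain embedding lookup; a quick shape check, with sizes illustrative:

import torch
from models.model import ReconModel

model = ReconModel(vocab_size=1000, embed_dim=300)
idx = torch.tensor([1, 5, 42])
vecs = model(idx)
print(vecs.shape)  # torch.Size([3, 300])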
112 changes: 112 additions & 0 deletions recon_with_debias/src/prepare_dataset.py
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
import codecs
import numpy as np
from utils.utils import get_total_line
import torch
import logging

logger = logging.getLogger()


class DataHandler(object):
    def __init__(self, args):
        self.args = args
        self.ref_vec_path = args.ref_vec_path
        self.test = args.test
        if not self.args.inference and self.args.freq_path != "":
            self.create_word_to_freq_dic()
        self.filtering_words = self.load_filtering_words(args.filtering_words_path)
        self.loaded_words_set = set()
        self.total_line = None

    def get_freq(self, word):
        try:
            return self.word_to_freq_dic[word]
        except KeyError:
            # Fall back to a small default count for out-of-vocabulary words
            return 2

    def create_word_to_freq_dic(self):
        logger.info("Create word to frequency dictionary ...")
        self.word_to_freq_dic = {}
        for line in codecs.open(self.args.freq_path, "r", 'utf-8', errors='replace'):
            col = line.strip().split("\t")
            assert len(col) == 2
            word, freq = col[0], int(col[1])
            self.word_to_freq_dic[word] = freq
        logger.info("Create word to frequency dictionary ... done")

    def load_filtering_words(self, path):
        if path == '':
            return set()
        else:
            logger.info('Loading filtering words ...')
            words = set()
            for line in codecs.open(path, "r", 'utf-8', errors='replace'):
                word = line.strip()
                words.add(word)
            logger.info('Loading filtering words ... done')

            return words

    def prepare_dataset(self):
        logger.info("Create dataset ...")
        self.train_data = self.load_dataset()

        # Optional Handling
        # pass

        logger.info("Create dataset ... done")

    def load_dataset(self):
        dataset = []
        self.total_line = get_total_line(path=self.ref_vec_path, test=self.test)
        skipped_word_count = 0
        with codecs.open(self.ref_vec_path, "r", 'utf-8', errors='replace') as input_data:
            for i, line in enumerate(input_data):

                # Log progress roughly every 10%
                if i % int(self.total_line / 10) == 0:
                    logger.info('{} % done'.format(round(i / (self.total_line / 100))))

                if i == 0:
                    # Get header info.
                    col = line.strip('\n').split()
                    vocab_size, dim = int(col[0]), int(col[1])
                    continue
                col = line.rstrip(' \n').rsplit(' ', dim)
                word = col[0]

                # if self.args.inference:
                #     raise NotImplementedError
                # else:
                # Skip filtered words, overly long words, and malformed lines
                if word in self.filtering_words \
                        or (len(word) > 30 and self.args.discard_long_word) \
                        or len(col) != dim + 1:
                    print(line)
                    skipped_word_count += 1
                    continue

                self.loaded_words_set.add(word)

                word_idx = i - 1 - skipped_word_count
                word_idx_array = np.array(word_idx, dtype=np.int32)
                ref = [None] if self.args.inference else col[1:]
                ref_vector = np.array(ref, dtype=np.float32)
                assert len(ref_vector) == dim
                y = np.array(0, dtype=np.int32)
                freq = self.get_freq(word) if self.args.freq_path != "" else 1
                freq_array = np.array(freq, dtype=np.float32)
                # Words past the frequency-rank threshold get the low-frequency label
                freq_label = torch.tensor(0).long() if word_idx < self.args.adv_freq_thresh else torch.tensor(1).long()

                dataset = self.set_dataset(dataset, word_idx_array, ref_vector, freq_array, freq_label, y)

                if self.test and len(dataset) == 1000:
                    logger.info("Prepared small dataset for quick test.")
                    break
        logger.info(f'len(dataset) = {len(dataset)}')
        assert len(dataset) == word_idx + 1
        return dataset

    def set_dataset(self, dataset, word_idx, ref_vector, freq_array, freq_label, y):
        dataset.append((word_idx, ref_vector, freq_array, freq_label, y))
        return dataset
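
Judging from the parsing above, the loader expects ref_vec_path in word2vec text format: a header line giving the vocabulary size and dimension, then one space-separated word and vector per line. A toy input with illustrative values:

3 4
the 0.12 -0.05 0.33 0.01
of 0.08 0.22 -0.11 0.40
and -0.30 0.17 0.09 -0.02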
184 changes: 184 additions & 0 deletions recon_with_debias/src/run_train.py
@@ -0,0 +1,184 @@
# coding: utf-8
import argparse
import json
import os
import sys
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, DistributedSampler
from models.model import ReconModel
from losses.loss import ReconstructionLoss
from utils.load_model import load_words
from utils.load_model import output_vector
from prepare_dataset import DataHandler
from trainer import Trainer
from global_utils import prepare_directory

import logging


def init_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    log_format = '%(asctime)s [%(levelname)-8s] [%(module)s#%(funcName)s %(lineno)d] %(message)s'
    formatter = logging.Formatter(log_format)
    handler.setFormatter(formatter)
    handler.flush = sys.stdout.flush
    if not logger.hasHandlers():
        logger.addHandler(handler)
    return logger


logger = init_logger(None)


def save_vector(args, result_dest, selected_epochs):
    words = load_words(args, args.test)
    vocab_size = len(words)

    for epoch_num in selected_epochs:
        epoch_num = epoch_num + 1
        w2v_model = model_setup(args, vocab_size)
        model_path = result_dest + f'/selected_epoch_{epoch_num}/word_embeddings.bin'
        w2v_model.load_state_dict(torch.load(model_path))
        embeddings = w2v_model.embedding.weight.detach().numpy()
        output_path = result_dest + f'/selected_epoch_{epoch_num}/'
        name = output_path + f'{args.save_vector_name}'
        output_vector(words, embeddings, args.embed_dim, name=name)
    print('Done')


def get_sampler(split, args, shuffle=False, distributed=False, rank=0):
    if distributed:
        return DistributedSampler(split, num_replicas=args.n_gpu, rank=rank, shuffle=shuffle)
    else:
        return None


def initial_setup(args):
    # Setup result directory
    result_dest_name = args.result_dir
    result_dest = prepare_directory(result_dest_name + '/')

    with open(os.path.join(result_dest, "settings.json"), "w") as fo:
        fo.write(json.dumps(vars(args), sort_keys=True, indent=4))
    print()
    print('###########')
    print('#Arguments#')
    print('###########')
    print(json.dumps(vars(args), sort_keys=True, indent=4), flush=True)

    logger.info("result dest: " + result_dest)
    return result_dest


def model_setup(args, vocab_size=1000):
    model = ReconModel(vocab_size=vocab_size,
                       embed_dim=args.embed_dim)
    return model


def main(args):
    # Fix all random seeds for reproducibility
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    result_dest = initial_setup(args)
    distributed = False
    rank = args.rank

    # Data setup
    data_handler = DataHandler(args)
    data_handler.prepare_dataset()
    train_dataset = data_handler.train_data
    vocab_size = len(train_dataset)

    # Model setup
    w2v_model = model_setup(args, vocab_size)
    loss_model = ReconstructionLoss(w2v_model)

    sampler = get_sampler(train_dataset,
                          args,
                          shuffle=False,
                          distributed=distributed,
                          rank=rank)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=False,
                                  batch_size=args.batch_size,
                                  sampler=sampler)

    # Convert the 1-based epoch list (e.g. "100_200_300") to 0-based indices;
    # default to an empty list so Trainer and save_vector always receive a value
    selected_epochs = []
    if args.save_selected_epoch:
        selected_epochs = [int(epoch) - 1 for epoch in args.save_selected_epoch.split('_')]

    trainer = Trainer(loss_model,
                      w2v_model,
                      train_dataloader,
                      device='cuda',
                      args=args,
                      epochs=args.epoch,
                      optimizer_class=torch.optim.Adam,
                      optimizer_params={'lr': args.lr},
                      weight_decay=0.0,
                      output_path=result_dest,
                      save_best_model=True,
                      max_grad_norm=999999.0,
                      use_amp=True,
                      rank=rank,
                      save_per_epoch=False,
                      save_selected_epoch=selected_epochs,
                      )

    trainer.run()

    if args.save_vector:
        save_vector(args, result_dest, selected_epochs)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=int, default=0, help='')

    # Training parameters
    parser.add_argument('--epoch', dest='epoch', type=int, default=300, help='number of epochs to learn')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=200, help='minibatch size')
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--seed', type=int, default=0)

    # Model parameters
    parser.add_argument('--embed_dim', dest='embed_dim', type=int, default=300)
    parser.add_argument('--adv_freq_thresh', type=int, default=200000)
    parser.add_argument('--adv_lr', type=float, default=0.02)
    parser.add_argument('--adv_wdecay', type=float, default=1.2e-6)
    parser.add_argument('--adv_lambda', type=float, default=0.02)

    # Training flags
    parser.add_argument('--test', action='store_true', help='use tiny dataset')
    parser.add_argument('--inference', action='store_true')
    parser.add_argument('--normalize', action='store_true')

    # Other flags
    parser.add_argument('--discard_long_word', action='store_true')

    # Data paths
    parser.add_argument('--ref_vec_path', type=str, default="")
    parser.add_argument('--result_dir', type=str, default="")

    parser.add_argument('--filtering_words_path', type=str, default="")
    parser.add_argument('--freq_path', type=str, default="")
    parser.add_argument('--out_prefix', type=str, default="")

    # For saving
    parser.add_argument('--save_vector', action='store_true')
    parser.add_argument('--save_selected_epoch', type=str, default=None)
    parser.add_argument('--save_vector_name', type=str, default='vector.txt')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    arguments = parse_args()
    main(arguments)
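
A minimal training invocation might look like the following; the paths and epoch choices are illustrative, every flag is from the parser above:

python run_train.py \
    --ref_vec_path /path/to/vectors.txt \
    --result_dir results \
    --embed_dim 300 \
    --save_selected_epoch 100_200_300 \
    --save_vector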