Commit 203c123
losyer committed on Jun 4, 2023
1 parent: 043908c
Showing 25 changed files with 2,468 additions and 1 deletion.
@@ -1,3 +1,3 @@
# whitening_effect

The code will be released after the paper of this study is accepted.
@@ -0,0 +1 @@
from global_utils.global_utils import *
recon_with_debias/global_utils/global_utils/global_utils.py
31 changes: 31 additions & 0 deletions
@@ -0,0 +1,31 @@

# CSS/matplotlib named colors (e.g. for plotting).
colors = ['aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey', 'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue', 'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite', 'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray', 'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white', 'whitesmoke', 'yellow', 'yellowgreen']


def test():
    print('test')


def prepare_directory(path, loop_num=100, sleep_sec=3, test=False):
    from time import sleep
    import datetime
    import os

    # Timestamps are taken in JST (UTC+9).
    t_delta = datetime.timedelta(hours=9)
    JST = datetime.timezone(t_delta, 'JST')

    for _ in range(loop_num):
        if test:
            raise NotImplementedError
        else:
            modified_path = path + '_' + \
                datetime.datetime.now(JST).strftime("%Y-%m-%d_%H-%M-%S")
            try:
                os.makedirs(modified_path)
                break
            except FileExistsError:
                # Another run created the same timestamped directory; wait and retry.
                print(modified_path)
                print(f'Directory name conflict: sleep {sleep_sec} sec.', flush=True)
                sleep(sleep_sec)

    return modified_path
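A minimal usage sketch (the path is illustrative, and assumes the global_utils package below has been installed): prepare_directory appends a JST timestamp to the requested path, retries on name collisions, and returns the directory it created.

from global_utils import prepare_directory

result_dest = prepare_directory('results/run')
# e.g. 'results/run_2023-06-04_21-00-00'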
@@ -0,0 +1,7 @@
from setuptools import setup, find_packages

setup(
    name='global_utils',
    version='0.1',
    packages=find_packages()
)
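Presumably this package is installed in editable mode (for example, pip install -e . from the directory containing this setup.py) so that train.py's `from global_utils import prepare_directory` resolves.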
@@ -0,0 +1,21 @@
from torch import nn
import torch.nn.functional as F


class ReconstructionLoss(nn.Module):

    def __init__(self, w2v_model):
        super(ReconstructionLoss, self).__init__()
        self.w2v_model = w2v_model

    def forward(self, word_idx, ref_vector, freq=None, normalize=False):
        vec = self.w2v_model(word_idx)
        if normalize:
            vec = F.normalize(vec)
            ref_vector = F.normalize(ref_vector)
        loss_fn = nn.MSELoss(reduction='none')
        # Per-example MSE between the reconstructed and reference vectors.
        loss = loss_fn(vec, ref_vector).mean(dim=1)
        if freq is not None:
            # Weight each example by the log of its word frequency.
            loss = loss * freq.log()

        return loss.mean()
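In effect, when frequencies are supplied the batch loss is mean_i( log(freq_i) * MSE(w2v_model(idx_i), ref_i) ), so reconstruction errors on frequent words are weighted more heavily.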
@@ -0,0 +1,18 @@
from torch import nn


class ReconModel(nn.Module):

    def __init__(self, vocab_size, embed_dim):
        super(ReconModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.normal_(mean=0, std=0.01)

    def forward(self, idx):
        embeddings = self.embedding(idx)
        return embeddings
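A minimal sketch of how the model and the loss above fit together (toy sizes; the class names follow the files in this commit, everything else is illustrative):

import torch
from models.model import ReconModel
from losses.loss import ReconstructionLoss

model = ReconModel(vocab_size=10, embed_dim=4)
criterion = ReconstructionLoss(model)

idx = torch.tensor([0, 1, 2])             # word indices
ref = torch.randn(3, 4)                   # reference vectors to reconstruct
freq = torch.tensor([100.0, 50.0, 10.0])  # word frequencies

loss = criterion(idx, ref, freq=freq, normalize=True)
loss.backward()                           # gradients flow into model.embedding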
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
import codecs
import numpy as np
from utils.utils import get_total_line
import torch
import logging

logger = logging.getLogger()


class DataHandler(object):
    def __init__(self, args):
        self.args = args
        self.ref_vec_path = args.ref_vec_path
        self.test = args.test
        if not self.args.inference and self.args.freq_path != "":
            self.create_word_to_freq_dic()
        self.filtering_words = self.load_filtering_words(args.filtering_words_path)
        self.loaded_words_set = set()
        self.total_line = None

    def get_freq(self, word):
        try:
            return self.word_to_freq_dic[word]
        except KeyError:
            # Fall back to a small constant for out-of-dictionary words.
            return 2

    def create_word_to_freq_dic(self):
        logger.info("Create word to frequency dictionary ...")
        self.word_to_freq_dic = {}
        for line in codecs.open(self.args.freq_path, "r", 'utf-8', errors='replace'):
            col = line.strip().split("\t")
            assert len(col) == 2
            word, freq = col[0], int(col[1])
            self.word_to_freq_dic[word] = freq
        logger.info("Create word to frequency dictionary ... done")

    def load_filtering_words(self, path):
        if path == '':
            return set()
        else:
            logger.info('Loading filtering words ...')
            words = set()
            for line in codecs.open(path, "r", 'utf-8', errors='replace'):
                word = line.strip()
                words.add(word)
            logger.info('Loading filtering words ... done')

            return words

    def prepare_dataset(self):
        logger.info("Create dataset ...")
        self.train_data = self.load_dataset()

        # Optional handling
        # pass

        logger.info("Create dataset ... done")

    def load_dataset(self):
        dataset = []
        self.total_line = get_total_line(path=self.ref_vec_path, test=self.test)
        skipped_word_count = 0
        with codecs.open(self.ref_vec_path, "r", 'utf-8', errors='replace') as input_data:
            for i, line in enumerate(input_data):

                if i % int(self.total_line / 10) == 0:
                    logger.info('{} % done'.format(round(i / (self.total_line / 100))))

                if i == 0:
                    # Get header info.
                    col = line.strip('\n').split()
                    vocab_size, dim = int(col[0]), int(col[1])
                    continue
                col = line.rstrip(' \n').rsplit(' ', dim)
                word = col[0]

                # if self.args.inference:
                #     raise NotImplementedError
                # else:
                # Skip special conditions
                if word in self.filtering_words \
                        or (len(word) > 30 and self.args.discard_long_word) \
                        or len(col) != dim + 1:
                    print(line)
                    skipped_word_count += 1
                    continue

                self.loaded_words_set.add(word)

                word_idx = i - 1 - skipped_word_count
                word_idx_array = np.array(word_idx, dtype=np.int32)
                ref = [None] if self.args.inference else col[1:]
                ref_vector = np.array(ref, dtype=np.float32)
                assert len(ref_vector) == dim
                y = np.array(0, dtype=np.int32)
                freq = self.get_freq(word) if self.args.freq_path != "" else 1
                freq_array = np.array(freq, dtype=np.float32)
                freq_label = torch.tensor(0).long() if word_idx < self.args.adv_freq_thresh else torch.tensor(1).long()

                dataset = self.set_dataset(dataset, word_idx_array, ref_vector, freq_array, freq_label, y)

                if self.test and len(dataset) == 1000:
                    logger.info("Prepared small dataset for quick test.")
                    break
        logger.info(f'len(dataset) = {len(dataset)}')
        assert len(dataset) == word_idx + 1
        return dataset

    def set_dataset(self, dataset, word_idx, ref_vector, freq_array, freq_label, y):
        dataset.append((word_idx, ref_vector, freq_array, freq_label, y))
        return dataset
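load_dataset appears to expect ref_vec_path in word2vec text format: a header line with the vocabulary size and dimensionality, followed by one word and dim space-separated floats per line. A toy example (values are made up):

3 4
the 0.01 -0.20 0.13 0.05
of 0.03 0.11 -0.07 0.02
cat 0.25 -0.14 0.08 -0.31

The optional freq_path file is read as tab-separated word/count pairs, one word per line.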
@@ -0,0 +1,184 @@
# coding: utf-8
import argparse
import json
import os
import sys
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from models.model import ReconModel
from losses.loss import ReconstructionLoss
from utils.load_model import load_words
from utils.load_model import output_vector
from prepare_dataset import DataHandler
from trainer import Trainer
from global_utils import prepare_directory

import logging


def init_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    log_format = '%(asctime)s [%(levelname)-8s] [%(module)s#%(funcName)s %(lineno)d] %(message)s'
    formatter = logging.Formatter(log_format)
    handler.setFormatter(formatter)
    handler.flush = sys.stdout.flush
    if not logger.hasHandlers():
        logger.addHandler(handler)
    return logger


logger = init_logger(None)


def save_vector(args, result_dest, selected_epochs):
    words = load_words(args, args.test)
    vocab_size = len(words)

    for epoch_num in selected_epochs:
        epoch_num = epoch_num + 1
        w2v_model = model_setup(args, vocab_size)
        model_path = result_dest + f'/selected_epoch_{epoch_num}/word_embeddings.bin'
        w2v_model.load_state_dict(torch.load(model_path))
        embeddings = w2v_model.embedding.weight.detach().numpy()
        output_path = result_dest + f'/selected_epoch_{epoch_num}/'
        name = output_path + f'{args.save_vector_name}'
        output_vector(words, embeddings, args.embed_dim, name=name)
    print('Done')


def get_sampler(split, args, shuffle=False, distributed=False, rank=0):
    if distributed:
        return DistributedSampler(split, num_replicas=args.n_gpu, rank=rank, shuffle=shuffle)
    else:
        return None


def initial_setup(args):
    # Setup result directory
    result_dest_name = args.result_dir
    result_dest = prepare_directory(result_dest_name + '/')

    with open(os.path.join(result_dest, "settings.json"), "w") as fo:
        fo.write(json.dumps(vars(args), sort_keys=True, indent=4))
    print()
    print('###########')
    print('#Arguments#')
    print('###########')
    print(json.dumps(vars(args), sort_keys=True, indent=4), flush=True)

    logger.info("result dest: " + result_dest)
    return result_dest


def model_setup(args, vocab_size=1000):

    model = ReconModel(vocab_size=vocab_size,
                       embed_dim=args.embed_dim)
    return model


def main(args):
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    result_dest = initial_setup(args)
    distributed = False
    rank = args.rank

    # Data setup
    data_handler = DataHandler(args)
    data_handler.prepare_dataset()
    train_dataset = data_handler.train_data
    vocab_size = len(train_dataset)

    # Model setup
    w2v_model = model_setup(args, vocab_size)
    loss_model = ReconstructionLoss(w2v_model)

    sampler = get_sampler(train_dataset,
                          args,
                          shuffle=False,
                          distributed=distributed,
                          rank=rank)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=False,
                                  batch_size=args.batch_size,
                                  sampler=sampler)

    selected_epochs = []  # empty unless --save_selected_epoch is given
    if args.save_selected_epoch:
        selected_epochs = args.save_selected_epoch.split('_')
        selected_epochs = [int(epoch) - 1 for epoch in selected_epochs]

    trainer = Trainer(loss_model,
                      w2v_model,
                      train_dataloader,
                      device='cuda',
                      args=args,
                      epochs=args.epoch,
                      optimizer_class=torch.optim.Adam,
                      optimizer_params={'lr': args.lr},
                      weight_decay=0.0,
                      output_path=result_dest,
                      save_best_model=True,
                      max_grad_norm=999999.0,
                      use_amp=True,
                      rank=rank,
                      save_per_epoch=False,
                      save_selected_epoch=selected_epochs,
                      )

    trainer.run()

    if args.save_vector:
        save_vector(args, result_dest, selected_epochs)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=int, default=0, help='')

    # Training parameters
    parser.add_argument('--epoch', dest='epoch', type=int, default=300, help='number of epochs to learn')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=200, help='minibatch size')
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--seed', type=int, default=0)

    # Model parameters
    parser.add_argument('--embed_dim', dest='embed_dim', type=int, default=300)
    parser.add_argument('--adv_freq_thresh', type=int, default=200000)
    parser.add_argument('--adv_lr', type=float, default=0.02)
    parser.add_argument('--adv_wdecay', type=float, default=1.2e-6)
    parser.add_argument('--adv_lambda', type=float, default=0.02)

    # Training flags
    parser.add_argument('--test', action='store_true', help='use tiny dataset')
    parser.add_argument('--inference', action='store_true')
    parser.add_argument('--normalize', action='store_true')

    # Other flags
    parser.add_argument('--discard_long_word', action='store_true')

    # Data paths
    parser.add_argument('--ref_vec_path', type=str, default="")
    parser.add_argument('--result_dir', type=str, default="")

    parser.add_argument('--filtering_words_path', type=str, default="")
    parser.add_argument('--freq_path', type=str, default="")
    parser.add_argument('--out_prefix', type=str, default="")

    # For saving
    parser.add_argument('--save_vector', action='store_true')
    parser.add_argument('--save_selected_epoch', type=str, default=None)
    parser.add_argument('--save_vector_name', type=str, default='vector.txt')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    arguments = parse_args()
    main(arguments)
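A plausible invocation, using only the flags defined in parse_args (paths and epoch numbers are illustrative):

python train.py \
    --ref_vec_path data/reference_vectors.txt \
    --result_dir results \
    --embed_dim 300 \
    --epoch 300 \
    --batch_size 200 \
    --save_selected_epoch 100_200_300 \
    --save_vector

Note that --save_selected_epoch takes underscore-separated epoch numbers, and --save_vector writes the embeddings of those epochs back out as text vectors.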