Skip to content

Commit

Permalink
v0.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Peotr Zagubisalo committed Jun 27, 2022
1 parent 903260a commit dc0dc84
Show file tree
Hide file tree
Showing 16 changed files with 278 additions and 368 deletions.
Binary file added output_new/epoch-3000-mu-corr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output_new/epoch-3000-mu-dist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output_new/epoch-3000-mu-rot-dist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output_new/epoch-3000-subdec-mu-dist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
9 changes: 5 additions & 4 deletions preprocess_db/add_rare_types_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pandas import DataFrame

MISSING_TYPES: Tuple[int, ...] = tuple(range(1, 17))
EXTRA_NO_SELF_TYPES = True
EXTRA_NO_SELF_TYPES = False
MALE_LABEL_SHIFT = 16 # `get_weight` function assumes this


Expand Down Expand Up @@ -139,12 +139,13 @@ def types_tal_good_mask(df: DataFrame,

def smart_coincide_2(
tal_profs: Array, types_self: Array, types_tal: Array, males: Array,
threshold: int = -80, # was 90,
threshold: int = 95, # was 90,
thresholds_males: Tuple[Tuple[int, Tuple[int, ...]], ...] = (
(85, (4,)), (95, (3, 10, 16)), (-90, (12, 13))
# (85, (4,)), (95, (3, 10, 16)), (-90, (12, 13))
(85, (4,)),
),
thresholds_females: Tuple[Tuple[int, Tuple[int, ...]], ...] = (
(95, (16,)),
# (95, (16,)),
)) -> Array:
return smart_coincide_1(tal_profs=tal_profs, types_self=types_self, types_tal=types_tal, males=males,
threshold=threshold,
Expand Down
10 changes: 5 additions & 5 deletions train/betatcvae/kld.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def forward(self, z: Tensor, *qz_params: Tensor) -> Tuple[Tensor, Dict[str, Tens
def extra_repr(self) -> str:
return f'dataset_size={self.dataset_size}'

def kld(self, z: Tensor, *qz_params: Tensor) -> Tensor:
    """
    Per-sample estimate of log q(z|x) - log p(z), the integrand of
    KL(q(z|x) || p(z)) — presumably averaged over the batch by the caller.

    :param z: latent sample of shape (batch_size, ...).
    :param qz_params: posterior parameters forwarded verbatim to ``self.q_dist``.
    :return: tensor of shape (batch_size,).
    """
    n = z.shape[0]  # batch_size
    # Sum the log-densities over all non-batch dimensions.
    logqz_x = self.q_dist(z, *qz_params).view(n, -1).sum(dim=1)
    logpz = self.prior_dist(z).view(n, -1).sum(dim=1)
    return logqz_x - logpz
9 changes: 4 additions & 5 deletions train/jats/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@


class PlotCallback(pl.Callback):
def __init__(self, df: DataFrame, interesting_ids_ast_file_path: str, plot_every_n_epoch: int = 0, verbose=False):
    """
    Callback that periodically renders diagnostic plots during training.

    :param df: source DataFrame; plotting arguments are precomputed from it once.
    :param interesting_ids_ast_file_path: path to a UTF-8 text file containing a
        Python literal (e.g. a list/tuple of ints) with ids to highlight.
    :param plot_every_n_epoch: plot every N epochs; 0 disables periodic plotting.
    :param verbose: kept for backward compatibility with older call sites;
        stored but not read by this class's visible methods.
    """
    super(PlotCallback, self).__init__()
    self.plot_vae_args_weighted_batch = get_plot_vae_args_weighted_batch(df)
    self.plot_vae_args__y__w = get_plot_vae_args__y__w(df)
    # ast.literal_eval parses the literal safely (no arbitrary code execution).
    with open(interesting_ids_ast_file_path, 'r', encoding='utf-8') as f:
        self.interesting_ids: Tuple[int, ...] = tuple(ast.literal_eval(f.read()))
    self.df = df
    self.verbose = verbose
    self.plot_every_n_epoch = plot_every_n_epoch

def plot(self, trainer, pl_module):
if trainer.current_epoch == 0:
return
if self.plot_every_n_epoch > 0:
if trainer.current_epoch % self.plot_every_n_epoch != 0:
if (trainer.current_epoch != 10) and (trainer.current_epoch % self.plot_every_n_epoch != 0):
return
prefix = path.join(pl_module.logger.log_dir, f'epoch-{pl_module.current_epoch}-')

Expand All @@ -35,8 +34,8 @@ def plot(self, trainer, pl_module):
plot_dist(mu_wb, z_beta_wb, mu, y, prefix + 'mu-dist', axis_name='μ')

jats = pl_module.jatsregularizer
mu_rot = jats.cat_rot_2d(mu)
plot_dist(jats.cat_rot_2d(mu_wb), jats.cat_rot_2d(z_beta_wb), mu_rot, y,
mu_rot = jats.cat_rot_np(mu)
plot_dist(jats.cat_rot_np(mu_wb), jats.cat_rot_np(z_beta_wb), mu_rot, y,
prefix + 'mu-rot-dist', axis_name='μ_{rot}')
plot_dist(subdec_mu_wb, None, subdec_mu, y, prefix + 'subdec-mu-dist', axis_name='s(μ)')

Expand Down
66 changes: 57 additions & 9 deletions train/jats/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_labeled_mask(df: DataFrame) -> Array:
return df['smart_coincide'].values > 0


def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0) -> DataLoader:
def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0, verbose=False) -> DataLoader:
"""
Returns a loader of the output format:
Expand All @@ -100,23 +100,26 @@ def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0)
"""
if mode not in ('unlbl', 'lbl', 'both', 'plot'): raise ValueError

def get_(x_: Array, verbose_=verbose) -> Tuple[Tensor, ...]:
return (Tensor(x_).to(dtype=tr.long),) if verbose_ else ()

if mode is 'unlbl':
_, passthr, x, e_ext, _, weights, _ = get_data(df)
dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr))
ids, passthr, x, e_ext, _, weights, wtarget = get_data(df)
dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr), *get_(ids), *get_(wtarget))
sampler = WeightedRandomSampler(weights=weights.astype(np.float64),
num_samples=len(x))
return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)

if mode == 'lbl':
_, passthr, x, e_ext, target, weights, _ = get_data(df[get_labeled_mask(df)])
ids, passthr, x, e_ext, target, weights, wtarget = get_data(df[get_labeled_mask(df)])
dataset = TensorDataset(Tensor(x), Tensor(e_ext),
Tensor(target).to(dtype=tr.long), Tensor(passthr))
Tensor(target).to(dtype=tr.long), Tensor(passthr), *get_(ids), *get_(wtarget))
sampler = WeightedRandomSampler(weights=weights.astype(np.float64),
num_samples=len(x))
return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)

if mode == 'both':
_, passthr, x, e_ext, target, weights, _ = get_data(df)
ids, passthr, x, e_ext, target, weights, wtarget = get_data(df)
weights_lbl = np.copy(weights)
mask_lbl = get_labeled_mask(df)
weights_lbl[mask_lbl] = get_data(df[mask_lbl])[5]
Expand All @@ -125,12 +128,57 @@ def get_loader(df: DataFrame, mode: str, batch_size: int, num_workers: int = 0)
Tensor(target).to(dtype=tr.long), Tensor(passthr),
Tensor(weights.astype(np.float64)),
Tensor(weights_lbl.astype(np.float64)),
Tensor(mask_lbl).to(dtype=tr.bool))
Tensor(mask_lbl).to(dtype=tr.bool),
*get_(ids), *get_(wtarget))
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

if mode == 'plot':
_, passthr, x, e_ext, _, weights, _ = get_data(df[get_labeled_mask(df)])
dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr))
ids, passthr, x, e_ext, _, weights, wtarget = get_data(df[get_labeled_mask(df)])
dataset = TensorDataset(Tensor(x), Tensor(e_ext), Tensor(passthr), *get_(ids), *get_(wtarget))
sampler = WeightedRandomSampler(weights=weights.astype(np.float64),
num_samples=len(x))
return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def test_loader(df_len: int, loader_verbose: DataLoader):
"""
:param df_len: len(df) expected.
:param loader_verbose: get_loader(df, 'unlbl', BATCH_SIZE, num_workers=NUM_WORKERS, verbose=True) expected.
:return:
"""
ids_, types_ = [], []
for smpl in loader_verbose:
id_i, t_i = smpl[-2], smpl[-1]
ids_.append(id_i.view(-1))
types_.append(t_i.view(-1))

ids = tr.cat(ids_).numpy()
uni, counts = np.unique(ids, return_counts=True)
print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100])
types = tr.cat(types_).numpy()
uni, counts = np.unique(types, return_counts=True)
print(df_len, counts)

for smpl in loader_verbose:
id_i, t_i = smpl[-2], smpl[-1]
ids_.append(id_i.view(-1))
types_.append(t_i.view(-1))

ids = tr.cat(ids_).numpy()
uni, counts = np.unique(ids, return_counts=True)
print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100])
types = tr.cat(types_).numpy()
uni, counts = np.unique(types, return_counts=True)
print(df_len, counts)

for smpl in loader_verbose:
id_i, t_i = smpl[-2], smpl[-1]
ids_.append(id_i.view(-1))
types_.append(t_i.view(-1))

ids = tr.cat(ids_).numpy()
uni, counts = np.unique(ids, return_counts=True)
print(df_len, len(ids), len(uni), list(reversed(sorted(counts)))[:100])
types = tr.cat(types_).numpy()
uni, counts = np.unique(types, return_counts=True)
print(df_len, counts)
4 changes: 4 additions & 0 deletions train/jats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def probs_temper(probs: Tensor) -> Tensor:
probs_[masks[4 - 1] | masks[8 - 1] | masks[12 - 1] | masks[16 - 1]] = 3
return probs_


def probs_quadraclub(probs: Tensor) -> Tensor:
""" (NTL, SFL, STC, NFC, SFC, NTC, NFL, STL) """
probs_ = probs.clone()
Expand All @@ -26,16 +27,19 @@ def probs_quadraclub(probs: Tensor) -> Tensor:
probs_[masks[15 - 1] | masks[16 - 1]] = 7
return probs_


def expand_quadraclub(probs: Tensor) -> Tensor:
    """ (NTL, SFL, STC, NFC, SFC, NTC, NFL, STL) """
    batch, width = probs.shape
    # Duplicate every column in place: [a, b] -> [a, a, b, b].
    doubled = probs.unsqueeze(2).expand(batch, width, 2)
    return doubled.reshape(batch, width * 2)


def expand_temper(probs: Tensor) -> Tensor:
    """ (EP/-IR, IJ/+IR, IP/-ER, EJ/+ER) """
    batch, width = probs.shape
    # Tile the whole row four times: [a, b] -> [a, b, a, b, a, b, a, b].
    tiled = probs.unsqueeze(1).expand(batch, 4, width)
    return tiled.reshape(batch, width * 4)


def expand_temper_to_stat_dyn(probs: Tensor) -> Tensor:
""" (EP/-IR, IJ/+IR, IP/-ER, EJ/+ER) """
n, m = probs.shape[0], probs.shape[1] // 2
Expand Down
Loading

0 comments on commit dc0dc84

Please sign in to comment.