GIT: modify code qual workflow
anmorgunov committed May 31, 2024
1 parent 4b1002b commit d8a9f0e
Showing 16 changed files with 81 additions and 107 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/quality.yml
@@ -0,0 +1,18 @@
name: Code Quality
on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.11"
      - run: pip install --upgrade pip
      - run: pip install ruff mypy black pytest
      - run: black .
      - run: ruff format
      - run: ruff check --fix
      - run: mypy --strict . --exclude=tests
      - run: pytest -v
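The workflow installs the toolchain and then runs each check in sequence. The same sequence can be chained locally before pushing; a minimal sketch (the script itself is hypothetical; only the tool invocations come from the workflow above):

# run_quality.py: hypothetical local mirror of .github/workflows/quality.yml
import subprocess
import sys

COMMANDS = [
    ["black", "."],
    ["ruff", "format"],
    ["ruff", "check", "--fix"],
    ["mypy", "--strict", ".", "--exclude=tests"],
    ["pytest", "-v"],
]

def main() -> int:
    for cmd in COMMANDS:
        print("$", " ".join(cmd))
        if (code := subprocess.run(cmd).returncode) != 0:
            return code  # stop at the first failing tool, as CI would
    return 0

if __name__ == "__main__":
    sys.exit(main())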
11 changes: 0 additions & 11 deletions .github/workflows/ruff.yml

This file was deleted.

69 changes: 9 additions & 60 deletions Data/process.py
@@ -31,7 +31,7 @@
generate_permutations,
)
from pathlib import Path
from typing import List, Tuple, Dict, Union, Set, Optional, cast
from typing import List, Dict, Union, Set, Optional, cast

data_path = Path(__file__).parent / "PaRoutes"
save_path = Path(__file__).parent / "Processed"
@@ -43,7 +43,7 @@


class PaRoutesDataset:
def __init__(self, data_path: Path, filename: str, verbose: bool = True):
def __init__(self, data_path: Path, filename: str, verbose: bool = True) -> None:
self.data_path = data_path
self.filename = filename
self.dataset = json.load(open(data_path.joinpath(filename), "r"))
@@ -52,70 +52,19 @@ def __init__(self, data_path: Path, filename: str, verbose: bool = True):

self.products: List[str] = []
self.filtered_data: FilteredType = []
self.path_strings: List[str] = []
self.max_steps: List[int] = []
self.SMs: List[List[str]] = []
# self.path_strings: List[str] = []
# self.max_steps: List[int] = []
# self.SMs: List[List[str]] = []

self.non_permuted_path_strings: List[str] = []
# self.non_permuted_path_strings: List[str] = []

def filter_dataset(self):
def filter_dataset(self) -> None:
if self.verbose:
print("- Filtering all_routes to remove meta data")
for route in tqdm(self.dataset):
filtered_node = filter_mol_nodes(route)
self.filtered_data.append(filtered_node)
self.products.append(filtered_node["smiles"])

def compress_to_string(self):
if self.verbose:
print(
"- Compressing python dictionaries into python strings and generating permutations"
)

for filtered_route in tqdm(self.filtered_data):
permuted_path_strings = generate_permutations(filtered_route)
# permuted_path_strings = [str(data).replace(" ", "")]
self.path_strings.append(permuted_path_strings)
self.non_permuted_path_strings.append(str(filtered_route).replace(" ", ""))

def find_max_depth(self):
if self.verbose:
print("- Finding the max depth of each route tree")
for filtered_route in tqdm(self.filtered_data):
self.max_steps.append(max_tree_depth(filtered_route))

def find_all_leaves(self):
if self.verbose:
print("- Finding all leaves of each route tree")
for filtered_route in tqdm(self.filtered_data):
self.SMs.append(find_leaves(filtered_route))

def preprocess(self):
self.filter_dataset()
self.compress_to_string()
self.find_max_depth()
self.find_all_leaves()

def prepare_final_datasets(
self, exclude: Optional[Set[int]] = None
) -> Tuple[Dataset, Dataset]:
if exclude is None:
exclude = set()
dataset: Dataset = []
dataset_each_sm: Dataset = []
for i in tqdm(range(len(self.products))):
if i in exclude:
continue
entry: DatasetEntry = {
"train_ID": i,
"product": self.products[i],
"path_strings": self.path_strings[i],
"max_step": self.max_steps[i],
}
dataset.append(entry | {"all_SM": self.SMs[i]})
for sm in self.SMs[i]:
dataset_each_sm.append({**entry, "SM": sm})
return (dataset, dataset_each_sm)
self.products.append(cast(str, filtered_node["smiles"]))

def prepare_final_dataset_v2(
self,
@@ -237,7 +186,7 @@ def prepare_final_dataset_v2
# ------- Remove SM info from datasets -------


def remove_sm_from_ds(load_path: Path, save_path: Path):
def remove_sm_from_ds(load_path: Path, save_path: Path) -> None:
products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb"))
pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb"))
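A hedged usage sketch for the helper above. The file names are hypothetical, but the on-disk layout (a pickled 4-tuple re-saved as a 3-tuple without the starting-material field) follows from the function body:

from pathlib import Path

processed = Path(__file__).parent / "Processed"
# Hypothetical file names; the diff does not show the real ones.
remove_sm_from_ds(
    load_path=processed / "dataset_with_sm.pkl",
    save_path=processed / "dataset_no_sm.pkl",
)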

3 changes: 1 addition & 2 deletions DirectMultiStep/Models/Architecture.py
@@ -45,7 +45,6 @@


class ModelConfig:

def __init__(
self,
input_dim: int,
@@ -150,7 +149,7 @@ def forward(
attn_output_BHLD.permute(0, 2, 1, 3).contiguous().view(B, L, self.hid_dim)
)
output_BLD = cast(Tensor, self.projection(attn_output_BLD))
return output_BLD
return output_BLD


class PositionwiseFeedforwardLayer(nn.Module):
6 changes: 3 additions & 3 deletions DirectMultiStep/Models/Configure.py
@@ -22,10 +22,10 @@

import torch
import torch.nn as nn
from .Architecture import Encoder, Decoder, Seq2Seq, ModelConfig
from DirectMultiStep.Models.Architecture import Encoder, Decoder, Seq2Seq, ModelConfig


def count_parameters(model:nn.Module)->int:
def count_parameters(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
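count_parameters sums p.numel() over trainable parameters only, so frozen weights are excluded. A quick sanity check with a toy module (illustrative, not from the repository):

import torch.nn as nn

toy = nn.Linear(128, 64)            # weight: 128 * 64, bias: 64
assert count_parameters(toy) == 128 * 64 + 64   # 8256

toy.bias.requires_grad = False      # frozen parameters no longer count
assert count_parameters(toy) == 128 * 64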


@@ -40,7 +40,7 @@ def determine_device(allow_mps: bool = False) -> str:
return device
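Only the tail of determine_device is visible in this hunk. A plausible reconstruction under the usual PyTorch device-selection pattern (hedged; the actual body may differ):

def determine_device(allow_mps: bool = False) -> str:
    # Prefer CUDA, optionally fall back to Apple MPS, else CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif allow_mps and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    return device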


def prepare_model(enc_config:ModelConfig, dec_config:ModelConfig)->nn.Module:
def prepare_model(enc_config: ModelConfig, dec_config: ModelConfig) -> nn.Module:
device = torch.device(determine_device())
encoder = Encoder(config=enc_config, device=device)
decoder = Decoder(config=dec_config, device=device)
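The hunk is cut off after the decoder is built. A hedged guess at the remainder; the Seq2Seq keyword names are assumptions, not confirmed by this diff:

    # Assumed continuation: wrap the pair and move it to the chosen device.
    model = Seq2Seq(encoder=encoder, decoder=decoder).to(device)
    return model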
6 changes: 4 additions & 2 deletions DirectMultiStep/Models/Generation.py
@@ -164,7 +164,9 @@ def _select_top_k_candidates(
return best_k_B_nt

def _generate_final_outputs(
self, beam_idxs_BSL_nt: List[List[List[int]]], beam_log_probs_BS_nt: npt.NDArray[np.float64]
self,
beam_idxs_BSL_nt: List[List[List[int]]],
beam_log_probs_BS_nt: npt.NDArray[np.float64],
) -> BeamSearchOutput:
"""Convert index sequences to final outputs."""
B = len(beam_idxs_BSL_nt)
@@ -183,7 +185,7 @@

return outputs_B2_nt

def decode(self, src_BC: Tensor, steps_B1: Tensor)->BeamSearchOutput:
def decode(self, src_BC: Tensor, steps_B1: Tensor) -> BeamSearchOutput:
"""
src_BC: product + one_sm
steps_B1: number of steps
17 changes: 9 additions & 8 deletions DirectMultiStep/Models/Training.py
@@ -25,15 +25,14 @@
from typing import Callable, Optional, Tuple, List, Dict, Any, cast
import torch
import torch.nn as nn
from .Architecture import Seq2Seq

Tensor = torch.Tensor


def _warmup_and_cosine_decay(
warmup_steps: int, decay_steps: int, decay_factor: float
) -> Callable[[int], float]:
def _get_new_lr(step:int)->float:
def _get_new_lr(step: int) -> float:
if step < warmup_steps:
return step / warmup_steps
elif step >= warmup_steps and step < warmup_steps + decay_steps:
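The hunk ends inside the second branch. Below is a hedged reconstruction of the whole factory: the linear warmup is visible above, while the cosine shape of the middle branch and the constant floor afterwards are assumptions consistent with the function's name:

import math
from typing import Callable

def _warmup_and_cosine_decay(
    warmup_steps: int, decay_steps: int, decay_factor: float
) -> Callable[[int], float]:
    def _get_new_lr(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup from 0 to the base learning rate.
            return step / warmup_steps
        elif step < warmup_steps + decay_steps:
            # Assumed: cosine decay from 1.0 down to decay_factor.
            progress = (step - warmup_steps) / decay_steps
            return decay_factor + (1 - decay_factor) * 0.5 * (
                1 + math.cos(math.pi * progress)
            )
        # Assumed: hold at the decayed floor afterwards.
        return decay_factor

    return _get_new_lr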
@@ -55,7 +54,7 @@ def __init__(
warmup_steps: int = 4000,
decay_steps: int = 24000,
decay_factor: float = 0.1,
model: Optional[Seq2Seq] = None,
model: Optional[nn.Module] = None,
criterion: Optional[nn.Module] = None,
):
super().__init__()
@@ -81,7 +80,7 @@ def mask_src(self, src_BC: Tensor, masking_prob: float) -> Tensor:
masked_src_BC[final_mask_BC] = self.mask_idx
return masked_src_BC

def compute_loss(self, batch:Tensor, batch_idx:int)->Tensor:
def compute_loss(self, batch: Tensor, batch_idx: int) -> Tensor:
"""
enc_item - product_item + one_sm_item
dec_item - path_string
@@ -99,7 +98,7 @@ def compute_loss(self, batch:Tensor, batch_idx:int)->Tensor:
self.processed_tokens += tgt_item_BL.shape[0] * tgt_item_BL.shape[1]
return cast(Tensor, loss)

def log_step_info(self, loss:Tensor, mode: str, prog_bar: bool)->None:
def log_step_info(self, loss: Tensor, mode: str, prog_bar: bool) -> None:
self.log(
f"{mode}_loss",
loss,
@@ -114,17 +113,19 @@ def log_step_info(self, loss:Tensor, mode: str, prog_bar: bool)->None:
f"{mode}_lr", current_lr, batch_size=self.batch_size, sync_dist=True
)

def training_step(self, batch:Tensor, batch_idx:int)->Tensor:
def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
loss = self.compute_loss(batch, batch_idx)
self.log_step_info(loss, "train", prog_bar=True)
return loss

def validation_step(self, batch:Tensor, batch_idx:int)->Tensor:
def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
loss = self.compute_loss(batch, batch_idx)
self.log_step_info(loss, "val", prog_bar=True)
return loss

def configure_optimizers(self)->Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
def configure_optimizers(
self,
) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
# return optimizer
scheduler = torch.optim.lr_scheduler.LambdaLR(
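configure_optimizers is truncated at the LambdaLR call. The visible pieces (AdamW plus the warmup-and-decay lambda) plausibly combine as below; the scheduler dict keys follow the standard Lightning convention and are an assumption here:

def configure_optimizers(
    self,
) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=_warmup_and_cosine_decay(
            self.warmup_steps, self.decay_steps, self.decay_factor
        ),
    )
    # Assumed: step the schedule every optimizer step, not every epoch.
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]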
2 changes: 1 addition & 1 deletion DirectMultiStep/Utils/Dataset.py
@@ -39,7 +39,7 @@ def tokenize_path_string(path_string: str, add_eos: bool = True) -> List[str]:
return tokens


class RoutesDataset(Dataset[Tuple[torch.Tensor,...]]):
class RoutesDataset(Dataset[Tuple[torch.Tensor, ...]]):
def __init__(self, metadata_path: str) -> None:
self.products: List[str] = []

4 changes: 3 additions & 1 deletion DirectMultiStep/Utils/PostProcess.py
@@ -77,7 +77,9 @@ def find_matching_paths(
if verbose
else zip(paths_NS2n, correct_paths)
)
for pathreac_S2n, correct_path in cast(Iterator[Tuple[BeamProcessedType, str]], iterator):
for pathreac_S2n, correct_path in cast(
Iterator[Tuple[BeamProcessedType, str]], iterator
):
path_match = None
path_match_perm = None
for rank, (path, _) in enumerate(pathreac_S2n):
26 changes: 19 additions & 7 deletions DirectMultiStep/Utils/Visualize.py
@@ -22,9 +22,9 @@

import os
import cairo
from rdkit import Chem # type: ignore
from rdkit.Chem import Draw # type: ignore
import cairosvg # type: ignore
from rdkit import Chem # type: ignore
from rdkit.Chem import Draw # type: ignore
import cairosvg # type: ignore
from pathlib import Path
from typing import Dict, List, Tuple, Union, cast

@@ -120,7 +120,12 @@ def check_overlap(


def draw_rounded_rectangle(
ctx: cairo.Context, x: int, y: int, width: int, height: int, corner_radius: int # type: ignore
ctx: cairo.Context,
x: int,
y: int,
width: int,
height: int,
corner_radius: int, # type: ignore
) -> None:
"""Draws a rounded rectangle."""
ctx.new_sub_path()
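The body is cut off after new_sub_path. The standard cairo recipe for a rounded rectangle, four arcs joined by implicit lines, would complete it roughly like this (hedged sketch; assumes `import math` at module top):

    r = corner_radius
    ctx.arc(x + width - r, y + r, r, -math.pi / 2, 0)          # top-right corner
    ctx.arc(x + width - r, y + height - r, r, 0, math.pi / 2)  # bottom-right
    ctx.arc(x + r, y + height - r, r, math.pi / 2, math.pi)    # bottom-left
    ctx.arc(x + r, y + r, r, math.pi, 3 * math.pi / 2)         # top-left
    ctx.close_path()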
@@ -152,7 +157,7 @@ def draw_molecule_tree(
height: int = 400,
x_margin: int = 50,
y_margin: int = 50,
)->None:
) -> None:
canvas_width, canvas_height = compute_canvas_dimensions(
tree, width, height, y_margin
)
@@ -161,7 +166,7 @@

existing_boxes: List[Tuple[int, int]] = []

def draw_node(node:'RetroSynthesisTree', x:int, y:int)->None:
def draw_node(node: "RetroSynthesisTree", x: int, y: int) -> None:
# Check for overlap and adjust position
while check_overlap(x, y, existing_boxes, width, height) or check_overlap(
x, y - y_margin, existing_boxes, width, height
@@ -241,7 +246,14 @@ def draw_node(node:'RetroSynthesisTree', x:int, y:int)->None:
os.remove("temp.png")


def draw_tree_from_path_string(path_string: str, save_path: Path, width: int = 400, height: int = 400, x_margin: int = 50, y_margin: int = 100)->None:
def draw_tree_from_path_string(
path_string: str,
save_path: Path,
width: int = 400,
height: int = 400,
x_margin: int = 50,
y_margin: int = 100,
) -> None:
assert save_path.suffix == "", "Please provide a path without extension"
retro_tree = create_tree_from_path_string(path_string=path_string)

2 changes: 1 addition & 1 deletion DirectMultiStep/helpers.py
@@ -93,7 +93,7 @@ def parse_version(ckpt: Path) -> int:
return sorted(last_checkpoints, key=parse_version, reverse=True)[0]

# If no "last" file, find the checkpoint with the largest epoch and step
def parse_epoch_step(filename: str):
def parse_epoch_step(filename: str) -> Tuple[int, int]:
# This pattern will match 'epoch=X-step=Y.ckpt' and extract X and Y
match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt", filename)
if match:
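The hunk ends at the match test. A hedged completion consistent with the new Tuple[int, int] return annotation (the fallback for names that do not match is an assumption):

def parse_epoch_step(filename: str) -> Tuple[int, int]:
    # Matches 'epoch=X-step=Y.ckpt' and extracts X and Y.
    match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt", filename)
    if match:
        return int(match.group(1)), int(match.group(2))
    # Assumed: sort unparseable names first.
    return (-1, -1)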
6 changes: 3 additions & 3 deletions DirectMultiStep/tests/test_preprocess.py
@@ -21,13 +21,13 @@
# SOFTWARE.

import pytest
from ..Utils.PreProcess import (
from DirectMultiStep.Utils.PreProcess import (
filter_mol_nodes,
max_tree_depth,
find_leaves,
generate_permutations,
)
from .test_data import (
from DirectMultiStep.tests.test_data import (
test1_leaves,
test2_depth1,
test3_depth2,
@@ -39,7 +39,7 @@
test9_tknz_smiles,
test10_tknz_path,
)
from ..Utils.Dataset import tokenize_smile, tokenize_path_string
from DirectMultiStep.Utils.Dataset import tokenize_smile, tokenize_path_string

test_filtering_and_depth = [
pytest.param(test1_leaves, 0, id="leaves"),
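The parameter list is truncated here. Entries like pytest.param(test1_leaves, 0, id="leaves") typically drive a parametrized test such as the sketch below; the test body is an assumption based on the imported helpers, not code from this commit:

@pytest.mark.parametrize("route, expected_depth", test_filtering_and_depth)
def test_filtering_and_depth_sketch(route, expected_depth):
    filtered = filter_mol_nodes(route)
    assert max_tree_depth(filtered) == expected_depth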
