From d8a9f0ea7296d7893733ca8cae5af173636e2210 Mon Sep 17 00:00:00 2001 From: Anton Morgunov Date: Fri, 31 May 2024 17:20:40 -0400 Subject: [PATCH] GIT: modify code qual workflow --- .github/workflows/quality.yml | 18 +++++++ .github/workflows/ruff.yml | 11 ---- Data/process.py | 69 ++++-------------------- DirectMultiStep/Models/Architecture.py | 3 +- DirectMultiStep/Models/Configure.py | 6 +-- DirectMultiStep/Models/Generation.py | 6 ++- DirectMultiStep/Models/Training.py | 17 +++--- DirectMultiStep/Utils/Dataset.py | 2 +- DirectMultiStep/Utils/PostProcess.py | 4 +- DirectMultiStep/Utils/Visualize.py | 26 ++++++--- DirectMultiStep/helpers.py | 2 +- DirectMultiStep/tests/test_preprocess.py | 6 +-- assess_single.py | 10 ++-- train_nosm.py | 2 +- train_wsm.py | 2 +- visualize_tree.py | 4 +- 16 files changed, 81 insertions(+), 107 deletions(-) create mode 100644 .github/workflows/quality.yml delete mode 100644 .github/workflows/ruff.yml diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml new file mode 100644 index 0000000..5af0703 --- /dev/null +++ b/.github/workflows/quality.yml @@ -0,0 +1,18 @@ +name: Code Quality +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: "3.11" + - run: pip install --upgrade pip + - run: pip install ruff mypy black pytest + - run: black . + - run: ruff format + - run: ruff check --fix + - run: mypy --strict . --exclude=tests + - run: pytest -v \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 0ace30d..0000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Ruff -on: [push, pull_request] -jobs: - clean-python: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: psf/black@stable - - uses: chartboost/ruff-action@v1 - with: - args: 'format --check' \ No newline at end of file diff --git a/Data/process.py b/Data/process.py index 59b2c2c..222a885 100644 --- a/Data/process.py +++ b/Data/process.py @@ -31,7 +31,7 @@ generate_permutations, ) from pathlib import Path -from typing import List, Tuple, Dict, Union, Set, Optional, cast +from typing import List, Dict, Union, Set, Optional, cast data_path = Path(__file__).parent / "PaRoutes" save_path = Path(__file__).parent / "Processed" @@ -43,7 +43,7 @@ class PaRoutesDataset: - def __init__(self, data_path: Path, filename: str, verbose: bool = True): + def __init__(self, data_path: Path, filename: str, verbose: bool = True) -> None: self.data_path = data_path self.filename = filename self.dataset = json.load(open(data_path.joinpath(filename), "r")) @@ -52,70 +52,19 @@ def __init__(self, data_path: Path, filename: str, verbose: bool = True): self.products: List[str] = [] self.filtered_data: FilteredType = [] - self.path_strings: List[str] = [] - self.max_steps: List[int] = [] - self.SMs: List[List[str]] = [] + # self.path_strings: List[str] = [] + # self.max_steps: List[int] = [] + # self.SMs: List[List[str]] = [] - self.non_permuted_path_strings: List[str] = [] + # self.non_permuted_path_strings: List[str] = [] - def filter_dataset(self): + def filter_dataset(self) -> None: if self.verbose: print("- Filtering all_routes to remove meta data") for route in tqdm(self.dataset): filtered_node = filter_mol_nodes(route) self.filtered_data.append(filtered_node) - self.products.append(filtered_node["smiles"]) - - def compress_to_string(self): - if self.verbose: - print( - "- Compressing python dictionaries into python strings and generating permutations" - ) - - for filtered_route in tqdm(self.filtered_data): - permuted_path_strings = generate_permutations(filtered_route) - # permuted_path_strings = [str(data).replace(" ", "")] - self.path_strings.append(permuted_path_strings) - self.non_permuted_path_strings.append(str(filtered_route).replace(" ", "")) - - def find_max_depth(self): - if self.verbose: - print("- Finding the max depth of each route tree") - for filtered_route in tqdm(self.filtered_data): - self.max_steps.append(max_tree_depth(filtered_route)) - - def find_all_leaves(self): - if self.verbose: - print("- Finding all leaves of each route tree") - for filtered_route in tqdm(self.filtered_data): - self.SMs.append(find_leaves(filtered_route)) - - def preprocess(self): - self.filter_dataset() - self.compress_to_string() - self.find_max_depth() - self.find_all_leaves() - - def prepare_final_datasets( - self, exclude: Optional[Set[int]] = None - ) -> Tuple[Dataset, Dataset]: - if exclude is None: - exclude = set() - dataset: Dataset = [] - dataset_each_sm: Dataset = [] - for i in tqdm(range(len(self.products))): - if i in exclude: - continue - entry: DatasetEntry = { - "train_ID": i, - "product": self.products[i], - "path_strings": self.path_strings[i], - "max_step": self.max_steps[i], - } - dataset.append(entry | {"all_SM": self.SMs[i]}) - for sm in self.SMs[i]: - dataset_each_sm.append({**entry, "SM": sm}) - return (dataset, dataset_each_sm) + self.products.append(cast(str, filtered_node["smiles"])) def prepare_final_dataset_v2( self, @@ -237,7 +186,7 @@ def prepare_final_dataset_v2( # ------- Remove SM info from datasets ------- -def remove_sm_from_ds(load_path: Path, save_path: Path): +def remove_sm_from_ds(load_path: Path, save_path: Path) -> None: products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb")) pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb")) diff --git a/DirectMultiStep/Models/Architecture.py b/DirectMultiStep/Models/Architecture.py index f22733d..446b390 100644 --- a/DirectMultiStep/Models/Architecture.py +++ b/DirectMultiStep/Models/Architecture.py @@ -45,7 +45,6 @@ class ModelConfig: - def __init__( self, input_dim: int, @@ -150,7 +149,7 @@ def forward( attn_output_BHLD.permute(0, 2, 1, 3).contiguous().view(B, L, self.hid_dim) ) output_BLD = cast(Tensor, self.projection(attn_output_BLD)) - return output_BLD + return output_BLD class PositionwiseFeedforwardLayer(nn.Module): diff --git a/DirectMultiStep/Models/Configure.py b/DirectMultiStep/Models/Configure.py index d91766d..6a0389e 100644 --- a/DirectMultiStep/Models/Configure.py +++ b/DirectMultiStep/Models/Configure.py @@ -22,10 +22,10 @@ import torch import torch.nn as nn -from .Architecture import Encoder, Decoder, Seq2Seq, ModelConfig +from DirectMultiStep.Models.Architecture import Encoder, Decoder, Seq2Seq, ModelConfig -def count_parameters(model:nn.Module)->int: +def count_parameters(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -40,7 +40,7 @@ def determine_device(allow_mps: bool = False) -> str: return device -def prepare_model(enc_config:ModelConfig, dec_config:ModelConfig)->nn.Module: +def prepare_model(enc_config: ModelConfig, dec_config: ModelConfig) -> nn.Module: device = torch.device(determine_device()) encoder = Encoder(config=enc_config, device=device) decoder = Decoder(config=dec_config, device=device) diff --git a/DirectMultiStep/Models/Generation.py b/DirectMultiStep/Models/Generation.py index 95ddff3..73e3f90 100644 --- a/DirectMultiStep/Models/Generation.py +++ b/DirectMultiStep/Models/Generation.py @@ -164,7 +164,9 @@ def _select_top_k_candidates( return best_k_B_nt def _generate_final_outputs( - self, beam_idxs_BSL_nt: List[List[List[int]]], beam_log_probs_BS_nt: npt.NDArray[np.float64] + self, + beam_idxs_BSL_nt: List[List[List[int]]], + beam_log_probs_BS_nt: npt.NDArray[np.float64], ) -> BeamSearchOutput: """Convert index sequences to final outputs.""" B = len(beam_idxs_BSL_nt) @@ -183,7 +185,7 @@ def _generate_final_outputs( return outputs_B2_nt - def decode(self, src_BC: Tensor, steps_B1: Tensor)->BeamSearchOutput: + def decode(self, src_BC: Tensor, steps_B1: Tensor) -> BeamSearchOutput: """ src_BC: product + one_sm steps_B1: number of steps diff --git a/DirectMultiStep/Models/Training.py b/DirectMultiStep/Models/Training.py index 893d31a..2037f58 100644 --- a/DirectMultiStep/Models/Training.py +++ b/DirectMultiStep/Models/Training.py @@ -25,7 +25,6 @@ from typing import Callable, Optional, Tuple, List, Dict, Any, cast import torch import torch.nn as nn -from .Architecture import Seq2Seq Tensor = torch.Tensor @@ -33,7 +32,7 @@ def _warmup_and_cosine_decay( warmup_steps: int, decay_steps: int, decay_factor: float ) -> Callable[[int], float]: - def _get_new_lr(step:int)->float: + def _get_new_lr(step: int) -> float: if step < warmup_steps: return step / warmup_steps elif step >= warmup_steps and step < warmup_steps + decay_steps: @@ -55,7 +54,7 @@ def __init__( warmup_steps: int = 4000, decay_steps: int = 24000, decay_factor: float = 0.1, - model: Optional[Seq2Seq] = None, + model: Optional[nn.Module] = None, criterion: Optional[nn.Module] = None, ): super().__init__() @@ -81,7 +80,7 @@ def mask_src(self, src_BC: Tensor, masking_prob: float) -> Tensor: masked_src_BC[final_mask_BC] = self.mask_idx return masked_src_BC - def compute_loss(self, batch:Tensor, batch_idx:int)->Tensor: + def compute_loss(self, batch: Tensor, batch_idx: int) -> Tensor: """ enc_item - product_item + one_sm_item dec_item - path_string @@ -99,7 +98,7 @@ def compute_loss(self, batch:Tensor, batch_idx:int)->Tensor: self.processed_tokens += tgt_item_BL.shape[0] * tgt_item_BL.shape[1] return cast(Tensor, loss) - def log_step_info(self, loss:Tensor, mode: str, prog_bar: bool)->None: + def log_step_info(self, loss: Tensor, mode: str, prog_bar: bool) -> None: self.log( f"{mode}_loss", loss, @@ -114,17 +113,19 @@ def log_step_info(self, loss:Tensor, mode: str, prog_bar: bool)->None: f"{mode}_lr", current_lr, batch_size=self.batch_size, sync_dist=True ) - def training_step(self, batch:Tensor, batch_idx:int)->Tensor: + def training_step(self, batch: Tensor, batch_idx: int) -> Tensor: loss = self.compute_loss(batch, batch_idx) self.log_step_info(loss, "train", prog_bar=True) return loss - def validation_step(self, batch:Tensor, batch_idx:int)->Tensor: + def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor: loss = self.compute_loss(batch, batch_idx) self.log_step_info(loss, "val", prog_bar=True) return loss - def configure_optimizers(self)->Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]: + def configure_optimizers( + self, + ) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]: optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) # return optimizer scheduler = torch.optim.lr_scheduler.LambdaLR( diff --git a/DirectMultiStep/Utils/Dataset.py b/DirectMultiStep/Utils/Dataset.py index 1d4dc18..8ddd0b6 100644 --- a/DirectMultiStep/Utils/Dataset.py +++ b/DirectMultiStep/Utils/Dataset.py @@ -39,7 +39,7 @@ def tokenize_path_string(path_string: str, add_eos: bool = True) -> List[str]: return tokens -class RoutesDataset(Dataset[Tuple[torch.Tensor,...]]): +class RoutesDataset(Dataset[Tuple[torch.Tensor, ...]]): def __init__(self, metadata_path: str) -> None: self.products: List[str] = [] diff --git a/DirectMultiStep/Utils/PostProcess.py b/DirectMultiStep/Utils/PostProcess.py index 4fbb470..524762a 100644 --- a/DirectMultiStep/Utils/PostProcess.py +++ b/DirectMultiStep/Utils/PostProcess.py @@ -77,7 +77,9 @@ def find_matching_paths( if verbose else zip(paths_NS2n, correct_paths) ) - for pathreac_S2n, correct_path in cast(Iterator[Tuple[BeamProcessedType, str]], iterator): + for pathreac_S2n, correct_path in cast( + Iterator[Tuple[BeamProcessedType, str]], iterator + ): path_match = None path_match_perm = None for rank, (path, _) in enumerate(pathreac_S2n): diff --git a/DirectMultiStep/Utils/Visualize.py b/DirectMultiStep/Utils/Visualize.py index 5a9907e..3702d5e 100644 --- a/DirectMultiStep/Utils/Visualize.py +++ b/DirectMultiStep/Utils/Visualize.py @@ -22,9 +22,9 @@ import os import cairo -from rdkit import Chem # type: ignore -from rdkit.Chem import Draw # type: ignore -import cairosvg # type: ignore +from rdkit import Chem # type: ignore +from rdkit.Chem import Draw # type: ignore +import cairosvg # type: ignore from pathlib import Path from typing import Dict, List, Tuple, Union, cast @@ -120,7 +120,12 @@ def check_overlap( def draw_rounded_rectangle( - ctx: cairo.Context, x: int, y: int, width: int, height: int, corner_radius: int # type: ignore + ctx: cairo.Context, + x: int, + y: int, + width: int, + height: int, + corner_radius: int, # type: ignore ) -> None: """Draws a rounded rectangle.""" ctx.new_sub_path() @@ -152,7 +157,7 @@ def draw_molecule_tree( height: int = 400, x_margin: int = 50, y_margin: int = 50, -)->None: +) -> None: canvas_width, canvas_height = compute_canvas_dimensions( tree, width, height, y_margin ) @@ -161,7 +166,7 @@ def draw_molecule_tree( existing_boxes: List[Tuple[int, int]] = [] - def draw_node(node:'RetroSynthesisTree', x:int, y:int)->None: + def draw_node(node: "RetroSynthesisTree", x: int, y: int) -> None: # Check for overlap and adjust position while check_overlap(x, y, existing_boxes, width, height) or check_overlap( x, y - y_margin, existing_boxes, width, height @@ -241,7 +246,14 @@ def draw_node(node:'RetroSynthesisTree', x:int, y:int)->None: os.remove("temp.png") -def draw_tree_from_path_string(path_string: str, save_path: Path, width: int = 400, height: int = 400, x_margin: int = 50, y_margin: int = 100)->None: +def draw_tree_from_path_string( + path_string: str, + save_path: Path, + width: int = 400, + height: int = 400, + x_margin: int = 50, + y_margin: int = 100, +) -> None: assert save_path.suffix == "", "Please provide a path without extension" retro_tree = create_tree_from_path_string(path_string=path_string) diff --git a/DirectMultiStep/helpers.py b/DirectMultiStep/helpers.py index 4638811..ce2b057 100644 --- a/DirectMultiStep/helpers.py +++ b/DirectMultiStep/helpers.py @@ -93,7 +93,7 @@ def parse_version(ckpt: Path) -> int: return sorted(last_checkpoints, key=parse_version, reverse=True)[0] # If no "last" file, find the checkpoint with the largest epoch and step - def parse_epoch_step(filename: str): + def parse_epoch_step(filename: str) -> Tuple[int, int]: # This pattern will match 'epoch=X-step=Y.ckpt' and extract X and Y match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt", filename) if match: diff --git a/DirectMultiStep/tests/test_preprocess.py b/DirectMultiStep/tests/test_preprocess.py index 5fbafd6..222f6c7 100644 --- a/DirectMultiStep/tests/test_preprocess.py +++ b/DirectMultiStep/tests/test_preprocess.py @@ -21,13 +21,13 @@ # SOFTWARE. import pytest -from ..Utils.PreProcess import ( +from DirectMultiStep.Utils.PreProcess import ( filter_mol_nodes, max_tree_depth, find_leaves, generate_permutations, ) -from .test_data import ( +from DirectMultiStep.tests.test_data import ( test1_leaves, test2_depth1, test3_depth2, @@ -39,7 +39,7 @@ test9_tknz_smiles, test10_tknz_path, ) -from ..Utils.Dataset import tokenize_smile, tokenize_path_string +from DirectMultiStep.Utils.Dataset import tokenize_smile, tokenize_path_string test_filtering_and_depth = [ pytest.param(test1_leaves, 0, id="leaves"), diff --git a/assess_single.py b/assess_single.py index e90ffaf..e86d496 100644 --- a/assess_single.py +++ b/assess_single.py @@ -21,11 +21,10 @@ # SOFTWARE. from DirectMultiStep.Models.TensorGen import BeamSearchOptimized as BeamSearch -from DirectMultiStep.Models.Configure import VanillaTransformerConfig, prepare_model +from DirectMultiStep.Models.Architecture import VanillaTransformerConfig +from DirectMultiStep.Models.Configure import prepare_model from DirectMultiStep.Utils.Dataset import RoutesDataset from DirectMultiStep.Utils.PostProcess import ( - BeamResultType, - BeamSearchOutput, find_valid_paths, process_paths, ) @@ -36,6 +35,7 @@ import lightning as L from rdkit import RDLogger # type: ignore import rdkit.Chem as Chem # type: ignore +from typing import List, Tuple, cast RDLogger.DisableLog("rdApp.*") @@ -51,7 +51,7 @@ def canonicalize(smile: str) -> str: - return Chem.MolToSmiles(Chem.MolFromSmiles(smile), isomericSmiles=True) + return cast(str, Chem.MolToSmiles(Chem.MolFromSmiles(smile), isomericSmiles=True)) product = canonicalize("OC(C1=C(N2N=CC=N2)C=CC(OC)=C1)=O") @@ -122,7 +122,7 @@ def canonicalize(smile: str) -> str: device=device, ) -all_beam_results_NS2: BeamSearchOutput = [] +all_beam_results_NS2: List[List[Tuple[str, float]]] = [] beam_result_BS2 = BSObject.decode( src_BC=encoder_inp, steps_B1=steps_tens, path_start_BL=path_tens ) diff --git a/train_nosm.py b/train_nosm.py index 0e738e2..c4b39bc 100644 --- a/train_nosm.py +++ b/train_nosm.py @@ -26,8 +26,8 @@ from DirectMultiStep.Models.Configure import ( prepare_model, determine_device, - VanillaTransformerConfig, ) +from DirectMultiStep.Models.Architecture import VanillaTransformerConfig from DirectMultiStep.Models.Training import PLTraining from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import RichModelSummary diff --git a/train_wsm.py b/train_wsm.py index 09ae12b..db7df8a 100644 --- a/train_wsm.py +++ b/train_wsm.py @@ -26,8 +26,8 @@ from DirectMultiStep.Models.Configure import ( prepare_model, determine_device, - VanillaTransformerConfig, ) +from DirectMultiStep.Models.Architecture import VanillaTransformerConfig from DirectMultiStep.Models.Training import PLTraining from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import RichModelSummary diff --git a/visualize_tree.py b/visualize_tree.py index 1076721..f90a554 100644 --- a/visualize_tree.py +++ b/visualize_tree.py @@ -6,4 +6,6 @@ if __name__ == "__main__": path = "{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCN(CC2CC2)CC1','children':[{'smiles':'O=C(O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(N)cc1'},{'smiles':'O=S(=O)(Cl)c1cccc2cccnc12'}]}]},{'smiles':'C1CN(CC2CC2)CCN1'}]}" - draw_tree_from_path_string(path_string=path, save_path=fig_path/"mitapivat", y_margin=150) \ No newline at end of file + draw_tree_from_path_string( + path_string=path, save_path=fig_path / "mitapivat", y_margin=150 + )