Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decoding fragments containing square brackets #25

Merged
merged 6 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
[![Conda](https://img.shields.io/conda/v/conda-forge/safe-mol?label=conda&color=success)](https://anaconda.org/conda-forge/safe-mol)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/safe-mol)](https://pypi.org/project/safe-mol/)
[![Conda](https://img.shields.io/conda/dn/conda-forge/safe-mol)](https://anaconda.org/conda-forge/safe-mol)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/safe-mol)](https://pypi.org/project/safe-mol/)
[![Code license](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/datamol-io/safe/blob/main/LICENSE)
[![Data License](https://img.shields.io/badge/Data%20License-CC%20BY%204.0-red.svg)](https://github.com/datamol-io/safe/blob/main/DATA_LICENSE)
[![GitHub Repo stars](https://img.shields.io/github/stars/datamol-io/safe)](https://github.com/datamol-io/safe/stargazers)
Expand Down Expand Up @@ -60,6 +59,13 @@ The construction of a SAFE strings requires defining a molecular fragmentation a
<img src="docs/assets/safe-construction.svg" width="100%">
</div>

## News

#### 2024/01/15
1. We have updated the model with the version used for the paper. The revision number is ``
2. @IanAWatson has a C++ implementation of SAFE in [LillyMol](https://github.com/IanAWatson/LillyMol/tree/bazel_version_float) that is quite fast and uses a custom fragmentation algorithm. Follow the installation instructions on the repo and check out the docs of the CLI here: [docs/Molecule_Tools/SAFE.md](https://github.com/IanAWatson/LillyMol/blob/bazel_version_float/docs/Molecule_Tools/SAFE.md)


### Installation

You can install `safe` using pip:
Expand Down
16 changes: 5 additions & 11 deletions safe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from .converter import encode
from .converter import decode
from .converter import SAFEConverter
from .viz import to_image
from .tokenizer import SAFETokenizer
from .tokenizer import split
from . import trainer, utils
from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError
from .converter import SAFEConverter, decode, encode
from .sample import SAFEDesign
from ._exception import SAFEDecodeError
from ._exception import SAFEEncodeError
from ._exception import SAFEFragmentationError
from . import trainer
from . import utils
from .tokenizer import SAFETokenizer, split
from .viz import to_image
29 changes: 13 additions & 16 deletions safe/converter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
from typing import Union
from typing import Optional
from typing import List
from typing import Callable

import re
import datamol as dm
import itertools
import numpy as np

from contextlib import suppress
import re
from collections import Counter
from contextlib import suppress
from typing import Callable, List, Optional, Union

import datamol as dm
import numpy as np
from rdkit import Chem
from rdkit.Chem import BRICS
from ._exception import SAFEDecodeError
from ._exception import SAFEEncodeError
from ._exception import SAFEFragmentationError

from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError
from .utils import standardize_attach


Expand Down Expand Up @@ -110,8 +104,8 @@ def _find_branch_number(cls, inp: str):
Args:
inp: input smiles
"""

matching_groups = re.findall(r"((?<=%)\d{2})|((?<!%)\d+)", inp)
inp = re.sub("[\[].*?[\]]", "", inp) # noqa
matching_groups = re.findall(r"((?<=%)\d{2})|((?<!%)\d+)(?![^\[]*\])", inp)
# first match is for multiple connection as multiple digits
# second match is for single connections requiring 2 digits
# SMILES does not support triple digits
Expand Down Expand Up @@ -262,6 +256,7 @@ def encoder(
# TODO(maclandrol): RDKit supports some extended form of ring closure, up to 5 digits
# https://www.rdkit.org/docs/RDKit_Book.html#ring-closures and I should try to include them
branch_numbers = self._find_branch_number(inp)

mol = dm.to_mol(inp, remove_hs=False)

bond_map_id = 1
Expand Down Expand Up @@ -327,7 +322,9 @@ def encoder(
)

scaffold_str = ".".join(frags_str)
attach_pos = set(re.findall(r"(\[\d+\*\]|\[[^:]*:\d+\])", scaffold_str))
# don't capture atom mapping in the scaffold
attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you double check this ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is that if we don't handle the atom mapping, the attach_pos set will contain the atom mapping pieces and will give us wrong molecules in some cases.

As an example:

smiles="c1cc2c(cc1[C@@H]1CC[C:3][NH2+]1)O[13C]CO2"
scaffold_str = ".".join(frags_str) # c1cc2c(cc1[1*])O[13C]CO2.[C@@H]1([1*])CC[C:3][NH2+]1
attach_pos = set(re.findall(r"(\[\d+\*\]|\[[^:]*:\d+\])", scaffold_str)) #  #{'[13C]CO2.[C@@H]1([1*])CC[C:3]', '[1*]'}

this will give us the safe string 'c1cc2c(cc13)O[13C]CO2.[C@@H]13CC[C:3][NH2+]1'

With my regex:

smiles="c1cc2c(cc1[C@@H]1CC[C:3][NH2+]1)O[13C]CO2"
scaffold_str = ".".join(frags_str) # c1cc2c(cc1[1*])O[13C]CO2.[C@@H]1([1*])CC[C:3][NH2+]1
attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) #  {'[1*]'}

that will give us in this case the exact safe string 'c1cc2c(cc13)O[13C]CO2.[C@@H]13CC[C:3][NH2+]1'

For more difficult cases of atom mapping, like [C+:1]#[C:2]CC(C[C:20][O:21]CN)C[C:3]1=[C:7]([H:10])[N-:6][O:5][C:4]1([H:8])[H:9]
image

We will have the following attach_pos and final safe string:

modified_regex = {'[1*]'}
old_regex = {'[C:3]', '[N-:6]', '[C:4]', '[O:21]', '[H:9]', '[H:8]', '[C:7]', '[C+:1]', '[C:2]', '[1*]', '[O:5]', '[H:10]', '[C:20]'}
# modified regex
'[C+:1]#[C:2]CC(C[C:3]1=[C:7]([H:10])[N-:6][O:5][C:4]1([H:8])[H:9])C[C:20]2.[O:21]2CN'
# old regex
"[C+:1]#[C:2]CC(C[C:3]1=[C:7]([H:10])[N-:6][O:5][C:4]1([H:8])[H:9])C[C:20][1*].[O:21]([1*])CN"

That will lead to the following molecules:
old regex
image
new regex
image


if canonical:
attach_pos = sorted(attach_pos)
starting_num = 1 if len(branch_numbers) == 0 else max(branch_numbers) + 1
Expand Down
32 changes: 14 additions & 18 deletions safe/sample.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
from typing import Union
from typing import List
from typing import Optional


from contextlib import suppress
from collections.abc import Mapping
from collections import Counter
from transformers.generation import PhrasalConstraint
from transformers.generation import DisjunctiveConstraint
from transformers import GenerationConfig
from safe.trainer.model import SAFEDoubleHeadsModel
from safe.tokenizer import SAFETokenizer
from loguru import logger
from tqdm.auto import tqdm

import itertools
import os
import re
import torch
import random
import re
from collections import Counter
from collections.abc import Mapping
from contextlib import suppress
from typing import List, Optional, Union

import datamol as dm
import torch
from loguru import logger
from tqdm.auto import tqdm
from transformers import GenerationConfig
from transformers.generation import DisjunctiveConstraint, PhrasalConstraint

import safe as sf
from safe.tokenizer import SAFETokenizer
from safe.trainer.model import SAFEDoubleHeadsModel


class SAFEDesign:
Expand Down
44 changes: 19 additions & 25 deletions safe/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,35 @@
from typing import Optional
from typing import List
from typing import Any
from typing import Iterator
from typing import Union
from typing import Dict

import re
import os
import contextlib
import fsspec
import copy
import torch
import numpy as np
import json
import os
import re
import warnings
import packaging.version
from typing import Any, Dict, Iterator, List, Optional, Union

import fsspec
import numpy as np
import packaging.version
import torch
from loguru import logger
from tokenizers import decoders
from tokenizers import Tokenizer
from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
from tokenizers.pre_tokenizers import PreTokenizer, Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from transformers import PreTrainedTokenizerFast
from transformers import __version__ as transformers_version
from transformers.utils import PushToHubMixin
from transformers.utils import is_offline_mode
from transformers.utils import is_remote_url
from transformers.utils import cached_file
from transformers.utils import download_url
from transformers.utils import extract_commit_hash
from transformers.utils import working_or_temp_dir
from transformers.utils import (
PushToHubMixin,
cached_file,
download_url,
extract_commit_hash,
is_offline_mode,
is_remote_url,
working_or_temp_dir,
)

from .utils import attr_as


SPECIAL_TOKENS = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
UNK_TOKEN = "[UNK]"
PADDING_TOKEN = "[PAD]"
Expand Down
26 changes: 11 additions & 15 deletions safe/trainer/cli.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
from typing import Optional
from typing import Literal

import math
import os
import sys
import uuid
import safe
from dataclasses import dataclass, field
from typing import Literal, Optional

import datasets
import evaluate
import torch
import transformers
import evaluate
import datasets
from dataclasses import dataclass, field
from loguru import logger
from transformers import AutoConfig
from transformers import AutoTokenizer
from transformers import set_seed
from transformers.utils.logging import log_levels as LOG_LEVELS
from transformers import AutoConfig, AutoTokenizer, TrainingArguments, set_seed
from transformers.trainer_utils import get_last_checkpoint
from transformers import TrainingArguments
from safe.trainer.model import SAFEDoubleHeadsModel
from transformers.utils.logging import log_levels as LOG_LEVELS

import safe
from safe.tokenizer import SAFETokenizer
from safe.trainer.data_utils import get_dataset
from safe.trainer.collator import SAFECollator
from safe.trainer.data_utils import get_dataset
from safe.trainer.model import SAFEDoubleHeadsModel
from safe.trainer.trainer_utils import SAFETrainer


CURRENT_DIR = os.path.join(safe.__path__[0], "trainer")


Expand Down
13 changes: 4 additions & 9 deletions safe/trainer/collator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from typing import Union
from typing import Optional
from typing import List
from typing import Dict
from typing import Any

import copy
import functools
import torch
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Union

import torch
from tokenizers import Tokenizer
from transformers.data.data_collator import _torch_collate_batch
from safe.tokenizer import SAFETokenizer

from tokenizers import Tokenizer
from safe.tokenizer import SAFETokenizer


class SAFECollator:
Expand Down
13 changes: 4 additions & 9 deletions safe/trainer/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from typing import Optional
from typing import Callable
from typing import Any
from typing import Union
from typing import Dict

import itertools
from collections.abc import Mapping
from tqdm.auto import tqdm
from functools import partial
from typing import Any, Callable, Dict, Optional, Union

import itertools
import upath
import datasets
import upath
from tqdm.auto import tqdm

from safe.tokenizer import SAFETokenizer

Expand Down
24 changes: 10 additions & 14 deletions safe/trainer/model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from typing import Optional
from typing import Union
from typing import Tuple
from typing import Callable
from typing import Any
from typing import Any, Callable, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss
from transformers import GPT2DoubleHeadsModel
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import GPT2DoubleHeadsModel, PretrainedConfig
from transformers.activations import get_activation
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
from transformers.models.gpt2.modeling_gpt2 import add_start_docstrings_to_model_forward
from transformers.models.gpt2.modeling_gpt2 import replace_return_docstrings
from transformers.models.gpt2.modeling_gpt2 import GPT2_INPUTS_DOCSTRING
from transformers.models.gpt2.modeling_gpt2 import _CONFIG_FOR_DOC
from transformers import PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import (
_CONFIG_FOR_DOC,
GPT2_INPUTS_DOCSTRING,
GPT2DoubleHeadsModelOutput,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)


class PropertyHead(torch.nn.Module):
Expand Down
4 changes: 1 addition & 3 deletions safe/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from transformers import Trainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class SAFETrainer(Trainer):
Expand Down
30 changes: 11 additions & 19 deletions safe/utils.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
from typing import Optional
from typing import Any
from typing import Union
from typing import List
from typing import Tuple

import itertools
import random
import re
from collections import deque
from contextlib import contextmanager, suppress
from functools import partial
from itertools import combinations
from itertools import compress
from itertools import combinations, compress
from typing import Any, List, Optional, Tuple, Union

import datamol as dm
import networkx as nx
import numpy as np
from loguru import logger
from networkx.utils import py_random_state

from rdkit import Chem
from rdkit.Chem import EditableMol, Atom
from rdkit.Chem.rdmolops import ReplaceCore
from rdkit.Chem.rdmolops import AdjustQueryParameters
from rdkit.Chem.rdmolops import AdjustQueryProperties
from rdkit.Chem import Atom, EditableMol
from rdkit.Chem.rdChemReactions import ReactionFromSmarts
from rdkit.Chem.rdmolops import AdjustQueryParameters, AdjustQueryProperties, ReplaceCore

import itertools
import random
import re
import numpy as np
import networkx as nx
import datamol as dm
import safe as sf

__implicit_carbon_query = dm.from_smarts("[#6;h]")
Expand Down
10 changes: 4 additions & 6 deletions safe/viz.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Optional
from typing import Any
from typing import Tuple
from typing import Union

import itertools
import matplotlib.pyplot as plt
from typing import Any, Optional, Tuple, Union

import datamol as dm
import matplotlib.pyplot as plt

import safe as sf


Expand Down
2 changes: 1 addition & 1 deletion tests/test_hgf_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from safe.sample import SAFEDesign
from safe.tokenizer import SAFETokenizer
from safe.trainer.model import SAFEDoubleHeadsModel
from safe.sample import SAFEDesign


def test_load_default_safe_model():
Expand Down
Loading