Skip to content

Commit

Permalink
Rewrite get_one_hot and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
renzph committed Aug 23, 2023
1 parent 09cb85e commit b4bcc22
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 63 deletions.
115 changes: 52 additions & 63 deletions fcd/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from contextlib import contextmanager
from multiprocessing import Pool
from typing import List

import numpy as np
import torch
Expand All @@ -8,74 +10,61 @@
from torch import nn
from torch.utils.data import Dataset

from .torch_layers import (IndexTensor, IndexTuple, Reverse, SamePadding1d,
Transpose)

__vocab = [
"C",
"N",
"O",
"H",
"F",
"Cl",
"P",
"B",
"Br",
"S",
"I",
"Si",
"#",
"(",
")",
"+",
"-",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"=",
"[",
"]",
"@",
"c",
"n",
"o",
"s",
"X",
".",
]
__vocab_i2c = {i: k for i, k in enumerate(__vocab)}
from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose

# fmt: off
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
# fmt: on
__vocab_c2i = {k: i for i, k in enumerate(__vocab)}
__unk = __vocab_c2i["X"]
__two_letters = {"r", "i", "l"}


def get_one_hot(smiles, pad_len=-1):
sorted_vocab = __vocab[:]
sorted_vocab.sort() # sort alphabetically
sorted_vocab.sort(key=len, reverse=True) # sort by length for regex
FULL_REGEX = "|".join(
"(%s)" % re.escape(base_symbol) for base_symbol in sorted_vocab
) # Tries to match longer tokens first.
FULL_REGEX += "|." # Handle unkown characters


def tokenize(smiles: str) -> List[str]:
"""Tokenizes the given smiles string. Needed for multi-character tokens like 'Cl'
Args:
smiles (str): Input molecule as Smiles
Returns:
List[str]: List of tokens
"""
tok_smile = [mo.group() for mo in re.finditer(FULL_REGEX, smiles)]
assert "".join(tok_smile) == smiles
return tok_smile


def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
"""Generate one-hot representation of a Smiles string.
Args:
smiles (str): Input molecule as Smiles
pad_len (int, optional): Whether or not to pad to a given size. Defaults to -1.
Returns:
np.ndarray: Array containing the one-hot encoded Smiles
"""
smiles = smiles + "."
one_hot = np.zeros((len(smiles) if pad_len < 0 else pad_len, len(__vocab)))

if len(smiles) == 1:
one_hot[0, __vocab_c2i.get(".")] = 1
return one_hot

src = 0
dst = 0
while True:
if smiles[src + 1] in __two_letters:
sym = smiles[src : src + 2]
src += 2
else:
sym = smiles[src]
src += 1
one_hot[dst, __vocab_c2i.get(sym, __unk)] = 1
dst += 1
if smiles[src] == "." or dst == one_hot.shape[0] - 1:
one_hot[dst, __vocab_c2i.get(".")] = 1
break

# initialize array
array_length = len(smiles) if pad_len < 0 else pad_len
vocab_size = len(__vocab)
one_hot = np.zeros((array_length, vocab_size))

tokens = tokenize(smiles)
numeric = [__vocab_c2i.get(token, __unk) for token in tokens]

for pos, num in enumerate(numeric):
one_hot[pos, num] = 1

return one_hot


Expand Down
34 changes: 34 additions & 0 deletions test/test_fcd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import unittest

import numpy as np

from fcd import get_fcd
from fcd.utils import get_one_hot


class TestFCD(unittest.TestCase):
Expand All @@ -26,6 +29,37 @@ def test_fcd_torch(self):
get_fcd(smiles_list1, smiles_list2), target, 3, f"Should be {target}"
)

def test_one_hot(self):
inputs = [
"O=C([C@H]1CC[C@@H]2[C@@H]",
"CCN(CC)S(=O)(=O)CCC[C@H](",
"COc1ccc(\C=C\2/C[N+](C)(C",
"COc1cc(CC(=O)N2CCN(CC2)c3",
"CN(CC#Cc1cc(ccc1C)ClC2(O)CC",
"CYACCyyCLCl",
]
# fmt: off
outputs = [
((26, 35), [2, 25, 0, 13, 26, 0, 28, 3, 27, 17, 0, 0, 26, 0, 28, 28, 3, 27, 18, 26, 0, 28, 28, 3, 27, 34]),
((26, 35), [0, 0, 1, 13, 0, 0, 14, 9, 13, 25, 2, 14, 13, 25, 2, 14, 0, 0, 0, 26, 0, 28, 3, 27, 13, 34]),
((25, 35), [0, 2, 29, 17, 29, 29, 29, 13, 33, 0, 25, 0, 33, 33, 0, 26, 1, 15, 27, 13, 0, 14, 13, 0, 34]),
((26, 35), [0, 2, 29, 17, 29, 29, 13, 0, 0, 13, 25, 2, 14, 1, 18, 0, 0, 1, 13, 0, 0, 18, 14, 29, 19, 34]),
((28, 35), [0, 1, 13, 0, 0, 12, 0, 29, 17, 29, 29, 13, 29, 29, 29, 17, 0, 14, 5, 0, 18, 13, 2, 14, 0, 0, 34]),
((12, 35), [0, 33, 33, 0, 0, 33, 33, 0, 33, 5, 34]),
]
# fmt: on

for inp, (correct_shape, correct_entries) in zip(inputs, outputs):
one_hot = get_one_hot(inp)
shape = one_hot.shape
entries = np.where(one_hot)[1].tolist()
self.assertEqual(shape, correct_shape)
self.assertEqual(entries, correct_entries)

# assert that no duplicate ones and no missing entries. Trailing zero vectors are allowed.
non_zero_idx = np.where(one_hot)[0]
assert np.all(non_zero_idx == np.arange(len(non_zero_idx)))


if __name__ == "__main__":
unittest.main()

0 comments on commit b4bcc22

Please sign in to comment.