From b4bcc222a7eeb0885560847d483de80ba1a0a709 Mon Sep 17 00:00:00 2001
From: Philipp Renz
Date: Wed, 23 Aug 2023 11:38:27 +0200
Subject: [PATCH] Rewrite get_one_hot and add test

---
 fcd/utils.py     | 115 +++++++++++++++++++++--------------
 test/test_fcd.py |  34 ++++++++++++++
 2 files changed, 86 insertions(+), 63 deletions(-)

diff --git a/fcd/utils.py b/fcd/utils.py
index 7350da9..bb7d57a 100644
--- a/fcd/utils.py
+++ b/fcd/utils.py
@@ -1,5 +1,7 @@
+import re
 from contextlib import contextmanager
 from multiprocessing import Pool
+from typing import List
 
 import numpy as np
 import torch
@@ -8,74 +10,61 @@ from torch import nn
 from torch.utils.data import Dataset
 
-from .torch_layers import (IndexTensor, IndexTuple, Reverse, SamePadding1d,
-                           Transpose)
-
-__vocab = [
-    "C",
-    "N",
-    "O",
-    "H",
-    "F",
-    "Cl",
-    "P",
-    "B",
-    "Br",
-    "S",
-    "I",
-    "Si",
-    "#",
-    "(",
-    ")",
-    "+",
-    "-",
-    "1",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "=",
-    "[",
-    "]",
-    "@",
-    "c",
-    "n",
-    "o",
-    "s",
-    "X",
-    ".",
-]
-__vocab_i2c = {i: k for i, k in enumerate(__vocab)}
+from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
+
+# fmt: off
+__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
+# fmt: on
+
 __vocab_c2i = {k: i for i, k in enumerate(__vocab)}
 __unk = __vocab_c2i["X"]
-__two_letters = {"r", "i", "l"}
 
 
-def get_one_hot(smiles, pad_len=-1):
+sorted_vocab = __vocab[:]
+sorted_vocab.sort()  # sort alphabetically
+sorted_vocab.sort(key=len, reverse=True)  # sort by length for regex
+FULL_REGEX = "|".join(
+    "(%s)" % re.escape(base_symbol) for base_symbol in sorted_vocab
+)  # Tries to match longer tokens first.
+FULL_REGEX += "|."  # Handle unknown characters
+
+
+def tokenize(smiles: str) -> List[str]:
+    """Tokenizes the given SMILES string. Needed for multi-character tokens like 'Cl'.
+
+    Args:
+        smiles (str): Input molecule as SMILES string
+
+    Returns:
+        List[str]: List of tokens
+    """
+    tok_smile = [mo.group() for mo in re.finditer(FULL_REGEX, smiles)]
+    assert "".join(tok_smile) == smiles
+    return tok_smile
+
+
+def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
+    """Generate the one-hot representation of a SMILES string.
+
+    Args:
+        smiles (str): Input molecule as SMILES string
+        pad_len (int, optional): Length to pad the encoding to. Negative values disable padding. Defaults to -1.
+
+    Returns:
+        np.ndarray: Array containing the one-hot encoded SMILES
+    """
     smiles = smiles + "."
-    one_hot = np.zeros((len(smiles) if pad_len < 0 else pad_len, len(__vocab)))
-
-    if len(smiles) == 1:
-        one_hot[0, __vocab_c2i.get(".")] = 1
-        return one_hot
-
-    src = 0
-    dst = 0
-    while True:
-        if smiles[src + 1] in __two_letters:
-            sym = smiles[src : src + 2]
-            src += 2
-        else:
-            sym = smiles[src]
-            src += 1
-        one_hot[dst, __vocab_c2i.get(sym, __unk)] = 1
-        dst += 1
-        if smiles[src] == "." or dst == one_hot.shape[0] - 1:
-            one_hot[dst, __vocab_c2i.get(".")] = 1
-            break
+
+    # initialize array
+    array_length = len(smiles) if pad_len < 0 else pad_len
+    vocab_size = len(__vocab)
+    one_hot = np.zeros((array_length, vocab_size))
+
+    tokens = tokenize(smiles)
+    numeric = [__vocab_c2i.get(token, __unk) for token in tokens]
+
+    for pos, num in enumerate(numeric):
+        one_hot[pos, num] = 1
+
     return one_hot
diff --git a/test/test_fcd.py b/test/test_fcd.py
index d5992df..92ead4e 100644
--- a/test/test_fcd.py
+++ b/test/test_fcd.py
@@ -1,6 +1,9 @@
 import unittest
 
+import numpy as np
+
 from fcd import get_fcd
+from fcd.utils import get_one_hot
 
 
 class TestFCD(unittest.TestCase):
@@ -26,6 +29,37 @@ def test_fcd_torch(self):
             get_fcd(smiles_list1, smiles_list2), target, 3, f"Should be {target}"
         )
 
+    def test_one_hot(self):
+        inputs = [
+            "O=C([C@H]1CC[C@@H]2[C@@H]",
+            "CCN(CC)S(=O)(=O)CCC[C@H](",
+            "COc1ccc(\C=C\2/C[N+](C)(C",
+            "COc1cc(CC(=O)N2CCN(CC2)c3",
+            "CN(CC#Cc1cc(ccc1C)ClC2(O)CC",
+            "CYACCyyCLCl",
+        ]
+        # fmt: off
+        outputs = [
+            ((26, 35), [2, 25, 0, 13, 26, 0, 28, 3, 27, 17, 0, 0, 26, 0, 28, 28, 3, 27, 18, 26, 0, 28, 28, 3, 27, 34]),
+            ((26, 35), [0, 0, 1, 13, 0, 0, 14, 9, 13, 25, 2, 14, 13, 25, 2, 14, 0, 0, 0, 26, 0, 28, 3, 27, 13, 34]),
+            ((25, 35), [0, 2, 29, 17, 29, 29, 29, 13, 33, 0, 25, 0, 33, 33, 0, 26, 1, 15, 27, 13, 0, 14, 13, 0, 34]),
+            ((26, 35), [0, 2, 29, 17, 29, 29, 13, 0, 0, 13, 25, 2, 14, 1, 18, 0, 0, 1, 13, 0, 0, 18, 14, 29, 19, 34]),
+            ((28, 35), [0, 1, 13, 0, 0, 12, 0, 29, 17, 29, 29, 13, 29, 29, 29, 17, 0, 14, 5, 0, 18, 13, 2, 14, 0, 0, 34]),
+            ((12, 35), [0, 33, 33, 0, 0, 33, 33, 0, 33, 5, 34]),
+        ]
+        # fmt: on
+
+        for inp, (correct_shape, correct_entries) in zip(inputs, outputs):
+            one_hot = get_one_hot(inp)
+            shape = one_hot.shape
+            entries = np.where(one_hot)[1].tolist()
+            self.assertEqual(shape, correct_shape)
+            self.assertEqual(entries, correct_entries)
+
+            # assert that there are no duplicate ones and no missing entries. Trailing zero vectors are allowed.
+            non_zero_idx = np.where(one_hot)[0]
+            assert np.all(non_zero_idx == np.arange(len(non_zero_idx)))
+
 
 if __name__ == "__main__":
     unittest.main()
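
Note: a minimal usage sketch of the rewritten helpers, for reviewers. The
inputs below are illustrative examples chosen here, not taken from the test
suite; it assumes the patched fcd package is importable:

    from fcd.utils import get_one_hot, tokenize

    # Longer tokens are matched before single characters,
    # so "Cl" becomes one token rather than "C" + "l".
    print(tokenize("CCl"))   # ['C', 'Cl']

    # Characters outside the vocabulary survive tokenization unchanged;
    # get_one_hot then maps them to the unknown token "X" via __unk.
    print(tokenize("CY"))    # ['C', 'Y']

    # get_one_hot appends the stop token "." before encoding, so "CCl"
    # yields a (4, 35) array: rows for C, Cl, and ".", plus one trailing
    # all-zero row (4 characters but only 3 tokens).
    print(get_one_hot("CCl").shape)   # (4, 35)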