diff --git a/.github/workflows/test_dev.yml b/.github/workflows/test_dev.yml new file mode 100644 index 0000000..c9e51a6 --- /dev/null +++ b/.github/workflows/test_dev.yml @@ -0,0 +1,40 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Tests (dev) + +on: + push: + branches: [ "dev" ] + pull_request: + branches: [ "dev" ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + python -m pip install -e . + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest diff --git a/.github/workflows/test_master.yml b/.github/workflows/test_master.yml new file mode 100644 index 0000000..adc98b5 --- /dev/null +++ b/.github/workflows/test_master.yml @@ -0,0 +1,40 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Tests (master) + +on: + push: + branches: [ "master"] + pull_request: + branches: [ "master"] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + python -m pip install -e . + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest diff --git a/README.md b/README.md index 65481fd..291150f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,12 @@ # Fréchet ChemNet Distance +![PyPI](https://img.shields.io/pypi/v/fcd) +![Tests (master)](https://github.com/bioinf-jku/fcd/actions/workflows/test_master.yml/badge.svg?branch=dev) +![Tests (dev)](https://github.com/bioinf-jku/fcd/actions/workflows/test_dev.yml/badge.svg?branch=dev) +![PyPI - Downloads](https://img.shields.io/pypi/dm/fcd) +![GitHub release (latest by date)](https://img.shields.io/github/v/release/bioinf-jku/fcd) +![GitHub release date](https://img.shields.io/github/release-date/bioinf-jku/fcd) +![GitHub](https://img.shields.io/github/license/bioinf-jku/fcd) + Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery" [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.8b00234) / @@ -6,10 +14,14 @@ Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models fo ## Installation -You can install the FCD using +You can install FCD using ``` pip install fcd ``` +or run the example notebook on Google Colab + +. + # Requirements ``` diff --git a/example.ipynb b/example.ipynb index b0d9471..3c08319 100644 --- a/example.ipynb +++ b/example.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install fcd" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -7,15 +17,15 @@ "outputs": [], "source": [ "import os\n", - "from rdkit import RDLogger \n", + "from rdkit import RDLogger\n", "import numpy as np\n", "import pandas as pd\n", - "from fcd import get_fcd, load_ref_model,canonical_smiles, get_predictions, calculate_frechet_distance\n", + "from fcd import get_fcd, load_ref_model, canonical_smiles, get_predictions, calculate_frechet_distance\n", "\n", - "RDLogger.DisableLog('rdApp.*')\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "np.random.seed(0)\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]= '0' #set gpu" + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # set gpu" ] }, { @@ -44,8 +54,10 @@ "model = load_ref_model()\n", "\n", "# Load generated molecules\n", - "gen_mol_file = \"generated_smiles/LSTM_Segler.smi\" #input file which contains one generated SMILES per line\n", - "gen_mol = pd.read_csv(gen_mol_file,header=None)[0] #IMPORTANT: take at least 10000 molecules as FCD can vary with sample size \n", + "gen_mol_file = \"generated_smiles/LSTM_Segler.smi\" # input file which contains one generated SMILES per line\n", + "gen_mol = pd.read_csv(gen_mol_file, header=None)[\n", + " 0\n", + "] # IMPORTANT: take at least 10000 molecules as FCD can vary with sample size\n", "sample1 = np.random.choice(gen_mol, 10000, replace=False)\n", "sample2 = np.random.choice(gen_mol, 10000, replace=False)\n", "\n", @@ -82,7 +94,7 @@ } ], "source": [ - "#get CHEBMLNET activations of generated molecules \n", + "# get CHEBMLNET activations of generated molecules\n", "act1 = get_predictions(model, can_sample1)\n", "act2 = get_predictions(model, can_sample2)\n", "\n", @@ -92,13 +104,9 @@ "mu2 = np.mean(act2, axis=0)\n", "sigma2 = np.cov(act2.T)\n", "\n", - "fcd_score = calculate_frechet_distance(\n", - " mu1=mu1,\n", - " mu2=mu2, \n", - " sigma1=sigma1,\n", - " sigma2=sigma2)\n", + "fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)\n", "\n", - "print('FCD: ',fcd_score)" + "print(\"FCD: \", fcd_score)" ] }, { @@ -123,7 +131,7 @@ "\"\"\"if you don't need to store the activations you can also take a shortcut.\"\"\"\n", "fcd_score = get_fcd(can_sample1, can_sample2, model)\n", "\n", - "print('FCD: ',fcd_score)" + "print(\"FCD: \", fcd_score)" ] }, { @@ -147,7 +155,7 @@ "source": [ "\"\"\"This is what happens if you do not canonicalize the smiles\"\"\"\n", "fcd_score = get_fcd(can_sample1, sample2, model)\n", - "print('FCD: ',fcd_score)" + "print(\"FCD: \", fcd_score)" ] } ], diff --git a/fcd/__init__.py b/fcd/__init__.py index 4ee7c79..cee3edd 100644 --- a/fcd/__init__.py +++ b/fcd/__init__.py @@ -1,4 +1,14 @@ -from .fcd import get_fcd, get_predictions, load_ref_model -from .utils import calculate_frechet_distance, canonical_smiles +# ruff: noqa: F401 -__version__ = "1.2" +from fcd.fcd import get_fcd, get_predictions, load_ref_model +from fcd.utils import calculate_frechet_distance, canonical_smiles + +__all__ = [ + "get_fcd", + "get_predictions", + "load_ref_model", + "calculate_frechet_distance", + "canonical_smiles", +] + +__version__ = "1.2.1" diff --git a/fcd/fcd.py b/fcd/fcd.py index c7e528e..b0f14d7 100644 --- a/fcd/fcd.py +++ b/fcd/fcd.py @@ -9,7 +9,7 @@ from torch import nn from torch.utils.data import DataLoader -from .utils import ( +from fcd.utils import ( SmilesDataset, calculate_frechet_distance, load_imported_model, @@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None): if model_path is None: chemnet_model_filename = "ChemNet_v0.13_pretrained.pt" model_bytes = pkgutil.get_data("fcd", chemnet_model_filename) + if model_bytes is None: + raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}") tmpdir = tempfile.TemporaryDirectory() model_path = os.path.join(tmpdir.name, chemnet_model_filename) @@ -48,7 +50,7 @@ def get_predictions( smiles_list: List[str], batch_size: int = 128, n_jobs: int = 1, - device: str = "cpu", + device: Optional[str] = None, ) -> np.ndarray: """Calculate Chemnet activations @@ -65,37 +67,48 @@ def get_predictions( if len(smiles_list) == 0: return np.zeros((0, 512)) - dataloader = DataLoader( - SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs - ) + dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs) + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + with todevice(model, device), torch.no_grad(): chemnet_activations = [] for batch in dataloader: chemnet_activations.append( - model(batch.transpose(1, 2).float().to(device)) - .to("cpu") - .detach() - .numpy() + model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32) ) return np.row_stack(chemnet_activations) -def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float: +def get_fcd(smiles1: List[str], smiles2: List[str], model: Optional[nn.Module] = None, device=None) -> float: """Calculate FCD between two sets of Smiles Args: - smiles1 (List[str]): First set of smiles - smiles2 (List[str]): Second set of smiles + smiles1 (List[str]): First set of SMILES. + smiles2 (List[str]): Second set of SMILES. model (nn.Module, optional): The model to use. Loads default model if None. + device: The device to use for computation. Returns: - float: The FCD score + float: The FCD score. + + Raises: + ValueError: If the input SMILES lists are empty. + + Example: + >>> smiles1 = ['CCO', 'CCN'] + >>> smiles2 = ['CCO', 'CCC'] + >>> fcd_score = get_fcd(smiles1, smiles2) """ + if not smiles1 or not smiles2: + raise ValueError("Input SMILES lists cannot be empty.") + if model is None: model = load_ref_model() - act1 = get_predictions(model, smiles1) - act2 = get_predictions(model, smiles2) + act1 = get_predictions(model, smiles1, device=device) + act2 = get_predictions(model, smiles2, device=device) mu1 = np.mean(act1, axis=0) sigma1 = np.cov(act1.T) @@ -103,8 +116,6 @@ def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> mu2 = np.mean(act2, axis=0) sigma2 = np.cov(act2.T) - fcd_score = calculate_frechet_distance( - mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2 - ) + fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2) return fcd_score diff --git a/fcd/utils.py b/fcd/utils.py index bb7d57a..ea780e6 100644 --- a/fcd/utils.py +++ b/fcd/utils.py @@ -1,7 +1,8 @@ import re +import warnings from contextlib import contextmanager from multiprocessing import Pool -from typing import List +from typing import List, Optional import numpy as np import torch @@ -10,7 +11,7 @@ from torch import nn from torch.utils.data import Dataset -from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose +from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose # fmt: off __vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."] @@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]: return tok_smile -def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray: +def get_one_hot(smiles: str, pad_len: Optional[int] = None) -> np.ndarray: """Generate one-hot representation of a Smiles string. Args: @@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray: Returns: np.ndarray: Array containing the one-hot encoded Smiles """ + # add end token smiles = smiles + "." # initialize array - array_length = len(smiles) if pad_len < 0 else pad_len + array_length = len(smiles) if pad_len is None else pad_len + assert array_length >= len(smiles), "Pad length must be greater than the length of the input SMILES string + 1." + vocab_size = len(__vocab) one_hot = np.zeros((array_length, vocab_size)) @@ -106,22 +110,57 @@ def load_imported_model(keras_config): class SmilesDataset(Dataset): - __PAD_LEN = 350 + """ + A dataset class for handling SMILES data. + + Args: + smiles_list (list): A list of SMILES strings. + pad_len (int, optional): The length to pad the SMILES strings to. If not provided, the default pad length of 350 will be used. + warn (bool, optional): Whether to display a warning message if the specified pad length is different from the default. Defaults to True. + + Attributes: + smiles_list (list): A list of SMILES strings. + pad_len (int): The length to pad the SMILES strings to. + + """ - def __init__(self, smiles_list): + def __init__(self, smiles_list, pad_len=None, warn=True): super().__init__() + DEFAULT_PAD_LEN = 350 + self.smiles_list = smiles_list + max_len = max(len(smiles) for smiles in smiles_list) + 1 # plus one for the end token + + if pad_len is None: + pad_len = max(DEFAULT_PAD_LEN, max_len) + else: + if pad_len < max_len: + raise ValueError(f"Specified pad_len {pad_len} is less than max_len {max_len}") + + if pad_len != DEFAULT_PAD_LEN: + warnings.warn( + """Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval. + Use warn=False to suppress this warning.""" + ) + + self.pad_len = pad_len def __getitem__(self, idx): smiles = self.smiles_list[idx] - features = get_one_hot(smiles, 350) + features = get_one_hot(smiles, pad_len=self.pad_len) return features / features.shape[1] def __len__(self): return len(self.smiles_list) -def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): +def calculate_frechet_distance( + mu1: np.ndarray, + sigma1: np.ndarray, + mu2: np.ndarray, + sigma2: np.ndarray, + eps: float = 1e-6, +) -> float: """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is @@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) - assert ( - mu1.shape == mu2.shape - ), "Training and test mean vectors have different lengths" - assert ( - sigma1.shape == sigma2.shape - ), "Training and test covariances have different dimensions" + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" diff = mu1 - mu2 # product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) - if not np.isfinite(covmean).all(): + is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3) + + if not np.isfinite(covmean).all() or not is_real: offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + assert isinstance(covmean, np.ndarray) # numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): @@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): tr_covmean = np.trace(covmean) - return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) @contextmanager @@ -188,11 +226,11 @@ def todevice(model, device): def canonical(smi): try: - return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) - except: + return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) # type: ignore + except Exception: return None -def canonical_smiles(smiles, njobs=32): +def canonical_smiles(smiles, njobs=-1): with Pool(njobs) as pool: return pool.map(canonical, smiles) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ada53c2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,64 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fcd" +version = "1.2.1" +dependencies = ["torch", "numpy", "scipy", "rdkit"] +requires-python = ">=3.8" +authors = [ + { name = "Philipp Renz", email = "renz.ph@gmail.com" }, +] +description = "Fréchet ChEMNet Distance" +readme = "README.md" +license = { file = "LICENSE" } +keywords = ["cheminformatics", "machine learning", "deep learning", "generative models"] +classifiers = [ + "Programming Language :: Python :: >=3.8", + "Development Status :: 5 - Production/Stable", +] + +[project.urls] +Homepage = "https://github.com/bioinf-jku/FCD" +Documentation = "https://github.com/bioinf-jku/FCD" +Repository = "https://github.com/bioinf-jku/FCD" + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 120 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + diff --git a/setup.py b/setup.py deleted file mode 100644 index e18e275..0000000 --- a/setup.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python - -# from distutils.core import setup -from setuptools import setup - -with open("README.md", "r", encoding="utf8") as fh: - long_description = fh.read() - -setup( - name="FCD", - version="1.2", - author="Philipp Renz", - author_email="renz.ph@gmail.com", - description="Fréchet ChEMNet Distance", - url="https://github.com/bioinf-jku/FCD", - packages=["fcd"], - license="LGPLv3", - long_description=long_description, - long_description_content_type="text/markdown", - install_requires=["torch", "numpy", "scipy", "rdkit"], - include_package_data=True, -) diff --git a/test/test_fcd.py b/test/test_fcd.py index 92ead4e..81395da 100644 --- a/test/test_fcd.py +++ b/test/test_fcd.py @@ -1,21 +1,34 @@ -import unittest - import numpy as np +import pytest +from pytest import approx from fcd import get_fcd -from fcd.utils import get_one_hot +from fcd.utils import SmilesDataset, get_one_hot + + +class TestFCD: + def test_random_smiles_cpu(self): + smiles_list1 = ["CNOHF", "NOHFCl", "OHFClP", "HFClPB"] + smiles_list2 = ["ISi#()", "Si#()+", "#()+-", "()+-1"] + target = 8.8086 + fcd = get_fcd(smiles_list1, smiles_list2, device="cpu") + assert fcd == approx(target, abs=1e-2) + def test_random_smiles_gpu(self): + # Skip test if CUDA is not available + # CUDA comp is less consistent than CPU + import torch + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") -class TestFCD(unittest.TestCase): - def test_random_smiles(self): smiles_list1 = ["CNOHF", "NOHFCl", "OHFClP", "HFClPB"] smiles_list2 = ["ISi#()", "Si#()+", "#()+-", "()+-1"] target = 8.8086 - self.assertAlmostEqual( - get_fcd(smiles_list1, smiles_list2), target, 3, f"Should be {target}" - ) + fcd = get_fcd(smiles_list1, smiles_list2, device="cuda") + assert fcd == approx(target, abs=1e-2) - def test_fcd_torch(self): + def test_random_smiles_cpu_2(self): smiles_list1 = [ "COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1", "Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1", @@ -25,15 +38,36 @@ def test_fcd_torch(self): "Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1", ] target = 47.773486382119444 - self.assertAlmostEqual( - get_fcd(smiles_list1, smiles_list2), target, 3, f"Should be {target}" - ) + fcd = get_fcd(smiles_list1, smiles_list2, device="cpu") + + assert fcd == approx(target, abs=1e-3) + + def test_random_smiles_gpu_2(self): + # Skip test if CUDA is not available + # CUDA comp is less consistent than CPU + import torch + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + smiles_list1 = [ + "COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1", + "Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1", + ] + smiles_list2 = [ + "Oc1ccccc1-c1cccc2cnccc12", + "Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1", + ] + target = 47.773486382119444 + fcd = get_fcd(smiles_list1, smiles_list2) + + assert fcd == approx(target, abs=1e-3) def test_one_hot(self): inputs = [ "O=C([C@H]1CC[C@@H]2[C@@H]", "CCN(CC)S(=O)(=O)CCC[C@H](", - "COc1ccc(\C=C\2/C[N+](C)(C", + r"COc1ccc(\C=C\2/C[N+](C)(C", "COc1cc(CC(=O)N2CCN(CC2)c3", "CN(CC#Cc1cc(ccc1C)ClC2(O)CC", "CYACCyyCLCl", @@ -42,7 +76,7 @@ def test_one_hot(self): outputs = [ ((26, 35), [2, 25, 0, 13, 26, 0, 28, 3, 27, 17, 0, 0, 26, 0, 28, 28, 3, 27, 18, 26, 0, 28, 28, 3, 27, 34]), ((26, 35), [0, 0, 1, 13, 0, 0, 14, 9, 13, 25, 2, 14, 13, 25, 2, 14, 0, 0, 0, 26, 0, 28, 3, 27, 13, 34]), - ((25, 35), [0, 2, 29, 17, 29, 29, 29, 13, 33, 0, 25, 0, 33, 33, 0, 26, 1, 15, 27, 13, 0, 14, 13, 0, 34]), + ((26, 35), [0, 2, 29, 17, 29, 29, 29, 13, 33, 0, 25, 0, 33, 18, 33, 0, 26, 1, 15, 27, 13, 0, 14, 13, 0, 34]), ((26, 35), [0, 2, 29, 17, 29, 29, 13, 0, 0, 13, 25, 2, 14, 1, 18, 0, 0, 1, 13, 0, 0, 18, 14, 29, 19, 34]), ((28, 35), [0, 1, 13, 0, 0, 12, 0, 29, 17, 29, 29, 13, 29, 29, 29, 17, 0, 14, 5, 0, 18, 13, 2, 14, 0, 0, 34]), ((12, 35), [0, 33, 33, 0, 0, 33, 33, 0, 33, 5, 34]), @@ -53,13 +87,49 @@ def test_one_hot(self): one_hot = get_one_hot(inp) shape = one_hot.shape entries = np.where(one_hot)[1].tolist() - self.assertEqual(shape, correct_shape) - self.assertEqual(entries, correct_entries) + assert shape == correct_shape + assert entries == correct_entries # assert that no duplicate ones and no missing entries. Trailing zero vectors are allowed. non_zero_idx = np.where(one_hot)[0] assert np.all(non_zero_idx == np.arange(len(non_zero_idx))) + def test_one_hot_padding(self): + smiles = "CNOHFCCCCCCCC" + pad_len = 5 + with pytest.raises(AssertionError): + one_hot = get_one_hot(smiles, pad_len=pad_len) + + +class TestSmilesDataset: + def test_dataset_okay(self): + smiles = ["CNOHF", "NOHFCl", "OHFClP", "HFClPB"] + smiles_dataset = SmilesDataset(smiles) + assert len(smiles_dataset) == len(smiles) + assert smiles_dataset.pad_len == 350 + + def test_smiles_too_long(self): + """Check if warning is raised when smiles are too long for default pad_length""" + + smiles = ["CNOHF" * 100, "NOHFCl", "OHFClP", "HFClPB"] + with pytest.warns(UserWarning): + smiles_dataset = SmilesDataset(smiles) + + assert len(smiles_dataset) == len(smiles) + assert smiles_dataset.pad_len == 501 # plus one for the end token + + def test_smiles_one_off(self): + smiles = ["CCCCC"] + with pytest.warns(UserWarning): + smiles_dataset = SmilesDataset(smiles, pad_len=len(smiles[0]) + 1) # plus one for the end token + + assert isinstance(smiles_dataset[0], np.ndarray) + + def test_custom_pad_length(self): + """Check if custom pad_length is used and warning is issued""" + smiles = ["CNOHF", "NOHFCl", "OHFClP", "HFClPB"] + with pytest.warns(UserWarning): + smiles_dataset = SmilesDataset(smiles, pad_len=20) -if __name__ == "__main__": - unittest.main() + assert len(smiles_dataset) == len(smiles) + assert smiles_dataset.pad_len == 20