Skip to content

Commit

Permalink
Add multiple fixes
Browse files Browse the repository at this point in the history
1. Changes to `get_one_hot`
Problems are given in:
- #14
- #17
- #13

I discarded the changes in the PRs and added more comprehensive handling of the input data in the
`SmilesDataset` class and the `get_one_hot` function.

2. Imaginary components
Frechet distance calculation fails to work for some cases because of badly conditioned matrices,
as described here #15.

Could not reproduce the error locally, but could do so on colab.

Fixed it in `calculate_frechet_distance` by checking whether the first `covmean` computation is real, and adding a small value to the diagonal if it is not.
This made it work for me and I got the same result as the original implementation run locally.

3. Added some more tests and changed to pytest

4. As described in #16
I changed the data type of the activations to float32 in the `get_predictions` function,
which saves memory for larger datasets.

5. Change to pyproject.toml
  • Loading branch information
renzph committed Apr 1, 2024
1 parent b4bcc22 commit bf69345
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 78 deletions.
14 changes: 12 additions & 2 deletions fcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# ruff: noqa: F401

from fcd.fcd import get_fcd, get_predictions, load_ref_model
from fcd.utils import calculate_frechet_distance, canonical_smiles

# Explicit public API of the fcd package.
__all__ = [
    "get_fcd",
    "get_predictions",
    "load_ref_model",
    "calculate_frechet_distance",
    "canonical_smiles",
]

# NOTE: keep in sync with the `version` field in pyproject.toml.
__version__ = "1.2.1"
47 changes: 29 additions & 18 deletions fcd/fcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
from fcd.utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
Expand All @@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None):
if model_path is None:
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
model_bytes = pkgutil.get_data("fcd", chemnet_model_filename)
if model_bytes is None:
raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}")

tmpdir = tempfile.TemporaryDirectory()
model_path = os.path.join(tmpdir.name, chemnet_model_filename)
Expand All @@ -48,7 +50,7 @@ def get_predictions(
smiles_list: List[str],
batch_size: int = 128,
n_jobs: int = 1,
device: str = "cpu",
device: str | None = None,
) -> np.ndarray:
"""Calculate Chemnet activations
Expand All @@ -65,46 +67,55 @@ def get_predictions(
if len(smiles_list) == 0:
return np.zeros((0, 512))

dataloader = DataLoader(
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
)
dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

with todevice(model, device), torch.no_grad():
chemnet_activations = []
for batch in dataloader:
chemnet_activations.append(
model(batch.transpose(1, 2).float().to(device))
.to("cpu")
.detach()
.numpy()
model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32)
)
return np.row_stack(chemnet_activations)


def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float:
def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module | None = None, device=None) -> float:
"""Calculate FCD between two sets of Smiles
Args:
smiles1 (List[str]): First set of smiles
smiles2 (List[str]): Second set of smiles
smiles1 (List[str]): First set of SMILES.
smiles2 (List[str]): Second set of SMILES.
model (nn.Module, optional): The model to use. Loads default model if None.
device: The device to use for computation.
Returns:
float: The FCD score
float: The FCD score.
Raises:
ValueError: If the input SMILES lists are empty.
Example:
>>> smiles1 = ['CCO', 'CCN']
>>> smiles2 = ['CCO', 'CCC']
>>> fcd_score = get_fcd(smiles1, smiles2)
"""
if not smiles1 or not smiles2:
raise ValueError("Input SMILES lists cannot be empty.")

if model is None:
model = load_ref_model()

act1 = get_predictions(model, smiles1)
act2 = get_predictions(model, smiles2)
act1 = get_predictions(model, smiles1, device=device)
act2 = get_predictions(model, smiles2, device=device)

mu1 = np.mean(act1, axis=0)
sigma1 = np.cov(act1.T)

mu2 = np.mean(act2, axis=0)
sigma2 = np.cov(act2.T)

fcd_score = calculate_frechet_distance(
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
)
fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

return fcd_score
74 changes: 56 additions & 18 deletions fcd/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from contextlib import contextmanager
from multiprocessing import Pool
from typing import List
Expand All @@ -10,7 +11,7 @@
from torch import nn
from torch.utils.data import Dataset

from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose

# fmt: off
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
Expand Down Expand Up @@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]:
return tok_smile


def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
def get_one_hot(smiles: str, pad_len: int | None = None) -> np.ndarray:
"""Generate one-hot representation of a Smiles string.
Args:
Expand All @@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
Returns:
np.ndarray: Array containing the one-hot encoded Smiles
"""
# add end token
smiles = smiles + "."

# initialize array
array_length = len(smiles) if pad_len < 0 else pad_len
array_length = len(smiles) if pad_len is None else pad_len
assert array_length >= len(smiles), "Pad length must be greater than the length of the input SMILES string + 1."

vocab_size = len(__vocab)
one_hot = np.zeros((array_length, vocab_size))

Expand Down Expand Up @@ -106,22 +110,57 @@ def load_imported_model(keras_config):


class SmilesDataset(Dataset):
    """
    A dataset class for handling SMILES data.

    Args:
        smiles_list (list): A list of SMILES strings.
        pad_len (int, optional): The length to pad the SMILES strings to. If not
            provided, uses the larger of the default (350) and the longest
            SMILES in the list plus one (for the end token).
        warn (bool, optional): Whether to display a warning message if the
            effective pad length differs from the default. Defaults to True.

    Attributes:
        smiles_list (list): A list of SMILES strings.
        pad_len (int): The length to pad the SMILES strings to.

    Raises:
        ValueError: If an explicit ``pad_len`` is too small for the longest
            SMILES string in ``smiles_list``.
    """

    def __init__(self, smiles_list, pad_len=None, warn=True):
        super().__init__()
        DEFAULT_PAD_LEN = 350

        self.smiles_list = smiles_list
        # +1 accounts for the end token appended by get_one_hot;
        # default=0 keeps an empty dataset from crashing max().
        max_len = max((len(smiles) for smiles in smiles_list), default=0) + 1

        if pad_len is None:
            pad_len = max(DEFAULT_PAD_LEN, max_len)
        elif pad_len < max_len:
            # Interpolate the actual limit so the error is actionable.
            raise ValueError(f"Specified pad_len {pad_len} is less than max_len {max_len}")

        # Non-default padding can change FCD scores; warn unless the caller
        # explicitly opted out via warn=False.
        if pad_len != DEFAULT_PAD_LEN and warn:
            warnings.warn(
                """Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval.
                Use warn=False to suppress this warning."""
            )

        self.pad_len = pad_len

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        features = get_one_hot(smiles, pad_len=self.pad_len)
        # Normalize by the vocabulary size (second axis of the one-hot array).
        return features / features.shape[1]

    def __len__(self):
        return len(self.smiles_list)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
def calculate_frechet_distance(
mu1: np.ndarray,
sigma1: np.ndarray,
mu2: np.ndarray,
sigma2: np.ndarray,
eps: float = 1e-6,
) -> float:
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
Expand Down Expand Up @@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)

assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

diff = mu1 - mu2

# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3)

if not np.isfinite(covmean).all() or not is_real:
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

assert isinstance(covmean, np.ndarray)
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
Expand All @@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

tr_covmean = np.trace(covmean)

return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


@contextmanager
Expand All @@ -188,11 +226,11 @@ def todevice(model, device):

def canonical(smi):
    """Canonicalize a single SMILES string.

    Returns the RDKit-canonical form of ``smi``, or ``None`` when RDKit
    cannot parse it.
    """
    try:
        mol = Chem.MolFromSmiles(smi)  # type: ignore
        return Chem.MolToSmiles(mol)
    except Exception:
        # Invalid SMILES: MolFromSmiles yields None and MolToSmiles raises.
        return None


def canonical_smiles(smiles, njobs=-1):
    """Canonicalize a list of SMILES strings in parallel.

    Args:
        smiles: Iterable of SMILES strings.
        njobs (int): Number of worker processes. Any value < 1 (including the
            default of -1) uses all available CPUs.

    Returns:
        list: Canonical SMILES strings, with None entries for inputs RDKit
        cannot parse.
    """
    # multiprocessing.Pool raises ValueError for non-positive process counts,
    # so translate "-1 / use everything" into None (= os.cpu_count()).
    processes = njobs if njobs is not None and njobs > 0 else None
    with Pool(processes) as pool:
        return pool.map(canonical, smiles)
64 changes: 64 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Build backend: hatchling (PEP 517/518).
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Package metadata (PEP 621).
[project]
name = "fcd"
version = "1.2.1"
dependencies = ["torch", "numpy", "scipy", "rdkit"]
requires-python = ">=3.8"
authors = [
{ name = "Philipp Renz", email = "[email protected]" },
]
description = "Fréchet ChEMNet Distance"
readme = "README.md"
license = { file = "LICENSE" }
keywords = ["cheminformatics", "machine learning", "deep learning", "generative models"]
classifiers = [
"Programming Language :: Python",
"Development Status :: 5 - Production/Stable",
]

[project.urls]
Homepage = "https://github.com/bioinf-jku/FCD"
Documentation = "https://github.com/bioinf-jku/FCD"
Repository = "https://github.com/bioinf-jku/FCD"

# Linter configuration (ruff).
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 120
indent-width = 4

# Assume Python 3.8
target-version = "py38"

22 changes: 0 additions & 22 deletions setup.py

This file was deleted.

Loading

0 comments on commit bf69345

Please sign in to comment.