Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix numerical issues and efficiency issues #19

Merged
merged 2 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/test_dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (dev)

on:
push:
branches: [ "dev" ]
pull_request:
branches: [ "dev" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
40 changes: 40 additions & 0 deletions .github/workflows/test_master.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (master)

on:
push:
branches: [ "master"]
pull_request:
branches: [ "master"]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# Fréchet ChemNet Distance
![PyPI](https://img.shields.io/pypi/v/fcd)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/fcd)
![Tests (master)](https://github.com/bioinf-jku/fcd/actions/workflows/test_master.yml/badge.svg?branch=dev)
![Tests (dev)](https://github.com/bioinf-jku/fcd/actions/workflows/test_dev.yml/badge.svg?branch=dev)
![PyPI - Downloads](https://img.shields.io/pypi/dm/fcd)
![GitHub release (latest by date)](https://img.shields.io/github/v/release/bioinf-jku/fcd)
![GitHub release date](https://img.shields.io/github/release-date/bioinf-jku/fcd)
![GitHub](https://img.shields.io/github/license/bioinf-jku/fcd)


Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery"
[JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.8b00234) /
Expand Down
16 changes: 13 additions & 3 deletions fcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .fcd import get_fcd, get_predictions, load_ref_model
from .utils import calculate_frechet_distance, canonical_smiles
# ruff: noqa: F401

__version__ = "1.2"
from fcd.fcd import get_fcd, get_predictions, load_ref_model
from fcd.utils import calculate_frechet_distance, canonical_smiles

__all__ = [
"get_fcd",
"get_predictions",
"load_ref_model",
"calculate_frechet_distance",
"canonical_smiles",
]

__version__ = "1.2.1"
47 changes: 29 additions & 18 deletions fcd/fcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
from fcd.utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
Expand All @@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None):
if model_path is None:
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
model_bytes = pkgutil.get_data("fcd", chemnet_model_filename)
if model_bytes is None:
raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}")

tmpdir = tempfile.TemporaryDirectory()
model_path = os.path.join(tmpdir.name, chemnet_model_filename)
Expand All @@ -48,7 +50,7 @@ def get_predictions(
smiles_list: List[str],
batch_size: int = 128,
n_jobs: int = 1,
device: str = "cpu",
device: Optional[str] = None,
) -> np.ndarray:
"""Calculate Chemnet activations
Expand All @@ -65,46 +67,55 @@ def get_predictions(
if len(smiles_list) == 0:
return np.zeros((0, 512))

dataloader = DataLoader(
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
)
dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

with todevice(model, device), torch.no_grad():
chemnet_activations = []
for batch in dataloader:
chemnet_activations.append(
model(batch.transpose(1, 2).float().to(device))
.to("cpu")
.detach()
.numpy()
model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32)
)
return np.row_stack(chemnet_activations)


def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float:
def get_fcd(smiles1: List[str], smiles2: List[str], model: Optional[nn.Module] = None, device=None) -> float:
"""Calculate FCD between two sets of Smiles
Args:
smiles1 (List[str]): First set of smiles
smiles2 (List[str]): Second set of smiles
smiles1 (List[str]): First set of SMILES.
smiles2 (List[str]): Second set of SMILES.
model (nn.Module, optional): The model to use. Loads default model if None.
device: The device to use for computation.
Returns:
float: The FCD score
float: The FCD score.
Raises:
ValueError: If the input SMILES lists are empty.
Example:
>>> smiles1 = ['CCO', 'CCN']
>>> smiles2 = ['CCO', 'CCC']
>>> fcd_score = get_fcd(smiles1, smiles2)
"""
if not smiles1 or not smiles2:
raise ValueError("Input SMILES lists cannot be empty.")

if model is None:
model = load_ref_model()

act1 = get_predictions(model, smiles1)
act2 = get_predictions(model, smiles2)
act1 = get_predictions(model, smiles1, device=device)
act2 = get_predictions(model, smiles2, device=device)

mu1 = np.mean(act1, axis=0)
sigma1 = np.cov(act1.T)

mu2 = np.mean(act2, axis=0)
sigma2 = np.cov(act2.T)

fcd_score = calculate_frechet_distance(
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
)
fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

return fcd_score
76 changes: 57 additions & 19 deletions fcd/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import warnings
from contextlib import contextmanager
from multiprocessing import Pool
from typing import List
from typing import List, Optional

import numpy as np
import torch
Expand All @@ -10,7 +11,7 @@
from torch import nn
from torch.utils.data import Dataset

from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose

# fmt: off
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
Expand Down Expand Up @@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]:
return tok_smile


def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
def get_one_hot(smiles: str, pad_len: Optional[int] = None) -> np.ndarray:
"""Generate one-hot representation of a Smiles string.

Args:
Expand All @@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
Returns:
np.ndarray: Array containing the one-hot encoded Smiles
"""
# add end token
smiles = smiles + "."

# initialize array
array_length = len(smiles) if pad_len < 0 else pad_len
array_length = len(smiles) if pad_len is None else pad_len
assert array_length >= len(smiles), "Pad length must be greater than the length of the input SMILES string + 1."

vocab_size = len(__vocab)
one_hot = np.zeros((array_length, vocab_size))

Expand Down Expand Up @@ -106,22 +110,57 @@ def load_imported_model(keras_config):


class SmilesDataset(Dataset):
__PAD_LEN = 350
"""
A dataset class for handling SMILES data.

Args:
smiles_list (list): A list of SMILES strings.
pad_len (int, optional): The length to pad the SMILES strings to. If not provided, the default pad length of 350 will be used.
warn (bool, optional): Whether to display a warning message if the specified pad length is different from the default. Defaults to True.

Attributes:
smiles_list (list): A list of SMILES strings.
pad_len (int): The length to pad the SMILES strings to.

"""

def __init__(self, smiles_list):
def __init__(self, smiles_list, pad_len=None, warn=True):
super().__init__()
DEFAULT_PAD_LEN = 350

self.smiles_list = smiles_list
max_len = max(len(smiles) for smiles in smiles_list) + 1 # plus one for the end token

if pad_len is None:
pad_len = max(DEFAULT_PAD_LEN, max_len)
else:
if pad_len < max_len:
raise ValueError(f"Specified pad_len {pad_len} is less than max_len {max_len}")

if pad_len != DEFAULT_PAD_LEN:
warnings.warn(
"""Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval.
Use warn=False to suppress this warning."""
)

self.pad_len = pad_len

def __getitem__(self, idx):
smiles = self.smiles_list[idx]
features = get_one_hot(smiles, 350)
features = get_one_hot(smiles, pad_len=self.pad_len)
return features / features.shape[1]

def __len__(self):
return len(self.smiles_list)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
def calculate_frechet_distance(
mu1: np.ndarray,
sigma1: np.ndarray,
mu2: np.ndarray,
sigma2: np.ndarray,
eps: float = 1e-6,
) -> float:
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
Expand Down Expand Up @@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)

assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

diff = mu1 - mu2

# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3)

if not np.isfinite(covmean).all() or not is_real:
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

assert isinstance(covmean, np.ndarray)
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
Expand All @@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

tr_covmean = np.trace(covmean)

return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


@contextmanager
Expand All @@ -188,11 +226,11 @@ def todevice(model, device):

def canonical(smi):
try:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
except:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) # type: ignore
except Exception:
return None


def canonical_smiles(smiles, njobs=32):
def canonical_smiles(smiles, njobs=-1):
with Pool(njobs) as pool:
return pool.map(canonical, smiles)
Loading