Skip to content

Commit

Permalink
Merge pull request #19 from bioinf-jku/dev
Browse files Browse the repository at this point in the history
Fix numerical issues and efficiency issues
  • Loading branch information
renzph committed Apr 1, 2024
2 parents b4bcc22 + 2c45937 commit f806d58
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 80 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/test_dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (dev)

on:
push:
branches: [ "dev" ]
pull_request:
branches: [ "dev" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
40 changes: 40 additions & 0 deletions .github/workflows/test_master.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (master)

on:
push:
branches: [ "master"]
pull_request:
branches: [ "master"]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Fréchet ChemNet Distance
![PyPI](https://img.shields.io/pypi/v/fcd)
![Tests (master)](https://github.com/bioinf-jku/fcd/actions/workflows/test_master.yml/badge.svg?branch=dev)
![Tests (dev)](https://github.com/bioinf-jku/fcd/actions/workflows/test_dev.yml/badge.svg?branch=dev)
![PyPI - Downloads](https://img.shields.io/pypi/dm/fcd)
![GitHub release (latest by date)](https://img.shields.io/github/v/release/bioinf-jku/fcd)
![GitHub release date](https://img.shields.io/github/release-date/bioinf-jku/fcd)
![GitHub](https://img.shields.io/github/license/bioinf-jku/fcd)


Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery"
[JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.8b00234) /
Expand Down
16 changes: 13 additions & 3 deletions fcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .fcd import get_fcd, get_predictions, load_ref_model
from .utils import calculate_frechet_distance, canonical_smiles
# ruff: noqa: F401

__version__ = "1.2"
from fcd.fcd import get_fcd, get_predictions, load_ref_model
from fcd.utils import calculate_frechet_distance, canonical_smiles

__all__ = [
"get_fcd",
"get_predictions",
"load_ref_model",
"calculate_frechet_distance",
"canonical_smiles",
]

__version__ = "1.2.1"
47 changes: 29 additions & 18 deletions fcd/fcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
from fcd.utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
Expand All @@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None):
if model_path is None:
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
model_bytes = pkgutil.get_data("fcd", chemnet_model_filename)
if model_bytes is None:
raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}")

tmpdir = tempfile.TemporaryDirectory()
model_path = os.path.join(tmpdir.name, chemnet_model_filename)
Expand All @@ -48,7 +50,7 @@ def get_predictions(
smiles_list: List[str],
batch_size: int = 128,
n_jobs: int = 1,
device: str = "cpu",
device: Optional[str] = None,
) -> np.ndarray:
"""Calculate Chemnet activations
Expand All @@ -65,46 +67,55 @@ def get_predictions(
if len(smiles_list) == 0:
return np.zeros((0, 512))

dataloader = DataLoader(
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
)
dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

with todevice(model, device), torch.no_grad():
chemnet_activations = []
for batch in dataloader:
chemnet_activations.append(
model(batch.transpose(1, 2).float().to(device))
.to("cpu")
.detach()
.numpy()
model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32)
)
return np.row_stack(chemnet_activations)


def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float:
def get_fcd(smiles1: List[str], smiles2: List[str], model: Optional[nn.Module] = None, device=None) -> float:
"""Calculate FCD between two sets of Smiles
Args:
smiles1 (List[str]): First set of smiles
smiles2 (List[str]): Second set of smiles
smiles1 (List[str]): First set of SMILES.
smiles2 (List[str]): Second set of SMILES.
model (nn.Module, optional): The model to use. Loads default model if None.
device: The device to use for computation.
Returns:
float: The FCD score
float: The FCD score.
Raises:
ValueError: If the input SMILES lists are empty.
Example:
>>> smiles1 = ['CCO', 'CCN']
>>> smiles2 = ['CCO', 'CCC']
>>> fcd_score = get_fcd(smiles1, smiles2)
"""
if not smiles1 or not smiles2:
raise ValueError("Input SMILES lists cannot be empty.")

if model is None:
model = load_ref_model()

act1 = get_predictions(model, smiles1)
act2 = get_predictions(model, smiles2)
act1 = get_predictions(model, smiles1, device=device)
act2 = get_predictions(model, smiles2, device=device)

mu1 = np.mean(act1, axis=0)
sigma1 = np.cov(act1.T)

mu2 = np.mean(act2, axis=0)
sigma2 = np.cov(act2.T)

fcd_score = calculate_frechet_distance(
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
)
fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

return fcd_score
76 changes: 57 additions & 19 deletions fcd/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import warnings
from contextlib import contextmanager
from multiprocessing import Pool
from typing import List
from typing import List, Optional

import numpy as np
import torch
Expand All @@ -10,7 +11,7 @@
from torch import nn
from torch.utils.data import Dataset

from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose

# fmt: off
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
Expand Down Expand Up @@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]:
return tok_smile


def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
def get_one_hot(smiles: str, pad_len: Optional[int] = None) -> np.ndarray:
"""Generate one-hot representation of a Smiles string.
Args:
Expand All @@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
Returns:
np.ndarray: Array containing the one-hot encoded Smiles
"""
# add end token
smiles = smiles + "."

# initialize array
array_length = len(smiles) if pad_len < 0 else pad_len
array_length = len(smiles) if pad_len is None else pad_len
assert array_length >= len(smiles), "Pad length must be greater than the length of the input SMILES string + 1."

vocab_size = len(__vocab)
one_hot = np.zeros((array_length, vocab_size))

Expand Down Expand Up @@ -106,22 +110,57 @@ def load_imported_model(keras_config):


class SmilesDataset(Dataset):
__PAD_LEN = 350
"""
A dataset class for handling SMILES data.
Args:
smiles_list (list): A list of SMILES strings.
pad_len (int, optional): The length to pad the SMILES strings to. If not provided, the default pad length of 350 will be used.
warn (bool, optional): Whether to display a warning message if the specified pad length is different from the default. Defaults to True.
Attributes:
smiles_list (list): A list of SMILES strings.
pad_len (int): The length to pad the SMILES strings to.
"""

def __init__(self, smiles_list):
def __init__(self, smiles_list, pad_len=None, warn=True):
super().__init__()
DEFAULT_PAD_LEN = 350

self.smiles_list = smiles_list
max_len = max(len(smiles) for smiles in smiles_list) + 1 # plus one for the end token

if pad_len is None:
pad_len = max(DEFAULT_PAD_LEN, max_len)
else:
if pad_len < max_len:
raise ValueError(f"Specified pad_len {pad_len} is less than max_len {max_len}")

if pad_len != DEFAULT_PAD_LEN:
warnings.warn(
"""Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval.
Use warn=False to suppress this warning."""
)

self.pad_len = pad_len

def __getitem__(self, idx):
smiles = self.smiles_list[idx]
features = get_one_hot(smiles, 350)
features = get_one_hot(smiles, pad_len=self.pad_len)
return features / features.shape[1]

def __len__(self):
return len(self.smiles_list)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
def calculate_frechet_distance(
mu1: np.ndarray,
sigma1: np.ndarray,
mu2: np.ndarray,
sigma2: np.ndarray,
eps: float = 1e-6,
) -> float:
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
Expand Down Expand Up @@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)

assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

diff = mu1 - mu2

# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3)

if not np.isfinite(covmean).all() or not is_real:
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

assert isinstance(covmean, np.ndarray)
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
Expand All @@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

tr_covmean = np.trace(covmean)

return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


@contextmanager
Expand All @@ -188,11 +226,11 @@ def todevice(model, device):

def canonical(smi):
try:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
except:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) # type: ignore
except Exception:
return None


def canonical_smiles(smiles, njobs=32):
def canonical_smiles(smiles, njobs=-1):
with Pool(njobs) as pool:
return pool.map(canonical, smiles)
Loading

0 comments on commit f806d58

Please sign in to comment.