Fix sampler.py issues creating sequence-score mismatches #15

Open · wants to merge 3 commits into base: main
Changes from all commits
199 changes: 199 additions & 0 deletions environment.yml
@@ -0,0 +1,199 @@
name: evo_env
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- bottleneck=1.3.7=py312ha883a20_0
- brotli=1.0.9=h5eee18b_8
- brotli-bin=1.0.9=h5eee18b_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- et_xmlfile=1.1.0=py312h06a4308_1
- expat=2.6.2=h6a678d5_0
- freetype=2.12.1=h4a9f257_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- joblib=1.4.2=py312h06a4308_0
- jpeg=9e=h5eee18b_1
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_8
- libbrotlidec=1.0.9=h5eee18b_8
- libbrotlienc=1.0.9=h5eee18b_8
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.5.1=h6a678d5_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_0
- lz4-c=1.9.4=h6a678d5_1
- matplotlib-base=3.8.4=py312h526ad5a_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py312h5eee18b_1
- mkl_fft=1.3.8=py312h5eee18b_0
- mkl_random=1.2.4=py312hdb19cb5_0
- ncurses=6.4=h6a678d5_0
- numexpr=2.8.7=py312hf827012_0
- numpy=1.26.4=py312hc5e2394_0
- numpy-base=1.26.4=py312h0da6c21_0
- openjpeg=2.4.0=h9ca470c_2
- openpyxl=3.1.2=py312h5eee18b_0
- openssl=3.0.14=h5eee18b_0
- packaging=24.1=py312h06a4308_0
- pandas=2.2.2=py312h526ad5a_0
- pillow=10.4.0=py312h5eee18b_0
- pip=24.0=py312h06a4308_0
- pybind11-abi=5=hd3eb1b0_0
- python=3.12.4=h5148396_1
- python-dateutil=2.9.0post0=py312h06a4308_2
- python-tzdata=2023.3=pyhd3eb1b0_0
- pytz=2024.1=py312h06a4308_0
- readline=8.2=h5eee18b_0
- scipy=1.13.1=py312hc5e2394_0
- seaborn=0.13.2=py312h06a4308_0
- setuptools=69.5.1=py312h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- unicodedata2=15.1.0=py312h5eee18b_0
- wheel=0.43.0=py312h06a4308_0
- xz=5.4.6=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.5=hc292b87_2
- pip:
- accelerate==0.32.1
- anyio==4.4.0
- argon2-cffi==23.1.0
- argon2-cffi-bindings==21.2.0
- arrow==1.3.0
- asttokens==2.4.1
- async-lru==2.0.4
- attrs==23.2.0
- babel==2.15.0
- beautifulsoup4==4.12.3
- bleach==6.1.0
- certifi==2024.7.4
- cffi==1.16.0
- charset-normalizer==3.3.2
- comm==0.2.2
- contourpy==1.2.1
- cycler==0.12.1
- debugpy==1.8.2
- decorator==5.1.1
- defusedxml==0.7.1
- executing==2.0.1
- fair-esm==2.0.1
- fastjsonschema==2.20.0
- filelock==3.15.4
- fonttools==4.53.1
- fqdn==1.5.1
- fsspec==2024.6.1
- h11==0.14.0
- httpcore==1.0.5
- httpx==0.27.0
- huggingface-hub==0.23.5
- idna==3.7
- ipykernel==6.29.5
- ipython==8.26.0
- ipywidgets==8.1.3
- isoduration==20.11.0
- jedi==0.19.1
- jinja2==3.1.4
- json5==0.9.25
- jsonpointer==3.0.0
- jsonschema==4.23.0
- jsonschema-specifications==2023.12.1
- jupyter==1.0.0
- jupyter-client==8.6.2
- jupyter-console==6.6.3
- jupyter-core==5.7.2
- jupyter-events==0.10.0
- jupyter-lsp==2.2.5
- jupyter-server==2.14.2
- jupyter-server-terminals==0.5.3
- jupyterlab==4.2.3
- jupyterlab-pygments==0.3.0
- jupyterlab-server==2.27.3
- jupyterlab-widgets==3.0.11
- kiwisolver==1.4.5
- markupsafe==2.1.5
- matplotlib==3.9.1
- matplotlib-inline==0.1.7
- mistune==3.0.2
- mpmath==1.3.0
- nbclient==0.10.0
- nbconvert==7.16.4
- nbformat==5.10.4
- nest-asyncio==1.6.0
- networkx==3.3
- notebook==7.2.1
- notebook-shim==0.2.4
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.82
- nvidia-nvtx-cu12==12.1.105
- overrides==7.7.0
- pandocfilters==1.5.1
- parso==0.8.4
- pexpect==4.9.0
- platformdirs==4.2.2
- prometheus-client==0.20.0
- prompt-toolkit==3.0.47
- psutil==6.0.0
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pycparser==2.22
- pygments==2.18.0
- pyparsing==3.1.2
- python-json-logger==2.0.7
- pyyaml==6.0.1
- pyzmq==26.0.3
- qtconsole==5.5.2
- qtpy==2.4.1
- referencing==0.35.1
- regex==2024.5.15
- requests==2.32.3
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rpds-py==0.19.0
- safetensors==0.4.3
- send2trash==1.8.3
- sniffio==1.3.1
- soupsieve==2.5
- stack-data==0.6.3
- sympy==1.13.0
- terminado==0.18.1
- tinycss2==1.3.0
- tokenizers==0.15.2
- torch==2.3.1
- tornado==6.4.1
- tqdm==4.66.4
- traitlets==5.14.3
- transformers==4.38.0
- types-python-dateutil==2.9.0.20240316
- typing-extensions==4.12.2
- tzdata==2024.1
- uri-template==1.3.0
- urllib3==2.2.2
- wcwidth==0.2.13
- webcolors==24.6.0
- webencodings==0.5.1
- websocket-client==1.8.0
- widgetsnbextension==4.0.11
prefix: /home/carlos/miniconda3/envs/evo_env
87 changes: 70 additions & 17 deletions evo_prot_grad/common/sampler.py
@@ -7,6 +7,34 @@
import pandas as pd
import gc

def convert_full_to_relative_sequences(full_sequences, ref_seq):
    """
    Convert a list of full sequences into relative-sequence format based on a reference sequence.

    Parameters:
        full_sequences (list): List of full sequences (strings) to be converted.
        ref_seq (str): Reference sequence (string) to compare against.

    Returns:
        list: List of relative sequences in the format 'A1B-A3C', where A is the reference
            amino acid, 1 is the 1-based site index, and B is the mutant amino acid.
    """
    relative_sequences = []

    for seq in full_sequences:
        mutations = []
        for i, (ref_aa, mut_aa) in enumerate(zip(ref_seq, seq)):
            if ref_aa != mut_aa:
                mutation = f"{ref_aa}{i+1}{mut_aa}"
                mutations.append(mutation)
        relative_sequences.append('-'.join(mutations))

    return relative_sequences

def prep_seqs(seqs, ref_seq):
    # Each call converts a single sequence, so this returns a list of
    # one-element lists of relative-mutation strings.
    return [convert_full_to_relative_sequences([seq.replace(" ", "")], ref_seq.replace(" ", "")) for seq in seqs]
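# A minimal usage sketch (hypothetical toy sequences, not from this repo):
#
#   >>> convert_full_to_relative_sequences(["MAT", "MKT", "AKA"], "MKT")
#   ['K2A', '', 'M1A-T3A']
#
# "MAT" differs from the reference "MKT" only at site 2 (K -> A); an exact
# match yields the empty string; multiple mutations are joined with '-'.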

class DirectedEvolution:
    """Main class for plug and play directed evolution with gradient-based discrete MCMC.
    """
@@ -134,7 +162,8 @@ def _product_of_experts(self, inputs: List[str]) -> Tuple[List[torch.Tensor], torch.Tensor]:
            oh, score = expert(inputs)
            ohs += [oh]
            scores += [expert.temperature * score]
        # sum scores over experts
        inputs2 = [seq.replace(" ", "") for seq in inputs]  # debug: space-free inputs (currently unused)
        scores2 = [score.detach().cpu().numpy() for score in scores]  # debug: per-expert numpy scores (currently unused)
        return ohs, torch.stack(scores, dim=0).sum(dim=0)
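    # Note on the product of experts: p(x) is proportional to prod_k p_k(x)^{T_k},
    # so the joint unnormalized log-score is sum_k T_k * log p_k(x); stacking the
    # temperature-weighted per-expert scores and summing over the expert
    # dimension computes exactly that for each chain.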


@@ -184,6 +213,8 @@ def prepare_results(self, variants, scores, n_seqs_to_keep=None):
        sequence_scores = {}

        # Iterate through the flattened list to count sequences and record first appearance
        if len(scores.shape) == 1:  # If scores has a single dimension
            scores = scores.reshape(-1, 1)  # Reshape to have two dimensions
        for i, sublist in enumerate(variants):
            for j, seq in enumerate(sublist):
                flat_seq = ''.join(seq.split())
@@ -242,6 +273,11 @@ def save_results(self, filename, variants, scores, n_seqs_to_keep=10000):
            for key, value in params.items():
                f.write(f'{key}: {value}\n')

    def _recompute_score(self, seq):
        # Re-score a single sequence using the first expert only.
        ohs, score = self.experts[0]([seq])
        return score.detach().flatten().cpu().numpy()
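    # Note: with multiple experts configured, this score differs from the
    # temperature-weighted sum across experts that _product_of_experts
    # returns; the two agree only in the single-expert case.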


    def __call__(self) -> Tuple[List[str], np.ndarray]:
        """
        Run the gradient-based MCMC sampler.
@@ -261,6 +297,9 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
        pos_mask = pos_mask.reshape(self.parallel_chains, -1)

        for i in range(self.n_steps):
            print(f"#### Starting iteration {i}.")
            start_seqs = self.chains.copy()  # debug: sequences at the start of this step (currently unused)
            # Re-decode so self.chains reflects any chains that were reset after
            # exceeding the mutation limit on the previous step.
            self.chains = self.canonical_chain_tokenizer.decode(cur_chains_oh)
            ###### sample path length
            U = torch.randint(1, 2 * self.max_pas_path_length, size=(self.parallel_chains, 1))
            max_u = int(torch.max(U).item())
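            # torch.randint(1, 2 * L) draws U uniformly from {1, ..., 2L - 1},
            # so the per-chain path length has mean L = max_pas_path_length;
            # chains that drew a smaller U are frozen for the remaining
            # intermediate steps via u_mask below.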
@@ -274,7 +313,6 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
            # Need to use the string version of the chain to pass to experts
            ohs, PoE = self._product_of_experts(self.chains)
            grad_x = self._compute_gradients(ohs, PoE)
            # do U intermediate steps
            with torch.no_grad():
                for step in range(max_u):
@@ -311,12 +349,12 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
                    row_select = changes_all.sum(-1).unsqueeze(-1)  # [n_chains,seq_len,1]
                    new_x = cur_chains_oh * (1.0 - row_select) + changes_all
                    cur_u_mask = u_mask[:, step].unsqueeze(-1).unsqueeze(-1)
                    cur_chains_oh2 = cur_u_mask * new_x + (1 - cur_u_mask) * cur_chains_oh

            y = cur_chains_oh2.clone()

            # last step
            # Decode the proposed chains to strings so the experts can score them,
            # and compare the proposals against the current sequences.
            y_strs = self.canonical_chain_tokenizer.decode(y)
            ohs, proposed_PoE = self._product_of_experts(y_strs)
            grad_y = self._compute_gradients(ohs, proposed_PoE)
            grad_y = grad_y.detach()
@@ -336,20 +374,32 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
            #log_acc = log_backwd - log_fwd
            m_term = (proposed_PoE.squeeze() - PoE.squeeze())
            log_acc = m_term + log_ratio
            accepted = (log_acc.exp() >= torch.rand_like(log_acc)).float().view(-1, *([1] * x_rank))
            # The vectorized state update (y * accepted + (1 - accepted) * cur_chains_oh)
            # is applied per chain in the loop below instead.
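            # Metropolis-Hastings in log space: log_acc = [log p(y) - log p(x)]
            # + [log q(x|y) - log q(y|x)], where m_term is the target ratio and
            # log_ratio the proposal ratio. Accepting when exp(log_acc) >= u with
            # u ~ Uniform(0, 1) implements the usual min(1, exp(log_acc)) rule.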

            # Apply accepted proposals chain by chain so the one-hot state, the
            # PoE score, and the decoded string stay in sync (the sequence-score
            # mismatch this PR targets).
            accepted, PoE = accepted.squeeze(), PoE.squeeze().clone()

            for chain_idx in range(len(accepted)):
                if accepted[chain_idx] == 1:
                    cur_chains_oh[chain_idx, :, :] = y[chain_idx, :, :]
                    PoE[chain_idx] = proposed_PoE[chain_idx]
                    self.chains[chain_idx] = self.canonical_chain_tokenizer.decode(cur_chains_oh[chain_idx, :, :].unsqueeze(0))[0]

            # Current chain state book-keeping
            self.chains_oh = cur_chains_oh.clone()
            # Sanity check: decoding the stored one-hot state should match self.chains
            decoded_sequences = self.canonical_chain_tokenizer.decode(self.chains_oh)
            # History book-keeping
            self.chains_oh_history += [cur_chains_oh.clone()]
            self.PoE_history += [PoE.clone()]

            if self.verbose:
                x_strs = self.canonical_chain_tokenizer.decode(cur_chains_oh)
                for idx in range(log_acc.size(0)):
@@ -364,7 +414,7 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
            mask_flag = (dist >= self.max_mutations).bool()
            mask_flag = mask_flag.reshape(self.parallel_chains)
            cur_chains_oh[mask_flag] = self.wt_oh

            if i > 10 and i % 100 == 0:
                print(f"Finished step {i} out of {self.n_steps}.")
                if torch.cuda.is_available():
@@ -376,8 +426,11 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
            scores_ = self.PoE_history[-1].detach().cpu().numpy()
        elif self.output == 'all':
            output_ = []
            for j in range(len(self.chains_oh_history)):
                decoded_sequences = self.canonical_chain_tokenizer.decode(self.chains_oh_history[j])
                scores = self.PoE_history[j].detach().cpu().numpy()  # Convert PoE to numpy for easier handling (currently unused here)
                seqs = prep_seqs(decoded_sequences, self.wtseq)  # relative-mutation strings (currently unused here)
                output_ += [ decoded_sequences ]
            scores_ = torch.stack(self.PoE_history).detach().cpu().numpy()
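            # For output == 'all', scores_ has shape [n_steps, parallel_chains]
            # and output_[t][c] is the decoded sequence for chain c at step t,
            # so sequences and scores stay index-aligned across the history.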
        elif self.output == 'best':
            best_idxs = torch.stack(self.PoE_history).argmax(0)