diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..f3aa2dc --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,102 @@ +# seq2squiggle contributing guide + +This guide provides an overview of the contribution workflow, from setting up a development environment and testing your changes to submitting a pull request and performing a release. + + +It's based on the contribution guide of [breakfast, written by Matthew Huska](https://github.com/rki-mf1/breakfast), which follows the packaging guidelines from ["Hypermodern Python" by Claudio Jolowicz](https://cjolowicz.github.io/posts/hypermodern-python-01-setup/). + +## New contributor guide +To get an overview of the project itself, read the [README](README.md). + +### Setting up your development tools + +Some tooling needs to be set up before you can work on seq2squiggle. To install these tools we use mamba, a faster replacement for the conda package manager, and place them in their own environment: + +```sh +mamba create -n seq2squiggle-dev python=3 poetry fortran-compiler nox pre-commit +``` + +Then, when you want to work on the project, or at the very least when you want to use poetry commands or run tests, you need to switch to this environment: + +```sh +mamba activate seq2squiggle-dev +``` + +The rest of this document assumes that you have the seq2squiggle-dev environment active. + +### Installing the package + +As you're developing, you can install your work in progress into your seq2squiggle-dev conda environment using `poetry install`: + +```sh +poetry install +seq2squiggle version +``` + +### Testing + +**Not implemented yet** + +### Adding dependencies, updating dependency versions + +You can add dependencies using poetry: + +```sh +poetry add scikit-learn +poetry add pandas +``` + +You can automatically update a dependency to the newest minor or patch release like this: + +```sh +poetry update pandas +``` + +For major releases you have to be more explicit, for example when moving from 1.x to 2.x: + +```sh +poetry add pandas@^2.0 +``` + +### Releasing a new version + +First update the version in pyproject.toml using poetry: + +```sh +poetry version patch +# +git commit -am "Bump version" +git push +``` + +Then tag the commit with the same version number (note the "v" prefix), push the code and push the tag: + +```sh +git tag v0.3.1 +git push origin v0.3.1 +``` + +Now create a release on GitHub, selecting the version tag you just pushed. This automatically triggers testing of the new version and, if the tests pass, publishes it to PyPI. + +### Updating the python version dependency + +Aside from updating package dependencies, it is also sometimes useful to update the dependency on python itself. One way to do this is to edit the pyproject.toml file and change the python version constraint. Versions can be specified using constraints that are documented in the [poetry docs](https://python-poetry.org/docs/dependency-specification/): + +``` +[tool.poetry.dependencies] +python = "^3.10" # <-- this +``` + +Afterwards, you need to use poetry to update the poetry.lock file to reflect the change you just made to the pyproject.toml file. Be sure to use the `--no-update` flag so that the locked versions of all other dependency packages are not updated. + +```sh +poetry lock --no-update +``` + +Then run your tests to make sure everything still works, and commit and push the changes. + +You might also need to update/change the version of python in your conda environment, but I'm not certain about that. 
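+ +If it does turn out to be necessary, a minimal sketch of the combined update could look like this (the Python version below and the extra mamba step are assumptions for illustration, not verified steps): + +```sh +# after raising the constraint in pyproject.toml to e.g. python = "^3.11": +poetry lock --no-update +# possibly also update the interpreter in the dev environment (assumed, not verified) +mamba install -n seq2squiggle-dev python=3.11 +poetry install +```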
+ +### Updating the bioconda package when dependencies, dependency versions, or the python version have changed + +For package updates that don't lead to added/removed dependencies, changes to dependency versions, or changes to the allowed python version, a normal release (as above) is sufficient to automatically update both the PyPI and bioconda packages. However, for changes that do result in changes to dependencies it is necessary to update the bioconda meta.yaml file. This is explained in the [bioconda docs](https://bioconda.github.io/contributor/updating.html), and they also provide tools to help you with this. diff --git a/README.md b/README.md index 40f5d84..548babe 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ `seq2squiggle` is a deep learning-based tool for generating artificial nanopore signals from DNA sequence data. - + Please cite the following publication if you use `seq2squiggle` in your work: -- Beslic, D., Kucklick, M., Engelmann, S., Fuchs, S., Renards, B.Y., Körber, N. End-to-end simulation of nanopore sequencing signals with feed-forward transformers. bioRxiv (2024). +- Beslic, D., Kucklick, M., Engelmann, S., Fuchs, S., Renard, B. Y., & Körber, N. (2024). End-to-end simulation of nanopore sequencing signals with feed-forward transformers. bioRxiv. doi:10.1101/2024.08.12.607296 ## Installation @@ -50,6 +50,10 @@ Generate 10,000 reads from a fasta file: ``` seq2squiggle predict example.fasta -o example.blow5 -n 10000 ``` +Generate 10,000 reads using R9.4.1 chemistry on a MinION: +``` +seq2squiggle predict example.fasta -o example.blow5 -n 10000 --profile dna-r9-min +``` Generate reads with a coverage of 30: ``` seq2squiggle predict example.fasta -o example.blow5 -c 30 @@ -67,41 +71,38 @@ Export as pod5: seq2squiggle predict example.fastq -o example.pod5 --read-input ``` +## Noise options +`seq2squiggle` provides flexible options for generating signal data with various noise configurations. By default, it uses its duration sampler and noise sampler modules to predict event durations and amplitude noise levels specific to each input k-mer. Alternatively, you can deactivate these modules (`--noise-sampler False --duration-sampler False`) and use static distributions to sample event durations and amplitude noise. The static distributions can be configured using the options `--noise-std`, `--dwell-std`, and `--dwell-mean`. +### Examples using different noise options -## Different noise options -`seq2squiggle` supports different options for generating the signal data. -Per default, the noise sampler and duration sampler are used. 
- -### Examples - -Generate reads using both the noise sampler and duration sampler: +Default configuration (noise sampler and duration sampler enabled): ``` seq2squiggle predict example.fasta -o example.blow5 ``` -Generate reads using the noise sampler with an increased factor and duration sampler: +Using the noise sampler with increased noise standard deviation and the duration sampler: ``` seq2squiggle predict example.fasta -o example.blow5 --noise-std 1.5 ``` -Generate reads using a static normal distribution for the noise and duration sampler: +Using a static normal distribution for the amplitude noise and the duration sampler: ``` -seq2squiggle predict example.fasta -o example.blow5 --noise-std 1.5 --noise-sampling False +seq2squiggle predict example.fasta -o example.blow5 --noise-std 1.0 --noise-sampler False ``` -Generate reads using only the noise sampler and a static normal distribution for the event length: +Using the noise sampler and a static normal distribution for event durations: ``` -seq2squiggle predict example.fasta -o example.blow5 --duration-sampling False --ideal-event-length -1 +seq2squiggle predict example.fasta -o example.blow5 --duration-sampler False --dwell-std 4.0 ``` -Generate reads using only the noise sampler and ideal event lengths: +Using the noise sampler with ideal event lengths (each k-mer event will have a length of 10 signal points): ``` -seq2squiggle predict example.fasta -o example.blow5 --duration-sampling False --ideal-event-length 10.0 +seq2squiggle predict example.fasta -o example.blow5 --duration-sampler False --dwell-mean 10.0 --dwell-std 0.0 ``` -Generate reads using a static normal distribution for the amplitude noise and ideal event lengths: +Using a static normal distribution for amplitude noise and ideal event lengths: ``` -seq2squiggle predict example.fasta -o example.blow5 --duration-sampling False --ideal-event-length 10.0 --noise-sampling False --noise-std 1.0 +seq2squiggle predict example.fasta -o example.blow5 --duration-sampler False --dwell-mean 10.0 --dwell-std 0.0 --noise-sampler False --noise-std 1.0 ``` -Generate reads using no amplitude noise and ideal event lengths: +Generating reads with no amplitude noise and ideal event lengths: ``` -seq2squiggle predict example.fasta -o example.blow5 --duration-sampling False --ideal-event-length 10.0 --noise-sampling False --noise-std -1 +seq2squiggle predict example.fasta -o example.blow5 --duration-sampler False --dwell-mean 10.0 --dwell-std 0.0 --noise-sampler False --noise-std 0.0 ``` ## Train a new model @@ -125,4 +126,5 @@ seq2squiggle train train_dir valid_dir --config my_config.yml --model last.ckpt ``` ## Acknowledgement -The model is based on [xcmyz's implementation of FastSpeech](https://github.com/xcmyz/FastSpeech). Some code snippets for preprocessing DNA-signal chunks have been taken from [bonito](https://github.com/nanoporetech/bonito). +The model is based on [xcmyz's implementation of FastSpeech](https://github.com/xcmyz/FastSpeech). Some code snippets for preprocessing DNA-signal chunks have been taken from [bonito](https://github.com/nanoporetech/bonito). We also incorporated code snippets from [Casanovo](https://github.com/Noble-Lab/casanovo) for different functionalities, including downloading weights, logging, and the design of the main function. +Additionally, we used parameter profiles from squigulator for various chemistries to set digitisation, sample-rate, range, median_before, and other signal parameters. 
These profiles are detailed in [squigulator's documentation](https://hasindu2008.github.io/squigulator/docs/profile.html). diff --git a/img/seq2squiggle.png b/img/seq2squiggle.png new file mode 100644 index 0000000..f7b9b7a Binary files /dev/null and b/img/seq2squiggle.png differ diff --git a/img/seq2squiggle_architecture.png b/img/seq2squiggle_architecture.png deleted file mode 100644 index 25b2d0c..0000000 Binary files a/img/seq2squiggle_architecture.png and /dev/null differ diff --git a/src/seq2squiggle/__init.py__ b/src/seq2squiggle/__init.py__ new file mode 100644 index 0000000..e69de29 diff --git a/src/seq2squiggle/config.yaml b/src/seq2squiggle/config.yaml index 5d3be3f..48a8cf6 100644 --- a/src/seq2squiggle/config.yaml +++ b/src/seq2squiggle/config.yaml @@ -9,7 +9,7 @@ log_name: "Human-R1041-4khz" wandb_logger_state: disabled # disabled, online, offline ### Preprocessing parameters -max_chunks_train: 170000000 +max_chunks_train: 210000000 max_chunks_valid: 100000 scaling_max_value: 165.0 # If valid_dir is not provided, validation data will be generated from the training dataset. @@ -28,9 +28,9 @@ encoder_layers: 2 encoder_heads: 8 decoder_layers: 2 decoder_heads: 8 -encoder_dropout: 0.1 -decoder_dropout: 0.1 -duration_dropout: 0.1 +encoder_dropout: 0.2 +decoder_dropout: 0.2 +duration_dropout: 0.2 ### Learning rate parameters train_batch_size: 512 @@ -39,7 +39,7 @@ save_model: True # Optimizer. Allowed options: Adam, AdamW, SGD, RMSProp, optimizer: "Adam" warmup_ratio: 0.01 # Percentage of total steps used for warmup -lr: 0.00025 +lr: 0.0005 weight_decay: 0.0 # Schedule for learning rate. Allowed options: warmup_cosine, warmup_constant, constant, warmup_cosine_restarts, one_cycle lr_schedule: "warmup_cosine" diff --git a/src/seq2squiggle/dataloader.py b/src/seq2squiggle/dataloader.py index 5909a36..81689f0 100644 --- a/src/seq2squiggle/dataloader.py +++ b/src/seq2squiggle/dataloader.py @@ -8,10 +8,12 @@ import numpy as np import logging import pytorch_lightning as pl -from torch.utils.data import DataLoader, IterableDataset, Dataset +from torch.utils.data import DataLoader, IterableDataset, Dataset, get_worker_info, DistributedSampler +from torch.distributed import init_process_group, get_rank, get_world_size from multiprocessing.pool import ThreadPool as Pool import itertools import multiprocessing +import torch.distributed as tdi from typing import Tuple, List, Optional, Dict, Generator from bisect import bisect @@ -82,6 +84,8 @@ def __init__( valid_dir: str = "path/to/dir", batch_size: int = 128, n_workers: int = 1, + rank: int = 0, + world_size: int = 1, ): super().__init__() self.data_dir = data_dir @@ -90,6 +94,8 @@ def __init__( self.config = config self.n_workers = n_workers self.total_l = total_l + self.rank = rank + self.world_size = world_size def setup(self, stage: str): if stage in ("fit", "validate"): @@ -108,7 +114,7 @@ def setup(self, stage: str): if stage in (None, "predict"): logger.debug("Loading fasta started") self.predict_loader_kwargs = load_fasta( - self.data_dir, self.config, self.total_l + self.data_dir, self.config, self.total_l, self.rank, self.world_size ) logger.debug("Loading fasta ended") @@ -134,7 +140,7 @@ def val_dataloader(self): def predict_dataloader(self): predict_loader = DataLoader( - batch_size=self.batch_size, # self.batch_size + batch_size=self.batch_size, num_workers=self.n_workers, pin_memory=True, **self.predict_loader_kwargs, @@ -253,6 +259,64 @@ def __len__(self): return self.data_count +class 
DataParallelIterableDataSet(IterableDataset): + """ + A PyTorch `IterableDataset` that wraps an iterable to provide data for prediction. + Multi-Threading not implemented yet. + + Parameters + ---------- + iterable : iterable + An iterable object that yields data samples. + length : int + The total length of the dataset. + + Attributes + ---------- + iterable : iterable + The iterable object used to provide data samples. + length : int + The length of the dataset, representing the number of samples. + + Methods + ------- + __iter__() + Returns the iterable object itself. + __len__() + Returns the length of the dataset. + """ + def __init__(self, iterable, length, rank, world_size): + self.iterable = iterable + self.length = length + self.rank = rank + self.world_size = world_size + + def __iter__(self): + # devices split + device_rank, num_devices = (tdi.get_rank(), tdi.get_world_size()) if tdi.is_initialized() else (0, 1) + # workers split + worker_info = get_worker_info() + worker_rank, num_workers = (worker_info.id, worker_info.num_workers) if worker_info else (0, 1) + + # total (devices + workers) split by device, then by worker + num_replicas = num_workers * num_devices + replica_rank = worker_rank * num_devices + device_rank + # by worker, then device would be: + # rank = device_rank * num_workers + worker_rank + + for i, data in enumerate(self.iterable): + if i % num_replicas == replica_rank: + #print(f"Device: {device_rank}, worker {worker_rank} fetches sample {i}") + yield data + else: + continue + + # return self.iterable + + + def __len__(self): + return self.length + class IterableFastaDataSet(IterableDataset): """ A PyTorch `IterableDataset` that wraps an iterable to provide data for prediction. @@ -287,6 +351,7 @@ def __init__(self, iterable, length): def __iter__(self): return self.iterable + def __len__(self): return self.length @@ -335,7 +400,7 @@ def process_read( def load_fasta( - fasta: List[Tuple[str, str]], config: Dict, total_l: int + fasta: List[Tuple[str, str]], config: Dict, total_l: int, rank:int, world_size:int ) -> Dict[str, "DataLoader"]: """ Loads and processes FASTA files into a dataset for prediction, using parallel processing. @@ -379,12 +444,13 @@ def load_fasta( combined_generator = itertools.chain(*results) logger.debug("Splitting the reads to chunks finished.") - + predict_loader_kwargs = { + #"dataset": DataParallelIterableDataSet(combined_generator, total_l, rank, world_size), "dataset": IterableFastaDataSet(combined_generator, total_l), "shuffle": False, } - + return predict_loader_kwargs @@ -424,59 +490,24 @@ def load_numpy( - Training DataLoader configuration with dataset and shuffle setting. - Validation DataLoader configuration with dataset and shuffle setting. 
""" - - chunks_train = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("chunks-") - ] - targets_train = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("targets-") - ] - c_lengths_train = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("chunks_lengths-") - ] - t_lengths_train = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("targets_lengths-") - ] - stdevs_train = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("stdevs-") - ] - - if npy_valid is not None and os.path.exists(npy_valid): - chunks_valid = [ - os.path.join(npy_valid, filename) - for filename in os.listdir(npy_valid) - if filename.startswith("chunks-") - ] - targets_valid = [ - os.path.join(npy_valid, filename) - for filename in os.listdir(npy_valid) - if filename.startswith("targets-") - ] - c_lengths_valid = [ - os.path.join(npy_valid, filename) - for filename in os.listdir(npy_valid) - if filename.startswith("chunks_lengths-") - ] - t_lengths_valid = [ - os.path.join(npy_valid, filename) - for filename in os.listdir(npy_valid) - if filename.startswith("targets_lengths-") - ] - stdevs_valid = [ - os.path.join(npy_train, filename) - for filename in os.listdir(npy_train) - if filename.startswith("stdevs-") - ] + def load_paths(directory: str, prefix: str) -> List[str]: + return sorted( + os.path.join(directory, f) for f in os.listdir(directory) if f.startswith(prefix) + ) + + + chunks_train = load_paths(npy_train, "chunks-") + targets_train = load_paths(npy_train, "targets-") + c_lengths_train = load_paths(npy_train, "chunks_lengths-") + t_lengths_train = load_paths(npy_train, "targets_lengths-") + stdevs_train = load_paths(npy_train, "stdevs-") + + if npy_valid and os.path.exists(npy_valid): + chunks_valid = load_paths(npy_valid, "chunks-") + targets_valid = load_paths(npy_valid, "targets-") + c_lengths_valid = load_paths(npy_valid, "chunks_lengths-") + t_lengths_valid = load_paths(npy_valid, "targets_lengths-") + stdevs_valid = load_paths(npy_valid, "stdevs-") else: # Lazy split for testing chunks_train, chunks_valid = train_test_split( @@ -611,32 +642,10 @@ def sort_files( - Sorted target lengths file paths. - Sorted standard deviations file paths. 
""" - # Extract the filename from each path - chunk_filenames = [path.split("/")[-1] for path in chunks_path] - target_filenames = [path.split("/")[-1] for path in targets_path] - clengths_filenames = [path.split("/")[-1] for path in c_lengths_path] - tlengths_filesnames = [path.split("/")[-1] for path in t_lengths_path] - stdevs_filesnames = [path.split("/")[-1] for path in stdevs_path] - - # Sort both lists based on filenames - sorted_chunk_paths = [path for _, path in sorted(zip(chunk_filenames, chunks_path))] - sorted_target_paths = [ - path for _, path in sorted(zip(target_filenames, targets_path)) - ] - sorted_clengths_paths = [ - path for _, path in sorted(zip(clengths_filenames, c_lengths_path)) - ] - sorted_tlengths_paths = [ - path for _, path in sorted(zip(tlengths_filesnames, t_lengths_path)) - ] - sorted_stdevs_paths = [ - path for _, path in sorted(zip(stdevs_filesnames, stdevs_path)) - ] - - return ( - sorted_chunk_paths, - sorted_target_paths, - sorted_clengths_paths, - sorted_tlengths_paths, - sorted_stdevs_paths, + def sort_by_filename(paths): + return sorted(paths, key=lambda path: path.split("/")[-1]) + + return tuple( + sort_by_filename(path_list) for path_list in + [chunks_path, targets_path, c_lengths_path, t_lengths_path, stdevs_path] ) diff --git a/src/seq2squiggle/inference.py b/src/seq2squiggle/inference.py index cf56931..83d6984 100644 --- a/src/seq2squiggle/inference.py +++ b/src/seq2squiggle/inference.py @@ -18,7 +18,7 @@ from .signal_io import BLOW5Writer, POD5Writer from .model import seq2squiggle -from .utils import get_reads +from .utils import get_reads, get_profile, update_profile, update_config from .train import DDPStrategy from .dataloader import PoreDataModule from . import __version__ @@ -28,7 +28,7 @@ def get_writer( - out: str, profile: object, ideal_event_length: int, export_every_n_samples: int + out: str, profile: object, ideal_mode: bool, export_every_n_samples: int ) -> tuple: """ Returns an appropriate file writer object based on the output file extension. @@ -58,7 +58,7 @@ def get_writer( os.remove(out) if any(out_base.endswith(ext) for ext in slow5_ext): - return BLOW5Writer(out, profile, ideal_event_length), export_every_n_samples + return BLOW5Writer(out, profile, ideal_mode), export_every_n_samples elif out_base.endswith(pod5_ext): logger.warning("POD5 Writer does not support appending to an existing file.") logger.warning( @@ -67,13 +67,13 @@ def get_writer( logger.warning( "This might lead to Out of Memory errors for large-scale simulations. Consider exporting to BLOW5/SLOW5 and using the blue_crab tool for conversion to pod5." ) - return POD5Writer(out, profile, ideal_event_length), float("inf") + return POD5Writer(out, profile, ideal_mode), float("inf") else: logger.error("Output file must have .pod5, .slow5, or .blow5 extension.") raise ValueError("Output file must have .pod5, .slow5, or .blow5 extension.") -def get_saved_weights() -> str: +def get_saved_weights(profile_name) -> str: """ Checks for the existence of the saved weights file and returns the appropriate file path. @@ -92,15 +92,31 @@ def get_saved_weights() -> str: FileNotFoundError If neither the specified saved weights file nor the default file in the logging directory is found. 
""" - logger.info("Weights file path is not provided.") cache_dir = appdirs.user_cache_dir("seq2squiggle", False, opinion=False) os.makedirs(cache_dir, exist_ok=True) + + # Log profile name details + if profile_name.startswith("dna-r10"): + logger.info("Detected R10.4.1 chemistry profile.") + logger.info("Profile can be changed with the --profile parameter") + profile_keyword = "R10" + elif profile_name.startswith("dna-r9"): + logger.info("Detected R9.4.1 chemistry profile.") + logger.info("Profile can be changed with the --profile parameter") + profile_keyword = "R9" + else: + logger.warning( + "Profile name '%s' does not match known patterns (R10- or R9-). Proceeding with latest weights.", + profile_name, + ) + profile_keyword = None + version = __version__ version_match = None, None, 0 - # Search in local cache + # Search local cache for version- and profile-matching weights for filename in os.listdir(cache_dir): root, ext = os.path.splitext(filename) if ext == ".ckpt": @@ -112,11 +128,13 @@ def get_saved_weights() -> str: if (m := [i == j for i, j in zip(version, file_version)])[0] else 0 ) - if match > version_match[2]: + if match > version_match[2] and profile_keyword and profile_keyword in root: version_match = os.path.join(cache_dir, filename), None, match + + # Return best-matching local weights if version_match[2] > 0: logger.info( - "Model weights file %s retrieved from local cache", + "Found matching weights in local cache: %s", version_match[0], ) return version_match[0] @@ -134,15 +152,37 @@ def get_saved_weights() -> str: for release_asset in release.get_assets(): fn, ext = os.path.splitext(release_asset.name) if ext == ".ckpt": - version_match = ( - os.path.join( - cache_dir, - f"{fn}@v{'.'.join(map(str, rel_version))}{ext}", - ), - release_asset.browser_download_url, - match, - ) - break + if profile_keyword and profile_keyword in release_asset.name: + logger.info( + "Found matching release for %s profile: %s", + profile_keyword, + release_asset.name, + ) + version_match = ( + os.path.join( + cache_dir, + f"{fn}@v{'.'.join(map(str, rel_version))}{ext}", + ), + release_asset.browser_download_url, + match, + ) + break + elif not (profile_keyword): + logger.info( + "Found no matching release for %s profile: %s", + profile_keyword, + release_asset.name, + ) + # Save the latest available release for fallback + version_match = ( + os.path.join( + cache_dir, + f"{fn}@v{'.'.join(map(str, rel_version))}{ext}", + ), + release_asset.browser_download_url, + match, + ) + break # Download the model weights if a matching release was found. 
if version_match[2] > 0: filename, url, _ = version_match @@ -159,13 +199,14 @@ def get_saved_weights() -> str: return filename else: logger.error( - "No matching model weights for release v%s found, please " + "No matching model weights for release v%s and profile %s found, please " "specify your model weights explicitly using the `--model` " "parameter", version, + profile_name, ) raise ValueError( - f"No matching model weights for release v{version} found, " + f"No matching model weights for release v{version} and profile {profile_name} found, " f"please specify your model weights explicitly using the " f"`--model` parameter" ) @@ -204,9 +245,16 @@ def check_model(model: object, config: dict) -> None: for param, value in architecture_params.items(): if param not in exclude_params: if model_params.get(param) != value: + if param == "seq_kmer": + raise ValueError( + f"Parameter 'seq_kmer' mismatch: Model checkpoint value is " + f"{model_params.get(param)}, while config value is {value}. " + f"The model was trained on {model_params.get(param)}-mers, while the config file expects {value}-mers. " + "Choose a different model or change the config value or the --profile option. " + ) logger.warning( - f"Mismatching {param} parameter in model checkpoint" - f" ({model_params.get(param)}) and in config file ({value})" + f"Mismatching {param} parameter in model checkpoint " + f"({model_params.get(param)}) and in config file ({value})" ) @@ -220,13 +268,21 @@ def inference_run( c: int, out: str, profile: dict, - ideal_event_length: int, + dwell_mean: int, + dwell_std: float, noise_std: float, noise_sampling: bool, duration_sampling: bool, distr: str, predict_batch_size: int, export_every_n_samples: int, + sample_rate: int, + digitisation: int, + range_val: float, + offset_mean: float, + offset_std: float, + median_before_mean: float, + median_before_std: float, seed: int, ): """ @@ -266,6 +322,8 @@ def inference_run( Batch size for predictions. export_every_n_samples : int Number of samples after which to export data. + sampling_rate : int + sampling rate seed : int Random seed for reproducibility. 
@@ -273,13 +331,27 @@ def inference_run( ------- None """ + profile_dict = get_profile(profile) + profile_dict = update_profile(profile_dict, sample_rate=sample_rate, + digitisation=digitisation, + range=range_val, + offset_mean=offset_mean, + offset_std=offset_std, + median_before_mean=median_before_mean, + median_before_std=median_before_std) + + # Update config based on profile_dict + config = update_config(profile, config) + + ideal_mode = not(duration_sampling or dwell_std > 0) + writer, export_every_n_samples = get_writer( - out, profile, ideal_event_length, export_every_n_samples + out, profile_dict, ideal_mode, export_every_n_samples ) if saved_weights is None: try: - saved_weights = get_saved_weights() + saved_weights = get_saved_weights(profile) except github.RateLimitExceededException: logger.error( "GitHub API rate limit exceeded while trying to download the " @@ -296,7 +368,8 @@ def inference_run( load_model = seq2squiggle.load_from_checkpoint( checkpoint_path=saved_weights, out_writer=writer, - ideal_event_length=ideal_event_length, + dwell_mean=dwell_mean, + dwell_std=dwell_std, noise_std=noise_std, noise_sampling=noise_sampling, duration_sampling=duration_sampling, @@ -307,18 +380,9 @@ def inference_run( reads, total_l = get_reads(fasta, read_input, n, r, c, config, distr, seed) - fasta_data = PoreDataModule( - config=config, - data_dir=reads, - total_l=total_l, - batch_size=predict_batch_size, - n_workers=1, # n_workers > 1 causes incorrect order of IterableDataset + slower than single process - ) # "gamma_cpu" not implemented for 'BFloat16' - precision = "64" - if torch.cuda.device_count() >= 1: - precision = "16-mixed" + precision = "16-mixed" if torch.cuda.device_count() >= 1 else "32" trainer = pl.Trainer( accelerator="auto", @@ -326,6 +390,20 @@ def inference_run( devices="auto", logger=False, strategy=_get_strategy(), + # use_distributed_sampler=False + ) + + rank = trainer.global_rank + world_size = trainer.world_size + + fasta_data = PoreDataModule( + config=config, + data_dir=reads, + total_l=total_l, + batch_size=predict_batch_size, + n_workers=1, # n_workers > 1 causes incorrect order of IterableDataset + slower than single process + rank=rank, + world_size=world_size, ) trainer.predict(model=load_model, datamodule=fasta_data, return_predictions=False) @@ -347,3 +425,5 @@ def _get_strategy(): if torch.cuda.device_count() > 1: return DDPStrategy(find_unused_parameters=False, static_graph=True) return "auto" + + diff --git a/src/seq2squiggle/layers.py b/src/seq2squiggle/layers.py index e8a7f53..1c84605 100644 --- a/src/seq2squiggle/layers.py +++ b/src/seq2squiggle/layers.py @@ -5,7 +5,7 @@ """ from torch import nn, bmm, FloatTensor -import numpy as np +import torch class ScaledDotProductAttention(nn.Module): @@ -21,11 +21,10 @@ def forward(self, q, k, v, mask=None): attn = attn / self.temperature if mask is not None: - attn = attn.masked_fill(mask, -np.inf) + attn = attn.masked_fill(mask, -torch.inf) attn = self.softmax(attn) output = bmm(attn, v) - return output, attn @@ -43,7 +42,7 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): self.w_ks = nn.Linear(d_model, n_head * d_k) self.w_vs = nn.Linear(d_model, n_head * d_v) - self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) + self.attention = ScaledDotProductAttention(temperature=d_k**0.5) self.layer_norm = nn.LayerNorm(d_model) self.fc = nn.Linear(n_head * d_v, d_model) @@ -134,20 +133,22 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): """Sinusoid 
position encoding table""" def cal_angle(position, hid_idx): - return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) + return position / 10000**(2 * (hid_idx // 2) / d_hid) def get_posi_angle_vec(position): return [cal_angle(position, hid_j) for hid_j in range(d_hid)] - - sinusoid_table = np.array( + + sinusoid_table = torch.tensor( [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] ) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) # dim 2i+1 if padding_idx is not None: # zero vector for padding dimension sinusoid_table[padding_idx] = 0.0 - return FloatTensor(sinusoid_table) + return torch.FloatTensor(sinusoid_table) + + diff --git a/src/seq2squiggle/model.py b/src/seq2squiggle/model.py index 3bb8ed9..7c866df 100644 --- a/src/seq2squiggle/model.py +++ b/src/seq2squiggle/model.py @@ -15,6 +15,7 @@ from .modules import Encoder, LengthRegulator, Decoder, NoiseSampler from .utils import generate_validation_plots +from .signal_io import BLOW5Writer logger = logging.getLogger("seq2squiggle") @@ -32,7 +33,8 @@ def __init__( config: dict, save_valid_plots: bool = True, out_writer: None = None, - ideal_event_length: int = 0, + dwell_mean: float = 9.0, + dwell_std: float = 0.0, noise_std: float = -1, noise_sampling: bool = False, duration_sampling: bool = False, @@ -48,7 +50,8 @@ def __init__( self.save_valid_plots = save_valid_plots self.results = [] self.out_writer = out_writer - self.ideal_event_length = ideal_event_length + self.dwell_mean = dwell_mean + self.dwell_std = dwell_std self.noise_std = noise_std self.noise_sampling = noise_sampling self.duration_sampling = duration_sampling @@ -195,7 +198,6 @@ def validation_step(self, batch, batch_idx): def predict_step(self, batch): read_id, data, *args = batch - bs, seq_l = data.shape[:2] data = data.reshape(bs, seq_l, -1) @@ -210,50 +212,42 @@ def predict_step(self, batch): target=None, noise_std_prediction=noise_std_prediction, max_length=self.config["max_signal_len"], - ideal_length=self.ideal_event_length, + dwell_mean=self.dwell_mean, + dwell_std=self.dwell_std, duration_sampling=self.duration_sampling, ) - prediction = self.decoders(length_predict_out) - prediction = prediction.cpu().squeeze(-1) + prediction = self.decoders(length_predict_out) prediction = prediction * self.config["scaling_max_value"] - - non_zero_mask = prediction != 0 - + prediction = prediction.squeeze(-1) + if self.noise_std > 0: + non_zero_mask = prediction != 0 if self.noise_sampling: - noise_std = noise_std_prediction_ext.detach().cpu().squeeze().numpy() - noise_std = ( - noise_std * self.noise_std * self.config["scaling_max_value"] - ) - gen_noise = np.random.normal(loc=0, scale=noise_std) - gen_noise = torch.tensor(gen_noise, dtype=prediction.dtype) + noise_std = noise_std_prediction_ext.squeeze() * self.noise_std * self.config["scaling_max_value"] + + gen_noise = torch.normal(mean=0, std=noise_std) + prediction[non_zero_mask] += gen_noise[non_zero_mask] else: - noise = np.random.normal( - loc=0, scale=self.noise_std, size=prediction.shape - ) - noise = torch.tensor( - noise, dtype=prediction.dtype, device=prediction.device - ) - prediction[non_zero_mask] += noise[non_zero_mask] + gen_noise = torch.normal(mean=0, std=self.noise_std, size=prediction.shape, device=prediction.device) - # Clamp the tensor to ensure no negative 
values + prediction[non_zero_mask] += gen_noise[non_zero_mask] + prediction = torch.clamp(prediction, min=0) - # Create dict of read_id and predictions d = {} for read, pred in zip(read_id, prediction): d.setdefault(read, []).append(pred) self.results.append(d) - # Increment the sample count self.total_samples += data.shape[0] - if self.total_samples >= self.export_every_n_samples: + if isinstance(self.out_writer, BLOW5Writer) and self.total_samples >= self.export_every_n_samples: self.export_and_clear_results(keep_last=True) self.total_samples = 0 # Reset sample count + def export_and_clear_results(self, keep_last: bool = True): """ diff --git a/src/seq2squiggle/modules.py b/src/seq2squiggle/modules.py index 61a97bd..fd46da0 100644 --- a/src/seq2squiggle/modules.py +++ b/src/seq2squiggle/modules.py @@ -67,6 +67,8 @@ def forward(self, src_seq, return_attns=False): batch_size, max_len = src_seq.shape[0], src_seq.shape[1] enc_slf_attn_list = [] + src_seq = src_seq.float() + src_seq = self.src_emb(src_seq) src_seq = self.relu(src_seq) if self.pre_layers > 0: @@ -74,11 +76,8 @@ def forward(self, src_seq, return_attns=False): src_seq = pre_layer(src_seq) src_seq = self.relu(src_seq) - enc_output = src_seq + get_sinusoid_encoding_table( - src_seq.shape[1], src_seq.shape[2] - )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( - src_seq.device - ) + + enc_output = src_seq + self.position_enc[:src_seq.shape[1]] for enc_layer in self.layer_stack: enc_output, enc_slf_attn = enc_layer( @@ -134,18 +133,7 @@ def __init__(self, config): def forward(self, enc_seq): batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] - # -- Forward - if not self.training and enc_seq.shape[1] > self.max_seq_len: - dec_output = enc_seq + get_sinusoid_encoding_table( - enc_seq.shape[1], self.d_model - )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( - enc_seq.device - ) - else: - max_len = min(max_len, self.max_seq_len) - dec_output = enc_seq[:, :max_len, :] + self.position_enc[ - :, :max_len, : - ].expand(batch_size, -1, -1) + dec_output = enc_seq + self.position_enc[:enc_seq.shape[1]] for dec_layer in self.layer_stack_FFT: dec_output, _ = dec_layer(dec_output, mask=None, slf_attn_mask=None) @@ -290,40 +278,32 @@ def forward(self, x): return stdv -@jit(nopython=True) -def create_alignment(base_mat, duration_predictor_output): +def get_padding_mask(lengths, max_len=None): """ - Create an alignment matrix based on the duration predictions. - - This function fills a base matrix with alignment values based on the - durations predicted for each sequence. For each predicted duration, - the function updates the corresponding positions in the base matrix. + Creates a padding mask for a batch of sequences based on their lengths. Parameters ---------- - base_mat : np.ndarray - The base matrix to be filled with alignment values. It should have - shape (N, max_duration, L) where N is the number of sequences, - max_duration is the maximum duration predicted, and L is the length of each sequence. - - duration_predictor_output : np.ndarray - The predicted durations for each sequence. It should have shape (N, L) where - N is the number of sequences and L is the length of each sequence. Each element - represents the predicted duration for the corresponding position in the sequence. + lengths + A tensor containing the lengths of each sequence in the batch. + max_len + The maximum length of the sequences. If not provided, it will be set to the maximum value in `lengths`. 
Returns ------- - np.ndarray - The updated base matrix with alignment values filled in based on the predicted durations. + torch.Tensor + A boolean tensor of shape (batch_size, max_len) where each element is True if it falls + within the corresponding sequence length (a valid position) and False if it is a padding position. """ - N, L = duration_predictor_output.shape - for i in range(N): - count = 0 - for j in range(L): - for k in range(duration_predictor_output[i][j]): - base_mat[i][count + k][j] = 1 - count = count + duration_predictor_output[i][j] - return base_mat + + if max_len is None: + max_len = lengths.max().item() + + # Create a mask by comparing each position index with the sequence lengths + ids = torch.arange(max_len, device=lengths.device) + mask = ids.unsqueeze(0) < lengths.unsqueeze(1) + + return mask class LengthRegulator(nn.Module): @@ -361,7 +341,7 @@ def __init__(self, config): self.config = config self.duration_sampler = DurationSampler(self.config) - def LR(self, x, duration_pred_out, max_length=None): + def LR(self, x, x_noise, duration_pred_out, max_length=None): """ Regulates the length of the input tensor x based on the duration predictions. @@ -369,6 +349,8 @@ def LR(self, x, duration_pred_out, max_length=None): ---------- x : torch.Tensor The input tensor with shape (batch_size, max_duration, dna_length). + x_noise : torch.Tensor + The noise std input tensor with shape (batch_size, max_duration, dna_length). duration_pred_out : torch.Tensor The tensor with predicted durations, shape (batch_size, dna_length). max_length : Optional[int] @@ -376,24 +358,35 @@ Returns ------- - torch.Tensor + output: torch.Tensor + The length-regulated tensor with the same shape as x. + output_noise: torch.Tensor The length-regulated tensor with the same shape as x. 
""" - # largest value - expand_max_len = torch.max(torch.sum(duration_pred_out, -1), -1)[0] - # intialize array filled with 0s, shape (bs, expand_max_len, dna_length) - alignment = torch.zeros( - duration_pred_out.size(0), expand_max_len, duration_pred_out.size(1) - ).numpy() + batch_size, input_max_seq_len = duration_pred_out.shape + # determine largest value + cum_duration = torch.cumsum(duration_pred_out, dim=1) + output_max_seq_len = torch.max(cum_duration) # create alignment matrix - alignment = create_alignment(alignment, duration_pred_out.cpu().numpy()) - alignment = torch.from_numpy(alignment).to(device) - # matrix multp - output = alignment @ x + cum_duration_reshaped = cum_duration.reshape(batch_size * input_max_seq_len) + M = get_padding_mask(cum_duration_reshaped,output_max_seq_len).reshape( + batch_size, input_max_seq_len,output_max_seq_len).float() + # adjust the matrix so that it captures the differences between cumulative durations + M = torch.diff(M, dim=1, prepend=torch.zeros_like(M[:, :1])) + # matrix multip + output = torch.bmm(M.permute(0, 2, 1), x) + + if x_noise is not None: + x_noise = torch.bmm(M.permute(0, 2, 1), x_noise) + # pad to max length if max_length: output = F.pad(output, (0, 0, 0, max_length - output.size(1), 0, 0)) - return output + if x_noise is not None: + x_noise = F.pad(x_noise, (0, 0, 0, max_length - x_noise.size(1), 0, 0)) + return output, x_noise + + def forward( self, @@ -403,7 +396,8 @@ def forward( alpha=1.0, target=None, max_length=None, - ideal_length=0.0, + dwell_mean=9.0, + dwell_std=0.0, duration_sampling=True, ): min_value = 3 @@ -418,13 +412,13 @@ def forward( else: bs, seq, _ = emb_out.shape dist = None - if ideal_length > 0: - duration_predictor_output = torch.full((bs, seq), ideal_length).to( + if dwell_std <= 0: + duration_predictor_output = torch.full((bs, seq), dwell_mean).to( device ) else: - mean = torch.full((bs, seq), 9.0).to(device) - std = torch.full((bs, seq), 4.0).to(device) + mean = torch.full((bs, seq), dwell_mean).to(device) + std = torch.full((bs, seq), dwell_std).to(device) duration_predictor_output = torch.normal(mean=mean, std=std) # Ensure all values are positive by clipping to a minimum value # a small positive value to ensure strictly positive values @@ -433,17 +427,10 @@ def forward( ) if target is not None: - output = self.LR(x, target, max_length=max_length) - if noise_std_prediction is not None: - noise_std_prediction = self.LR( - noise_std_prediction, target, max_length=max_length - ) + output, noise_std_prediction = self.LR(x, noise_std_prediction, target, max_length=max_length) else: duration_prediction = duration_predictor_output.detach().clone() duration_prediction = torch.round(duration_prediction).int() - output = self.LR(x, duration_prediction, max_length=max_length) - if noise_std_prediction is not None: - noise_std_prediction = self.LR( - noise_std_prediction, duration_prediction, max_length=max_length - ) + output, noise_std_prediction = self.LR(x, noise_std_prediction, duration_prediction, max_length=max_length) return output, duration_predictor_output, dist, noise_std_prediction + diff --git a/src/seq2squiggle/preprocess.py b/src/seq2squiggle/preprocess.py index a823df1..0a27203 100644 --- a/src/seq2squiggle/preprocess.py +++ b/src/seq2squiggle/preprocess.py @@ -177,28 +177,15 @@ def load_numpy_datasets( A tuple containing numpy arrays for chunks, targets, chunks lengths, targets lengths, and standard deviations. 
""" - data = np.load(path) - chunks = data["chunks"] - c_lengths = data["chunks_lengths"] - targets = data["targets"] - t_lengths = data["targets_lengths"] - stdevs = data["stdevs"] - data.close() - - if limit: - chunks = chunks[:limit] - targets = targets[:limit] - c_lengths = c_lengths[:limit] - t_lengths = t_lengths[:limit] - stdevs = stdevs[:limit] - - return ( - np.array(chunks), - np.array(targets), - np.array(c_lengths), - np.array(t_lengths), - np.array(stdevs), - ) + with np.load(path, allow_pickle=False) as data: + slices = slice(None, limit) + return ( + np.array(data.get("chunks", [])[slices]), + np.array(data.get("targets", [])[slices]), + np.array(data.get("chunks_lengths", [])[slices]), + np.array(data.get("targets_lengths", [])[slices]), + np.array(data.get("stdevs", [])[slices]), + ) def save_chunks(chunks, output_directory: str) -> None: @@ -218,23 +205,21 @@ def save_chunks(chunks, output_directory: str) -> None: """ os.makedirs(output_directory, exist_ok=True) - np.save(os.path.join(output_directory, "chunks.npy"), chunks.chunks.squeeze(1)) - np.save( - os.path.join(output_directory, "chunks_lengths.npy"), - chunks.c_lengths.squeeze(1), - ) - np.save(os.path.join(output_directory, "targets.npy"), chunks.targets) - np.save(os.path.join(output_directory, "targets_lengths.npy"), chunks.t_lengths) - np.save(os.path.join(output_directory, "stdevs.npy"), chunks.stdevs) + # Define file names and their corresponding data attributes + data_map = { + "chunks": chunks.chunks.squeeze(1), + "chunks_lengths": chunks.c_lengths.squeeze(1), + "targets": chunks.targets, + "targets_lengths": chunks.t_lengths, + "stdevs": chunks.stdevs, + } + + # Save and log each attribute + for name, data in data_map.items(): + np.save(os.path.join(output_directory, f"{name}.npy"), data) + logger.debug(f" - {name}.npy with shape {data.shape}") logger.debug(f"> data written to: {output_directory}") - logger.debug(f" - chunks.npy with shape {chunks.chunks.squeeze(1).shape}") - logger.debug( - f" - chunks_lengths.npy with shape {chunks.c_lengths.squeeze(1).shape}" - ) - logger.debug(f" - targets.npy with shape {chunks.targets.shape}") - logger.debug(f" - targets_lengths.npy shape {chunks.t_lengths.shape}") - logger.debug(f" - stdevs.npy shape {chunks.stdevs.shape}") def save_chunks_in_batches(chunks, output_directory: str, counter: int = 0) -> None: @@ -255,34 +240,20 @@ def save_chunks_in_batches(chunks, output_directory: str, counter: int = 0) -> N None """ os.makedirs(output_directory, exist_ok=True) - np.save( - os.path.join(output_directory, f"chunks-{counter:04d}.npy"), - chunks.chunks.squeeze(1), - ) - np.save( - os.path.join(output_directory, f"chunks_lengths-{counter:04d}.npy"), - chunks.c_lengths.squeeze(1), - ) - np.save( - os.path.join(output_directory, f"targets-{counter:04d}.npy"), chunks.targets - ) - np.save( - os.path.join(output_directory, f"targets_lengths-{counter:04d}.npy"), - chunks.t_lengths, - ) - np.save( - os.path.join(output_directory, f"stdevs-{counter:04d}.npy"), - chunks.stdevs, - ) + + data_map = { + "chunks": chunks.chunks.squeeze(1), + "chunks_lengths": chunks.c_lengths.squeeze(1), + "targets": chunks.targets, + "targets_lengths": chunks.t_lengths, + "stdevs": chunks.stdevs, + } + + for name, data in data_map.items(): + np.save(os.path.join(output_directory, f"{name}-{counter:04d}.npy"), data) + logger.debug(f" - {name}.npy with shape {data.shape}") logger.debug(f"> data written to: {output_directory}") - logger.debug(f" - chunks.npy with shape 
{chunks.chunks.squeeze(1).shape}") - logger.debug( - f" - chunks_lengths.npy with shape {chunks.c_lengths.squeeze(1).shape}" - ) - logger.debug(f" - targets.npy with shape {chunks.targets.shape}") - logger.debug(f" - targets_lengths.npy shape {chunks.t_lengths.shape}") - logger.debug(f" - stdevs.npy shape {chunks.t_lengths.shape}") def get_chunks( @@ -365,23 +336,30 @@ def get_kmer(dna_seq: List[str], kmer_size: int) -> List[str]: list of str List of k-mer sequences of the specified size. """ - if kmer_size == 9: - return dna_seq - elif kmer_size == 8: - return [i[1:] for i in dna_seq] - elif kmer_size == 7: - return [i[1:-1] for i in dna_seq] - elif kmer_size == 6: - return [i[2:-1] for i in dna_seq] - elif kmer_size == 5: - return [i[3:-1] for i in dna_seq] - elif kmer_size == 4: - return [i[4:-1] for i in dna_seq] - elif kmer_size == 3: - return [i[5:-1] for i in dna_seq] + if not (3 <= kmer_size <= 9): + logger.error(f"Choose a kmer value between 3 and 9. You chose {kmer_size}") + raise ValueError(f"Choose a kmer value between 3 and 9. You chose {kmer_size}") + + # Check the length of the first sequence + seq_length = len(dna_seq[0]) + + # for R9 + if seq_length == 6: + slice_map = {6: slice(None), 5: slice(0, -1), 4: slice(1, -1), 3: slice(1, 4)} + # for R10 + elif seq_length == 9: + slice_map = {9: slice(None), 8: slice(1, None), 7: slice(1, -1), 6: slice(2, -1), + 5: slice(3, -1), 4: slice(4, -1), 3: slice(5, -1)} else: - logger.error(f"Choose a kmer value between 3 and 9. You choose {kmer_size}") - raise ValueError(f"Choose a kmer value between 3 and 9. You choose {kmer_size}") + logger.error("Sequence length should be 6 (R9.4) or 9 (R10.4).") + raise ValueError("Sequence length should be 6 (R9.4) or 9 (R10.4).") + + if kmer_size > seq_length: + logger.error(f"kmer_size {kmer_size} is larger than the sequence length {seq_length}.") + raise ValueError(f"kmer_size {kmer_size} is larger than the sequence length {seq_length}.") + + return [seq[slice_map[kmer_size]] for seq in dna_seq] + def process_df( @@ -407,7 +385,7 @@ def process_df( - Array of standard deviations for each event. """ # Filter out artifacts of uncalled4 signal processing - df = df.sort(["position"]).filter(pl.col("model_kmer") != ("N" * 9)) + df = df.sort(["position"]).filter(pl.col("model_kmer") != ("N" * config["seq_kmer"])) # add 0s so that remainder is 0 df = df.with_columns(pl.col("end_idx").sub(pl.col("start_idx")).alias("signal_len")) @@ -418,7 +396,7 @@ def process_df( signal_len = df.select(pl.col(["signal_len"])).to_numpy().squeeze() # process DNA - dna_seq = df["model_kmer"].to_list() + dna_seq = df["model_kmer"].to_list() # Add remainder remain = config["max_dna_len"] - (len(dna_seq) % config["max_dna_len"]) @@ -426,7 +404,7 @@ def process_df( dna_seq = dna_seq + zero_array # One hot encode sequence - dna_seq = one_hot_encode(dna_seq) + dna_seq = one_hot_encode(dna_seq, len(dna_seq[0])) # Process the signal signal = df["samples"].to_list() diff --git a/src/seq2squiggle/seq2squiggle.py b/src/seq2squiggle/seq2squiggle.py index 1fbc8b0..0eb791a 100755 --- a/src/seq2squiggle/seq2squiggle.py +++ b/src/seq2squiggle/seq2squiggle.py @@ -135,9 +135,9 @@ def preprocess( verbosity, ): """ - Preprocess f5c's events.tsv for training the model + Preprocess uncalled4's events.tsv for training the model - EVENTS_PATH must be a events.tsv from f5c. + EVENTS_PATH must be a events.tsv from uncalled4 or f5c. 
OUTDIR must be a path to the output directory """ setup_logging(verbosity) @@ -205,12 +205,140 @@ def train( logger.info("Training done.") -@main.command(cls=_SharedParams) +# Attach the advanced prediction options (hidden by default) to the predict command +def conditional_option(f): + f = click.option( + "--noise-sampler", + default=True, + type=bool, + help="Enable or disable the noise sampler.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--duration-sampler", + default=True, + type=bool, + help="Enable or disable the duration sampler.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--dwell-mean", + default=9.0, + type=float, + help="Specify the mean dwell time (=number of signal points per k-mer). This will only be used if the duration sampler is deactivated.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--dwell-std", + default=0.0, + type=float, + help="Specify the standard deviation of the dwell time (=number of signal points per k-mer). This will only be used if the duration sampler is deactivated.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--noise-std", + default=1.0, + type=float, + help="Set the standard deviation for noise.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--distr", + default="expon", + type=click.Choice(["expon", "beta", "gamma"]), + help="Choose a distribution for read sampling.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--predict-batch-size", + default=1024, + type=int, + help="Specify the batch size for prediction.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--export-every-n-samples", + default=1000000, + type=int, + help="Specify how often the predicted samples should be saved.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--sample-rate", + default=5000, + type=int, + help="Specify the sampling rate.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--digitisation", + default=None, + type=int, + help="Specify the digitisation.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--range_val", + default=None, + type=float, + help="Specify the range value.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--offset_mean", + default=None, + type=float, + help="Specify the offset mean.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--offset_std", + default=None, + type=float, + help="Specify the offset standard deviation.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--median_before_mean", + default=None, + type=float, + help="Specify the median_before mean.", + show_default=True, + hidden=True # Hidden by default + )(f) + f = click.option( + "--median_before_std", + default=None, + type=float, + help="Specify the median_before standard deviation.", + show_default=True, + hidden=True # Hidden by default + )(f) + return f + + +@main.command(cls=_SharedParams, context_settings={"ignore_unknown_options": True}) @click.argument( "fasta", - required=True, + required=False, type=click.Path( - exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path + exists=False, file_okay=True, dir_okay=False, path_type=pathlib.Path ), ) @click.option( "-o", "--out", - required=True, + 
required=False, type=click.Path(file_okay=True, dir_okay=False, path_type=pathlib.Path), help="Specify the path to the output POD5/SLOW5/BLOW5 file.", ) @click.option( "--profile", - default="prom_r10_dna", + default="dna-r10-prom", show_default=True, - type=click.Choice(["minion_r10_dna", "prom_r10_dna"]), + type=click.Choice(["dna-r10-prom", "dna-r10-min", "dna-r9-prom", "dna-r9-min"]), help="Select a profile for data simulation. The profile determines values for digitization, sample rate, range, offset mean, offset standard deviation, median before mean, and median before standard deviation.", ) @click.option( - "--noise-sampler", - show_default=True, - default=True, - type=bool, - help="Enable or disable the noise sampler. If disabled, no noise will be added to the signal.", -) -@click.option( - "--duration-sampler", - show_default=True, - default=True, - type=bool, - help="Enable or disable the duration sampler. If disabled, the ideal event length will be used.", -) -@click.option( - "--ideal-event-length", - default=-1.0, - show_default=True, - type=float, - help="Specify the ideal event length to use. This option is only effective if the duration sampler is disabled. If set to -1, a static normal distribution will be used.", -) -@click.option( - "--noise-std", - default=1.0, - show_default=True, - type=float, - help="Set the standard deviation for noise. When the noise sampler is enabled, the noise generated will be scaled by this value. If the noise sampler is disabled, a static normal distribution will be used. No additional noise will be added if noise-std is less than or equal to 0.", -) -@click.option( - "--distr", - default="expon", - show_default=True, - type=click.Choice(["expon", "beta", "gamma"]), - help="Choose a distribution for read sampling. This option is only required in genome mode.", -) -@click.option( - "--predict-batch-size", - default=1024, - show_default=True, - type=int, - help="Specify the batch size for prediction.", -) -@click.option( - "--export-every-n-samples", - default=500000, - show_default=True, - type=int, - help="Specify how often the predicted samples (chunk) should be saved to output file. Increasing it will reduce runtime and increase memory consumption.", + "--show-advanced-options", + is_flag=True, + default=False, + help="Show advanced options for signal prediction." 
) +@conditional_option +@click.pass_context def predict( + ctx, fasta, read_input, num_reads, @@ -313,13 +401,22 @@ def predict( coverage, out, profile, + show_advanced_options, noise_sampler, duration_sampler, - ideal_event_length, + dwell_mean, + dwell_std, noise_std, distr, predict_batch_size, export_every_n_samples, + sample_rate, + digitisation, + range_val, + offset_mean, + offset_std, + median_before_mean, + median_before_std, seed, model, config, @@ -330,6 +427,26 @@ def predict( FASTA must be .fasta file with desired genome or reads for simulation """ + if show_advanced_options: + # Dynamically re-generate the command's help message with hidden=False + for param in ctx.command.params: + param.hidden = False + + # Re-run help message to show advanced options + click.echo(ctx.get_help()) + ctx.exit() # Exit after showing help with advanced options + + # Check for help flag + if ctx.invoked_subcommand is None and ctx.args and "-h" in ctx.args: + # Print the normal help and exit + click.echo(ctx.get_help()) + ctx.exit() + + if not fasta or not out: + logger.error("FASTA file and Output file are required for prediction.") + ctx.exit(1) + + setup_logging(verbosity) logger.info("seq2squiggle version %s", str(__version__)) @@ -344,11 +461,19 @@ def predict( "profile": profile, "noise_sampler": noise_sampler, "duration_sampler": duration_sampler, - "ideal_event_length": ideal_event_length, + "dwell_mean": dwell_mean, + "dwell_std": dwell_std, "noise_std": noise_std, "distr": distr, "predict_batch_size": predict_batch_size, "export_every_n_samples": export_every_n_samples, + "sample_rate": sample_rate, + "digitisation": digitisation, + "range": range_val, + "offset_mean": offset_mean, + "offset_std": offset_std, + "median_before_mean": median_before_mean, + "median_before_std": median_before_std, "seed": seed, "model": model, "config": config, @@ -376,13 +501,21 @@ def predict( c=coverage, out=out, profile=profile, - ideal_event_length=ideal_event_length, + dwell_mean=dwell_mean, + dwell_std=dwell_std, noise_std=noise_std, noise_sampling=noise_sampler, duration_sampling=duration_sampler, distr=distr, predict_batch_size=predict_batch_size, export_every_n_samples=export_every_n_samples, + sample_rate=sample_rate, + digitisation=digitisation, + range_val=range_val, + offset_mean=offset_mean, + offset_std=offset_std, + median_before_mean=median_before_mean, + median_before_std=median_before_std, seed=seed, ) logger.info("Prediction done.") @@ -428,7 +561,6 @@ def version(): def set_config(config_path : dict) -> dict: default_config_path = pathlib.Path(__file__).parent / "config.yaml" - path_to_use = default_config_path if config_path is None else config_path try: diff --git a/src/seq2squiggle/signal_io.py b/src/seq2squiggle/signal_io.py index 8586e86..566defe 100644 --- a/src/seq2squiggle/signal_io.py +++ b/src/seq2squiggle/signal_io.py @@ -12,8 +12,6 @@ import os from uuid import uuid4 -from .utils import get_profile - logger = logging.getLogger("seq2squiggle") @@ -27,19 +25,18 @@ class BLOW5Writer: The name of the slow5 file. 
""" - def __init__(self, filename, profile, ideal_event_length): + def __init__(self, filename, profile, ideal_mode): self.filename = filename self.profile: dict = profile - self.ideal_event_length = ideal_event_length + self.ideal_mode = ideal_mode self.signals = None - self.profile_d = get_profile(self.profile) - self.median_before = float(self.profile_d["median_before_mean"]) - self.median_before_std = float(self.profile_d["median_before_std"]) - self.offset = float(self.profile_d["offset_mean"]) - self.offset_std = float(self.profile_d["offset_std"]) - self.digitisation = float(self.profile_d["digitisation"]) - self.signal_range = float(self.profile_d["range"]) - self.sample_rate = float(self.profile_d["sample_rate"]) + self.median_before = float(self.profile["median_before_mean"]) + self.median_before_std = float(self.profile["median_before_std"]) + self.offset = float(self.profile["offset_mean"]) + self.offset_std = float(self.profile["offset_std"]) + self.digitisation = float(self.profile["digitisation"]) + self.signal_range = float(self.profile["range"]) + self.sample_rate = float(self.profile["sample_rate"]) self.start_time = 0 def save(self): @@ -57,11 +54,24 @@ def save(self): # To write a file, mode in Open() must be set to 'w' and when appending, 'a' s5 = pyslow5.Open(str(self.filename), f_mode) + if f_mode == 'w': + header, end_reason_labels = s5.get_empty_header(aux=True) + header['asic_id'] = 'asic_id_0' + header['exp_start_time'] = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + header['experiment_type'] = 'genomic_dna' + header['run_id'] = 'run_id_0' + header['sample_frequency'] = int(self.sample_rate) + + # Remove any fields with None values from header + header = {key: value for key, value in header.items() if value is not None} + + ret = s5.write_header(header,end_reason_labels=end_reason_labels) + records = {} auxs = {} for idx, (read_id, signal) in enumerate(self.signals.items()): - if self.ideal_event_length <= 0: + if self.ideal_mode: median_before_value = np.random.normal( self.median_before, self.median_before_std ) @@ -69,7 +79,7 @@ def save(self): else: median_before_value = self.median_before offset_value = self.offset - signal = signal.numpy().astype(np.float32) + signal = signal.cpu().numpy().astype(np.float32) signal_raw = np.round( signal * self.digitisation / self.signal_range - self.offset ) @@ -111,19 +121,18 @@ class POD5Writer: The name of the pod5 file. 
""" - def __init__(self, filename, profile, ideal_event_length): + def __init__(self, filename, profile, ideal_mode): self.filename = filename self.profile: dict = profile - self.ideal_event_length = ideal_event_length + self.ideal_mode = ideal_mode self.signals = None - self.profile_d = get_profile(self.profile) - self.median_before = float(self.profile_d["median_before_mean"]) - self.median_before_std = float(self.profile_d["median_before_std"]) - self.offset = float(self.profile_d["offset_mean"]) - self.offset_std = float(self.profile_d["offset_std"]) - self.digitisation = float(self.profile_d["digitisation"]) - self.signal_range = float(self.profile_d["range"]) - self.sample_rate = float(self.profile_d["sample_rate"]) + self.median_before = float(self.profile["median_before_mean"]) + self.median_before_std = float(self.profile["median_before_std"]) + self.offset = float(self.profile["offset_mean"]) + self.offset_std = float(self.profile["offset_std"]) + self.digitisation = float(self.profile["digitisation"]) + self.signal_range = float(self.profile["range"]) + self.sample_rate = float(self.profile["sample_rate"]) self.start_time = 0 def save(self): @@ -134,6 +143,8 @@ def save(self): logger.warning("POD5 was not exported. No signals were found") raise ValueError("POD5 was not exported. No signals were found") + + run_info = pod5.RunInfo( acquisition_id="", # f5d5051ec9f7983c76e78543f720289d2988ce48 acquisition_start_time=datetime.now(), @@ -157,10 +168,12 @@ def save(self): tracking_id={}, ) + + pod5_reads = [] for idx, (read_id, signal) in enumerate(self.signals.items()): - if self.ideal_event_length <= 0: + if self.ideal_mode: median_before_value = np.random.normal( self.median_before, self.median_before_std ) @@ -168,13 +181,13 @@ def save(self): else: median_before_value = self.median_before offset_value = self.offset - signal = signal.numpy().astype(np.float32) + signal = signal.cpu().numpy().astype(np.float32) signal_raw = np.round( signal * self.digitisation / self.signal_range - self.offset ) signal_raw = signal_raw.astype(np.int16) - pore = pod5.Pore(channel=123, well=3, pore_type="pore_type") + pore = pod5.Pore(channel=123, well=3, pore_type="not_set") calibration = pod5.Calibration( offset=offset_value, scale=(self.signal_range / self.digitisation) ) diff --git a/src/seq2squiggle/train.py b/src/seq2squiggle/train.py index 5cb489e..efe6adf 100644 --- a/src/seq2squiggle/train.py +++ b/src/seq2squiggle/train.py @@ -91,9 +91,7 @@ def train_run( ] # "gamma_cpu" not implemented for 'BFloat16' - precision = "64" - if torch.cuda.device_count() >= 1: - precision = "16-mixed" + precision = "16-mixed" if torch.cuda.device_count() >= 1 else "64" trainer = pl.Trainer( accelerator="auto", @@ -106,7 +104,6 @@ def train_run( logger=wandb_logger, gradient_clip_val=config["gradient_clip_val"], strategy=_get_strategy(), - #num_nodes=2, ) trainer.fit(fft_model, poredata) diff --git a/src/seq2squiggle/utils.py b/src/seq2squiggle/utils.py index 7405e00..d14468c 100644 --- a/src/seq2squiggle/utils.py +++ b/src/seq2squiggle/utils.py @@ -19,6 +19,7 @@ import psutil import multiprocessing from typing import List, Generator, Tuple, Union +from uuid import uuid4 logger = logging.getLogger("seq2squiggle") @@ -50,7 +51,7 @@ def n_workers() -> int: return n_cpu // n_gpu if (n_gpu := torch.cuda.device_count()) > 1 else n_cpu -def one_hot_encode(sequences: List[str]) -> np.ndarray: +def one_hot_encode(sequences: List[str], seq_len: int) -> np.ndarray: """ One-hot encodes a list of DNA sequences. 
@@ -58,6 +59,8 @@ def one_hot_encode(sequences: List[str]) -> np.ndarray:
     ----------
     sequences : list of str
         A list where each string is a DNA sequence containing characters from {"_", "A", "C", "G", "T"}.
+    seq_len : int
+        Length of the input k-mer sequences.
 
     Returns
     -------
@@ -73,7 +76,7 @@
 
     # Initialize an empty array to store the one-hot encoded sequences
     n_outer_sequences = len(sequences)
-    one_hot_encoded = np.zeros((n_outer_sequences, 9, n_letters), dtype=np.float16)
+    one_hot_encoded = np.zeros((n_outer_sequences, seq_len, n_letters), dtype=np.float16)
 
     # Iterate through each outer sequence and its inner sequences, and one-hot encode them
     for i, outer_sequence in enumerate(sequences):
@@ -142,24 +145,42 @@ def get_profile(profile):
     -------
     """
     profiles = {
-        "minion_r10_dna": {
+        "dna-r10-min": {
             "digitisation": 8192,
-            "sample_rate": 4000,
+            "sample_rate": 5000,
             "range": 1536.598389,
             "offset_mean": 13.380569389019,
             "offset_std": 16.311471649012,
             "median_before_mean": 202.15407438804,
             "median_before_std": 13.406139241768,
         },
-        "prom_r10_dna": {
+        "dna-r10-prom": {
             "digitisation": 2048,
-            "sample_rate": 4000,
+            "sample_rate": 5000,
             "range": 281.345551,
             "offset_mean": -127.5655735,
             "offset_std": 19.377283387665,
             "median_before_mean": 189.87607393756,
             "median_before_std": 15.788097978713,
         },
+        "dna-r9-min": {
+            "digitisation": 8192,
+            "sample_rate": 4000,
+            "range": 1443.030273,
+            "offset_mean": 13.7222605,
+            "offset_std": 10.25279688,
+            "median_before_mean": 200.815801,
+            "median_before_std": 20.48933762,
+        },
+        "dna-r9-prom": {
+            "digitisation": 2048,
+            "sample_rate": 4000,
+            "range": 748.5801,
+            "offset_mean": -237.4102,
+            "offset_std": 14.1575,
+            "median_before_mean": 214.2890337,
+            "median_before_std": 18.0127916,
+        },
     }
 
     if profile in profiles:
@@ -168,6 +189,54 @@
         logger.error(f"Incorrect value for profile: {profile}")
 
 
+def update_profile(profile_dict, **kwargs):
+    """
+    Update the profile dictionary with the provided parameters.
+
+    Any parameter in kwargs that is not None will replace the corresponding
+    value in the profile_dict.
+
+    Parameters
+    ----------
+    profile_dict : dict
+        The current profile dictionary to update
+    **kwargs
+        The parameters to update in the profile dictionary
+
+    Returns
+    -------
+    dict
+        The updated profile dictionary
+    """
+    for key, value in kwargs.items():
+        if value is not None and key in profile_dict:
+            profile_dict[key] = value
+        elif key not in profile_dict:
+            logger.warning(f"{key} is not a valid key in the profile")
+
+    return profile_dict
+
+def update_config(profile_name, config):
+    """
+    Updates the configuration dictionary with the appropriate sequence k-mer size
+    based on the profile name.
+
+    Parameters:
+    profile_name (str): The profile name, typically indicating sequencing chemistry.
+    config (dict): The configuration dictionary to update.
+
+    Returns:
+    dict: The updated configuration dictionary.
+    """
+    if profile_name.startswith("dna-r10"):
+        config["seq_kmer"] = 9
+    elif profile_name.startswith("dna-r9"):
+        config["seq_kmer"] = 6
+    else:
+        raise ValueError(f"Unsupported profile name: {profile_name}. Expected 'dna-r10' or 'dna-r9' prefix.")
+    return config
+
+
 def regular_break_points(n, chunk_len, overlap=0, align="left"):
     """
     Returns breakpoints of a signal given a chunk_len.
@@ -255,7 +324,7 @@ def add_remainder(x, max_dna, k):
 def split_sequence(x, config):
     x = extract_kmers(x, config["seq_kmer"])
     x = add_remainder(x, config["max_dna_len"], config["seq_kmer"])
-    x = one_hot_encode(x)
+    x = one_hot_encode(x, config["seq_kmer"])
     breakpoints = regular_break_points(len(x), config["max_dna_len"], align="left")
     x_breaks = np.array([x[i:j] for (i, j) in breakpoints])
     return x_breaks
@@ -370,12 +439,12 @@ def export_fasta(read_l, fasta):
     out_file = f"{file_name}_reads.fasta"
     with open(out_file, "w") as f:
         for i, read in enumerate(read_l):
-            f.write(f">Read_{i}\n{''.join(read)}\n")
+            f.write(f">{str(uuid4())}\n{''.join(read)}\n")
     return out_file
 
 
 def yield_reads(reads):
-    return ((read, f">Read_{i}") for i, read in enumerate(reads))
+    return ((read, str(uuid4())) for read in reads)
 
 
 def sample_reads_from_genome(