Skip to content

Commit

Permalink
add a module 'reconstruct', which combines train and inference into o…
Browse files Browse the repository at this point in the history
…ne step
  • Loading branch information
ShuminBAL committed Dec 1, 2023
1 parent 0028724 commit c50f7a5
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ __pycache__/
*.pyc
results/
.ipynb_checkpoints/
*_test.json
.DS_Store
47 changes: 42 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ A preprint describing CellContrast's algorithms and results is at [bioRxiv](http
- [Latest Updates](#latest-updates)
- [Installations](#installation)
- [Usage](#usage)
- [Quick start](#quick-start)
- [Model training](#model-training)
- [Performance evaluation](#performance-evaluation)
- [Spatial inference](#spatial-inference)
Expand Down Expand Up @@ -65,22 +66,58 @@ conda install pytorch torchvision torchaudio cpuonly -c pytorch
## Usage


CellContrast contains 3 main moduels: train, eval and inference, for training model, benchmarking evaluation and inference of spatial relationships, respectively. To check available modules, run:
CellContrast contains 3 main moduels: `train`, `eval` and `inference`, for training model, benchmarking evaluation and inference of spatial relationships, respectively. In addition, We also provide `reconstruct` module for integrating `train` and `inference`. To check available modules, run:

```bash

python cellContrast.py -h

```

### Quick Start


#### Run with sequencing-based ST

```bash

python cellContrast.py reconstruct \
--train_data_path train_ST.h5ad ## required, use your ST h5ad file here\
--query_data_path query_sc.h5ad ## path of query SC h5ad file\
--parameter_file parameters/parameters_spot.json ## optional. use the our default for spot or single-cell ST, or your customized parameters here\
--save_folder cellContrast_models/ ## optional, model output path\
--enable_denovo ## optional, run MDS to leverage the SC-SC pairwise distance to 2D pseudo space
--save_path spatial_reconstructed_sc.h5ad \ ## path of of the spatial reconstructed SC data
```

#### Run with imaging-based ST

* Adopt the predefined parameters for imaging-based ST data by setting `--single_cell`.

```bash
python cellContrast.py reconstruct \
--train_data_path train_ST.h5ad ## required, use your ST h5ad file here\
--query_data_path query_sc.h5ad ## path of query SC h5ad file\
--single_cell \
--parameter_file parameters/parameters_singleCell.json ## optional. use the our default for spot or single-cell ST, or your customized parameters here\
--save_folder cellContrast_models/ ## optional, model output path\
--enable_denovo ## optional, run MDS to leverage the SC-SC pairwise distance to 2D pseudo space
--save_path spatial_reconstructed_sc.h5ad \ ## path of of the spatial reconstructed SC data

```

### Model training
CellContrast model was trained based on ST data (which should be in [AnnData](https://anndata.readthedocs.io/en/latest/) format, with truth locations in `.obs[['x','y']])`. The model can be trained with the following command:

* :bangbang: Default parameters are defined for sequencing-based ST, adopt the predefined parameters for imaging-based ST data by setting `--single_cell`.

```bash

python cellContrast.py train \
--train_data_path train_ST.h5ad \ ## required, use your ST h5ad file here
--save_folder cellContrast_models/ \ ## optional, model output path
--parameter_file parameters.json ## optional. use the default or your customized parameters here

--single_cell # defaut: not enabled. Set this flag to switch to our prefined parameters for imaging-based ST.
--parameter_file parameters/parameters_singleCell.json ## optional. use the our default for spot or single-cell ST, or your customized parameters here\
## Output file: cellContrast_models/epoch_3000.pt
```

Expand All @@ -92,7 +129,7 @@ python cellContrast.py eval \
--ref_data_path ref_ST.h5ad \ ## path of refernece ST h5ad file
--query_data_path query_ST.h5ad \ ## path of testing h5ad file with truth locations
--model_foldercellContrast_models\ ## folder of trained model
--parameter_file parameters.json \ ## parameters of trained model
--parameter_file parameters/parameters_singleCell.json ## Take the parameter file you used in the training phase.\
--save_path results.csv \ ## evaluation result path

## Output file: result.csv with neighbor hit, JSD, spearman's rank correlation for each testing sample.
Expand All @@ -107,7 +144,7 @@ python cellContrast.py inference \
--ref_data_path train_ST.h5ad \ ## path of refernece ST h5ad file
--query_data_path query_sc.h5ad \ ## path of query SC h5ad file
--model_folder \ ## folder of trained model
--parameter_file parameters.json \ ## uparameters of trained model
--parameter_file parameters/parameters_singleCell.json ## Take the parameter file you used in the training phase.\
--save_path spatial_reconstructed_sc.h5ad \ ## path of of the spatial reconstructed SC data
--enable_denovo \ ## optional, run MDS to leverage the SC-SC pairwise distance to 2D pseudo space

Expand Down
6 changes: 3 additions & 3 deletions cellContrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from importlib import import_module


data_preprocess_modules = ["normalization"]
deep_learning_modules = ["train","eval","inference"]
data_preprocess_modules = [""]
deep_learning_modules = ["reconstruct","train","eval","inference"]


DEEP_LEARNING_FOLDER = "cellContrast"
Expand Down Expand Up @@ -70,4 +70,4 @@ def main():

main()

pass
pass
3 changes: 2 additions & 1 deletion cellContrast/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def main():
parser.add_argument('--model_folder', type=str,default="./cellContrast_models",
help="Save folder of model related files, default:'./cellContrast_models'")
parser.add_argument('--parameter_file_path', type=str,
help="Path of parameter settings, default:'./parameters.json'",default="./parameters.json")
help="Please take the parameter file you used in the training phase,\
default:'./parameters/parameters_spot.json'",default="./parameters/parameters_spot.json")

parser.add_argument('--save_path',type=str,help="Save path of evaluation result",default="./result.csv")
args = parser.parse_args()
Expand Down
3 changes: 2 additions & 1 deletion cellContrast/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def main():
parser.add_argument('--model_folder', type=str,
help="Save folder of model related files, default:'./cellContrast_models'",default="./cellContrast_models")
parser.add_argument('--parameter_file_path', type=str,
help="Path of parameter settings, default:'./parameters.json'",default="./parameters.json")
help="Please take the parameter file you used in the training phase,\
default:'./parameters/parameters_spot.json'",default="./parameters/parameters_spot.json")
parser.add_argument('--ref_data_path',type=str, help="reference ST data, used in generating the coordinates of SC data as the reference, usually should be the training data of the model")

# whether to enable de novo coordinates inference
Expand Down
111 changes: 111 additions & 0 deletions cellContrast/reconstruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from argparse import ArgumentParser, SUPPRESS
import scanpy as sc
import cellContrast.train
import cellContrast.inference
import logging
import sys
import os

logging.getLogger().setLevel(logging.INFO)






def formatInputAnn(args):

'''
'''
train_adata = sc.read_h5ad(args.train_data_path)
query_adata = sc.read_h5ad(args.query_data_path)

# get the overlap genes of reference and query AnnData
train_genes = train_adata.var_names
query_genes = query_adata.var_names
overlapped_genes = list(set(train_genes).intersection(set(query_genes)))
if(len(overlapped_genes)<=0):
sys.exit("[ERROR] 0 overlapped gene found between the training and query data.")

logging.info("%s overlapped genes found between the training and query data" % (str(len(overlapped_genes))))

train_gene_indices = [train_adata.var_names.get_loc(gene) for gene in overlapped_genes]
formatted_train_adata = train_adata[:, train_gene_indices]

query_gene_indices = [query_adata.var_names.get_loc(gene) for gene in overlapped_genes]
formatted_query_adata = query_adata[:, query_gene_indices]

return formatted_train_adata,formatted_query_adata

def main():

# `reconstrcut.py` combines `train.py` and `inference.py` into one step

parser = ArgumentParser(description="Train a cellContrast model")

# training arguments
parser.add_argument('--train_data_path', type=str,
help="The path of training data with h5ad format (annData object)")

parser.add_argument('--save_folder', type=str,
help="Save folder of model related files, default:'./cellContrast_models'",default="./cellContrast_models")

parser.add_argument('--parameter_file_path', type=str,
help="Path of parameter settings, customize it based on reference ST\
default:'./parameters/parameters_spot.json'",default="./parameters/parameters_spot.json")

parser.add_argument('-sc','--single_cell',\
help="default:false, set this flag will swithing to the single-cell resolution ST mode, which uses the predefined './parameters/parameters_singleCell.json'",\
action='store_true')


# inference arugments
parser.add_argument('--query_data_path', type=str,
help="The path of querying data with h5ad format (annData object)")
parser.add_argument('--enable_denovo', action="store_true",help="(Optional) generate the coordinates de novo by MDS algorithm",default=False)
parser.add_argument('--save_path',type=str,help="Save path of the spatial reconstructed SC data",default="./reconstructed_sc.h5ad")



args = parser.parse_args()

if len(sys.argv[1:]) == 0:
parser.print_help()
sys.exit(1)

# check arguments
if(not os.path.exists(args.train_data_path)):
logging.error("train data not exists!")
sys.exit(1)
if(not os.path.exists(args.query_data_path)):
logging.error("query data not exists!")
sys.exit(1)

# check the parameter files
if(args.single_cell):
# change the parameter settings to the single-cell mode unless users have customized it.
args.parameter_file_path = "./parameters/parameters_singleCell.json"


if(not os.path.exists(args.parameter_file_path)):
print("parameter file not exists!")
sys.exit(1)

if(not os.path.exists(args.save_folder)):
os.mkdir(args.save_folder)

# format input training and query data
train_adata, query_adata = formatInputAnn(args)

# training the cellContrast model
model = cellContrast.train.train_model(args,train_adata)

# reconstruct the spatial relationships of query data
logging.info("Performing spatial inference for the query data")
reconstructed_query_adata = cellContrast.inference.perform_inference(query_adata,train_adata,model,args.enable_denovo)
reconstructed_query_adata.write(args.save_path)


35 changes: 29 additions & 6 deletions cellContrast/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import numpy as np
from tqdm import tqdm
import time
import logging

logging.getLogger().setLevel(logging.INFO)

sample_field_name = 'embryo'

Expand Down Expand Up @@ -39,18 +41,23 @@ def save_model(args,model,params,optimizer,LOSS,train_genes):



def train_model(args):
def train_model(args,train_adata=None):

print("train model")
logging.info("Training cellContrast model")


# load parameter settings
with open(args.parameter_file_path,"r") as json_file:
params = json.load(json_file)
print("parameters",params)

# load data
train_adata = sc.read_h5ad(args.train_data_path)


# load data if necessary
if(not train_adata):
logging.info("Load training data")
train_adata = sc.read_h5ad(args.train_data_path)

if(sample_field_name in train_adata.obs):
train_sample_number = train_adata.obs[sample_field_name].unique().shape[0]
else:
Expand All @@ -72,6 +79,8 @@ def train_model(args):
dev = "cuda:0"
else:
dev = "cpu"
logging.info("Using device %s" %(dev))

device = torch.device(dev)


Expand Down Expand Up @@ -131,6 +140,7 @@ def train_model(args):

# Save the last model
save_model(args,model,params,optimizer,LOSS,train_genes)
return model



Expand All @@ -143,11 +153,18 @@ def main():

parser.add_argument('--train_data_path', type=str,
help="The path of training data with h5ad format (annData object)")

parser.add_argument('--save_folder', type=str,
help="Save folder of model related files, default:'./cellContrast_models'",default="./cellContrast_models")

parser.add_argument('--parameter_file_path', type=str,
help="Path of parameter settings, default:'./parameters.json'",default="./parameters.json")
help="Path of parameter settings, customize it based on reference ST\
default:'./parameters/parameters_spot.json'",default="./parameters/parameters_spot.json")

parser.add_argument('-sc','--single_cell',\
help="default:false, set this flag will swithing to the single-cell resolution ST mode, which uses the predefined './parameters/parameters_singleCell.json'",\
action='store_true')



args = parser.parse_args()
Expand All @@ -156,11 +173,17 @@ def main():
parser.print_help()
sys.exit(1)

# check parameters
# check arguments
if(not os.path.exists(args.train_data_path)):
print("train data not exists!")
sys.exit(1)

# check the parameter files
if(args.single_cell):
# change the parameter settings to the single-cell mode unless users have customized it.
args.parameter_file_path = "./parameters/parameters_singleCell.json"


if(not os.path.exists(args.parameter_file_path)):
print("parameter file not exists!")
sys.exit(1)
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions parameters/parameters_spot.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"n_encoder_hidden": 1024,
"n_encoder_latent": 512,
"n_encoder_layers": 2,
"n_projection_hidden": 256,
"n_projection_output": 128,
"dropout_rate": 0,
"training_epoch": 3000,
"inital_learning_rate": 0.1,
"k_nearest_positives": 20,
"batch_size": 64,
"temperature": 0.05
}

0 comments on commit c50f7a5

Please sign in to comment.