diff --git a/README.md b/README.md index 01ca10c..86d6b28 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,16 @@ # CITRUS🍊: A phenotype simulation tool with the flexibility to model complex interactions -CITRUS, the CIs and Trans inteRaction nUmerical Simulator, is a tool for simulating phenotypes with complex genetic archetectures that go beyond simple models that assume linear, additive contributions of individual SNPs. The goal of this tool is to provide better simulations for benchmarking GWAS/PRS models. +CITRUS, the CIs and Trans inteRaction nUmerical Simulator, is a collection of tools for simulating phenotypes with complex genetic architectures that go beyond simple models that assume linear, additive contributions of individual SNPs. The goal of CITRUS is to provide better simulations for benchmarking GWAS/PRS models. -## Installation +The key component of CITRUS is the ability to specify custom models relating genotypes to phenotypes. See the [designing simulations](doc/designing_simulations.md) for details on specifying models. Example models are provided in `example-files/`. + +CITRUS provides multiple command line utilities for performing and analyzing simulations: + +* [citrus simulate](doc/cli.md#simulate): Perform a simulation using a given model +* [citrus plot](doc/cli.md#plot): Visualize a phenotype model +* [citrus shap](doc/cli.md#shap): Generate SHAP values for a model -For plotting models, you will need to have [graphviz](https://graphviz.org/) installed. +## Installation ### With conda @@ -31,11 +37,21 @@ conda run -n citrus poetry install conda activate citrus ``` +Note, for plotting models, you will need to have [graphviz](https://graphviz.org/) installed. + +## Quickstart + +``` +# Visualize a model +citrus plot -c example-files/linear_additive.json +``` + ## Full documentation +[Command Line Interface](doc/cli.md) + [User Guide](doc/user_guide.md) [Designing Simulations](doc/designing_simulations.md) -[Command Line Interface](doc/cli.md) diff --git a/citrus/cli.py b/citrus/cli.py index 2b2b62f..51cbbf3 100644 --- a/citrus/cli.py +++ b/citrus/cli.py @@ -1,59 +1,19 @@ """CITRUS command line interface. -See CITRUS/doc/CLI.md for more information. - -This tool can be used to run the simulation based on either: - - 1. A single configuration JSON file that specifies paths to genotype - data files. - - CITRUS_sim -c - - 2. A single configuration JSON file and a list of paths to genotype - data files. The list of paths must be the same length as the - number of input source files in the configuration file (i.e. - the length of the list under the 'input' key in the JSON). Any - paths in the configuration file will be ignored. The -g or - --genotype_files flag can be used to specify the paths to the - genotype files. - - CITRUS_sim -c -g \\ - ... - - CITRUS_sim -c -g - -Output: - - If no additional flags are provided, output will be written to the - current working directory. The output will be a CSV file (output.csv) - containing sample IDs and all corresponding values from the simulation, - and a JSON file (sim_config.json) containing the simulation configuration - (including any random selections made by nodes). - - If the -o or --output_dir flag is provided, if the directory does not - exist it will be created, and the output files will be saved to it. By - default the output files will be named output.csv and sim_config.json, - but these can be changed with the -f or --output_file_name and -j or - --output_config_json flags, respectively. - - Output file will by default be a comma seperated CSV file. Use -t or - --tsv flag to instead save as a tab seperated TSV file. - -Example Usage: - - CITRUS_sim -c config.json -o sim_results/output_dir - - CITRUS_sim -c config.json -g genotype_file_1 genotype_file_2 \\ - -o sim_results/output_dir -t -f my_output.tsv -j my_sim_config.json +See CITRUS/doc/CLI.md and individual tools for more information. """ import click +import sys @click.group() @click.version_option(package_name="citrus", message="%(version)s") def citrus(): pass +""" +citrus simulate +""" @citrus.command(no_args_is_help=True) @click.option( '-c', '--config_file', @@ -147,6 +107,9 @@ def simulate( sep="\t" if tsv else "," ) +""" +citrus plot +""" @citrus.command(no_args_is_help=True) @click.option( '-c', '--config_file', @@ -161,28 +124,39 @@ def simulate( help="Output filename (without extension) for saving plot." ) @click.option( - '-f', '--format', + '-f', '--fmt', type=click.Choice(['jpg', 'png', 'svg']), default='png', show_default=True, help="File format and extension for the output plot." ) -def plot(config_file: str, out: str, format: str): +@click.option( + '--verbose', + is_flag=True, + help="Print extra output to the terminal", + default=False +) +def plot(config_file: str, out: str, fmt: str, verbose: str): """ Save a plot of the network defined by the simulation config file. Note: Colors correspond to cis, inheritance, and trans effects """ - from pheno_sim import plot + from . import plot from json import load with open(config_file, "r") as f: config = load(f) # Create a plot of the model - plot.visualize(input_spec=config, filename=out, img_format=format) + retcode = plot.visualize(input_spec=config, filename=out, + img_format=fmt, verbose=verbose) + sys.exit(retcode) +""" +citrus shap +""" @citrus.command(no_args_is_help=True) @click.option( '-c', '--config_file', diff --git a/citrus/plot.py b/citrus/plot.py new file mode 100644 index 0000000..38a62a6 --- /dev/null +++ b/citrus/plot.py @@ -0,0 +1,234 @@ +"""Plot simulation as directed graph.""" + +import pydot +from PIL import Image, ImageDraw + +from pheno_sim.pheno_simulation import PhenoSimulation +from pheno_sim.base_nodes import AbstractBaseCombineFunctionNode +from .utils import MSG + +def visualize(input_spec: dict, filename: str, + img_format: str, verbose: bool=False) -> int: + """ + Visualize a phenotype model + + Parameters + ---------- + input_spec : dict + Model configuration + filename : str + Prefix of output filename + img_format : str + Format of output file (jpg, png, or svg}) + verbose : bool + If true, print extra output to terminal + + Returns + ------- + retcode : int + Return code (0 for success) + """ + + # Generate sim object + sim = PhenoSimulation(input_spec) + + # Add input nodes + sim_nodes = dict() + + for input_source in sim.input_runner.input_sources: + for input_node in input_source.input_nodes: + sim_nodes[input_node.alias] = CITRUSNode( + alias=input_node.alias, + node_type='input', + class_name=type(input_node).__name__ + ) + + # Add operator nodes + for sim_step in sim.simulation_steps: + step_alias = sim_step.alias + step_inputs = sim_step.inputs + + if isinstance(step_inputs, str): + step_inputs = [step_inputs] + elif isinstance(step_inputs, dict): + step_inputs = list(step_inputs.values()) + + if isinstance(sim_step, AbstractBaseCombineFunctionNode): + step_type = 'combine' + else: + step_type = 'trans' + + sim_nodes[step_alias] = CITRUSNode( + alias=step_alias, + inputs=step_inputs, + node_type=step_type, + class_name=type(sim_step).__name__ + ) + + # Identify nodes that are of type 'combine'. + combine_nodes = [node for node in sim_nodes.values() if node.node_type == 'combine'] + + # For each combine node, identify its ancestors. + all_ancestor_nodes = set() + for combine_node in combine_nodes: + all_ancestor_nodes.update(get_ancestor_nodes(combine_node, sim_nodes)) + + # Change the node_type of ancestor nodes which are of type 'trans' to 'cis'. + for ancestor_node_alias in all_ancestor_nodes: + if sim_nodes[ancestor_node_alias].node_type == 'trans': + sim_nodes[ancestor_node_alias].node_type = 'cis' + + # Print nodes + if verbose: + for k,v in sim_nodes.items(): + MSG(f"{k}:\n\t{str(v)}") + + # Plot the graph + plot_graph_with_legend(sim_nodes, filename, img_format) + MSG(f"Plot output to {filename}.{img_format}") + + # Return + return 0 + + +class CITRUSNode: + """ + Helper class for plotting. + Represents a single CITRUS model node + + Attributes + ---------- + alias : str + Node alias + inputs : list + Simulation steps input to this node + node_type : str + One of: input, cis, trans, combine + class_name : str + Name of the class of the node + """ + def __init__( + self, + alias, + inputs=[], + node_type=None, + class_name=None + ): + self.alias = alias + self.inputs = inputs + self.node_type = node_type + self.class_name = class_name + + def __str__(self): + return f"alias: {self.alias}\tnode_type: {self.node_type}\tclass_name: {self.class_name}\tinputs: {self.inputs}" + + +def get_ancestor_nodes(node: CITRUSNode, + sim_nodes: dict[str, CITRUSNode]) -> list[CITRUSNode]: + """ + Iteratively get ancestor nodes of a given node. + + Parameters + ---------- + node : CITRUSNode + node to get ancestors of + sim_nodes : dict[str]->CITRUSNode + Dictionary of all nodes + + Returns + ------- + ancestors : list of CITRUSNode + List of nodes that are ancestors of node + """ + ancestors = set() + to_visit = [node] + + while to_visit: + current_node = to_visit.pop() + + # Add the direct parents to the ancestors set + for input_node_alias in current_node.inputs: + # Check if the ancestor is already in the set to avoid cycles + if input_node_alias not in ancestors: + ancestors.add(input_node_alias) + to_visit.append(sim_nodes[input_node_alias]) + + return list(ancestors) + + + +def plot_graph_with_legend(sim_nodes: dict[str, CITRUSNode], + filename: str, img_format: str): + """ + Make the plot from a list of nodes + + Parameters + ---------- + sim_nodes : dict[str]->CITRUSNode + Dictionary of all nodes + filename : str + Prefix of output filename + img_format : str + Format of output file (jpg, png, or svg}) + """ + + # Define a dictionary for colors based on node_type + node_colors = { + 'input': '#FF8C00', # dark orange + 'cis': '#90EE90', # light green + 'trans': '#DEB887', # burly wood (light brown) + 'combine': '#FFF44F' # lemon yellow + } + + # Create a new graph + graph = pydot.Dot(graph_type='digraph', rankdir='UD', ranksep='0.5') + + # Add nodes with their respective colors + for node in sim_nodes.values(): + label = f"{node.alias}\n<{node.class_name}>" + graph.add_node(pydot.Node(node.alias, label=label, style="filled", fillcolor=node_colors[node.node_type])) + + # Add edges based on the inputs of each node + for node in sim_nodes.values(): + for input_node in node.inputs: + graph.add_edge(pydot.Edge(input_node, node.alias)) + + # Save graph to file + graph.write(filename + '.' + img_format, format=img_format) + + # Adding a legend using PIL + img = Image.open(filename + '.' + img_format) + legend = Image.new('RGB', (250, 100), (255, 255, 255)) + + # Draw the legend + for index, (label, color) in enumerate(node_colors.items()): + d = ImageDraw.Draw(legend) + d.rectangle([10, 10 + index * 25, 30, 30 + index * 25], fill=color) + d.text((40, 10 + index * 25), label, fill=(0, 0, 0)) + + # Scale up the legend by any desired factor + scale_factor = 2.0 + scaled_width = int(legend.width * scale_factor) + scaled_height = int(legend.height * scale_factor) + legend = legend.resize( + (scaled_width, scaled_height), + resample=Image.BICUBIC + ) + + # Trim whitespace: Here we'll crop out 40% from the right, adjust as needed + crop_percentage = 0.5 + cropped_width = int(scaled_width * (1 - crop_percentage)) + legend = legend.crop((0, 0, cropped_width, scaled_height)) + + # Combine original image and legend + total_width = img.width + legend.width + max_height = max(img.height, legend.height) + + combined = Image.new('RGB', (total_width, max_height), 'white') + combined.paste(img, (0, 0)) + + # Center the legend vertically + y_offset = (img.height - legend.height) // 2 + combined.paste(legend, (img.width, y_offset)) + + combined.save(filename + '.' + img_format) \ No newline at end of file diff --git a/citrus/utils.py b/citrus/utils.py new file mode 100644 index 0000000..3a52884 --- /dev/null +++ b/citrus/utils.py @@ -0,0 +1,8 @@ +""" +Utility functions +""" + +import sys + +def MSG(msg: str): + sys.stderr.write("[CITRUS]: " + msg.strip() +"\n") \ No newline at end of file diff --git a/example-files/linear_additive_nogenotypes.json b/example-files/linear_additive_nogenotypes.json new file mode 100644 index 0000000..d43a80d --- /dev/null +++ b/example-files/linear_additive_nogenotypes.json @@ -0,0 +1,66 @@ +{ + "input": [ + { + "input_nodes": [ + { + "alias": "chr19_280540_G_A", + "type": "SNP", + "chr": "19", + "pos": 280540 + }, + { + "alias": "chr19_523746_C_T", + "type": "SNP", + "chr": "19", + "pos": [523746] + } + ] + } + ], + "simulation_steps": [ + { + "type": "Constant", + "alias": "chr19_280540_G_A_beta", + "input_match_size": "chr19_280540_G_A", + "constant": 0.1 + }, + { + "type": "Constant", + "alias": "chr19_523746_C_T_beta", + "input_match_size": "chr19_523746_C_T", + "constant": 0.3 + }, + { + "type": "Product", + "alias": "chr19_280540_G_A_effect", + "input_aliases": [ + "chr19_280540_G_A_beta", "chr19_280540_G_A" + ] + }, + { + "type": "Product", + "alias": "chr19_523746_C_T_effect", + "input_aliases": [ + "chr19_523746_C_T_beta", "chr19_523746_C_T" + ] + }, + { + "type": "Concatenate", + "alias": "effects_by_haplotype", + "input_aliases": [ + "chr19_280540_G_A_effect", + "chr19_523746_C_T_effect" + ] + }, + { + "type": "AdditiveCombine", + "alias": "effects", + "input_alias": "effects_by_haplotype" + }, + { + "type": "SumReduce", + "alias": "phenotype", + "input_alias": "effects" + } + ] +} \ No newline at end of file diff --git a/pheno_sim/plot.py b/pheno_sim/plot.py deleted file mode 100644 index 68664f1..0000000 --- a/pheno_sim/plot.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Plot simulation as directed graph.""" - -import pydot -from PIL import Image, ImageDraw - -from pheno_sim.pheno_simulation import PhenoSimulation -from pheno_sim.base_nodes import AbstractBaseCombineFunctionNode - - -def visualize(input_spec: dict, filename: str, img_format: str): - # Generate sim object - sim = PhenoSimulation(input_spec) - - # Add input nodes - sim_nodes = dict() - - for input_source in sim.input_runner.input_sources: - for input_node in input_source.input_nodes: - sim_nodes[input_node.alias] = CITRUSNode( - alias=input_node.alias, - node_type='input', - class_name=type(input_node).__name__ - ) - - # Add operator nodes - for sim_step in sim.simulation_steps: - step_alias = sim_step.alias - step_inputs = sim_step.inputs - - if isinstance(step_inputs, str): - step_inputs = [step_inputs] - elif isinstance(step_inputs, dict): - step_inputs = list(step_inputs.values()) - - if isinstance(sim_step, AbstractBaseCombineFunctionNode): - step_type = 'combine' - else: - step_type = 'trans' - - sim_nodes[step_alias] = CITRUSNode( - alias=step_alias, - inputs=step_inputs, - node_type=step_type, - class_name=type(sim_step).__name__ - ) - - # Identify nodes that are of type 'combine'. - combine_nodes = [node for node in sim_nodes.values() if node.node_type == 'combine'] - - # For each combine node, identify its ancestors. - all_ancestor_nodes = set() - for combine_node in combine_nodes: - all_ancestor_nodes.update(get_ancestor_nodes(combine_node, sim_nodes)) - - # Change the node_type of ancestor nodes which are of type 'trans' to 'cis'. - for ancestor_node_alias in all_ancestor_nodes: - if sim_nodes[ancestor_node_alias].node_type == 'trans': - sim_nodes[ancestor_node_alias].node_type = 'cis' - - # Print nodes - for k,v in sim_nodes.items(): - print(f"{k}:\n\t{str(v)}") - - # Plot the graph - plot_graph_with_legend(sim_nodes, filename, img_format) - - -class CITRUSNode: - - def __init__( - self, - alias, - inputs=[], - node_type=None, - class_name=None - ): - self.alias = alias - self.inputs = inputs - self.node_type = node_type - self.class_name = class_name - - def __str__(self): - return f"alias: {self.alias}\tnode_type: {self.node_type}\tclass_name: {self.class_name}\tinputs: {self.inputs}" - - -def get_ancestor_nodes(node, sim_nodes): - """Iteratively get ancestor nodes of a given node.""" - ancestors = set() - to_visit = [node] - - while to_visit: - current_node = to_visit.pop() - - # Add the direct parents to the ancestors set - for input_node_alias in current_node.inputs: - # Check if the ancestor is already in the set to avoid cycles - if input_node_alias not in ancestors: - ancestors.add(input_node_alias) - to_visit.append(sim_nodes[input_node_alias]) - - return list(ancestors) - - - -def plot_graph_with_legend(sim_nodes, filename, img_format): - # Define a dictionary for colors based on node_type - node_colors = { - 'input': '#FF8C00', # dark orange - 'cis': '#90EE90', # light green - 'trans': '#DEB887', # burly wood (light brown) - 'combine': '#FFF44F' # lemon yellow - } - - # Create a new graph - graph = pydot.Dot(graph_type='digraph', rankdir='UD', ranksep='0.5') - - # Add nodes with their respective colors - for node in sim_nodes.values(): - label = f"{node.alias}\n<{node.class_name}>" - graph.add_node(pydot.Node(node.alias, label=label, style="filled", fillcolor=node_colors[node.node_type])) - - # Add edges based on the inputs of each node - for node in sim_nodes.values(): - for input_node in node.inputs: - graph.add_edge(pydot.Edge(input_node, node.alias)) - - # Save graph to file - graph.write(filename + '.' + img_format, format=img_format) - - # Adding a legend using PIL - img = Image.open(filename + '.' + img_format) - legend = Image.new('RGB', (250, 100), (255, 255, 255)) - - # Draw the legend - for index, (label, color) in enumerate(node_colors.items()): - d = ImageDraw.Draw(legend) - d.rectangle([10, 10 + index * 25, 30, 30 + index * 25], fill=color) - d.text((40, 10 + index * 25), label, fill=(0, 0, 0)) - - # Scale up the legend by any desired factor - scale_factor = 2.0 - scaled_width = int(legend.width * scale_factor) - scaled_height = int(legend.height * scale_factor) - legend = legend.resize( - (scaled_width, scaled_height), - resample=Image.BICUBIC - ) - - # Trim whitespace: Here we'll crop out 40% from the right, adjust as needed - crop_percentage = 0.5 - cropped_width = int(scaled_width * (1 - crop_percentage)) - legend = legend.crop((0, 0, cropped_width, scaled_height)) - - # Combine original image and legend - total_width = img.width + legend.width - max_height = max(img.height, legend.height) - - combined = Image.new('RGB', (total_width, max_height), 'white') - combined.paste(img, (0, 0)) - - # Center the legend vertically - y_offset = (img.height - legend.height) // 2 - combined.paste(legend, (img.width, y_offset)) - - combined.save(filename + '.' + img_format) - - -# if __name__ == '__main__': - -# # dev -# input_spec = '../benchmarking_mlcb/pheno_sim/sim_configs/xor_pheno_1.json' -# filename = '../benchmarking_mlcb/pheno_sim/sim_configs/xor_pheno_1' -# img_format = 'png' - -# # Generate sim object -# sim = PhenoSimulation.from_JSON_file(input_spec) - -# # Add input nodes -# sim_nodes = dict() - -# for input_source in sim.input_runner.input_sources: -# for input_node in input_source.input_nodes: -# sim_nodes[input_node.alias] = CITRUSNode( -# alias=input_node.alias, -# node_type='input', -# class_name=type(input_node).__name__ -# ) - -# # Add operator nodes -# for sim_step in sim.simulation_steps: -# step_alias = sim_step.alias -# step_inputs = sim_step.inputs - -# if isinstance(step_inputs, str): -# step_inputs = [step_inputs] -# elif isinstance(step_inputs, dict): -# step_inputs = list(step_inputs.values()) - -# if isinstance(sim_step, AbstractBaseCombineFunctionNode): -# step_type = 'combine' -# else: -# step_type = 'trans' - -# sim_nodes[step_alias] = CITRUSNode( -# alias=step_alias, -# inputs=step_inputs, -# node_type=step_type, -# class_name=type(sim_step).__name__ -# ) - -# # Identify nodes that are of type 'combine'. -# combine_nodes = [node for node in sim_nodes.values() if node.node_type == 'combine'] - -# # For each combine node, identify its ancestors. -# all_ancestor_nodes = set() -# for combine_node in combine_nodes: -# all_ancestor_nodes.update(get_ancestor_nodes(combine_node, sim_nodes)) - -# # Change the node_type of ancestor nodes which are of type 'trans' to 'cis'. -# for ancestor_node_alias in all_ancestor_nodes: -# if sim_nodes[ancestor_node_alias].node_type == 'trans': -# sim_nodes[ancestor_node_alias].node_type = 'cis' - -# # Print nodes -# for k,v in sim_nodes.items(): -# print(f"{k}:\n\t{str(v)}") - -# # Plot the graph -# plot_graph_with_legend(sim_nodes, filename, img_format)