diff --git a/Publication/SI/element_similairty_SI.ipynb b/Publication/SI/element_similairty_SI.ipynb new file mode 100644 index 0000000..8c69f4c --- /dev/null +++ b/Publication/SI/element_similairty_SI.ipynb @@ -0,0 +1,597 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Supplementary Information for Element Similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "import matplotlib.pyplot as plt\n", + "from typing import List, Optional, Tuple\n", + "from elementembeddings.core import Embedding, data_directory\n", + "from elementembeddings.plotter import dimension_plotter, heatmap_plotter\n", + "import pandas as pd\n", + "import os\n", + "import seaborn as sns\n", + "\n", + "sns.set_context(\"paper\", font_scale=1.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the 7 embeddings\n", + "cbfvs = [\n", + " \"magpie_sc\",\n", + " \"mat2vec\",\n", + " \"megnet16\",\n", + " \"random_200\",\n", + " \"matscholar\",\n", + " \"oliynyk_sc\",\n", + " \"skipatom\",\n", + "]\n", + "element_embedddings = {cbfv: Embedding.load_data(cbfv) for cbfv in cbfvs}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the ordered symbols file\n", + "symbols_path = os.path.join(data_directory, \"element_data\", \"ordered_periodic.txt\")\n", + "with open(symbols_path) as f:\n", + " symbols = f.read().splitlines()\n", + "\n", + "# Get the first 83 elements\n", + "symbols = symbols[:83]\n", + "\n", + "for cbfv in element_embedddings.keys():\n", + " # Get the keys of the atomic embeddings object\n", + " elements = set(element_embedddings[cbfv].element_list)\n", + " el_symbols_set = set(symbols)\n", + "\n", + " # Get the element symbols we want to remove\n", + " els_to_remove = list(elements - el_symbols_set)\n", + "\n", + " # 
Iteratively delete the elements with atomic number\n", + " # greater than 83 from our embeddings\n", + " for el in els_to_remove:\n", + " del element_embedddings[cbfv].embeddings[el]\n", + "\n", + " # Verify that we have 83 elements\n", + " print(len(element_embedddings[cbfv].element_list))\n", + "\n", + "# Remove Xe and Kr from SkipAtom\n", + "# del element_embedddings[\"skipatom\"].embeddings[\"Xe\"]\n", + "# del element_embedddings[\"skipatom\"].embeddings[\"Kr\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Similarity measures\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Euclidean distance\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"euclidean\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_euclidean.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Skipatom\n", + "\n", + "From the above plot, we can observe two element vectors causing anomalous behaviour in the skipatom plot. We plot the skipatom map with the axis labelled to determine which elements are causing this behaviour." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 20))\n", + "heatmap_plotter(\n", + " embedding=element_embedddings[\"skipatom\"],\n", + " metric=\"euclidean\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=True,\n", + " ax=ax,\n", + ")\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manhattan distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"manhattan\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_manhattan.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chebyshev" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"chebyshev\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_chebyshev.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Wasserstein distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"wasserstein\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_wasserstein.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cosine distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"cosine_distance\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_cosdistance.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pearson correlation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "heatmap_params = {\"vmin\": -1, \"vmax\": 1}\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"pearson\",\n", + " cmap=\"Blues_r\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " **heatmap_params\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + 
"fig.savefig(\"SI_pearson.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Spearman correlation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "heatmap_params = {\"vmin\": -1, \"vmax\": 1}\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"spearman\",\n", + " cmap=\"Blues_r\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " **heatmap_params\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_spearman.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cosine similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (axes) = plt.subplots(4, 2, figsize=(20, 20))\n", + "heatmap_params = {\"vmin\": -1, \"vmax\": 1}\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " heatmap_plotter(\n", + " embedding=cbfv,\n", + " metric=\"cosine_similarity\",\n", + " cmap=\"Blues_r\",\n", + " sortaxisby=\"atomic_number\",\n", + " show_axislabels=False,\n", + " ax=ax,\n", + " **heatmap_params\n", + " )\n", + " # plt.subplots_adjust(wspace=0.001)\n", + "axes[-1][-1].remove()\n", + "\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(\"SI_cosinesimilarity.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-dimensional projections" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Principal Component Analysis (PCA)" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(\n", + " 4,\n", + " 2,\n", + " figsize=(20, 20),\n", + ")\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " dimension_plotter(\n", + " embedding=cbfv,\n", + " reducer=\"pca\",\n", + " n_components=2,\n", + " ax=ax,\n", + " adjusttext=True,\n", + " )\n", + " ax.legend().remove()\n", + "axes[-1][-1].remove()\n", + "\n", + "handles, labels = ax.get_legend_handles_labels()\n", + "fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n", + "fig.tight_layout()\n", + "plt.savefig(\"SI_pca.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### t-distributed Stochastic Neighbor Embedding (t-SNE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(\n", + " 4,\n", + " 2,\n", + " figsize=(20, 20),\n", + ")\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " dimension_plotter(\n", + " embedding=cbfv,\n", + " reducer=\"tsne\",\n", + " n_components=2,\n", + " ax=ax,\n", + " # adjusttext=True,\n", + " )\n", + " ax.legend().remove()\n", + "axes[-1][-1].remove()\n", + "\n", + "handles, labels = ax.get_legend_handles_labels()\n", + "fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n", + "fig.tight_layout()\n", + "plt.savefig(\"SI_tsne.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Uniform Manifold Approximation and Projection (UMAP)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(\n", + " 4,\n", + " 2,\n", + " figsize=(20, 20),\n", + ")\n", + "\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.values()):\n", + " 
dimension_plotter(\n", + " embedding=cbfv,\n", + " reducer=\"umap\",\n", + " n_components=2,\n", + " ax=ax,\n", + " adjusttext=True,\n", + " )\n", + " ax.legend().remove()\n", + "axes[-1][-1].remove()\n", + "\n", + "handles, labels = ax.get_legend_handles_labels()\n", + "fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n", + "fig.tight_layout()\n", + "plt.savefig(\"SI_umap.pdf\", bbox_inches=\"tight\")\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Distribution of similarity measures" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pearson correlation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "correlation_metrics = [\"pearson\", \"cosine_similarity\"]\n", + "correlation_dfs = {}\n", + "for rep in element_embedddings.keys():\n", + " correlation_dfs[rep] = {\n", + " \"pearson\": element_embedddings[rep].correlation_df(),\n", + " \"cosine_similarity\": element_embedddings[rep].correlation_df(\n", + " metric=\"cosine_similarity\"\n", + " ),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(4, 2, figsize=(20, 20))\n", + "for ax, cbfv in zip(axes.flatten(), element_embedddings.keys()):\n", + " sns.histplot(correlation_dfs[cbfv][\"pearson\"], x=\"pearson\", ax=ax)\n", + " ax.set_title(cbfv)\n", + " ax.set_xlim(-1, 1)\n", + " ax.set_xlabel(\"Pearson correlation\")\n", + " ax.set_ylabel(\"Count\")\n", + "\n", + "axes[-1][-1].remove()\n", + "plt.tight_layout()\n", + "plt.savefig(\"SI_pearson_distribution.pdf\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cosine similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(4, 2, figsize=(20, 20))\n", + "for ax, 
cbfv in zip(axes.flatten(), element_embedddings.keys()):\n", + " sns.histplot(\n", + " correlation_dfs[cbfv][\"cosine_similarity\"], x=\"cosine_similarity\", ax=ax\n", + " )\n", + " ax.set_title(cbfv)\n", + " ax.set_xlim(-1, 1)\n", + " ax.set_xlabel(\"Cosine similarity\")\n", + " ax.set_ylabel(\"Count\")\n", + "\n", + "axes[-1][-1].remove()\n", + "plt.tight_layout()\n", + "plt.savefig(\"SI_cosine_similarity_distribution.pdf\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "atomic_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Publication/element_similarity.ipynb b/Publication/element_similarity.ipynb index ab5fa73..fc3f92e 100644 --- a/Publication/element_similarity.ipynb +++ b/Publication/element_similarity.ipynb @@ -152,24 +152,32 @@ "source": [ "distances = [\"euclidean\", \"manhattan\", \"chebyshev\"]\n", "for distance in distances:\n", - " d = element_embedddings[\"magpie\"].compute_distance_metric(\"Li\", \"K\", distance)\n", + " d = element_embedddings[\"magpie_sc\"].compute_distance_metric(\"Li\", \"K\", distance)\n", " print(f\"Distance between Li and K using {distance} is {d:.2f}\")\n", "\n", "# Get the pearson correlation and cosine similarity between the embeddings for Li and K\n", "similarity_metrics = [\"pearson\", \"cosine_similarity\"]\n", "for similarity_metric in similarity_metrics:\n", - " d = element_embedddings[\"magpie\"].compute_correlation_metric(\n", + " magpie_d = element_embedddings[\"magpie_sc\"].compute_correlation_metric(\n", " \"Li\", \"K\", similarity_metric\n", " )\n", - " if similarity_metric == \"pearson\":\n", - " d = d.statistic\n", + "\n", + " magpie_d_Li_Bi = 
element_embedddings[\"magpie_sc\"].compute_correlation_metric(\n", + " \"Li\", \"Bi\", similarity_metric\n", + " )\n", + "\n", " mvec_d = element_embedddings[\"mat2vec\"].compute_correlation_metric(\n", " \"Li\", \"K\", similarity_metric\n", " )\n", - " if similarity_metric == \"pearson\":\n", - " mvec_d = mvec_d.statistic\n", + " mvec_d_Li_Bi = element_embedddings[\"mat2vec\"].compute_correlation_metric(\n", + " \"Li\", \"Bi\", similarity_metric\n", + " )\n", + "\n", + " print(\n", + " f\"The metric, {similarity_metric}, between Li and K is {magpie_d:.3f} for magpie and {mvec_d:.3f} for mat2vec\"\n", + " )\n", " print(\n", - " f\"The metric, {similarity_metric}, between Li and K is {d:.3f} for magpie and {mvec_d:.3f} for mat2vec\"\n", + " f\"The metric, {similarity_metric}, between Li and Bi is {magpie_d_Li_Bi:.3f} for magpie and {mvec_d_Li_Bi:.3f} for mat2vec\"\n", " )" ] }, @@ -270,7 +278,7 @@ "outputs": [], "source": [ "fig, (axes) = plt.subplots(2, 2, figsize=(10, 10))\n", - "\n", + "heatmap_params = {\"vmin\": -1, \"vmax\": 1}\n", "for ax, cbfv in zip(axes.flatten(), cbfvs_to_keep):\n", " heatmap_plotter(\n", " embedding=element_vectors[cbfv],\n", @@ -279,6 +287,7 @@ " show_axislabels=False,\n", " cmap=\"Blues_r\",\n", " ax=ax,\n", + " **heatmap_params\n", " )\n", "\n", "plt.tight_layout()\n", @@ -305,7 +314,7 @@ "outputs": [], "source": [ "fig, (axes) = plt.subplots(2, 2, figsize=(10, 10))\n", - "\n", + "heatmap_params = {\"vmin\": -1, \"vmax\": 1}\n", "for ax, cbfv in zip(axes.flatten(), cbfvs_to_keep):\n", " heatmap_plotter(\n", " embedding=element_vectors[cbfv],\n", @@ -314,6 +323,7 @@ " show_axislabels=False,\n", " cmap=\"Blues_r\",\n", " ax=ax,\n", + " **heatmap_params\n", " )\n", "\n", "plt.tight_layout()\n", @@ -443,10 +453,77 @@ "handles, labels = ax.get_legend_handles_labels()\n", "fig.legend(handles, labels, bbox_to_anchor=(0.54, 1.06), loc=\"upper center\", ncol=5)\n", "fig.tight_layout()\n", - "plt.savefig(\"7_umap.pdf\")\n", + 
"plt.savefig(\"7_umap.pdf\", bbox_inches=\"tight\")\n", "fig.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SI\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Distribution plot\n", + "# Create dictionaries to store the correlation dataframes for the embeddings for each metric\n", + "correlation_metrics = [\"pearson\", \"cosine_similarity\"]\n", + "correlation_dfs = {}\n", + "for cbfv in cbfvs_to_keep:\n", + " correlation_dfs[cbfv] = {\n", + " \"pearson\": element_vectors[cbfv].correlation_df(),\n", + " \"cosine_similarity\": element_vectors[cbfv].correlation_df(\n", + " metric=\"cosine_similarity\"\n", + " ),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n", + "for ax, cbfv in zip(axes.flatten(), cbfvs_to_keep):\n", + " sns.histplot(correlation_dfs[cbfv][\"pearson\"], x=\"pearson\", ax=ax)\n", + " ax.set_title(cbfv)\n", + " ax.set_xlim(-1, 1)\n", + " ax.set_xlabel(\"Pearson correlation\")\n", + " ax.set_ylabel(\"Count\")\n", + "plt.tight_layout()\n", + "plt.savefig(\"SI_pearson_distribution.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n", + "for ax, cbfv in zip(axes.flatten(), cbfvs_to_keep):\n", + " sns.histplot(\n", + " correlation_dfs[cbfv][\"cosine_similarity\"], x=\"cosine_similarity\", ax=ax\n", + " )\n", + " ax.set_title(cbfv)\n", + " ax.set_xlim(-1, 1)\n", + " ax.set_xlabel(\"Cosine similarity\")\n", + " ax.set_ylabel(\"Count\")\n", + "plt.tight_layout()\n", + "plt.savefig(\"SI_cosine_distribution.pdf\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/elementembeddings/core.py b/src/elementembeddings/core.py 
index af65bb4..5d8c132 100644 --- a/src/elementembeddings/core.py +++ b/src/elementembeddings/core.py @@ -510,10 +510,14 @@ def compute_correlation_metric( # Define the allowable metrics scipy_corrs = {"pearson": pearsonr, "spearman": spearmanr} - if metric in scipy_corrs: + if metric == "pearson": return scipy_corrs[metric]( self.embeddings[ele1], self.embeddings[ele2] ).statistic + elif metric == "spearman": + return scipy_corrs[metric]( + self.embeddings[ele1], self.embeddings[ele2] + ).correlation elif metric == "cosine_similarity": return cosine_similarity(self.embeddings[ele1], self.embeddings[ele2])