From b75619f8bcaece6420c7000dfab966b05ae59c7d Mon Sep 17 00:00:00 2001 From: Ayuei Date: Wed, 25 Jan 2023 17:36:21 +1100 Subject: [PATCH] Directory refactoring of source code, no major breaking changes. debeir.interfaces -> debeir.core debeir.data_sets -> debeir.datasets - Documentation updated with refactoring - Test cases updated with refactoring --- .idea/deployment.xml | 2 +- .idea/misc.xml | 2 +- docs/debeir.html | 14 +- docs/debeir/core.html | 259 ++ docs/debeir/core/callbacks.html | 840 +++++++ docs/debeir/core/config.html | 1477 ++++++++++++ docs/debeir/core/converters.html | 446 ++++ docs/debeir/core/document.html | 999 ++++++++ docs/debeir/core/executor.html | 951 ++++++++ docs/debeir/core/indexer.html | 723 ++++++ docs/debeir/core/parser.html | 1095 +++++++++ docs/debeir/core/pipeline.html | 910 +++++++ docs/debeir/core/query.html | 934 ++++++++ docs/debeir/core/results.html | 455 ++++ docs/debeir/datasets.html | 264 +++ docs/debeir/datasets/bioreddit.html | 548 +++++ docs/debeir/datasets/clinical_trials.html | 2094 +++++++++++++++++ docs/debeir/datasets/factory.html | 676 ++++++ docs/debeir/datasets/marco.html | 770 ++++++ .../debeir/datasets/trec_clinical_trials.html | 610 +++++ docs/debeir/datasets/trec_covid.html | 453 ++++ docs/debeir/datasets/types.html | 731 ++++++ docs/debeir/datasets/utils.html | 544 +++++ docs/debeir/engines.html | 16 +- docs/debeir/engines/client.html | 269 +-- docs/debeir/engines/dummyindex.html | 6 +- docs/debeir/engines/dummyindex/index.html | 6 +- docs/debeir/engines/elasticsearch.html | 18 +- .../engines/elasticsearch/change_bm25.html | 233 +- .../engines/elasticsearch/executor.html | 619 ++--- .../elasticsearch/generate_script_score.html | 730 +++--- docs/debeir/engines/solr.html | 6 +- docs/debeir/evaluation.html | 6 +- docs/debeir/evaluation/cross_validation.html | 175 +- docs/debeir/evaluation/evaluator.html | 382 +-- docs/debeir/evaluation/residual_scoring.html | 278 ++- docs/debeir/models.html | 6 +- docs/debeir/models/colbert.html | 1453 ++++++------ docs/debeir/rankers.html | 6 +- docs/debeir/rankers/reranking.html | 6 +- docs/debeir/rankers/reranking/nir.html | 278 ++- docs/debeir/rankers/reranking/reranker.html | 270 ++- docs/debeir/rankers/reranking/use.html | 17 +- .../rankers/transformer_sent_encoder.html | 386 +-- docs/debeir/training.html | 6 +- docs/debeir/training/evaluate_reranker.html | 393 ++-- docs/debeir/training/hparm_tuning.html | 23 +- docs/debeir/training/hparm_tuning/config.html | 26 +- .../training/hparm_tuning/optuna_rank.html | 251 +- .../debeir/training/hparm_tuning/trainer.html | 791 ++++--- docs/debeir/training/hparm_tuning/types.html | 323 ++- docs/debeir/training/losses.html | 6 +- docs/debeir/training/losses/contrastive.html | 810 +++---- docs/debeir/training/losses/ranking.html | 10 +- docs/debeir/training/train_reranker.html | 136 +- .../training/train_sentence_encoder.html | 525 +++-- docs/debeir/training/utils.html | 693 +++--- docs/debeir/utils.html | 6 +- docs/debeir/utils/scaler.html | 61 +- docs/debeir/utils/utils.html | 20 +- docs/search.js | 2 +- .../hparam_tuning_from_config.py | 4 +- examples/hparam_tuning/tune_bm25.py | 2 +- examples/hparam_tuning/tune_z_param.py | 2 +- examples/trec2022/training.py | 4 +- src/debeir/__init__.py | 9 +- src/debeir/core/__init__.py | 5 + src/debeir/{interfaces => core}/callbacks.py | 10 +- src/debeir/{interfaces => core}/config.py | 4 +- src/debeir/{interfaces => core}/converters.py | 4 +- src/debeir/{interfaces => core}/document.py | 2 +- src/debeir/{interfaces => 
core}/executor.py | 59 +- src/debeir/{interfaces => core}/indexer.py | 3 +- src/debeir/{interfaces => core}/parser.py | 9 +- src/debeir/{interfaces => core}/pipeline.py | 17 +- src/debeir/{interfaces => core}/query.py | 8 +- src/debeir/{interfaces => core}/results.py | 3 +- .../{data_sets => datasets}/__init__.py | 4 +- .../{data_sets => datasets}/bioreddit.py | 5 +- .../clinical_trials.py | 64 +- src/debeir/{data_sets => datasets}/factory.py | 31 +- src/debeir/{data_sets => datasets}/marco.py | 49 +- .../trec_clinical_trials.py | 7 +- .../{data_sets => datasets}/trec_covid.py | 4 +- src/debeir/{data_sets => datasets}/types.py | 4 +- src/debeir/{data_sets => datasets}/utils.py | 7 +- src/debeir/engines/__init__.py | 4 +- src/debeir/engines/client.py | 1 + src/debeir/engines/elasticsearch/__init__.py | 4 + .../engines/elasticsearch/change_bm25.py | 25 +- src/debeir/engines/elasticsearch/executor.py | 37 +- .../elasticsearch/generate_script_score.py | 9 +- src/debeir/evaluation/__init__.py | 2 +- src/debeir/evaluation/cross_validation.py | 8 +- src/debeir/evaluation/evaluator.py | 7 +- src/debeir/evaluation/residual_scoring.py | 20 +- src/debeir/interfaces/__init__.py | 3 - src/debeir/models/colbert.py | 31 +- src/debeir/rankers/__init__.py | 2 +- src/debeir/rankers/reranking/nir.py | 11 +- src/debeir/rankers/reranking/reranker.py | 4 +- .../rankers/transformer_sent_encoder.py | 5 +- src/debeir/training/__init__.py | 2 +- src/debeir/training/evaluate_reranker.py | 15 +- src/debeir/training/hparm_tuning/__init__.py | 3 + src/debeir/training/hparm_tuning/config.py | 4 +- .../training/hparm_tuning/optuna_rank.py | 9 +- src/debeir/training/hparm_tuning/trainer.py | 17 +- src/debeir/training/hparm_tuning/types.py | 3 +- src/debeir/training/losses/contrastive.py | 10 +- src/debeir/training/losses/ranking.py | 3 - src/debeir/training/train_reranker.py | 7 +- src/debeir/training/train_sentence_encoder.py | 11 +- src/debeir/training/utils.py | 11 +- src/debeir/utils/__init__.py | 2 +- src/debeir/utils/scaler.py | 2 - src/debeir/utils/utils.py | 6 +- tests/test_callbacks.py | 10 +- tests/test_pipeline.py | 8 +- tests/test_reranking.py | 8 +- tests/test_results.py | 8 +- 121 files changed, 20874 insertions(+), 4803 deletions(-) create mode 100644 docs/debeir/core.html create mode 100644 docs/debeir/core/callbacks.html create mode 100644 docs/debeir/core/config.html create mode 100644 docs/debeir/core/converters.html create mode 100644 docs/debeir/core/document.html create mode 100644 docs/debeir/core/executor.html create mode 100644 docs/debeir/core/indexer.html create mode 100644 docs/debeir/core/parser.html create mode 100644 docs/debeir/core/pipeline.html create mode 100644 docs/debeir/core/query.html create mode 100644 docs/debeir/core/results.html create mode 100644 docs/debeir/datasets.html create mode 100644 docs/debeir/datasets/bioreddit.html create mode 100644 docs/debeir/datasets/clinical_trials.html create mode 100644 docs/debeir/datasets/factory.html create mode 100644 docs/debeir/datasets/marco.html create mode 100644 docs/debeir/datasets/trec_clinical_trials.html create mode 100644 docs/debeir/datasets/trec_covid.html create mode 100644 docs/debeir/datasets/types.html create mode 100644 docs/debeir/datasets/utils.html create mode 100644 src/debeir/core/__init__.py rename src/debeir/{interfaces => core}/callbacks.py (96%) rename src/debeir/{interfaces => core}/config.py (99%) rename src/debeir/{interfaces => core}/converters.py (96%) rename src/debeir/{interfaces => core}/document.py (99%) 
rename src/debeir/{interfaces => core}/executor.py (81%) rename src/debeir/{interfaces => core}/indexer.py (99%) rename src/debeir/{interfaces => core}/parser.py (99%) rename src/debeir/{interfaces => core}/pipeline.py (90%) rename src/debeir/{interfaces => core}/query.py (96%) rename src/debeir/{interfaces => core}/results.py (95%) rename src/debeir/{data_sets => datasets}/__init__.py (78%) rename src/debeir/{data_sets => datasets}/bioreddit.py (92%) rename src/debeir/{data_sets => datasets}/clinical_trials.py (93%) rename src/debeir/{data_sets => datasets}/factory.py (84%) rename src/debeir/{data_sets => datasets}/marco.py (68%) rename src/debeir/{data_sets => datasets}/trec_clinical_trials.py (94%) rename src/debeir/{data_sets => datasets}/trec_covid.py (89%) rename src/debeir/{data_sets => datasets}/types.py (96%) rename src/debeir/{data_sets => datasets}/utils.py (96%) delete mode 100644 src/debeir/interfaces/__init__.py diff --git a/.idea/deployment.xml b/.idea/deployment.xml index 7a3f1ac..6c45b01 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,6 +1,6 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 79bd307..4772e2e 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/docs/debeir.html b/docs/debeir.html index a6f7705..27beff9 100644 --- a/docs/debeir.html +++ b/docs/debeir.html @@ -3,7 +3,7 @@ - + debeir API documentation @@ -24,10 +24,10 @@

Submodules

    -
  • data_sets
  • +
  • core
  • +
  • datasets
  • engines
  • evaluation
  • -
  • interfaces
  • models
  • rankers
  • training
  • @@ -48,7 +48,7 @@

    Submodules

    debeir

    -

    The NIR (Neural Index Ranker) source code library.

    +

    The DeBEIR (Dense Bi-Encoder Information Retrieval) source code library.

    See ./main.py in the parent directory for an out-of-the-box runnable code.

    @@ -60,7 +60,7 @@

    1"""
    -2The NIR (Neural Index Ranker) source code library.
    +2The DeBEIR (Dense Bi-Encoder Information Retrieval) source code library.
     3
     4See ./main.py in the parent directory for an out-of-the-box runnable code.
     5
    @@ -171,7 +171,7 @@ 

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -198,7 +198,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/core.html b/docs/debeir/core.html new file mode 100644 index 0000000..b05bd8c --- /dev/null +++ b/docs/debeir/core.html @@ -0,0 +1,259 @@ + + + + + + + debeir.core API documentation + + + + + + + + + +
    +
    +

    +debeir.core

    + +

    Core library interfaces that must be implemented for custom datasets

    + +

Interfaces to implement custom datasets in debeir.datasets.

    +
    + + + + + +
    1"""
    +2Core library interfaces that must be implemented for custom datasets
    +3
+4Interfaces to implement custom datasets in debeir.datasets.
    +5"""
    +
    + + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/callbacks.html b/docs/debeir/core/callbacks.html new file mode 100644 index 0000000..fc27af8 --- /dev/null +++ b/docs/debeir/core/callbacks.html @@ -0,0 +1,840 @@ + + + + + + + debeir.core.callbacks API documentation + + + + + + + + + +
    +
    +

    +debeir.core.callbacks

    + +

Callbacks for before/after running. +E.g. before is for setup, +after is for evaluation/serialization, etc.

    +
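A minimal sketch of a custom callback, assuming only the interface documented below (before receives the running Pipeline, after receives the list of results); the class name and logging behaviour are illustrative, not part of the library:

    from typing import List

    import loguru
    from debeir.core.callbacks import Callback
    from debeir.core.pipeline import Pipeline


    class LoggingCallback(Callback):
        # Hypothetical callback: remember the pipeline and report how many results a run produced.
        def before(self, pipeline: Pipeline):
            self.pipeline = pipeline
            loguru.logger.info("Pipeline run starting")

        def after(self, results: List):
            loguru.logger.info(f"Pipeline run produced {len(results)} results")
            return results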
    + + + + + +
      1"""
+  2Callbacks for before/after running.
+  3E.g. before is for setup,
+  4after is for evaluation/serialization, etc.
    +  5"""
    +  6
    +  7import abc
    +  8import os
    +  9import tempfile
    + 10import uuid
    + 11from typing import List
    + 12
    + 13import loguru
    + 14from debeir.datasets.factory import query_factory
    + 15from debeir.evaluation.evaluator import Evaluator
    + 16from debeir.core.config import GenericConfig, NIRConfig
    + 17from debeir.core.pipeline import Pipeline
    + 18
    + 19
    + 20class Callback:
    + 21    def __init__(self):
    + 22        self.pipeline = None
    + 23
    + 24    @abc.abstractmethod
    + 25    def before(self, pipeline: Pipeline):
    + 26        pass
    + 27
    + 28    @abc.abstractmethod
    + 29    def after(self, results: List):
    + 30        pass
    + 31
    + 32
    + 33class SerializationCallback(Callback):
    + 34    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
    + 35        super().__init__()
    + 36        self.config = config
    + 37        self.nir_config = nir_config
    + 38        self.output_file = None
    + 39        self.query_cls = query_factory[self.config.query_fn]
    + 40
    + 41    def before(self, pipeline: Pipeline):
    + 42        """
    + 43        Check if output file exists
    + 44
    + 45        :return:
    + 46            Output file path
    + 47        """
    + 48
    + 49        self.pipeline = Pipeline
    + 50
    + 51        output_file = self.config.output_file
    + 52        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)
    + 53
    + 54        if output_file is None:
    + 55            os.makedirs(name=output_dir, exist_ok=True)
    + 56            output_file = os.path.join(output_dir, str(uuid.uuid4()))
    + 57
    + 58            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
    + 59
    + 60        else:
    + 61            output_file = os.path.join(output_dir, output_file)
    + 62
    + 63        if os.path.exists(output_file):
    + 64            if not self.config.overwrite_output_if_exists:
    + 65                raise RuntimeError("Directory exists and isn't explicitly overwritten "
    + 66                                   "in config with overwrite_output_if_exists=True")
    + 67
    + 68            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
    + 69            open(output_file, "w+").close()
    + 70
    + 71        pipeline.output_file = output_file
    + 72        self.output_file = output_file
    + 73
    + 74    def after(self, results: List):
    + 75        """
    + 76        Serialize results to self.output_file in a TREC-style format
    + 77        :param topic_num: Topic number to serialize
    + 78        :param res: Raw elasticsearch result
    + 79        :param run_name: The run name for TREC-style runs (default: NO_RUN_NAME)
    + 80        """
    + 81
    + 82        self._after(results,
    + 83                    output_file=self.output_file,
    + 84                    run_name=self.config.run_name)
    + 85
    + 86    @classmethod
    + 87    def _after(self, results: List, output_file, run_name=None):
    + 88        if run_name is None:
    + 89            run_name = "NO_RUN_NAME"
    + 90
    + 91        with open(output_file, "a+t") as writer:
    + 92            for doc in results:
    + 93                line = f"{doc.topic_num}\t" \
    + 94                       f"Q0\t" \
    + 95                       f"{doc.doc_id}\t" \
    + 96                       f"{doc.scores['rank']}\t" \
    + 97                       f"{doc.score}\t" \
    + 98                       f"{run_name}\n"
    + 99
    +100                writer.write(line)
    +101
    +102
    +103class EvaluationCallback(Callback):
    +104    def __init__(self, evaluator: Evaluator, config):
    +105        super().__init__()
    +106        self.evaluator = evaluator
    +107        self.config = config
    +108        self.parsed_run = None
    +109
    +110    def before(self, pipeline: Pipeline):
    +111        self.pipeline = Pipeline
    +112
    +113    def after(self, results: List, id_field="id"):
    +114        if self.pipeline.output_file is None:
    +115            directory_name = tempfile.mkdtemp()
    +116            fn = str(uuid.uuid4())
    +117
    +118            fp = os.path.join(directory_name, fn)
    +119
    +120            query = query_factory[self.config.query_fn]
    +121            query.id_field = id_field
    +122
    +123            SerializationCallback._after(results,
    +124                                         output_file=fp,
    +125                                         run_name=self.config.run_name)
    +126
    +127            self.pipeline.output_file = fp
    +128
    +129        parsed_run = self.evaluator.evaluate_runs(self.pipeline.output_file,
    +130                                                  disable_cache=True)
    +131        self.parsed_run = parsed_run
    +132
    +133        return self.parsed_run
    +
    + + +
    +
    + +
    + + class + Callback: + + + +
    + +
    21class Callback:
    +22    def __init__(self):
    +23        self.pipeline = None
    +24
    +25    @abc.abstractmethod
    +26    def before(self, pipeline: Pipeline):
    +27        pass
    +28
    +29    @abc.abstractmethod
    +30    def after(self, results: List):
    +31        pass
    +
    + + + + +
    + +
    + + Callback() + + + +
    + +
    22    def __init__(self):
    +23        self.pipeline = None
    +
    + + + + +
    +
    + +
    +
    @abc.abstractmethod
    + + def + before(self, pipeline: debeir.core.pipeline.Pipeline): + + + +
    + +
    25    @abc.abstractmethod
    +26    def before(self, pipeline: Pipeline):
    +27        pass
    +
    + + + + +
    +
    + +
    +
    @abc.abstractmethod
    + + def + after(self, results: List): + + + +
    + +
    29    @abc.abstractmethod
    +30    def after(self, results: List):
    +31        pass
    +
    + + + + +
    +
    +
    + +
    + + class + SerializationCallback(Callback): + + + +
    + +
     34class SerializationCallback(Callback):
    + 35    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
    + 36        super().__init__()
    + 37        self.config = config
    + 38        self.nir_config = nir_config
    + 39        self.output_file = None
    + 40        self.query_cls = query_factory[self.config.query_fn]
    + 41
    + 42    def before(self, pipeline: Pipeline):
    + 43        """
    + 44        Check if output file exists
    + 45
    + 46        :return:
    + 47            Output file path
    + 48        """
    + 49
    + 50        self.pipeline = Pipeline
    + 51
    + 52        output_file = self.config.output_file
    + 53        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)
    + 54
    + 55        if output_file is None:
    + 56            os.makedirs(name=output_dir, exist_ok=True)
    + 57            output_file = os.path.join(output_dir, str(uuid.uuid4()))
    + 58
    + 59            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
    + 60
    + 61        else:
    + 62            output_file = os.path.join(output_dir, output_file)
    + 63
    + 64        if os.path.exists(output_file):
    + 65            if not self.config.overwrite_output_if_exists:
    + 66                raise RuntimeError("Directory exists and isn't explicitly overwritten "
    + 67                                   "in config with overwrite_output_if_exists=True")
    + 68
    + 69            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
    + 70            open(output_file, "w+").close()
    + 71
    + 72        pipeline.output_file = output_file
    + 73        self.output_file = output_file
    + 74
    + 75    def after(self, results: List):
    + 76        """
    + 77        Serialize results to self.output_file in a TREC-style format
    + 78        :param topic_num: Topic number to serialize
    + 79        :param res: Raw elasticsearch result
    + 80        :param run_name: The run name for TREC-style runs (default: NO_RUN_NAME)
    + 81        """
    + 82
    + 83        self._after(results,
    + 84                    output_file=self.output_file,
    + 85                    run_name=self.config.run_name)
    + 86
    + 87    @classmethod
    + 88    def _after(self, results: List, output_file, run_name=None):
    + 89        if run_name is None:
    + 90            run_name = "NO_RUN_NAME"
    + 91
    + 92        with open(output_file, "a+t") as writer:
    + 93            for doc in results:
    + 94                line = f"{doc.topic_num}\t" \
    + 95                       f"Q0\t" \
    + 96                       f"{doc.doc_id}\t" \
    + 97                       f"{doc.scores['rank']}\t" \
    + 98                       f"{doc.score}\t" \
    + 99                       f"{run_name}\n"
    +100
    +101                writer.write(line)
    +
    + + + + +
    + +
    + + SerializationCallback( config: debeir.core.config.GenericConfig, nir_config: debeir.core.config.NIRConfig) + + + +
    + +
    35    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
    +36        super().__init__()
    +37        self.config = config
    +38        self.nir_config = nir_config
    +39        self.output_file = None
    +40        self.query_cls = query_factory[self.config.query_fn]
    +
    + + + + +
    +
    + +
    + + def + before(self, pipeline: debeir.core.pipeline.Pipeline): + + + +
    + +
    42    def before(self, pipeline: Pipeline):
    +43        """
    +44        Check if output file exists
    +45
    +46        :return:
    +47            Output file path
    +48        """
    +49
    +50        self.pipeline = Pipeline
    +51
    +52        output_file = self.config.output_file
    +53        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)
    +54
    +55        if output_file is None:
    +56            os.makedirs(name=output_dir, exist_ok=True)
    +57            output_file = os.path.join(output_dir, str(uuid.uuid4()))
    +58
    +59            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
    +60
    +61        else:
    +62            output_file = os.path.join(output_dir, output_file)
    +63
    +64        if os.path.exists(output_file):
    +65            if not self.config.overwrite_output_if_exists:
    +66                raise RuntimeError("Directory exists and isn't explicitly overwritten "
    +67                                   "in config with overwrite_output_if_exists=True")
    +68
    +69            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
    +70            open(output_file, "w+").close()
    +71
    +72        pipeline.output_file = output_file
    +73        self.output_file = output_file
    +
    + + +

    Check if output file exists

    + +
    Returns
    + +
    +
    Output file path
    +
    +
    +
    + + +
    +
    + +
    + + def + after(self, results: List): + + + +
    + +
    75    def after(self, results: List):
    +76        """
    +77        Serialize results to self.output_file in a TREC-style format
    +78        :param topic_num: Topic number to serialize
    +79        :param res: Raw elasticsearch result
    +80        :param run_name: The run name for TREC-style runs (default: NO_RUN_NAME)
    +81        """
    +82
    +83        self._after(results,
    +84                    output_file=self.output_file,
    +85                    run_name=self.config.run_name)
    +
    + + +

    Serialize results to self.output_file in a TREC-style format

    + +
    Parameters
    + +
      +
    • topic_num: Topic number to serialize
    • +
    • res: Raw elasticsearch result
    • +
    • run_name: The run name for TREC-style runs (default: NO_RUN_NAME)
    • +
    +
    + + +
    +
    +
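For illustration, each line written by _after follows a tab-separated TREC run layout of topic number, the literal Q0, document id, rank, score and run name; all values in this sample line are made up:

    3	Q0	NCT00001234	1	42.7	my_run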
    + +
    + + class + EvaluationCallback(Callback): + + + +
    + +
    104class EvaluationCallback(Callback):
    +105    def __init__(self, evaluator: Evaluator, config):
    +106        super().__init__()
    +107        self.evaluator = evaluator
    +108        self.config = config
    +109        self.parsed_run = None
    +110
    +111    def before(self, pipeline: Pipeline):
    +112        self.pipeline = Pipeline
    +113
    +114    def after(self, results: List, id_field="id"):
    +115        if self.pipeline.output_file is None:
    +116            directory_name = tempfile.mkdtemp()
    +117            fn = str(uuid.uuid4())
    +118
    +119            fp = os.path.join(directory_name, fn)
    +120
    +121            query = query_factory[self.config.query_fn]
    +122            query.id_field = id_field
    +123
    +124            SerializationCallback._after(results,
    +125                                         output_file=fp,
    +126                                         run_name=self.config.run_name)
    +127
    +128            self.pipeline.output_file = fp
    +129
    +130        parsed_run = self.evaluator.evaluate_runs(self.pipeline.output_file,
    +131                                                  disable_cache=True)
    +132        self.parsed_run = parsed_run
    +133
    +134        return self.parsed_run
    +
    + + + + +
    + +
    + + EvaluationCallback(evaluator: debeir.evaluation.evaluator.Evaluator, config) + + + +
    + +
    105    def __init__(self, evaluator: Evaluator, config):
    +106        super().__init__()
    +107        self.evaluator = evaluator
    +108        self.config = config
    +109        self.parsed_run = None
    +
    + + + + +
    +
    + +
    + + def + before(self, pipeline: debeir.core.pipeline.Pipeline): + + + +
    + +
    111    def before(self, pipeline: Pipeline):
    +112        self.pipeline = Pipeline
    +
    + + + + +
    +
    + +
    + + def + after(self, results: List, id_field='id'): + + + +
    + +
    114    def after(self, results: List, id_field="id"):
    +115        if self.pipeline.output_file is None:
    +116            directory_name = tempfile.mkdtemp()
    +117            fn = str(uuid.uuid4())
    +118
    +119            fp = os.path.join(directory_name, fn)
    +120
    +121            query = query_factory[self.config.query_fn]
    +122            query.id_field = id_field
    +123
    +124            SerializationCallback._after(results,
    +125                                         output_file=fp,
    +126                                         run_name=self.config.run_name)
    +127
    +128            self.pipeline.output_file = fp
    +129
    +130        parsed_run = self.evaluator.evaluate_runs(self.pipeline.output_file,
    +131                                                  disable_cache=True)
    +132        self.parsed_run = parsed_run
    +133
    +134        return self.parsed_run
    +
    + + + + +
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/config.html b/docs/debeir/core/config.html new file mode 100644 index 0000000..71c2b89 --- /dev/null +++ b/docs/debeir/core/config.html @@ -0,0 +1,1477 @@ + + + + + + + debeir.core.config API documentation + + + + + + + + + +
    +
    +

    +debeir.core.config

    + + + + + + +
      1import abc
    +  2import dataclasses
    +  3import os
    +  4from abc import ABC
    +  5from dataclasses import dataclass
    +  6from pathlib import Path
    +  7from typing import Dict, List, MutableMapping, Union
    +  8
    +  9import loguru
    + 10import toml
    + 11
    + 12
    + 13class Config:
    + 14    """
    + 15    Config Interface with creation class methods
    + 16    """
    + 17
    + 18    def __update__(self, **kwargs):
    + 19        attrs = vars(self)
    + 20        kwargs.update(attrs)
    + 21
    + 22        return kwargs
    + 23
    + 24    @classmethod
    + 25    def from_toml(cls, fp: Union[str, Path], field_class, *args, **kwargs) -> 'Config':
    + 26        """
    + 27        Instantiates a Config object from a toml file
    + 28
    + 29        :param fp: File path of the Config TOML file
    + 30        :param field_class: Class of the Config object to be instantiated
    + 31        :param args: Arguments to be passed to Config
    + 32        :param kwargs: Keyword arguments to be passed
    + 33        :return:
    + 34            A instantiated and validated Config object.
    + 35        """
    + 36        args_dict = toml.load(fp)
    + 37
    + 38        return cls.from_args(args_dict, field_class, *args, **kwargs)
    + 39
    + 40    @classmethod
    + 41    def from_args(cls, args_dict: MutableMapping, field_class, *args, **kwargs):
    + 42        """
    + 43        Instantiates a Config object from arguments
    + 44
    + 45
    + 46        :param args_dict:
    + 47        :param field_class:
    + 48        :param args:
    + 49        :param kwargs:
    + 50        :return:
    + 51        """
    + 52        from debeir.rankers.transformer_sent_encoder import Encoder
    + 53
    + 54        field_names = set(f.name for f in dataclasses.fields(field_class))
    + 55        obj = field_class(**{k: v for k, v in args_dict.items() if k in field_names})
    + 56        if hasattr(obj, 'encoder_fp') and obj.encoder_fp:
    + 57            obj.encoder = Encoder(obj.encoder_fp, obj.encoder_normalize)
    + 58
    + 59        obj.validate()
    + 60
    + 61        return obj
    + 62
    + 63    @classmethod
    + 64    def from_dict(cls, data_class, **kwargs):
    + 65        """
    + 66        Instantiates a Config object from a dictionary
    + 67
    + 68        :param data_class:
    + 69        :param kwargs:
    + 70        :return:
    + 71        """
    + 72        from debeir.rankers.transformer_sent_encoder import Encoder
    + 73
    + 74        if "encoder_fp" in kwargs and kwargs["encoder_fp"]:
    + 75            kwargs["encoder"] = Encoder(kwargs["encoder_fp"])
    + 76
    + 77        field_names = set(f.name for f in dataclasses.fields(data_class))
    + 78        obj = data_class(**{k: v for k, v in kwargs.items() if k in field_names})
    + 79        obj.validate(0)
    + 80
    + 81        return obj
    + 82
    + 83    @abc.abstractmethod
    + 84    def validate(self):
    + 85        """
    + 86        Validates if the config is correct.
    + 87        Must be implemented by inherited classes.
    + 88        """
    + 89        pass
    + 90
    + 91
    + 92@dataclass(init=True, unsafe_hash=True)
    + 93class GenericConfig(Config, ABC):
    + 94    """
    + 95    Generic NIR Configuration file for which all configs will inherit
    + 96    """
    + 97    query_type: str
    + 98    index: str = None
    + 99    encoder_normalize: bool = True
    +100    ablations: bool = False
    +101    norm_weight: float = None
    +102    automatic: bool = None
    +103    encoder: object = None
    +104    encoder_fp: str = None
    +105    query_weights: List[float] = None
    +106    cosine_weights: List[float] = None
    +107    evaluate: bool = False
    +108    qrels: str = None
    +109    config_fn: str = None
    +110    query_fn: str = None
    +111    parser_fn: str = None
    +112    executor_fn: str = None
    +113    cosine_ceiling: float = None
    +114    topics_path: str = None
    +115    return_id_only: bool = False
    +116    overwrite_output_if_exists: bool = False
    +117    output_file: str = None
    +118    run_name: str = None
    +119
    +120    @classmethod
    +121    def from_toml(cls, fp: Union[str, Path], *args, **kwargs) -> 'GenericConfig':
    +122        return Config.from_toml(fp, cls, *args, **kwargs)
    +123
    +124
    +125@dataclass(init=True)
    +126class _NIRMasterConfig(Config):
    +127    """
    +128    Base NIR Master config: nir.toml
    +129    """
    +130    metrics: Dict
    +131    search: Dict
    +132    nir: Dict
    +133
    +134    def get_metrics(self, key='common', return_as_instance=False):
    +135        metrics = self.metrics[key]
    +136        if return_as_instance:
    +137            return MetricsConfig.from_args(metrics, MetricsConfig)
    +138
    +139        return metrics
    +140
    +141    def get_search_engine_settings(self, key='elasticsearch', return_as_instance=False):
    +142        engine_settings = self.search['engines'][key]
    +143        if return_as_instance:
    +144            return ElasticsearchConfig.from_args(engine_settings, ElasticsearchConfig)
    +145
    +146        return engine_settings
    +147
    +148    def get_nir_settings(self, key='default_settings', return_as_instance=False):
    +149        nir_settings = self.nir[key]
    +150
    +151        if return_as_instance:
    +152            return NIRConfig.from_args(nir_settings, NIRConfig)
    +153
    +154        return nir_settings
    +155
    +156    def validate(self):
    +157        return True
    +158
    +159
    +160@dataclass(init=True)
    +161class ElasticsearchConfig(Config):
    +162    """
    +163    Basic Elasticsearch configuration file settings from the master nir.toml file
    +164    """
    +165    protocol: str
    +166    ip: str
    +167    port: str
    +168    timeout: int
    +169
    +170    def validate(self):
    +171        """
    +172        Checks if Elasticsearch URL is correct
    +173        """
    +174        assert self.protocol in ['http', 'https']
    +175        assert self.port.isdigit()
    +176
    +177
    +178@dataclass(init=True)
    +179class SolrConfig(ElasticsearchConfig):
    +180    """
    +181    Basic Solr configuration file settings from the master nir.toml file
    +182    """
    +183    pass
    +184
    +185
    +186@dataclass(init=True)
    +187class MetricsConfig(Config):
    +188    """
    +189    Basic Metrics configuration file settings from the master nir.toml file
    +190    """
    +191    metrics: List[str]
    +192
    +193    def validate(self):
    +194        """
    +195        Checks if each Metrics is usable by evaluator classes
    +196        """
    +197        for metric in self.metrics:
    +198            assert "@" in metric
    +199
    +200            metric, depth = metric.split("@")
    +201
    +202            assert metric.isalpha()
    +203            assert depth.isdigit()
    +204
    +205
    +206@dataclass(init=True)
    +207class NIRConfig(Config):
    +208    """
    +209    Basic NIR configuration file settings from the master nir.toml file
    +210    """
    +211    norm_weight: str
    +212    evaluate: bool
    +213    return_size: int
    +214    output_directory: str
    +215
    +216    def validate(self):
    +217        return True
    +218
    +219
    +220def apply_config(func):
    +221    """
    +222    Configuration decorator.
    +223
    +224    :param func: Decorated function
    +225    :return:
    +226    """
    +227
    +228    def use_config(self, *args, **kwargs):
    +229        """
    +230        Replaces keywords and args passed to the function with ones from self.config.
    +231
    +232        :param self:
    +233        :param args: To be updated
    +234        :param kwargs: To be updated
    +235        :return:
    +236        """
    +237        if self.config is not None:
    +238            kwargs = self.config.__update__(**kwargs)
    +239
    +240        return func(self, *args, **kwargs)
    +241
    +242    return use_config
    +243
    +244
    +245def override_with_toml_config(func):
    +246    """
    +247    Configuration decorator. Overwrite a functions kwargs and args with a specified toml config file.
    +248    Pass override_with_config=path/to/config
    +249
    +250    :param func: Decorated function
    +251    :return:
    +252    """
    +253
    +254    def override_with(override_with_config_: str = None, *args, **kwargs):
    +255        """
    +256        Replaces keywords and args passed to the function with ones from self.config.
    +257
    +258        :param override_with_config_: Path to config else None
    +259        :param args: To be updated
    +260        :param kwargs: To be updated
    +261        :return:
    +262        """
    +263
    +264        if f"override_{func.__name__}_with_config_" in kwargs:
    +265            override_with_config_ = f"override_{func.__name__}_with_config_"
    +266
    +267        if override_with_config_ is not None:
    +268            if os.path.exists(override_with_config_):
    +269                toml_kwargs = toml.load(override_with_config_)
    +270                kwargs = kwargs.update(**toml_kwargs)
    +271
    +272        return func(*args, **kwargs)
    +273
    +274    return override_with
    +275
    +276
    +277def save_kwargs_to_file(func):
    +278    def save_kwargs(save_kwargs_to_: str = None, *args, **kwargs):
    +279        """
    +280        Save kwargs passed to the function output_file = f"{save_kwargs_to_}_{func.__name__}.toml"
    +281
    +282        :param save_kwargs_to_: Path to save location for config else None. This should be a DIRECTORY.
    +283        :param args: To be updated
    +284        :param kwargs: To be updated
    +285        :return:
    +286        """
    +287        if save_kwargs_to_ is not None:
    +288            os.makedirs(save_kwargs_to_, exist_ok=True)
    +289
    +290            if os.path.exists(save_kwargs_to_):
    +291                output_file = f"{save_kwargs_to_}/{func.__name__}.toml"
    +292                loguru.logger.info(f"Saving kwargs to {output_file}")
    +293                toml.dump(kwargs, open(output_file, "w+"))
    +294
    +295        return func(*args, **kwargs)
    +296
    +297    return save_kwargs
    +
    + + +
    +
    + +
    + + class + Config: + + + +
    + +
    14class Config:
    +15    """
    +16    Config Interface with creation class methods
    +17    """
    +18
    +19    def __update__(self, **kwargs):
    +20        attrs = vars(self)
    +21        kwargs.update(attrs)
    +22
    +23        return kwargs
    +24
    +25    @classmethod
    +26    def from_toml(cls, fp: Union[str, Path], field_class, *args, **kwargs) -> 'Config':
    +27        """
    +28        Instantiates a Config object from a toml file
    +29
    +30        :param fp: File path of the Config TOML file
    +31        :param field_class: Class of the Config object to be instantiated
    +32        :param args: Arguments to be passed to Config
    +33        :param kwargs: Keyword arguments to be passed
    +34        :return:
    +35            A instantiated and validated Config object.
    +36        """
    +37        args_dict = toml.load(fp)
    +38
    +39        return cls.from_args(args_dict, field_class, *args, **kwargs)
    +40
    +41    @classmethod
    +42    def from_args(cls, args_dict: MutableMapping, field_class, *args, **kwargs):
    +43        """
    +44        Instantiates a Config object from arguments
    +45
    +46
    +47        :param args_dict:
    +48        :param field_class:
    +49        :param args:
    +50        :param kwargs:
    +51        :return:
    +52        """
    +53        from debeir.rankers.transformer_sent_encoder import Encoder
    +54
    +55        field_names = set(f.name for f in dataclasses.fields(field_class))
    +56        obj = field_class(**{k: v for k, v in args_dict.items() if k in field_names})
    +57        if hasattr(obj, 'encoder_fp') and obj.encoder_fp:
    +58            obj.encoder = Encoder(obj.encoder_fp, obj.encoder_normalize)
    +59
    +60        obj.validate()
    +61
    +62        return obj
    +63
    +64    @classmethod
    +65    def from_dict(cls, data_class, **kwargs):
    +66        """
    +67        Instantiates a Config object from a dictionary
    +68
    +69        :param data_class:
    +70        :param kwargs:
    +71        :return:
    +72        """
    +73        from debeir.rankers.transformer_sent_encoder import Encoder
    +74
    +75        if "encoder_fp" in kwargs and kwargs["encoder_fp"]:
    +76            kwargs["encoder"] = Encoder(kwargs["encoder_fp"])
    +77
    +78        field_names = set(f.name for f in dataclasses.fields(data_class))
    +79        obj = data_class(**{k: v for k, v in kwargs.items() if k in field_names})
    +80        obj.validate(0)
    +81
    +82        return obj
    +83
    +84    @abc.abstractmethod
    +85    def validate(self):
    +86        """
    +87        Validates if the config is correct.
    +88        Must be implemented by inherited classes.
    +89        """
    +90        pass
    +
    + + +

    Config Interface with creation class methods

    +
    + + +
    +
    + + Config() + + +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + from_toml( cls, fp: Union[str, pathlib.Path], field_class, *args, **kwargs) -> debeir.core.config.Config: + + + +
    + +
    25    @classmethod
    +26    def from_toml(cls, fp: Union[str, Path], field_class, *args, **kwargs) -> 'Config':
    +27        """
    +28        Instantiates a Config object from a toml file
    +29
    +30        :param fp: File path of the Config TOML file
    +31        :param field_class: Class of the Config object to be instantiated
    +32        :param args: Arguments to be passed to Config
    +33        :param kwargs: Keyword arguments to be passed
    +34        :return:
    +35            A instantiated and validated Config object.
    +36        """
    +37        args_dict = toml.load(fp)
    +38
    +39        return cls.from_args(args_dict, field_class, *args, **kwargs)
    +
    + + +

    Instantiates a Config object from a toml file

    + +
    Parameters
    + +
      +
    • fp: File path of the Config TOML file
    • +
    • field_class: Class of the Config object to be instantiated
    • +
    • args: Arguments to be passed to Config
    • +
    • kwargs: Keyword arguments to be passed
    • +
    + +
    Returns
    + +
    +
An instantiated and validated Config object.
    +
    +
    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + from_args(cls, args_dict: MutableMapping, field_class, *args, **kwargs): + + + +
    + +
    41    @classmethod
    +42    def from_args(cls, args_dict: MutableMapping, field_class, *args, **kwargs):
    +43        """
    +44        Instantiates a Config object from arguments
    +45
    +46
    +47        :param args_dict:
    +48        :param field_class:
    +49        :param args:
    +50        :param kwargs:
    +51        :return:
    +52        """
    +53        from debeir.rankers.transformer_sent_encoder import Encoder
    +54
    +55        field_names = set(f.name for f in dataclasses.fields(field_class))
    +56        obj = field_class(**{k: v for k, v in args_dict.items() if k in field_names})
    +57        if hasattr(obj, 'encoder_fp') and obj.encoder_fp:
    +58            obj.encoder = Encoder(obj.encoder_fp, obj.encoder_normalize)
    +59
    +60        obj.validate()
    +61
    +62        return obj
    +
    + + +

    Instantiates a Config object from arguments

    + +
    Parameters
    + +
      +
    • args_dict:
    • +
    • field_class:
    • +
    • args:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + from_dict(cls, data_class, **kwargs): + + + +
    + +
    64    @classmethod
    +65    def from_dict(cls, data_class, **kwargs):
    +66        """
    +67        Instantiates a Config object from a dictionary
    +68
    +69        :param data_class:
    +70        :param kwargs:
    +71        :return:
    +72        """
    +73        from debeir.rankers.transformer_sent_encoder import Encoder
    +74
    +75        if "encoder_fp" in kwargs and kwargs["encoder_fp"]:
    +76            kwargs["encoder"] = Encoder(kwargs["encoder_fp"])
    +77
    +78        field_names = set(f.name for f in dataclasses.fields(data_class))
    +79        obj = data_class(**{k: v for k, v in kwargs.items() if k in field_names})
    +80        obj.validate(0)
    +81
    +82        return obj
    +
    + + +

    Instantiates a Config object from a dictionary

    + +
    Parameters
    + +
      +
    • data_class:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    +
    @abc.abstractmethod
    + + def + validate(self): + + + +
    + +
    84    @abc.abstractmethod
    +85    def validate(self):
    +86        """
    +87        Validates if the config is correct.
    +88        Must be implemented by inherited classes.
    +89        """
    +90        pass
    +
    + + +

Validates that the config is correct. +Must be implemented by inheriting classes.

    +
    + + +
    +
    +
    + +
    +
    @dataclass(init=True, unsafe_hash=True)
    + + class + GenericConfig(Config, abc.ABC): + + + +
    + +
     93@dataclass(init=True, unsafe_hash=True)
    + 94class GenericConfig(Config, ABC):
    + 95    """
    + 96    Generic NIR Configuration file for which all configs will inherit
    + 97    """
    + 98    query_type: str
    + 99    index: str = None
    +100    encoder_normalize: bool = True
    +101    ablations: bool = False
    +102    norm_weight: float = None
    +103    automatic: bool = None
    +104    encoder: object = None
    +105    encoder_fp: str = None
    +106    query_weights: List[float] = None
    +107    cosine_weights: List[float] = None
    +108    evaluate: bool = False
    +109    qrels: str = None
    +110    config_fn: str = None
    +111    query_fn: str = None
    +112    parser_fn: str = None
    +113    executor_fn: str = None
    +114    cosine_ceiling: float = None
    +115    topics_path: str = None
    +116    return_id_only: bool = False
    +117    overwrite_output_if_exists: bool = False
    +118    output_file: str = None
    +119    run_name: str = None
    +120
    +121    @classmethod
    +122    def from_toml(cls, fp: Union[str, Path], *args, **kwargs) -> 'GenericConfig':
    +123        return Config.from_toml(fp, cls, *args, **kwargs)
    +
    + + +

Generic NIR Configuration file from which all configs will inherit

    +
    + + +
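A hedged usage sketch: GenericConfig is abstract, so a concrete config (normally defined per dataset in debeir.datasets) subclasses it and implements validate(); the class name, TOML path and field check below are hypothetical:

    from dataclasses import dataclass

    from debeir.core.config import GenericConfig


    @dataclass(init=True)
    class MyDatasetConfig(GenericConfig):
        # Hypothetical concrete config; validate() is required by the Config interface.
        def validate(self):
            assert self.index, "an index name is required"


    # Hypothetical path: TOML keys matching field names populate the dataclass
    # (query_type is the only required key), encoder_fp (if set) builds an Encoder,
    # and validate() is run before the object is returned.
    config = MyDatasetConfig.from_toml("./configs/my_dataset.toml")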
    +
    + + GenericConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None) + + +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + from_toml( cls, fp: Union[str, pathlib.Path], *args, **kwargs) -> debeir.core.config.GenericConfig: + + + +
    + +
    121    @classmethod
    +122    def from_toml(cls, fp: Union[str, Path], *args, **kwargs) -> 'GenericConfig':
    +123        return Config.from_toml(fp, cls, *args, **kwargs)
    +
    + + +

    Instantiates a Config object from a toml file

    + +
    Parameters
    + +
      +
    • fp: File path of the Config TOML file
    • +
    • field_class: Class of the Config object to be instantiated
    • +
    • args: Arguments to be passed to Config
    • +
    • kwargs: Keyword arguments to be passed
    • +
    + +
    Returns
    + +
    +
An instantiated and validated Config object.
    +
    +
    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclass(init=True)
    + + class + ElasticsearchConfig(Config): + + + +
    + +
    161@dataclass(init=True)
    +162class ElasticsearchConfig(Config):
    +163    """
    +164    Basic Elasticsearch configuration file settings from the master nir.toml file
    +165    """
    +166    protocol: str
    +167    ip: str
    +168    port: str
    +169    timeout: int
    +170
    +171    def validate(self):
    +172        """
    +173        Checks if Elasticsearch URL is correct
    +174        """
    +175        assert self.protocol in ['http', 'https']
    +176        assert self.port.isdigit()
    +
    + + +

    Basic Elasticsearch configuration file settings from the master nir.toml file

    +
    + + +
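A small construction sketch with placeholder values; validate() only asserts that the protocol is http/https and that the port is numeric:

    from debeir.core.config import ElasticsearchConfig

    # Placeholder connection settings; in normal use these come from the search engine
    # section of the master nir.toml file rather than being hard-coded.
    es_config = ElasticsearchConfig(protocol="http", ip="localhost", port="9200", timeout=600)
    es_config.validate()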
    +
    + + ElasticsearchConfig(protocol: str, ip: str, port: str, timeout: int) + + +
    + + + + +
    +
    + +
    + + def + validate(self): + + + +
    + +
    171    def validate(self):
    +172        """
    +173        Checks if Elasticsearch URL is correct
    +174        """
    +175        assert self.protocol in ['http', 'https']
    +176        assert self.port.isdigit()
    +
    + + +

    Checks if Elasticsearch URL is correct

    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclass(init=True)
    + + class + SolrConfig(ElasticsearchConfig): + + + +
    + +
    179@dataclass(init=True)
    +180class SolrConfig(ElasticsearchConfig):
    +181    """
    +182    Basic Solr configuration file settings from the master nir.toml file
    +183    """
    +184    pass
    +
    + + +

    Basic Solr configuration file settings from the master nir.toml file

    +
    + + +
    +
    + + SolrConfig(protocol: str, ip: str, port: str, timeout: int) + + +
    + + + + +
    +
    +
    Inherited Members
    +
    + + +
    +
    +
    +
    + +
    +
    @dataclass(init=True)
    + + class + MetricsConfig(Config): + + + +
    + +
    187@dataclass(init=True)
    +188class MetricsConfig(Config):
    +189    """
    +190    Basic Metrics configuration file settings from the master nir.toml file
    +191    """
    +192    metrics: List[str]
    +193
    +194    def validate(self):
    +195        """
    +196        Checks if each Metrics is usable by evaluator classes
    +197        """
    +198        for metric in self.metrics:
    +199            assert "@" in metric
    +200
    +201            metric, depth = metric.split("@")
    +202
    +203            assert metric.isalpha()
    +204            assert depth.isdigit()
    +
    + + +

    Basic Metrics configuration file settings from the master nir.toml file

    +
    + + +
    +
    + + MetricsConfig(metrics: List[str]) + + +
    + + + + +
    +
    + +
    + + def + validate(self): + + + +
    + +
    194    def validate(self):
    +195        """
    +196        Checks if each Metrics is usable by evaluator classes
    +197        """
    +198        for metric in self.metrics:
    +199            assert "@" in metric
    +200
    +201            metric, depth = metric.split("@")
    +202
    +203            assert metric.isalpha()
    +204            assert depth.isdigit()
    +
    + + +

Checks that each metric is usable by the evaluator classes (each must be of the form name@depth, e.g. ndcg@10)

    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclass(init=True)
    + + class + NIRConfig(Config): + + + +
    + +
    207@dataclass(init=True)
    +208class NIRConfig(Config):
    +209    """
    +210    Basic NIR configuration file settings from the master nir.toml file
    +211    """
    +212    norm_weight: str
    +213    evaluate: bool
    +214    return_size: int
    +215    output_directory: str
    +216
    +217    def validate(self):
    +218        return True
    +
    + + +

    Basic NIR configuration file settings from the master nir.toml file

    +
    + + +
    +
    + + NIRConfig( norm_weight: str, evaluate: bool, return_size: int, output_directory: str) + + +
    + + + + +
    +
    + +
    + + def + validate(self): + + + +
    + +
    217    def validate(self):
    +218        return True
    +
    + + +

Validates that the config is correct. +Must be implemented by inheriting classes.

    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    + + def + apply_config(func): + + + +
    + +
    221def apply_config(func):
    +222    """
    +223    Configuration decorator.
    +224
    +225    :param func: Decorated function
    +226    :return:
    +227    """
    +228
    +229    def use_config(self, *args, **kwargs):
    +230        """
    +231        Replaces keywords and args passed to the function with ones from self.config.
    +232
    +233        :param self:
    +234        :param args: To be updated
    +235        :param kwargs: To be updated
    +236        :return:
    +237        """
    +238        if self.config is not None:
    +239            kwargs = self.config.__update__(**kwargs)
    +240
    +241        return func(self, *args, **kwargs)
    +242
    +243    return use_config
    +
    + + +

    Configuration decorator.

    + +
    Parameters
    + +
      +
    • func: Decorated function
    • +
    + +
    Returns
    +
    + + +
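A hedged sketch of the intended use, assuming a class that keeps a Config (or None) in self.config; the Searcher class and its method are hypothetical:

    from debeir.core.config import apply_config


    class Searcher:
        def __init__(self, config=None):
            # apply_config expects the decorated method's owner to expose self.config.
            self.config = config

        @apply_config
        def run(self, *args, **kwargs):
            # When self.config is set, its attributes (index, run_name, ...) are merged
            # into kwargs via Config.__update__, overriding caller-supplied keywords.
            return kwargs.get("index")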
    +
    + +
    + + def + override_with_toml_config(func): + + + +
    + +
    246def override_with_toml_config(func):
    +247    """
    +248    Configuration decorator. Overwrite a functions kwargs and args with a specified toml config file.
    +249    Pass override_with_config=path/to/config
    +250
    +251    :param func: Decorated function
    +252    :return:
    +253    """
    +254
    +255    def override_with(override_with_config_: str = None, *args, **kwargs):
    +256        """
    +257        Replaces keywords and args passed to the function with ones from self.config.
    +258
    +259        :param override_with_config_: Path to config else None
    +260        :param args: To be updated
    +261        :param kwargs: To be updated
    +262        :return:
    +263        """
    +264
    +265        if f"override_{func.__name__}_with_config_" in kwargs:
    +266            override_with_config_ = f"override_{func.__name__}_with_config_"
    +267
    +268        if override_with_config_ is not None:
    +269            if os.path.exists(override_with_config_):
    +270                toml_kwargs = toml.load(override_with_config_)
    +271                kwargs = kwargs.update(**toml_kwargs)
    +272
    +273        return func(*args, **kwargs)
    +274
    +275    return override_with
    +
    + + +

Configuration decorator. Overwrites a function's kwargs and args with values from a specified TOML config file. +Pass override_with_config=path/to/config

    + +
    Parameters
    + +
      +
    • func: Decorated function
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    + + def + save_kwargs_to_file(func): + + + +
    + +
    278def save_kwargs_to_file(func):
    +279    def save_kwargs(save_kwargs_to_: str = None, *args, **kwargs):
    +280        """
    +281        Save kwargs passed to the function output_file = f"{save_kwargs_to_}_{func.__name__}.toml"
    +282
    +283        :param save_kwargs_to_: Path to save location for config else None. This should be a DIRECTORY.
    +284        :param args: To be updated
    +285        :param kwargs: To be updated
    +286        :return:
    +287        """
    +288        if save_kwargs_to_ is not None:
    +289            os.makedirs(save_kwargs_to_, exist_ok=True)
    +290
    +291            if os.path.exists(save_kwargs_to_):
    +292                output_file = f"{save_kwargs_to_}/{func.__name__}.toml"
    +293                loguru.logger.info(f"Saving kwargs to {output_file}")
    +294                toml.dump(kwargs, open(output_file, "w+"))
    +295
    +296        return func(*args, **kwargs)
    +297
    +298    return save_kwargs
    +
    + + + + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/converters.html b/docs/debeir/core/converters.html new file mode 100644 index 0000000..29ccec4 --- /dev/null +++ b/docs/debeir/core/converters.html @@ -0,0 +1,446 @@ + + + + + + + debeir.core.converters API documentation + + + + + + + + + +
    +
    +

    +debeir.core.converters

    + + + + + + +
     1from collections import defaultdict
    + 2from typing import Dict, Union
    + 3
    + 4from debeir.core.parser import Parser
    + 5
    + 6import datasets
    + 7
    + 8
    + 9class ParsedTopicsToDataset:
    +10    """
    +11    Converts a parser's output to a huggingface dataset object.
    +12    """
    +13
    +14    @classmethod
    +15    def convert(cls, parser: Parser, output: Dict[Union[str, int], Dict]):
    +16        """
    +17        Flatten a Dict of shape (traditional parser output)
    +18        {topic_id: {
    +19                "Facet_1": ...
    +20                "Facet_2": ...
    +21            }
    +22        }
    +23
    +24        ->
    +25
    +26        To a flattened arrow-like dataset.
    +27        {
    +28        topic_ids: [],
    +29        Facet_1s: [],
    +30        Facet_2s: [],
    +31        }
    +32
    +33        :param output: Topics output from the parser object
    +34        :return:
    +35        """
    +36        flattened_topics = defaultdict(lambda: [])
    +37
    +38        for topic_id, topic in output.items():
    +39            flattened_topics["topic_id"].append(topic_id)
    +40
    +41            for field in parser.parse_fields:
    +42                if field in topic:
    +43                    flattened_topics[field].append(topic[field])
    +44                else:
    +45                    flattened_topics[field].append(None)
    +46
    +47        return datasets.Dataset.from_dict(flattened_topics)
    +
    + + +
    +
    + +
    + + class + ParsedTopicsToDataset: + + + +
    + +
    10class ParsedTopicsToDataset:
    +11    """
    +12    Converts a parser's output to a huggingface dataset object.
    +13    """
    +14
    +15    @classmethod
    +16    def convert(cls, parser: Parser, output: Dict[Union[str, int], Dict]):
    +17        """
    +18        Flatten a Dict of shape (traditional parser output)
    +19        {topic_id: {
    +20                "Facet_1": ...
    +21                "Facet_2": ...
    +22            }
    +23        }
    +24
    +25        ->
    +26
    +27        To a flattened arrow-like dataset.
    +28        {
    +29        topic_ids: [],
    +30        Facet_1s: [],
    +31        Facet_2s: [],
    +32        }
    +33
    +34        :param output: Topics output from the parser object
    +35        :return:
    +36        """
    +37        flattened_topics = defaultdict(lambda: [])
    +38
    +39        for topic_id, topic in output.items():
    +40            flattened_topics["topic_id"].append(topic_id)
    +41
    +42            for field in parser.parse_fields:
    +43                if field in topic:
    +44                    flattened_topics[field].append(topic[field])
    +45                else:
    +46                    flattened_topics[field].append(None)
    +47
    +48        return datasets.Dataset.from_dict(flattened_topics)
    +
    + + +

    Converts a parser's output to a huggingface dataset object.

    +
    + + +
    +
    + + ParsedTopicsToDataset() + + +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + convert( cls, parser: debeir.core.parser.Parser, output: Dict[Union[str, int], Dict]): + + + +
    + +
    15    @classmethod
    +16    def convert(cls, parser: Parser, output: Dict[Union[str, int], Dict]):
    +17        """
    +18        Flatten a Dict of shape (traditional parser output)
    +19        {topic_id: {
    +20                "Facet_1": ...
    +21                "Facet_2": ...
    +22            }
    +23        }
    +24
    +25        ->
    +26
    +27        To a flattened arrow-like dataset.
    +28        {
    +29        topic_ids: [],
    +30        Facet_1s: [],
    +31        Facet_2s: [],
    +32        }
    +33
    +34        :param output: Topics output from the parser object
    +35        :return:
    +36        """
    +37        flattened_topics = defaultdict(lambda: [])
    +38
    +39        for topic_id, topic in output.items():
    +40            flattened_topics["topic_id"].append(topic_id)
    +41
    +42            for field in parser.parse_fields:
    +43                if field in topic:
    +44                    flattened_topics[field].append(topic[field])
    +45                else:
    +46                    flattened_topics[field].append(None)
    +47
    +48        return datasets.Dataset.from_dict(flattened_topics)
    +
    + + +

    Flatten a Dict of shape (traditional parser output) +{topic_id: { + "Facet_1": ... + "Facet_2": ... + } +}

    + +

    ->

    + +

    To a flattened arrow-like dataset. +{ +topic_ids: [], +Facet_1s: [], +Facet_2s: [], +}

    + +
    Parameters
    + +
      +
    • output: Topics output from the parser object
    • +
    + +
    Returns
    +
    + + +
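A hedged usage sketch of convert. The parser instance and parsed topics below are illustrative; parser is assumed to be a Parser whose parse_fields is ["Facet_1", "Facet_2"].

    from debeir.core.converters import ParsedTopicsToDataset

    parsed_topics = {
        "1": {"Facet_1": "elderly patients", "Facet_2": "type 2 diabetes"},
        "2": {"Facet_1": "children"},  # missing Facet_2 is stored as None
    }

    # parser: a Parser with parse_fields == ["Facet_1", "Facet_2"] (assumed)
    dataset = ParsedTopicsToDataset.convert(parser, parsed_topics)
    print(dataset.column_names)  # ['topic_id', 'Facet_1', 'Facet_2']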
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/document.html b/docs/debeir/core/document.html new file mode 100644 index 0000000..6362663 --- /dev/null +++ b/docs/debeir/core/document.html @@ -0,0 +1,999 @@ + + + + + + + debeir.core.document API documentation + + + + + + + + + +
    +
    +

    +debeir.core.document

    + + + + + + +
      1import abc
    +  2import dataclasses
    +  3from collections import defaultdict
    +  4from typing import Dict, List, Union
    +  5
    +  6from debeir.utils.utils import flatten
    +  7
    +  8
    +  9@dataclasses.dataclass
    + 10class Document:
    + 11    """
    + 12    Generic Document class.
    + 13    Used as an interface for interacting across multiple indexes with different mappings.
    + 14    """
    + 15    doc_id: Union[int, float, str]
    + 16    topic_num: Union[int, str, float] = None
    + 17    facets: Dict = None
+ 18    score: Union[float, int] = 0.0  # Primary Score
    + 19    scores: Dict[str, Union[float, int]] = dataclasses.field(
    + 20        default_factory=lambda: {})  # Include other scores if needed
    + 21
    + 22    @classmethod
    + 23    @abc.abstractmethod
    + 24    def from_results(cls, results, *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    + 25        """
    + 26        Produces a list of Document objects from raw results returned from the index
    + 27
    + 28        In the format {topic_num: [Document, ..., Document]}
    + 29        """
    + 30        pass
    + 31
    + 32    def get_document_id(self):
    + 33        """
    + 34        :return:
    + 35            self.doc_id
    + 36        """
    + 37        return self.doc_id
    + 38
    + 39    def flatten_facets(self, *args, **kwargs):
    + 40        """
    + 41        Flattens multi-level internal document facets into a single level
    + 42            e.g. Doc['Upper']['Lower'] -> Doc['Upper_Lower']
    + 43        :param args:
    + 44        :param kwargs:
    + 45        """
    + 46        self.facets = flatten(self.facets, *args, **kwargs)
    + 47
    + 48    @classmethod
    + 49    def _get_document_facet(cls, intermediate_repr, key):
    + 50        return intermediate_repr[key]
    + 51
    + 52    def get_document_facet(self, key, sep="_"):
    + 53        """
    + 54        Retrieve a document facet
    + 55        Works for multidimensional keys or single
    + 56        :param key: Facet to retrieve
+ 57        :param sep: The separator for multidimensional key
    + 58        :return:
    + 59            Returns the document facet given the key (field)
    + 60        """
    + 61        if sep in key:
    + 62            keys = key.split(sep)
    + 63
    + 64            intermediate_repr = self.facets
    + 65            for k in keys:
    + 66                intermediate_repr = self._get_document_facet(intermediate_repr, k)
    + 67
    + 68            return intermediate_repr
    + 69
    + 70        return self.facets[key]
    + 71
    + 72    def set(self, doc_id=None, facets=None, score=None, facet=None, facet_value=None) -> 'Document':
    + 73        """
    + 74        Set attributes of the object. Use keyword arguments to do so. Works as a builder class.
    + 75        doc.set(doc_id="123").set(facets={"title": "my title"})
    + 76        :param doc_id:
    + 77        :param facets:
    + 78        :param score:
    + 79        :param facet:
    + 80        :param facet_value:
    + 81
    + 82        :return:
    + 83            Returns document object
    + 84        """
    + 85        if doc_id is not None:
    + 86            self.doc_id = doc_id
    + 87
    + 88        if facets is not None:
    + 89            self.facets = facets
    + 90
    + 91        if score is not None:
    + 92            self.score = score
    + 93
    + 94        if facet is not None and facet_value is not None:
    + 95            self.facets[facet] = facet_value
    + 96
    + 97        return self
    + 98
    + 99    def to_trec_format(self, rank, run_name) -> str:
    +100        """
    +101        Returns TREC format for the document
    +102        :return:
    +103            A trec formatted string
    +104        """
    +105
    +106        return f"{self.topic_num}\t" \
    +107               f"Q0\t" \
    +108               f"{self.doc_id}\t" \
    +109               f"{rank}\t" \
    +110               f"{self.score}\t" \
    +111               f"{run_name}\n"
    +112
    +113    @classmethod
    +114    def get_trec_format(cls, ranked_list: List['Document'], run_name="NO_RUN_NAME", sort=True, sorting_func=None):
    +115        """
    +116        Get the trec format of a list of ranked documents. This function is a generator.
    +117
    +118        :param ranked_list: A list of Document-type objects
    +119        :param run_name: Run name to print in the TREC formatted string
    +120        :param sort: Whether to sort the input list in descending order of score.
    +121        :param sorting_func: Custom sorting function will be used if provided
    +122        """
    +123
    +124        if sort:
    +125            if sorting_func:
    +126                ranked_list = sorting_func(ranked_list)
    +127            else:
    +128                ranked_list.sort(key=lambda doc: doc.score, reverse=True)
    +129
    +130        for rank, document in enumerate(ranked_list, start=1):
    +131            yield document.to_trec_format(rank, run_name)
    +132
    +133
    +134class ElasticsearchDocument(Document):
    +135    @classmethod
    +136    def from_results(cls, results, query_cls, ignore_facets=True,
    +137                     *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    +138
    +139        documents = defaultdict(lambda: [])
    +140
    +141        for (topic_num, res) in results:
    +142            for rank, result in enumerate(res["hits"]["hits"], start=1):
    +143                doc_id = query_cls.get_id_mapping(result["_source"])
    +144                facets = {}
    +145
    +146                if not ignore_facets:
    +147                    facets = {k: v for (k, v) in result['_source'].items() if not k.startswith("_")}
    +148
    +149                documents[topic_num].append(ElasticsearchDocument(doc_id,
    +150                                                                  topic_num,
    +151                                                                  facets=facets,
    +152                                                                  score=float(result['_score'])))
    +153
    +154                documents[topic_num][-1].scores['rank'] = rank
    +155
    +156        return dict(documents)
    +157
    +158
    +159document_factory = {
    +160    "elasticsearch": ElasticsearchDocument
    +161}
    +
    + + +
    +
    + +
    +
    @dataclasses.dataclass
    + + class + Document: + + + +
    + +
     10@dataclasses.dataclass
    + 11class Document:
    + 12    """
    + 13    Generic Document class.
    + 14    Used as an interface for interacting across multiple indexes with different mappings.
    + 15    """
    + 16    doc_id: Union[int, float, str]
    + 17    topic_num: Union[int, str, float] = None
    + 18    facets: Dict = None
+ 19    score: Union[float, int] = 0.0  # Primary Score
    + 20    scores: Dict[str, Union[float, int]] = dataclasses.field(
    + 21        default_factory=lambda: {})  # Include other scores if needed
    + 22
    + 23    @classmethod
    + 24    @abc.abstractmethod
    + 25    def from_results(cls, results, *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    + 26        """
    + 27        Produces a list of Document objects from raw results returned from the index
    + 28
    + 29        In the format {topic_num: [Document, ..., Document]}
    + 30        """
    + 31        pass
    + 32
    + 33    def get_document_id(self):
    + 34        """
    + 35        :return:
    + 36            self.doc_id
    + 37        """
    + 38        return self.doc_id
    + 39
    + 40    def flatten_facets(self, *args, **kwargs):
    + 41        """
    + 42        Flattens multi-level internal document facets into a single level
    + 43            e.g. Doc['Upper']['Lower'] -> Doc['Upper_Lower']
    + 44        :param args:
    + 45        :param kwargs:
    + 46        """
    + 47        self.facets = flatten(self.facets, *args, **kwargs)
    + 48
    + 49    @classmethod
    + 50    def _get_document_facet(cls, intermediate_repr, key):
    + 51        return intermediate_repr[key]
    + 52
    + 53    def get_document_facet(self, key, sep="_"):
    + 54        """
    + 55        Retrieve a document facet
    + 56        Works for multidimensional keys or single
    + 57        :param key: Facet to retrieve
+ 58        :param sep: The separator for multidimensional key
    + 59        :return:
    + 60            Returns the document facet given the key (field)
    + 61        """
    + 62        if sep in key:
    + 63            keys = key.split(sep)
    + 64
    + 65            intermediate_repr = self.facets
    + 66            for k in keys:
    + 67                intermediate_repr = self._get_document_facet(intermediate_repr, k)
    + 68
    + 69            return intermediate_repr
    + 70
    + 71        return self.facets[key]
    + 72
    + 73    def set(self, doc_id=None, facets=None, score=None, facet=None, facet_value=None) -> 'Document':
    + 74        """
    + 75        Set attributes of the object. Use keyword arguments to do so. Works as a builder class.
    + 76        doc.set(doc_id="123").set(facets={"title": "my title"})
    + 77        :param doc_id:
    + 78        :param facets:
    + 79        :param score:
    + 80        :param facet:
    + 81        :param facet_value:
    + 82
    + 83        :return:
    + 84            Returns document object
    + 85        """
    + 86        if doc_id is not None:
    + 87            self.doc_id = doc_id
    + 88
    + 89        if facets is not None:
    + 90            self.facets = facets
    + 91
    + 92        if score is not None:
    + 93            self.score = score
    + 94
    + 95        if facet is not None and facet_value is not None:
    + 96            self.facets[facet] = facet_value
    + 97
    + 98        return self
    + 99
    +100    def to_trec_format(self, rank, run_name) -> str:
    +101        """
    +102        Returns TREC format for the document
    +103        :return:
    +104            A trec formatted string
    +105        """
    +106
    +107        return f"{self.topic_num}\t" \
    +108               f"Q0\t" \
    +109               f"{self.doc_id}\t" \
    +110               f"{rank}\t" \
    +111               f"{self.score}\t" \
    +112               f"{run_name}\n"
    +113
    +114    @classmethod
    +115    def get_trec_format(cls, ranked_list: List['Document'], run_name="NO_RUN_NAME", sort=True, sorting_func=None):
    +116        """
    +117        Get the trec format of a list of ranked documents. This function is a generator.
    +118
    +119        :param ranked_list: A list of Document-type objects
    +120        :param run_name: Run name to print in the TREC formatted string
    +121        :param sort: Whether to sort the input list in descending order of score.
    +122        :param sorting_func: Custom sorting function will be used if provided
    +123        """
    +124
    +125        if sort:
    +126            if sorting_func:
    +127                ranked_list = sorting_func(ranked_list)
    +128            else:
    +129                ranked_list.sort(key=lambda doc: doc.score, reverse=True)
    +130
    +131        for rank, document in enumerate(ranked_list, start=1):
    +132            yield document.to_trec_format(rank, run_name)
    +
    + + +

    Generic Document class. +Used as an interface for interacting across multiple indexes with different mappings.

    +
    + + +
    +
    + + Document( doc_id: Union[int, float, str], topic_num: Union[int, str, float] = None, facets: Dict = None, score: Union[float, int] = 0.0, scores: Dict[str, Union[float, int]] = <factory>) + + +
    + + + + +
    +
    + +
    +
    @classmethod
    +
    @abc.abstractmethod
    + + def + from_results( cls, results, *args, **kwargs) -> Dict[Union[int, float], debeir.core.document.Document]: + + + +
    + +
    23    @classmethod
    +24    @abc.abstractmethod
    +25    def from_results(cls, results, *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    +26        """
    +27        Produces a list of Document objects from raw results returned from the index
    +28
    +29        In the format {topic_num: [Document, ..., Document]}
    +30        """
    +31        pass
    +
    + + +

    Produces a list of Document objects from raw results returned from the index

    + +

    In the format {topic_num: [Document, ..., Document]}

    +
    + + +
    +
    + +
    + + def + get_document_id(self): + + + +
    + +
    33    def get_document_id(self):
    +34        """
    +35        :return:
    +36            self.doc_id
    +37        """
    +38        return self.doc_id
    +
    + + +
    Returns
    + +
    +
    self.doc_id
    +
    +
    +
    + + +
    +
    + +
    + + def + flatten_facets(self, *args, **kwargs): + + + +
    + +
    40    def flatten_facets(self, *args, **kwargs):
    +41        """
    +42        Flattens multi-level internal document facets into a single level
    +43            e.g. Doc['Upper']['Lower'] -> Doc['Upper_Lower']
    +44        :param args:
    +45        :param kwargs:
    +46        """
    +47        self.facets = flatten(self.facets, *args, **kwargs)
    +
    + + +

    Flattens multi-level internal document facets into a single level + e.g. Doc['Upper']['Lower'] -> Doc['Upper_Lower']

    + +
    Parameters
    + +
      +
    • args:
    • +
    • kwargs:
    • +
    +
    + + +
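A small illustration of the flattening behaviour described above, assuming the underlying flatten() helper joins nested keys with an underscore as in the docstring example.

    doc = Document(doc_id="d1", facets={"Upper": {"Lower": "value"}})
    doc.flatten_facets()
    # doc.facets is now {"Upper_Lower": "value"} (assuming "_" is the join character)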
    +
    + +
    + + def + get_document_facet(self, key, sep='_'): + + + +
    + +
    53    def get_document_facet(self, key, sep="_"):
    +54        """
    +55        Retrieve a document facet
    +56        Works for multidimensional keys or single
    +57        :param key: Facet to retrieve
+58        :param sep: The separator for multidimensional key
    +59        :return:
    +60            Returns the document facet given the key (field)
    +61        """
    +62        if sep in key:
    +63            keys = key.split(sep)
    +64
    +65            intermediate_repr = self.facets
    +66            for k in keys:
    +67                intermediate_repr = self._get_document_facet(intermediate_repr, k)
    +68
    +69            return intermediate_repr
    +70
    +71        return self.facets[key]
    +
    + + +

    Retrieve a document facet +Works for multidimensional keys or single

    + +
    Parameters
    + +
      +
    • key: Facet to retrieve
    • +
• sep: The separator for multidimensional key
    • +
    + +
    Returns
    + +
    +
    Returns the document facet given the key (field)
    +
    +
    +
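A short example of both lookup modes, using made-up facet names:

    doc = Document(doc_id="d1", facets={"Title": {"Abstract": "Sample abstract text"}})
    doc.get_document_facet("Title_Abstract")  # walks Title -> Abstract, returns "Sample abstract text"
    doc.get_document_facet("Title")           # single-level lookup, returns the nested dict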
    + + +
    +
    + +
    + + def + set( self, doc_id=None, facets=None, score=None, facet=None, facet_value=None) -> debeir.core.document.Document: + + + +
    + +
    73    def set(self, doc_id=None, facets=None, score=None, facet=None, facet_value=None) -> 'Document':
    +74        """
    +75        Set attributes of the object. Use keyword arguments to do so. Works as a builder class.
    +76        doc.set(doc_id="123").set(facets={"title": "my title"})
    +77        :param doc_id:
    +78        :param facets:
    +79        :param score:
    +80        :param facet:
    +81        :param facet_value:
    +82
    +83        :return:
    +84            Returns document object
    +85        """
    +86        if doc_id is not None:
    +87            self.doc_id = doc_id
    +88
    +89        if facets is not None:
    +90            self.facets = facets
    +91
    +92        if score is not None:
    +93            self.score = score
    +94
    +95        if facet is not None and facet_value is not None:
    +96            self.facets[facet] = facet_value
    +97
    +98        return self
    +
    + + +

    Set attributes of the object. Use keyword arguments to do so. Works as a builder class. +doc.set(doc_id="123").set(facets={"title": "my title"})

    + +
    Parameters
    + +
      +
    • doc_id:
    • +
    • facets:
    • +
    • score:
    • +
    • facet:
    • +
    • facet_value:
    • +
    + +
    Returns
    + +
    +
    Returns document object
    +
    +
    +
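The builder-style chaining mentioned in the docstring, with illustrative values:

    doc = Document(doc_id="temp")
    doc.set(doc_id="123").set(facets={"title": "my title"}).set(score=12.5)
    doc.set(facet="title", facet_value="updated title")  # overwrite a single facet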
    + + +
    +
    + +
    + + def + to_trec_format(self, rank, run_name) -> str: + + + +
    + +
    100    def to_trec_format(self, rank, run_name) -> str:
    +101        """
    +102        Returns TREC format for the document
    +103        :return:
    +104            A trec formatted string
    +105        """
    +106
    +107        return f"{self.topic_num}\t" \
    +108               f"Q0\t" \
    +109               f"{self.doc_id}\t" \
    +110               f"{rank}\t" \
    +111               f"{self.score}\t" \
    +112               f"{run_name}\n"
    +
    + + +

    Returns TREC format for the document

    + +
    Returns
    + +
    +
    A trec formatted string
    +
    +
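For example, a document with hypothetical values topic_num=1, doc_id="NCT01234567" and score=12.5 would be rendered as a tab-separated, newline-terminated run line:

    doc = Document(doc_id="NCT01234567", topic_num=1, score=12.5)
    doc.to_trec_format(rank=3, run_name="my_run")
    # -> "1\tQ0\tNCT01234567\t3\t12.5\tmy_run\n"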
    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + get_trec_format( cls, ranked_list: List[debeir.core.document.Document], run_name='NO_RUN_NAME', sort=True, sorting_func=None): + + + +
    + +
    114    @classmethod
    +115    def get_trec_format(cls, ranked_list: List['Document'], run_name="NO_RUN_NAME", sort=True, sorting_func=None):
    +116        """
    +117        Get the trec format of a list of ranked documents. This function is a generator.
    +118
    +119        :param ranked_list: A list of Document-type objects
    +120        :param run_name: Run name to print in the TREC formatted string
    +121        :param sort: Whether to sort the input list in descending order of score.
    +122        :param sorting_func: Custom sorting function will be used if provided
    +123        """
    +124
    +125        if sort:
    +126            if sorting_func:
    +127                ranked_list = sorting_func(ranked_list)
    +128            else:
    +129                ranked_list.sort(key=lambda doc: doc.score, reverse=True)
    +130
    +131        for rank, document in enumerate(ranked_list, start=1):
    +132            yield document.to_trec_format(rank, run_name)
    +
    + + +

    Get the trec format of a list of ranked documents. This function is a generator.

    + +
    Parameters
    + +
      +
    • ranked_list: A list of Document-type objects
    • +
    • run_name: Run name to print in the TREC formatted string
    • +
    • sort: Whether to sort the input list in descending order of score.
    • +
    • sorting_func: Custom sorting function will be used if provided
    • +
    +
    + + +
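A typical way to consume the generator is to stream it straight into a run file; ranked_docs and the run name are placeholders.

    # ranked_docs: a list of Document objects for one topic (assumed)
    with open("run.txt", "w") as fp:
        for line in Document.get_trec_format(ranked_docs, run_name="bm25_baseline"):
            fp.write(line)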
    +
    +
    + +
    + + class + ElasticsearchDocument(Document): + + + +
    + +
    135class ElasticsearchDocument(Document):
    +136    @classmethod
    +137    def from_results(cls, results, query_cls, ignore_facets=True,
    +138                     *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    +139
    +140        documents = defaultdict(lambda: [])
    +141
    +142        for (topic_num, res) in results:
    +143            for rank, result in enumerate(res["hits"]["hits"], start=1):
    +144                doc_id = query_cls.get_id_mapping(result["_source"])
    +145                facets = {}
    +146
    +147                if not ignore_facets:
    +148                    facets = {k: v for (k, v) in result['_source'].items() if not k.startswith("_")}
    +149
    +150                documents[topic_num].append(ElasticsearchDocument(doc_id,
    +151                                                                  topic_num,
    +152                                                                  facets=facets,
    +153                                                                  score=float(result['_score'])))
    +154
    +155                documents[topic_num][-1].scores['rank'] = rank
    +156
    +157        return dict(documents)
    +
    + + +

    Generic Document class. +Used as an interface for interacting across multiple indexes with different mappings.

    +
    + + +
    + +
    +
    @classmethod
    + + def + from_results( cls, results, query_cls, ignore_facets=True, *args, **kwargs) -> Dict[Union[int, float], debeir.core.document.Document]: + + + +
    + +
    136    @classmethod
    +137    def from_results(cls, results, query_cls, ignore_facets=True,
    +138                     *args, **kwargs) -> Dict[Union[int, float], 'Document']:
    +139
    +140        documents = defaultdict(lambda: [])
    +141
    +142        for (topic_num, res) in results:
    +143            for rank, result in enumerate(res["hits"]["hits"], start=1):
    +144                doc_id = query_cls.get_id_mapping(result["_source"])
    +145                facets = {}
    +146
    +147                if not ignore_facets:
    +148                    facets = {k: v for (k, v) in result['_source'].items() if not k.startswith("_")}
    +149
    +150                documents[topic_num].append(ElasticsearchDocument(doc_id,
    +151                                                                  topic_num,
    +152                                                                  facets=facets,
    +153                                                                  score=float(result['_score'])))
    +154
    +155                documents[topic_num][-1].scores['rank'] = rank
    +156
    +157        return dict(documents)
    +
    + + +

    Produces a list of Document objects from raw results returned from the index

    + +

    In the format {topic_num: [Document, ..., Document]}

    +
    + + +
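A rough sketch of feeding from_results. It assumes results has the [(topic_num, raw_elasticsearch_response), ...] shape produced by the executor's run_all_queries(return_results=True), and that query is a query object exposing get_id_mapping.

    results = await executor.run_all_queries(return_results=True)  # assumed call site, inside a coroutine
    docs_by_topic = ElasticsearchDocument.from_results(results, query_cls=query, ignore_facets=False)

    for topic_num, docs in docs_by_topic.items():
        print(topic_num, docs[0].doc_id, docs[0].score, docs[0].scores["rank"])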
    + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/executor.html b/docs/debeir/core/executor.html new file mode 100644 index 0000000..6ffc0a6 --- /dev/null +++ b/docs/debeir/core/executor.html @@ -0,0 +1,951 @@ + + + + + + + debeir.core.executor API documentation + + + + + + + + + +
    +
    +

    +debeir.core.executor

    + + + + + + +
      1from typing import Dict, Optional, Union
    +  2
    +  3import loguru
    +  4from debeir.engines.elasticsearch.executor import ElasticsearchExecutor
    +  5from debeir.core.config import GenericConfig, NIRConfig
    +  6from debeir.core.query import GenericElasticsearchQuery
    +  7from debeir.rankers.transformer_sent_encoder import Encoder
    +  8from debeir.utils.scaler import unpack_elasticsearch_scores
    +  9from elasticsearch import AsyncElasticsearch as Elasticsearch
    + 10
    + 11
    + 12class GenericElasticsearchExecutor(ElasticsearchExecutor):
    + 13    """
    + 14    Generic Executor class for Elasticsearch
    + 15    """
    + 16    query: GenericElasticsearchQuery
    + 17
    + 18    def __init__(
    + 19            self,
    + 20            topics: Dict[Union[str, int], Dict[str, str]],
    + 21            client: Elasticsearch,
    + 22            index_name: str,
    + 23            output_file: str,
    + 24            query: GenericElasticsearchQuery,
    + 25            encoder: Optional[Encoder] = None,
    + 26            config=None,
    + 27            *args,
    + 28            **kwargs,
    + 29    ):
    + 30        super().__init__(
    + 31            topics,
    + 32            client,
    + 33            index_name,
    + 34            output_file,
    + 35            query,
    + 36            encoder,
    + 37            config=config,
    + 38            *args,
    + 39            **kwargs,
    + 40        )
    + 41
    + 42        self.query_fns = {
    + 43            "query": self.generate_query,
    + 44            "embedding": self.generate_embedding_query,
    + 45        }
    + 46
    + 47    def generate_query(self, topic_num, best_fields=True, **kwargs):
    + 48        """
    + 49        Generates a standard BM25 query given the topic number
    + 50
    + 51        :param topic_num: Query topic number to generate
    + 52        :param best_fields: Whether to use a curated list of fields
    + 53        :param kwargs:
    + 54        :return:
    + 55        """
    + 56        return self.query.generate_query(topic_num, **kwargs)
    + 57
    + 58    # def generate_query_ablation(self, topic_num, **kwargs):
    + 59    #    return self.query.generate_query_ablation(topic_num)
    + 60
    + 61    def generate_embedding_query(
    + 62            self,
    + 63            topic_num,
    + 64            cosine_weights=None,
    + 65            query_weights=None,
    + 66            norm_weight=2.15,
    + 67            automatic_scores=None,
    + 68            **kwargs,
    + 69    ):
    + 70        """
+ 71        Generates an NIR-style query with combined scoring.
    + 72
    + 73        :param topic_num:
    + 74        :param cosine_weights:
    + 75        :param query_weights:
    + 76        :param norm_weight:
    + 77        :param automatic_scores:
    + 78        :param kwargs:
    + 79        :return:
    + 80        """
    + 81        assert self.encoder is not None or self.config.encoder is not None
    + 82
    + 83        if "encoder" not in kwargs:
    + 84            kwargs["encoder"] = self.encoder
    + 85
    + 86        return self.query.generate_query_embedding(
    + 87            topic_num,
    + 88            cosine_weights=cosine_weights,
    + 89            query_weight=query_weights,
    + 90            norm_weight=norm_weight,
    + 91            automatic_scores=automatic_scores,
    + 92            **kwargs,
    + 93        )
    + 94
    + 95    # @apply_config
    + 96    async def execute_query(
    + 97            self, query=None, return_size: int = None, return_id_only: bool = None,
    + 98            topic_num=None, ablation=False, query_type=None,
    + 99            **kwargs
    +100    ):
    +101        """
    +102        Executes a query using the underlying elasticsearch client.
    +103
    +104        :param query:
    +105        :param topic_num:
    +106        :param ablation:
    +107        :param query_type:
    +108        :param return_size:
    +109        :param return_id_only:
    +110        :param kwargs:
    +111        :return:
    +112        """
    +113
    +114        if ablation:
    +115            query_type = "ablation"
    +116
    +117        assert query is not None or topic_num is not None
    +118
    +119        if query:
    +120            if return_id_only:
    +121                # query["fields"] = [self.query.id_mapping]
    +122                # query["_source"] = False
    +123                query["_source"] = [self.query.id_mapping]
    +124            res = await self.client.search(
    +125                index=self.index_name, body=query, size=return_size
    +126            )
    +127
    +128            return [query, res]
    +129
    +130        if topic_num:
    +131            loguru.logger.debug(query_type)
    +132            body = self.query_fns[query_type](topic_num=topic_num, **kwargs)
    +133            if return_id_only:
    +134                loguru.logger.debug("Skip")
    +135                body["_source"] = [self.query.id_mapping]
    +136
    +137            loguru.logger.debug(body)
    +138            res = await self.client.search(
    +139                index=self.index_name, body=body, size=return_size
    +140            )
    +141
    +142            return [topic_num, res]
    +143
    +144    async def run_automatic_adjustment(self, return_results=False):
    +145        """
    +146        Get the normalization constant to be used in NIR-style queries for all topics given an initial
    +147        run of BM25 results.
    +148        """
    +149        loguru.logger.info("Running automatic BM25 weight adjustment")
    +150
    +151        # Backup variables temporarily
    +152        # size = self.return_size
    +153        # self.return_size = 1
    +154        # self.return_id_only = True
    +155        # prev_qt = self.config.query_type
    +156        # self.config.query_type = "query"
    +157
    +158        results = await self.run_all_queries(query_type="query",
    +159                                             return_results=True,
    +160                                             return_size=1,
    +161                                             return_id_only=True)
    +162
    +163        res = unpack_elasticsearch_scores(results)
    +164        self.query.set_bm25_scores(res)
    +165
    +166        if return_results:
    +167            return results
    +168
    +169    @classmethod
    +170    def build_from_config(cls, topics: Dict, query_obj: GenericElasticsearchQuery, client,
    +171                          config: GenericConfig, nir_config: NIRConfig):
    +172        """
+173        Build a query executor engine from a config file.
    +174        """
    +175
    +176        return cls(
    +177            topics=topics,
    +178            client=client,
    +179            config=config,
    +180            index_name=config.index,
    +181            output_file="",
    +182            return_size=nir_config.return_size,
    +183            query=query_obj
    +184        )
    +
    + + +
    +
    + +
    + + class + GenericElasticsearchExecutor(debeir.engines.elasticsearch.executor.ElasticsearchExecutor): + + + +
    + +
     13class GenericElasticsearchExecutor(ElasticsearchExecutor):
    + 14    """
    + 15    Generic Executor class for Elasticsearch
    + 16    """
    + 17    query: GenericElasticsearchQuery
    + 18
    + 19    def __init__(
    + 20            self,
    + 21            topics: Dict[Union[str, int], Dict[str, str]],
    + 22            client: Elasticsearch,
    + 23            index_name: str,
    + 24            output_file: str,
    + 25            query: GenericElasticsearchQuery,
    + 26            encoder: Optional[Encoder] = None,
    + 27            config=None,
    + 28            *args,
    + 29            **kwargs,
    + 30    ):
    + 31        super().__init__(
    + 32            topics,
    + 33            client,
    + 34            index_name,
    + 35            output_file,
    + 36            query,
    + 37            encoder,
    + 38            config=config,
    + 39            *args,
    + 40            **kwargs,
    + 41        )
    + 42
    + 43        self.query_fns = {
    + 44            "query": self.generate_query,
    + 45            "embedding": self.generate_embedding_query,
    + 46        }
    + 47
    + 48    def generate_query(self, topic_num, best_fields=True, **kwargs):
    + 49        """
    + 50        Generates a standard BM25 query given the topic number
    + 51
    + 52        :param topic_num: Query topic number to generate
    + 53        :param best_fields: Whether to use a curated list of fields
    + 54        :param kwargs:
    + 55        :return:
    + 56        """
    + 57        return self.query.generate_query(topic_num, **kwargs)
    + 58
    + 59    # def generate_query_ablation(self, topic_num, **kwargs):
    + 60    #    return self.query.generate_query_ablation(topic_num)
    + 61
    + 62    def generate_embedding_query(
    + 63            self,
    + 64            topic_num,
    + 65            cosine_weights=None,
    + 66            query_weights=None,
    + 67            norm_weight=2.15,
    + 68            automatic_scores=None,
    + 69            **kwargs,
    + 70    ):
    + 71        """
+ 72        Generates an NIR-style query with combined scoring.
    + 73
    + 74        :param topic_num:
    + 75        :param cosine_weights:
    + 76        :param query_weights:
    + 77        :param norm_weight:
    + 78        :param automatic_scores:
    + 79        :param kwargs:
    + 80        :return:
    + 81        """
    + 82        assert self.encoder is not None or self.config.encoder is not None
    + 83
    + 84        if "encoder" not in kwargs:
    + 85            kwargs["encoder"] = self.encoder
    + 86
    + 87        return self.query.generate_query_embedding(
    + 88            topic_num,
    + 89            cosine_weights=cosine_weights,
    + 90            query_weight=query_weights,
    + 91            norm_weight=norm_weight,
    + 92            automatic_scores=automatic_scores,
    + 93            **kwargs,
    + 94        )
    + 95
    + 96    # @apply_config
    + 97    async def execute_query(
    + 98            self, query=None, return_size: int = None, return_id_only: bool = None,
    + 99            topic_num=None, ablation=False, query_type=None,
    +100            **kwargs
    +101    ):
    +102        """
    +103        Executes a query using the underlying elasticsearch client.
    +104
    +105        :param query:
    +106        :param topic_num:
    +107        :param ablation:
    +108        :param query_type:
    +109        :param return_size:
    +110        :param return_id_only:
    +111        :param kwargs:
    +112        :return:
    +113        """
    +114
    +115        if ablation:
    +116            query_type = "ablation"
    +117
    +118        assert query is not None or topic_num is not None
    +119
    +120        if query:
    +121            if return_id_only:
    +122                # query["fields"] = [self.query.id_mapping]
    +123                # query["_source"] = False
    +124                query["_source"] = [self.query.id_mapping]
    +125            res = await self.client.search(
    +126                index=self.index_name, body=query, size=return_size
    +127            )
    +128
    +129            return [query, res]
    +130
    +131        if topic_num:
    +132            loguru.logger.debug(query_type)
    +133            body = self.query_fns[query_type](topic_num=topic_num, **kwargs)
    +134            if return_id_only:
    +135                loguru.logger.debug("Skip")
    +136                body["_source"] = [self.query.id_mapping]
    +137
    +138            loguru.logger.debug(body)
    +139            res = await self.client.search(
    +140                index=self.index_name, body=body, size=return_size
    +141            )
    +142
    +143            return [topic_num, res]
    +144
    +145    async def run_automatic_adjustment(self, return_results=False):
    +146        """
    +147        Get the normalization constant to be used in NIR-style queries for all topics given an initial
    +148        run of BM25 results.
    +149        """
    +150        loguru.logger.info("Running automatic BM25 weight adjustment")
    +151
    +152        # Backup variables temporarily
    +153        # size = self.return_size
    +154        # self.return_size = 1
    +155        # self.return_id_only = True
    +156        # prev_qt = self.config.query_type
    +157        # self.config.query_type = "query"
    +158
    +159        results = await self.run_all_queries(query_type="query",
    +160                                             return_results=True,
    +161                                             return_size=1,
    +162                                             return_id_only=True)
    +163
    +164        res = unpack_elasticsearch_scores(results)
    +165        self.query.set_bm25_scores(res)
    +166
    +167        if return_results:
    +168            return results
    +169
    +170    @classmethod
    +171    def build_from_config(cls, topics: Dict, query_obj: GenericElasticsearchQuery, client,
    +172                          config: GenericConfig, nir_config: NIRConfig):
    +173        """
+174        Build a query executor engine from a config file.
    +175        """
    +176
    +177        return cls(
    +178            topics=topics,
    +179            client=client,
    +180            config=config,
    +181            index_name=config.index,
    +182            output_file="",
    +183            return_size=nir_config.return_size,
    +184            query=query_obj
    +185        )
    +
    + + +

    Generic Executor class for Elasticsearch

    +
    + + +
    + +
    + + GenericElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.core.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs) + + + +
    + +
    19    def __init__(
    +20            self,
    +21            topics: Dict[Union[str, int], Dict[str, str]],
    +22            client: Elasticsearch,
    +23            index_name: str,
    +24            output_file: str,
    +25            query: GenericElasticsearchQuery,
    +26            encoder: Optional[Encoder] = None,
    +27            config=None,
    +28            *args,
    +29            **kwargs,
    +30    ):
    +31        super().__init__(
    +32            topics,
    +33            client,
    +34            index_name,
    +35            output_file,
    +36            query,
    +37            encoder,
    +38            config=config,
    +39            *args,
    +40            **kwargs,
    +41        )
    +42
    +43        self.query_fns = {
    +44            "query": self.generate_query,
    +45            "embedding": self.generate_embedding_query,
    +46        }
    +
    + + + + +
    +
    + +
    + + def + generate_query(self, topic_num, best_fields=True, **kwargs): + + + +
    + +
    48    def generate_query(self, topic_num, best_fields=True, **kwargs):
    +49        """
    +50        Generates a standard BM25 query given the topic number
    +51
    +52        :param topic_num: Query topic number to generate
    +53        :param best_fields: Whether to use a curated list of fields
    +54        :param kwargs:
    +55        :return:
    +56        """
    +57        return self.query.generate_query(topic_num, **kwargs)
    +
    + + +

    Generates a standard BM25 query given the topic number

    + +
    Parameters
    + +
      +
    • topic_num: Query topic number to generate
    • +
    • best_fields: Whether to use a curated list of fields
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    + + def + generate_embedding_query( self, topic_num, cosine_weights=None, query_weights=None, norm_weight=2.15, automatic_scores=None, **kwargs): + + + +
    + +
    62    def generate_embedding_query(
    +63            self,
    +64            topic_num,
    +65            cosine_weights=None,
    +66            query_weights=None,
    +67            norm_weight=2.15,
    +68            automatic_scores=None,
    +69            **kwargs,
    +70    ):
    +71        """
+72        Generates an NIR-style query with combined scoring.
    +73
    +74        :param topic_num:
    +75        :param cosine_weights:
    +76        :param query_weights:
    +77        :param norm_weight:
    +78        :param automatic_scores:
    +79        :param kwargs:
    +80        :return:
    +81        """
    +82        assert self.encoder is not None or self.config.encoder is not None
    +83
    +84        if "encoder" not in kwargs:
    +85            kwargs["encoder"] = self.encoder
    +86
    +87        return self.query.generate_query_embedding(
    +88            topic_num,
    +89            cosine_weights=cosine_weights,
    +90            query_weight=query_weights,
    +91            norm_weight=norm_weight,
    +92            automatic_scores=automatic_scores,
    +93            **kwargs,
    +94        )
    +
    + + +

Generates an NIR-style query with combined scoring.

    + +
    Parameters
    + +
      +
    • topic_num:
    • +
    • cosine_weights:
    • +
    • query_weights:
    • +
    • norm_weight:
    • +
    • automatic_scores:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    + + async def + execute_query( self, query=None, return_size: int = None, return_id_only: bool = None, topic_num=None, ablation=False, query_type=None, **kwargs): + + + +
    + +
     97    async def execute_query(
    + 98            self, query=None, return_size: int = None, return_id_only: bool = None,
    + 99            topic_num=None, ablation=False, query_type=None,
    +100            **kwargs
    +101    ):
    +102        """
    +103        Executes a query using the underlying elasticsearch client.
    +104
    +105        :param query:
    +106        :param topic_num:
    +107        :param ablation:
    +108        :param query_type:
    +109        :param return_size:
    +110        :param return_id_only:
    +111        :param kwargs:
    +112        :return:
    +113        """
    +114
    +115        if ablation:
    +116            query_type = "ablation"
    +117
    +118        assert query is not None or topic_num is not None
    +119
    +120        if query:
    +121            if return_id_only:
    +122                # query["fields"] = [self.query.id_mapping]
    +123                # query["_source"] = False
    +124                query["_source"] = [self.query.id_mapping]
    +125            res = await self.client.search(
    +126                index=self.index_name, body=query, size=return_size
    +127            )
    +128
    +129            return [query, res]
    +130
    +131        if topic_num:
    +132            loguru.logger.debug(query_type)
    +133            body = self.query_fns[query_type](topic_num=topic_num, **kwargs)
    +134            if return_id_only:
    +135                loguru.logger.debug("Skip")
    +136                body["_source"] = [self.query.id_mapping]
    +137
    +138            loguru.logger.debug(body)
    +139            res = await self.client.search(
    +140                index=self.index_name, body=body, size=return_size
    +141            )
    +142
    +143            return [topic_num, res]
    +
    + + +

Executes a query using the underlying elasticsearch client.

+ +
Parameters
+ +
  +
• query:
• +
• topic_num:
• +
• ablation:
• +
• query_type:
• +
• return_size:
• +
• return_id_only:
• +
• kwargs:
• +
    +
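A hedged sketch of calling execute_query from an asyncio coroutine; the executor is assumed to be already constructed, and the topic number and sizes are illustrative.

    import asyncio

    async def search_one_topic(executor):
        topic_num, response = await executor.execute_query(
            topic_num="1", query_type="query", return_size=10, return_id_only=True
        )
        return response["hits"]["hits"]

    hits = asyncio.run(search_one_topic(executor))  # executor: a GenericElasticsearchExecutor (assumed)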
    + + +
    +
    + +
    + + async def + run_automatic_adjustment(self, return_results=False): + + + +
    + +
    145    async def run_automatic_adjustment(self, return_results=False):
    +146        """
    +147        Get the normalization constant to be used in NIR-style queries for all topics given an initial
    +148        run of BM25 results.
    +149        """
    +150        loguru.logger.info("Running automatic BM25 weight adjustment")
    +151
    +152        # Backup variables temporarily
    +153        # size = self.return_size
    +154        # self.return_size = 1
    +155        # self.return_id_only = True
    +156        # prev_qt = self.config.query_type
    +157        # self.config.query_type = "query"
    +158
    +159        results = await self.run_all_queries(query_type="query",
    +160                                             return_results=True,
    +161                                             return_size=1,
    +162                                             return_id_only=True)
    +163
    +164        res = unpack_elasticsearch_scores(results)
    +165        self.query.set_bm25_scores(res)
    +166
    +167        if return_results:
    +168            return results
    +
    + + +

    Get the normalization constant to be used in NIR-style queries for all topics given an initial +run of BM25 results.

    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + build_from_config( cls, topics: Dict, query_obj: debeir.core.query.GenericElasticsearchQuery, client, config: debeir.core.config.GenericConfig, nir_config: debeir.core.config.NIRConfig): + + + +
    + +
    170    @classmethod
    +171    def build_from_config(cls, topics: Dict, query_obj: GenericElasticsearchQuery, client,
    +172                          config: GenericConfig, nir_config: NIRConfig):
    +173        """
+174        Build a query executor engine from a config file.
    +175        """
    +176
    +177        return cls(
    +178            topics=topics,
    +179            client=client,
    +180            config=config,
    +181            index_name=config.index,
    +182            output_file="",
    +183            return_size=nir_config.return_size,
    +184            query=query_obj
    +185        )
    +
    + + +

Build a query executor engine from a config file.

    +
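One way the classmethod above might be wired up; the topics, query object and config instances are assumed to come from the library's parser/factory layers rather than being constructed here.

    from elasticsearch import AsyncElasticsearch

    client = AsyncElasticsearch("http://localhost:9200")
    executor = GenericElasticsearchExecutor.build_from_config(
        topics=topics,          # parsed topics, e.g. from a Parser (assumed)
        query_obj=query,        # a GenericElasticsearchQuery for the target index (assumed)
        client=client,
        config=config,          # GenericConfig with .index set (assumed)
        nir_config=nir_config,  # NIRConfig with .return_size set (assumed)
    )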
    + + +
    + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/indexer.html b/docs/debeir/core/indexer.html new file mode 100644 index 0000000..b0ec356 --- /dev/null +++ b/docs/debeir/core/indexer.html @@ -0,0 +1,723 @@ + + + + + + + debeir.core.indexer API documentation + + + + + + + + + +
    +
    +

    +debeir.core.indexer

    + + + + + + +
     1import abc
    + 2import threading
    + 3from queue import Queue
    + 4from typing import List
    + 5
    + 6from debeir.rankers.transformer_sent_encoder import Encoder
    + 7from debeir.utils.utils import remove_excess_whitespace
    + 8from elasticsearch import Elasticsearch
    + 9
    +10
    +11class Indexer:
    +12    def __init__(self, client):
    +13        super().__init__()
    +14        self.client = client
    +15
    +16    @abc.abstractmethod
    +17    def get_field(self, document, field):
    +18        pass
    +19
    +20
    +21class SemanticElasticsearchIndexer(Indexer, threading.Thread):
    +22    """
+23    Create a NIR-style index with dense field representations built by the provided sentence encoder.
    +24    Assumes you've already indexed to start with.
    +25    """
    +26
    +27    def __init__(self, es_client: Elasticsearch, encoder: Encoder, index: str,
    +28                 fields_to_encode: List[str], queue: Queue):
    +29        super().__init__(es_client)
    +30        self.encoder = encoder
    +31        self.index = index
    +32        self.fields = fields_to_encode
    +33        self.q = queue
    +34        self.update_mappings(self.index, self.fields, self.client)
    +35
    +36    @classmethod
    +37    def update_mappings(self, index, fields, client: Elasticsearch):
    +38        mapping = {}
    +39        value = {
    +40            "type": "dense_vector",
    +41            "dims": 768
    +42        }
    +43
    +44        for field in fields:
    +45            mapping[field + "_Embedding"] = value
    +46            mapping[field + "_Text"] = {"type": "text"}
    +47
    +48        client.indices.put_mapping(
    +49            body={
    +50                "properties": mapping
    +51            }, index=index)
    +52
    +53    # async def create_index(self, document_itr=None):
    +54    #    await self._update_mappings()
    +55
    +56    #    if document_itr is None:
    +57    #        document_itr = helpers.async_scan(self.es_client, index=self.index)
    +58
    +59    #    bar = tqdm(desc="Indexing", total=35_000)
    +60
    +61    #    async for document in document_itr:
    +62    #        doc = document["_source"]
    +63    #        await self.index_document(doc)
    +64
    +65    #        bar.update(1)
    +66
    +67    def get_field(self, document, field):
    +68        if field not in document:
    +69            return False
    +70
+71        if f"{field}_Text" in document and document[f"{field}_Text"] != 0:
    +72            return False
    +73
    +74        if 'Textblock' in document[field]:
    +75            return remove_excess_whitespace(document[field]['Textblock'])
    +76
    +77        return remove_excess_whitespace(document[field])
    +78
    +79    def index_document(self, document):
    +80        update_doc = {}
    +81        doc = document["_source"]
    +82
    +83        for field in self.fields:
    +84            text_field = self.get_field(doc, field)
    +85
    +86            if text_field:
    +87                embedding = self.encoder.encode(topic=text_field, disable_cache=True)
    +88                update_doc[f"{field}_Embedding"] = embedding
    +89                update_doc[f"{field}_Text"] = text_field
    +90
    +91        if update_doc:
    +92            self.client.update(index=self.index,
    +93                               id=document['_id'],
    +94                               doc=update_doc)
    +95
    +96    def run(self):
    +97        while not self.q.empty():
    +98            document = self.q.get()
    +99            self.index_document(document)
    +
    + + +
    +
    + +
    + + class + Indexer: + + + +
    + +
    12class Indexer:
    +13    def __init__(self, client):
    +14        super().__init__()
    +15        self.client = client
    +16
    +17    @abc.abstractmethod
    +18    def get_field(self, document, field):
    +19        pass
    +
    + + + + +
    + +
    + + Indexer(client) + + + +
    + +
    13    def __init__(self, client):
    +14        super().__init__()
    +15        self.client = client
    +
    + + + + +
    +
    + +
    +
    @abc.abstractmethod
    + + def + get_field(self, document, field): + + + +
    + +
    17    @abc.abstractmethod
    +18    def get_field(self, document, field):
    +19        pass
    +
    + + + + +
    +
    +
    + +
    + + class + SemanticElasticsearchIndexer(Indexer, threading.Thread): + + + +
    + +
     22class SemanticElasticsearchIndexer(Indexer, threading.Thread):
    + 23    """
+ 24    Create a NIR-style index with dense field representations built by the provided sentence encoder.
    + 25    Assumes you've already indexed to start with.
    + 26    """
    + 27
    + 28    def __init__(self, es_client: Elasticsearch, encoder: Encoder, index: str,
    + 29                 fields_to_encode: List[str], queue: Queue):
    + 30        super().__init__(es_client)
    + 31        self.encoder = encoder
    + 32        self.index = index
    + 33        self.fields = fields_to_encode
    + 34        self.q = queue
    + 35        self.update_mappings(self.index, self.fields, self.client)
    + 36
    + 37    @classmethod
    + 38    def update_mappings(self, index, fields, client: Elasticsearch):
    + 39        mapping = {}
    + 40        value = {
    + 41            "type": "dense_vector",
    + 42            "dims": 768
    + 43        }
    + 44
    + 45        for field in fields:
    + 46            mapping[field + "_Embedding"] = value
    + 47            mapping[field + "_Text"] = {"type": "text"}
    + 48
    + 49        client.indices.put_mapping(
    + 50            body={
    + 51                "properties": mapping
    + 52            }, index=index)
    + 53
    + 54    # async def create_index(self, document_itr=None):
    + 55    #    await self._update_mappings()
    + 56
    + 57    #    if document_itr is None:
    + 58    #        document_itr = helpers.async_scan(self.es_client, index=self.index)
    + 59
    + 60    #    bar = tqdm(desc="Indexing", total=35_000)
    + 61
    + 62    #    async for document in document_itr:
    + 63    #        doc = document["_source"]
    + 64    #        await self.index_document(doc)
    + 65
    + 66    #        bar.update(1)
    + 67
    + 68    def get_field(self, document, field):
    + 69        if field not in document:
    + 70            return False
    + 71
+ 72        if f"{field}_Text" in document and document[f"{field}_Text"] != 0:
    + 73            return False
    + 74
    + 75        if 'Textblock' in document[field]:
    + 76            return remove_excess_whitespace(document[field]['Textblock'])
    + 77
    + 78        return remove_excess_whitespace(document[field])
    + 79
    + 80    def index_document(self, document):
    + 81        update_doc = {}
    + 82        doc = document["_source"]
    + 83
    + 84        for field in self.fields:
    + 85            text_field = self.get_field(doc, field)
    + 86
    + 87            if text_field:
    + 88                embedding = self.encoder.encode(topic=text_field, disable_cache=True)
    + 89                update_doc[f"{field}_Embedding"] = embedding
    + 90                update_doc[f"{field}_Text"] = text_field
    + 91
    + 92        if update_doc:
    + 93            self.client.update(index=self.index,
    + 94                               id=document['_id'],
    + 95                               doc=update_doc)
    + 96
    + 97    def run(self):
    + 98        while not self.q.empty():
    + 99            document = self.q.get()
    +100            self.index_document(document)
    +
    + + +

Create a NIR-style index with dense field representations built by the provided sentence encoder. +Assumes you've already indexed to start with.

    +
    + + +
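A hedged sketch of driving the indexer with a queue of raw Elasticsearch hits; the index name, field list and encoder instance are placeholders, and each queued item must carry the usual _id/_source keys.

    from queue import Queue
    from elasticsearch import Elasticsearch

    client = Elasticsearch("http://localhost:9200")
    q = Queue()

    # Fill the queue with raw hits ({"_id": ..., "_source": {...}}) from an existing index.
    for hit in client.search(index="my_index", size=100)["hits"]["hits"]:
        q.put(hit)

    indexer = SemanticElasticsearchIndexer(
        es_client=client,
        encoder=encoder,  # a debeir Encoder instance (assumed)
        index="my_index",
        fields_to_encode=["Title"],
        queue=q,
    )
    indexer.start()  # consumes the queue in a background thread, updating each document
    indexer.join()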
    + +
    + + SemanticElasticsearchIndexer( es_client: elasticsearch.Elasticsearch, encoder: debeir.rankers.transformer_sent_encoder.Encoder, index: str, fields_to_encode: List[str], queue: queue.Queue) + + + +
    + +
    28    def __init__(self, es_client: Elasticsearch, encoder: Encoder, index: str,
    +29                 fields_to_encode: List[str], queue: Queue):
    +30        super().__init__(es_client)
    +31        self.encoder = encoder
    +32        self.index = index
    +33        self.fields = fields_to_encode
    +34        self.q = queue
    +35        self.update_mappings(self.index, self.fields, self.client)
    +
    + + +

    This constructor should always be called with keyword arguments. Arguments are:

    + +

    group should be None; reserved for future extension when a ThreadGroup +class is implemented.

    + +

    target is the callable object to be invoked by the run() +method. Defaults to None, meaning nothing is called.

    + +

    name is the thread name. By default, a unique name is constructed of +the form "Thread-N" where N is a small decimal number.

    + +

    args is the argument tuple for the target invocation. Defaults to ().

    + +

    kwargs is a dictionary of keyword arguments for the target +invocation. Defaults to {}.

    + +

    If a subclass overrides the constructor, it must make sure to invoke the base class constructor (Thread.__init__()) before doing anything else to the thread.

    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + update_mappings(self, index, fields, client: elasticsearch.Elasticsearch): + + + +
    + +
    37    @classmethod
    +38    def update_mappings(self, index, fields, client: Elasticsearch):
    +39        mapping = {}
    +40        value = {
    +41            "type": "dense_vector",
    +42            "dims": 768
    +43        }
    +44
    +45        for field in fields:
    +46            mapping[field + "_Embedding"] = value
    +47            mapping[field + "_Text"] = {"type": "text"}
    +48
    +49        client.indices.put_mapping(
    +50            body={
    +51                "properties": mapping
    +52            }, index=index)
    +
    + + + + +
    +
    + +
    + + def + get_field(self, document, field): + + + +
    + +
    68    def get_field(self, document, field):
    +69        if field not in document:
    +70            return False
    +71
    +72        if f"{field}_Text" in document and document[f"{field}_Text"] != 0:
    +73            return False
    +74
    +75        if 'Textblock' in document[field]:
    +76            return remove_excess_whitespace(document[field]['Textblock'])
    +77
    +78        return remove_excess_whitespace(document[field])
    +
    + + + + +
    +
    + +
    + + def + index_document(self, document): + + + +
    + +
    80    def index_document(self, document):
    +81        update_doc = {}
    +82        doc = document["_source"]
    +83
    +84        for field in self.fields:
    +85            text_field = self.get_field(doc, field)
    +86
    +87            if text_field:
    +88                embedding = self.encoder.encode(topic=text_field, disable_cache=True)
    +89                update_doc[f"{field}_Embedding"] = embedding
    +90                update_doc[f"{field}_Text"] = text_field
    +91
    +92        if update_doc:
    +93            self.client.update(index=self.index,
    +94                               id=document['_id'],
    +95                               doc=update_doc)
    +
    + + + + +
    +
    + +
    + + def + run(self): + + + +
    + +
     97    def run(self):
    + 98        while not self.q.empty():
    + 99            document = self.q.get()
    +100            self.index_document(document)
    +
    + + +

    Method representing the thread's activity.

    + +

    You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

    +
    + + +
    +
    +
    Inherited Members
    +
    +
    threading.Thread
    +
    start
    +
    join
    +
    name
    +
    ident
    +
    is_alive
    +
    daemon
    +
    isDaemon
    +
    setDaemon
    +
    getName
    +
    setName
    +
    native_id
    + +
    +
    +
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/parser.html b/docs/debeir/core/parser.html new file mode 100644 index 0000000..30ed940 --- /dev/null +++ b/docs/debeir/core/parser.html @@ -0,0 +1,1095 @@ + + + + + + + debeir.core.parser API documentation + + + + + + + + + +
    +
    +

    +debeir.core.parser

    + + + + + + +
      1import abc
    +  2import csv
    +  3import dataclasses
    +  4import json
    +  5from collections import defaultdict
    +  6from dataclasses import dataclass
    +  7from typing import Dict, List
    +  8from xml.etree import ElementTree as ET
    +  9
    + 10import dill
    + 11import pandas as pd
    + 12
    + 13
    + 14# TODO: Parse fields can come from a config or ID_fields
    + 15# TODO: move _get_topics to private cls method with arguments, and expose get_topics as an instance method.
    + 16
    + 17
    + 18@dataclass(init=True)
    + 19class Parser:
    + 20    """
    + 21    Parser interface
    + 22    """
    + 23
    + 24    id_field: object
    + 25    parse_fields: List[str]
    + 26
    + 27    @classmethod
    + 28    def normalize(cls, input_dict) -> Dict:
    + 29        """
    + 30        Flatten the dictionary, i.e. from Dict[int, Dict] -> Dict[str, str_or_int]
    + 31
    + 32        :param input_dict:
    + 33        :return:
    + 34        """
    + 35        return pd.io.json.json_normalize(input_dict,
    + 36                                         sep=".").to_dict(orient='records')[0]
    + 37
    + 38    def get_topics(self, path, *args, **kwargs):
    + 39        """
    + 40        Instance method for getting topics, forwards instance self parameters to the _get_topics class method.
    + 41        """
    + 42
    + 43        self_kwargs = vars(self)
    + 44        kwargs.update(self_kwargs)
    + 45
    + 46        return self._get_topics(path, *args, **kwargs)
    + 47
    + 48    @classmethod
    + 49    @abc.abstractmethod
    + 50    def _get_topics(cls, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    + 51        raise NotImplementedError
    + 52
    + 53
    + 54@dataclasses.dataclass(init=True)
    + 55class PickleParser(Parser):
    + 56    """
    + 57    Load topics from a pickle file
    + 58    """
    + 59
    + 60    @classmethod
    + 61    def _get_topics(cls, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    + 62        return dill.load(path)
    + 63
    + 64
    + 65@dataclasses.dataclass(init=True)
    + 66class XMLParser(Parser):
    + 67    """
    + 68    Load topics from an XML file
    + 69    """
    + 70    topic_field_name: str
    + 71    id_field: str
    + 72    parse_fields: List[str]
    + 73
    + 74    @classmethod
    + 75    def _recurse_to_child_node(cls, node: ET.Element, track: List):
    + 76        """
    + 77        Helper method to get all children nodes for text extraction in an xml.
    + 78
    + 79        :param node: Current node
    + 80        :param track: List to track nodes
    + 81        :return:
    + 82        """
    + 83        if len(node.getchildren()) > 0:
    + 84            for child in node.getchildren():
    + 85                track.append(cls._recurse_to_child_node(child, track))
    + 86
    + 87        return node
    + 88
    + 89    @classmethod
    + 90    def unwrap(cls, doc_dict, key):
    + 91        """
    + 92        Converts defaultdict to dict and list of size 1 to just the element
    + 93
    + 94        :param doc_dict:
    + 95        :param key:
    + 96        """
    + 97        if isinstance(doc_dict[key], defaultdict):
    + 98            doc_dict[key] = dict(doc_dict[key])
    + 99
    +100            for e_key in doc_dict[key]:
    +101                cls.unwrap(doc_dict[key], e_key)
    +102
    +103        if isinstance(doc_dict[key], list):
    +104            if len(doc_dict[key]) == 1:
    +105                doc_dict[key] = doc_dict[key][0]
    +106
    +107    def _get_topics(self, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +108        all_topics = ET.parse(path).getroot()
    +109        qtopics = {}
    +110
    +111        for topic in all_topics.findall(self.topic_field_name):
    +112            _id = topic.attrib[self.id_field]
    +113            if _id.isnumeric():
    +114                _id = int(_id)
    +115
    +116            if self.parse_fields:
    +117                temp = {}
    +118                for field in self.parse_fields:
    +119                    try:
    +120                        temp[field] = topic.find(field).text.strip()
    +121                    except:
    +122                        continue
    +123
    +124                qtopics[_id] = temp
    +125            else:
    +126                #  The topic contains the text
    +127                qtopics[_id] = {"query": topic.text.strip()}
    +128
    +129        return qtopics
    +130
    +131
    +132@dataclasses.dataclass
    +133class CSVParser(Parser):
    +134    """
    +135    Loads topics from a CSV file
    +136    """
    +137    id_field = "id"
    +138    parse_fields = ["Text"]
    +139
    +140    def __init__(self, id_field=None, parse_fields=None):
    +141        if parse_fields is None:
    +142            parse_fields = ["id", "text"]
    +143
    +144        if id_field is None:
    +145            id_field = "id"
    +146
    +147        super().__init__(id_field, parse_fields)
    +148
    +149    @classmethod
    +150    def _get_topics(cls, csvfile, dialect="excel",
    +151                    id_field: str = None,
    +152                    parse_fields: List[str] = None,
    +153                    *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +154        topics = {}
    +155
    +156        if isinstance(csvfile, str):
    +157            csvfile = open(csvfile, 'rt')
    +158
    +159        if id_field is None:
    +160            id_field = cls.id_field
    +161
    +162        if parse_fields is None:
    +163            parse_fields = cls.parse_fields
    +164
    +165        reader = csv.DictReader(csvfile, dialect=dialect)
    +166        for row in reader:
    +167            temp = {}
    +168
    +169            for field in parse_fields:
    +170                temp[field] = row[field]
    +171
    +172            topics[row[id_field]] = temp
    +173
    +174        return topics
    +175
    +176
    +177@dataclasses.dataclass(init=True)
    +178class TSVParser(CSVParser):
    +179
    +180    @classmethod
    +181    def _get_topics(cls, tsvfile, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +182        return CSVParser._get_topics(tsvfile, *args, dialect='excel-tab', **kwargs)
    +183
    +184
    +185@dataclasses.dataclass(init=True)
    +186class JsonLinesParser(Parser):
    +187    """
    +188    Loads topics from a jsonl file,
    +189    a JSON per line
    +190
    +191    Provide parse_fields, id_field and whether to ignore full matches on json keys
    +192    secondary_id appends to the primary id as jsonlines are flattened structure and may contain duplicate ids.
    +193    """
    +194    parse_fields: List[str]
    +195    id_field: str
    +196    ignore_full_match: bool = True
    +197    secondary_id: str = None
    +198
    +199    @classmethod
    +200    def _get_topics(cls, jsonlfile, id_field, parse_fields,
    +201                    ignore_full_match=True, secondary_id=None, *args, **kwargs) -> Dict[str, Dict]:
    +202        with open(jsonlfile, "r") as jsonl_f:
    +203            topics = {}
    +204
    +205            for jsonl in jsonl_f:
    +206                json_dict = json.loads(jsonl)
    +207                _id = json_dict.pop(id_field)
    +208
    +209                if secondary_id:
    +210                    _id = str(_id) + "_" + str(json_dict[secondary_id])
    +211
    +212                for key in list(json_dict.keys()):
    +213                    found = False
    +214                    for _key in parse_fields:
    +215                        if ignore_full_match:
    +216                            if key in _key or key == _key or _key in key:
    +217                                found = True
    +218                        else:
    +219                            if _key == key:
    +220                                found = True
    +221                    if not found:
    +222                        json_dict.pop(key)
    +223
    +224                topics[_id] = json_dict
    +225
    +226        return topics
    +
    + + +
    +
    + +
    +
    @dataclass(init=True)
    + + class + Parser: + + + +
    + +
    19@dataclass(init=True)
    +20class Parser:
    +21    """
    +22    Parser interface
    +23    """
    +24
    +25    id_field: object
    +26    parse_fields: List[str]
    +27
    +28    @classmethod
    +29    def normalize(cls, input_dict) -> Dict:
    +30        """
    +31        Flatten the dictionary, i.e. from Dict[int, Dict] -> Dict[str, str_or_int]
    +32
    +33        :param input_dict:
    +34        :return:
    +35        """
    +36        return pd.io.json.json_normalize(input_dict,
    +37                                         sep=".").to_dict(orient='records')[0]
    +38
    +39    def get_topics(self, path, *args, **kwargs):
    +40        """
    +41        Instance method for getting topics, forwards instance self parameters to the _get_topics class method.
    +42        """
    +43
    +44        self_kwargs = vars(self)
    +45        kwargs.update(self_kwargs)
    +46
    +47        return self._get_topics(path, *args, **kwargs)
    +48
    +49    @classmethod
    +50    @abc.abstractmethod
    +51    def _get_topics(cls, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +52        raise NotImplementedError
    +
    + + +

    Parser interface

    +
    + + +
    +
    + + Parser(id_field: object, parse_fields: List[str]) + + +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + normalize(cls, input_dict) -> Dict: + + + +
    + +
    28    @classmethod
    +29    def normalize(cls, input_dict) -> Dict:
    +30        """
    +31        Flatten the dictionary, i.e. from Dict[int, Dict] -> Dict[str, str_or_int]
    +32
    +33        :param input_dict:
    +34        :return:
    +35        """
    +36        return pd.io.json.json_normalize(input_dict,
    +37                                         sep=".").to_dict(orient='records')[0]
    +
    + + +

    Flatten the dictionary, i.e. from Dict[int, Dict] -> Dict[str, str_or_int]

    + +
    Parameters
    + +
      +
    • input_dict:
    • +
    + +
    Returns
    +
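    A tiny illustration of the flattening behaviour; the input dictionary below is made up.

    from debeir.core.parser import Parser

    nested = {"query": {"title": "heart attack", "meta": {"year": 2021}}}

    flat = Parser.normalize(nested)
    # flat == {"query.title": "heart attack", "query.meta.year": 2021}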
    + + +
    +
    + +
    + + def + get_topics(self, path, *args, **kwargs): + + + +
    + +
    39    def get_topics(self, path, *args, **kwargs):
    +40        """
    +41        Instance method for getting topics, forwards instance self parameters to the _get_topics class method.
    +42        """
    +43
    +44        self_kwargs = vars(self)
    +45        kwargs.update(self_kwargs)
    +46
    +47        return self._get_topics(path, *args, **kwargs)
    +
    + + +

    Instance method for getting topics; forwards the instance's attributes as keyword arguments to the _get_topics class method.

    +
    + + +
    +
    +
    + +
    +
    @dataclasses.dataclass(init=True)
    + + class + PickleParser(Parser): + + + +
    + +
    55@dataclasses.dataclass(init=True)
    +56class PickleParser(Parser):
    +57    """
    +58    Load topics from a pickle file
    +59    """
    +60
    +61    @classmethod
    +62    def _get_topics(cls, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +63        return dill.load(path)
    +
    + + +

    Load topics from a pickle file

    +
    + + +
    +
    + + PickleParser(id_field: object, parse_fields: List[str]) + + +
    + + + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclasses.dataclass(init=True)
    + + class + XMLParser(Parser): + + + +
    + +
     66@dataclasses.dataclass(init=True)
    + 67class XMLParser(Parser):
    + 68    """
    + 69    Load topics from an XML file
    + 70    """
    + 71    topic_field_name: str
    + 72    id_field: str
    + 73    parse_fields: List[str]
    + 74
    + 75    @classmethod
    + 76    def _recurse_to_child_node(cls, node: ET.Element, track: List):
    + 77        """
    + 78        Helper method to get all children nodes for text extraction in an xml.
    + 79
    + 80        :param node: Current node
    + 81        :param track: List to track nodes
    + 82        :return:
    + 83        """
    + 84        if len(node.getchildren()) > 0:
    + 85            for child in node.getchildren():
    + 86                track.append(cls._recurse_to_child_node(child, track))
    + 87
    + 88        return node
    + 89
    + 90    @classmethod
    + 91    def unwrap(cls, doc_dict, key):
    + 92        """
    + 93        Converts defaultdict to dict and list of size 1 to just the element
    + 94
    + 95        :param doc_dict:
    + 96        :param key:
    + 97        """
    + 98        if isinstance(doc_dict[key], defaultdict):
    + 99            doc_dict[key] = dict(doc_dict[key])
    +100
    +101            for e_key in doc_dict[key]:
    +102                cls.unwrap(doc_dict[key], e_key)
    +103
    +104        if isinstance(doc_dict[key], list):
    +105            if len(doc_dict[key]) == 1:
    +106                doc_dict[key] = doc_dict[key][0]
    +107
    +108    def _get_topics(self, path, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +109        all_topics = ET.parse(path).getroot()
    +110        qtopics = {}
    +111
    +112        for topic in all_topics.findall(self.topic_field_name):
    +113            _id = topic.attrib[self.id_field]
    +114            if _id.isnumeric():
    +115                _id = int(_id)
    +116
    +117            if self.parse_fields:
    +118                temp = {}
    +119                for field in self.parse_fields:
    +120                    try:
    +121                        temp[field] = topic.find(field).text.strip()
    +122                    except:
    +123                        continue
    +124
    +125                qtopics[_id] = temp
    +126            else:
    +127                #  The topic contains the text
    +128                qtopics[_id] = {"query": topic.text.strip()}
    +129
    +130        return qtopics
    +
    + + +

    Load topics from an XML file
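    A hedged usage sketch; the topic tag, attribute and field names below are assumptions about the XML layout, not part of the documented API.

    from debeir.core.parser import XMLParser

    # Expects topics shaped like: <topics><topic number="1"><query>...</query></topic></topics>
    parser = XMLParser(id_field="number", parse_fields=["query"], topic_field_name="topic")
    topics = parser.get_topics("topics.xml")  # -> {1: {"query": "..."}, 2: {"query": "..."}, ...}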

    +
    + + +
    +
    + + XMLParser(id_field: str, parse_fields: List[str], topic_field_name: str) + + +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + unwrap(cls, doc_dict, key): + + + +
    + +
     90    @classmethod
    + 91    def unwrap(cls, doc_dict, key):
    + 92        """
    + 93        Converts defaultdict to dict and list of size 1 to just the element
    + 94
    + 95        :param doc_dict:
    + 96        :param key:
    + 97        """
    + 98        if isinstance(doc_dict[key], defaultdict):
    + 99            doc_dict[key] = dict(doc_dict[key])
    +100
    +101            for e_key in doc_dict[key]:
    +102                cls.unwrap(doc_dict[key], e_key)
    +103
    +104        if isinstance(doc_dict[key], list):
    +105            if len(doc_dict[key]) == 1:
    +106                doc_dict[key] = doc_dict[key][0]
    +
    + + +

    Converts a defaultdict to a dict, and unwraps a single-element list to just its element

    + +
    Parameters
    + +
      +
    • doc_dict:
    • +
    • key:
    • +
    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclasses.dataclass
    + + class + CSVParser(Parser): + + + +
    + +
    133@dataclasses.dataclass
    +134class CSVParser(Parser):
    +135    """
    +136    Loads topics from a CSV file
    +137    """
    +138    id_field = "id"
    +139    parse_fields = ["Text"]
    +140
    +141    def __init__(self, id_field=None, parse_fields=None):
    +142        if parse_fields is None:
    +143            parse_fields = ["id", "text"]
    +144
    +145        if id_field is None:
    +146            id_field = "id"
    +147
    +148        super().__init__(id_field, parse_fields)
    +149
    +150    @classmethod
    +151    def _get_topics(cls, csvfile, dialect="excel",
    +152                    id_field: str = None,
    +153                    parse_fields: List[str] = None,
    +154                    *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +155        topics = {}
    +156
    +157        if isinstance(csvfile, str):
    +158            csvfile = open(csvfile, 'rt')
    +159
    +160        if id_field is None:
    +161            id_field = cls.id_field
    +162
    +163        if parse_fields is None:
    +164            parse_fields = cls.parse_fields
    +165
    +166        reader = csv.DictReader(csvfile, dialect=dialect)
    +167        for row in reader:
    +168            temp = {}
    +169
    +170            for field in parse_fields:
    +171                temp[field] = row[field]
    +172
    +173            topics[row[id_field]] = temp
    +174
    +175        return topics
    +
    + + +

    Loads topics from a CSV file
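    Usage sketch with the parser's default column names ("id" and "text"); the file path and the CSV layout are assumptions.

    from debeir.core.parser import CSVParser

    parser = CSVParser()  # defaults to id_field="id", parse_fields=["id", "text"]
    topics = parser.get_topics("topics.csv")
    # Each CSV row becomes topics[row["id"]] == {"id": ..., "text": ...}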

    +
    + + +
    + +
    + + CSVParser(id_field=None, parse_fields=None) + + + +
    + +
    141    def __init__(self, id_field=None, parse_fields=None):
    +142        if parse_fields is None:
    +143            parse_fields = ["id", "text"]
    +144
    +145        if id_field is None:
    +146            id_field = "id"
    +147
    +148        super().__init__(id_field, parse_fields)
    +
    + + + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclasses.dataclass(init=True)
    + + class + TSVParser(CSVParser): + + + +
    + +
    178@dataclasses.dataclass(init=True)
    +179class TSVParser(CSVParser):
    +180
    +181    @classmethod
    +182    def _get_topics(cls, tsvfile, *args, **kwargs) -> Dict[int, Dict[str, str]]:
    +183        return CSVParser._get_topics(tsvfile, *args, dialect='excel-tab', **kwargs)
    +
    + + + + +
    +
    + + TSVParser(id_field: object, parse_fields: List[str]) + + +
    + + + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    +
    @dataclasses.dataclass(init=True)
    + + class + JsonLinesParser(Parser): + + + +
    + +
    186@dataclasses.dataclass(init=True)
    +187class JsonLinesParser(Parser):
    +188    """
    +189    Loads topics from a jsonl file,
    +190    a JSON per line
    +191
    +192    Provide parse_fields, id_field and whether to ignore full matches on json keys
    +193    secondary_id appends to the primary id as jsonlines are flattened structure and may contain duplicate ids.
    +194    """
    +195    parse_fields: List[str]
    +196    id_field: str
    +197    ignore_full_match: bool = True
    +198    secondary_id: str = None
    +199
    +200    @classmethod
    +201    def _get_topics(cls, jsonlfile, id_field, parse_fields,
    +202                    ignore_full_match=True, secondary_id=None, *args, **kwargs) -> Dict[str, Dict]:
    +203        with open(jsonlfile, "r") as jsonl_f:
    +204            topics = {}
    +205
    +206            for jsonl in jsonl_f:
    +207                json_dict = json.loads(jsonl)
    +208                _id = json_dict.pop(id_field)
    +209
    +210                if secondary_id:
    +211                    _id = str(_id) + "_" + str(json_dict[secondary_id])
    +212
    +213                for key in list(json_dict.keys()):
    +214                    found = False
    +215                    for _key in parse_fields:
    +216                        if ignore_full_match:
    +217                            if key in _key or key == _key or _key in key:
    +218                                found = True
    +219                        else:
    +220                            if _key == key:
    +221                                found = True
    +222                    if not found:
    +223                        json_dict.pop(key)
    +224
    +225                topics[_id] = json_dict
    +226
    +227        return topics
    +
    + + +

    Loads topics from a jsonl file, a JSON object per line.

    + +

    Provide parse_fields, id_field and whether to ignore full matches on JSON keys. secondary_id is appended to the primary id, since JSON lines are a flattened structure and may contain duplicate ids.
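    A usage sketch with assumed field names (doc_id, title, abstract); the jsonl path is hypothetical.

    from debeir.core.parser import JsonLinesParser

    parser = JsonLinesParser(id_field="doc_id",
                             parse_fields=["title", "abstract"],
                             ignore_full_match=True,
                             secondary_id=None)
    topics = parser.get_topics("topics.jsonl")  # one JSON object per line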

    +
    + + +
    +
    + + JsonLinesParser( id_field: str, parse_fields: List[str], ignore_full_match: bool = True, secondary_id: str = None) + + +
    + + + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/pipeline.html b/docs/debeir/core/pipeline.html new file mode 100644 index 0000000..c15dabe --- /dev/null +++ b/docs/debeir/core/pipeline.html @@ -0,0 +1,910 @@ + + + + + + + debeir.core.pipeline API documentation + + + + + + + + + +
    +
    +

    +debeir.core.pipeline

    + + + + + + +
      1import abc
    +  2from typing import List
    +  3
    +  4import debeir
    +  5from debeir.core.config import Config, GenericConfig
    +  6from debeir.core.executor import GenericElasticsearchExecutor
    +  7from debeir.core.results import Results
    +  8from debeir.datasets.factory import factory_fn, get_nir_config
    +  9from debeir.engines.client import Client
    + 10from loguru import logger
    + 11
    + 12
    + 13class Pipeline:
    + 14    pipeline_structure = ["parser", "query", "engine", "evaluator"]
    + 15    cannot_disable = ["parser", "query", "engine"]
    + 16    callbacks: List['debeir.core.callbacks.Callback']
    + 17    output_file = None
    + 18
    + 19    def __init__(self, engine: GenericElasticsearchExecutor,
    + 20                 engine_name: str,
    + 21                 metrics_config,
    + 22                 engine_config,
    + 23                 nir_config,
    + 24                 run_config: Config,
    + 25                 callbacks=None):
    + 26
    + 27        self.engine = engine
    + 28        self.engine_name = engine_name
    + 29        self.run_config = run_config
    + 30        self.metrics_config = metrics_config
    + 31        self.engine_config = engine_config
    + 32        self.nir_config = nir_config
    + 33        self.output_file = None
    + 34        self.disable = {}
    + 35
    + 36        if callbacks is None:
    + 37            self.callbacks = []
    + 38        else:
    + 39            self.callbacks = callbacks
    + 40
    + 41    @classmethod
    + 42    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
    + 43        query_cls, config, parser, executor_cls = factory_fn(config_fp)
    + 44
    + 45        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
    + 46                                                                          engine=engine,
    + 47                                                                          ignore_errors=False)
    + 48
    + 49        client = Client.build_from_config(engine, search_engine_config)
    + 50        topics = parser._get_topics(config.topics_path)
    + 51
    + 52        query = query_cls(topics=topics, query_type=config.query_type, config=config)
    + 53
    + 54        executor = executor_cls.build_from_config(
    + 55            topics,
    + 56            query,
    + 57            client.get_client(engine),
    + 58            config,
    + 59            nir_config
    + 60        )
    + 61
    + 62        return cls(
    + 63            executor,
    + 64            engine,
    + 65            metrics_config,
    + 66            search_engine_config,
    + 67            nir_config,
    + 68            config
    + 69        )
    + 70
    + 71    def disable(self, parts: list):
    + 72        for part in parts:
    + 73            if part in self.pipeline_structure and part not in self.cannot_disable:
    + 74                self.disable[part] = True
    + 75            else:
    + 76                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")
    + 77
    + 78    @abc.abstractmethod
    + 79    async def run_pipeline(self, *args,
    + 80                           **kwargs):
    + 81        raise NotImplementedError()
    + 82
    + 83
    + 84class NIRPipeline(Pipeline):
    + 85    run_config: GenericConfig
    + 86
    + 87    def __init__(self, *args, **kwargs):
    + 88        super().__init__(*args, **kwargs)
    + 89
    + 90    async def prehook(self):
    + 91        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
    + 92            logger.info(f"Running initial BM25 for query adjustment")
    + 93            await self.engine.run_automatic_adjustment()
    + 94
    + 95    async def run_engine(self, *args, **kwargs):
    + 96        # Run bm25 nir adjustment
    + 97        logger.info(f"Running {self.run_config.query_type} queries")
    + 98
    + 99        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
    +100
    +101    async def posthook(self, *args, **kwargs):
    +102        pass
    +103
    +104    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +105        for cb in self.callbacks:
    +106            cb.before(self)
    +107
    +108        await self.prehook()
    +109        results = await self.run_engine(*args, **kwargs)
    +110        results = Results(results, self.engine.query, self.engine_name)
    +111
    +112        for cb in self.callbacks:
    +113            cb.after(results)
    +114
    +115        return results
    +116
    +117    def register_callback(self, cb):
    +118        self.callbacks.append(cb)
    +119
    +120
    +121class BM25Pipeline(NIRPipeline):
    +122    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +123        for cb in self.callbacks:
    +124            cb.before(self)
    +125
    +126        results = await self.engine.run_all_queries(query_type="query",
    +127                                                    return_results=True)
    +128
    +129        results = Results(results, self.engine.query, self.engine_name)
    +130
    +131        for cb in self.callbacks:
    +132            cb.after(results)
    +133
    +134        return results
    +
    + + +
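    A hedged end-to-end sketch of the module above. The two config file paths are placeholders for a dataset config and an NIR config, the engine key "elasticsearch" is an assumption, and run_pipeline is a coroutine, so it is driven with asyncio here.

    import asyncio

    from debeir.core.pipeline import NIRPipeline

    pipeline = NIRPipeline.build_from_config(nir_config_fp="./configs/nir.toml",      # assumed path
                                             engine="elasticsearch",                  # assumed engine key
                                             config_fp="./configs/dataset.toml")      # assumed path
    results = asyncio.run(pipeline.run_pipeline())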
    +
    + +
    + + class + Pipeline: + + + +
    + +
    14class Pipeline:
    +15    pipeline_structure = ["parser", "query", "engine", "evaluator"]
    +16    cannot_disable = ["parser", "query", "engine"]
    +17    callbacks: List['debeir.core.callbacks.Callback']
    +18    output_file = None
    +19
    +20    def __init__(self, engine: GenericElasticsearchExecutor,
    +21                 engine_name: str,
    +22                 metrics_config,
    +23                 engine_config,
    +24                 nir_config,
    +25                 run_config: Config,
    +26                 callbacks=None):
    +27
    +28        self.engine = engine
    +29        self.engine_name = engine_name
    +30        self.run_config = run_config
    +31        self.metrics_config = metrics_config
    +32        self.engine_config = engine_config
    +33        self.nir_config = nir_config
    +34        self.output_file = None
    +35        self.disable = {}
    +36
    +37        if callbacks is None:
    +38            self.callbacks = []
    +39        else:
    +40            self.callbacks = callbacks
    +41
    +42    @classmethod
    +43    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
    +44        query_cls, config, parser, executor_cls = factory_fn(config_fp)
    +45
    +46        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
    +47                                                                          engine=engine,
    +48                                                                          ignore_errors=False)
    +49
    +50        client = Client.build_from_config(engine, search_engine_config)
    +51        topics = parser._get_topics(config.topics_path)
    +52
    +53        query = query_cls(topics=topics, query_type=config.query_type, config=config)
    +54
    +55        executor = executor_cls.build_from_config(
    +56            topics,
    +57            query,
    +58            client.get_client(engine),
    +59            config,
    +60            nir_config
    +61        )
    +62
    +63        return cls(
    +64            executor,
    +65            engine,
    +66            metrics_config,
    +67            search_engine_config,
    +68            nir_config,
    +69            config
    +70        )
    +71
    +72    def disable(self, parts: list):
    +73        for part in parts:
    +74            if part in self.pipeline_structure and part not in self.cannot_disable:
    +75                self.disable[part] = True
    +76            else:
    +77                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")
    +78
    +79    @abc.abstractmethod
    +80    async def run_pipeline(self, *args,
    +81                           **kwargs):
    +82        raise NotImplementedError()
    +
    + + + + +
    + +
    + + Pipeline( engine: debeir.core.executor.GenericElasticsearchExecutor, engine_name: str, metrics_config, engine_config, nir_config, run_config: debeir.core.config.Config, callbacks=None) + + + +
    + +
    20    def __init__(self, engine: GenericElasticsearchExecutor,
    +21                 engine_name: str,
    +22                 metrics_config,
    +23                 engine_config,
    +24                 nir_config,
    +25                 run_config: Config,
    +26                 callbacks=None):
    +27
    +28        self.engine = engine
    +29        self.engine_name = engine_name
    +30        self.run_config = run_config
    +31        self.metrics_config = metrics_config
    +32        self.engine_config = engine_config
    +33        self.nir_config = nir_config
    +34        self.output_file = None
    +35        self.disable = {}
    +36
    +37        if callbacks is None:
    +38            self.callbacks = []
    +39        else:
    +40            self.callbacks = callbacks
    +
    + + + + +
    +
    + +
    + + def + disable(self, parts: list): + + + +
    + +
    72    def disable(self, parts: list):
    +73        for part in parts:
    +74            if part in self.pipeline_structure and part not in self.cannot_disable:
    +75                self.disable[part] = True
    +76            else:
    +77                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")
    +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + build_from_config(cls, nir_config_fp, engine, config_fp) -> debeir.core.pipeline.Pipeline: + + + +
    + +
    42    @classmethod
    +43    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
    +44        query_cls, config, parser, executor_cls = factory_fn(config_fp)
    +45
    +46        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
    +47                                                                          engine=engine,
    +48                                                                          ignore_errors=False)
    +49
    +50        client = Client.build_from_config(engine, search_engine_config)
    +51        topics = parser._get_topics(config.topics_path)
    +52
    +53        query = query_cls(topics=topics, query_type=config.query_type, config=config)
    +54
    +55        executor = executor_cls.build_from_config(
    +56            topics,
    +57            query,
    +58            client.get_client(engine),
    +59            config,
    +60            nir_config
    +61        )
    +62
    +63        return cls(
    +64            executor,
    +65            engine,
    +66            metrics_config,
    +67            search_engine_config,
    +68            nir_config,
    +69            config
    +70        )
    +
    + + + + +
    +
    + +
    +
    @abc.abstractmethod
    + + async def + run_pipeline(self, *args, **kwargs): + + + +
    + +
    79    @abc.abstractmethod
    +80    async def run_pipeline(self, *args,
    +81                           **kwargs):
    +82        raise NotImplementedError()
    +
    + + + + +
    +
    +
    + +
    + + class + NIRPipeline(Pipeline): + + + +
    + +
     85class NIRPipeline(Pipeline):
    + 86    run_config: GenericConfig
    + 87
    + 88    def __init__(self, *args, **kwargs):
    + 89        super().__init__(*args, **kwargs)
    + 90
    + 91    async def prehook(self):
    + 92        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
    + 93            logger.info(f"Running initial BM25 for query adjustment")
    + 94            await self.engine.run_automatic_adjustment()
    + 95
    + 96    async def run_engine(self, *args, **kwargs):
    + 97        # Run bm25 nir adjustment
    + 98        logger.info(f"Running {self.run_config.query_type} queries")
    + 99
    +100        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
    +101
    +102    async def posthook(self, *args, **kwargs):
    +103        pass
    +104
    +105    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +106        for cb in self.callbacks:
    +107            cb.before(self)
    +108
    +109        await self.prehook()
    +110        results = await self.run_engine(*args, **kwargs)
    +111        results = Results(results, self.engine.query, self.engine_name)
    +112
    +113        for cb in self.callbacks:
    +114            cb.after(results)
    +115
    +116        return results
    +117
    +118    def register_callback(self, cb):
    +119        self.callbacks.append(cb)
    +
    + + + + +
    + +
    + + NIRPipeline(*args, **kwargs) + + + +
    + +
    88    def __init__(self, *args, **kwargs):
    +89        super().__init__(*args, **kwargs)
    +
    + + + + +
    +
    + +
    + + async def + prehook(self): + + + +
    + +
    91    async def prehook(self):
    +92        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
    +93            logger.info(f"Running initial BM25 for query adjustment")
    +94            await self.engine.run_automatic_adjustment()
    +
    + + + + +
    +
    + +
    + + async def + run_engine(self, *args, **kwargs): + + + +
    + +
     96    async def run_engine(self, *args, **kwargs):
    + 97        # Run bm25 nir adjustment
    + 98        logger.info(f"Running {self.run_config.query_type} queries")
    + 99
    +100        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
    +
    + + + + +
    +
    + +
    + + async def + posthook(self, *args, **kwargs): + + + +
    + +
    102    async def posthook(self, *args, **kwargs):
    +103        pass
    +
    + + + + +
    +
    + +
    + + async def + run_pipeline(self, *args, return_results=False, **kwargs): + + + +
    + +
    105    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +106        for cb in self.callbacks:
    +107            cb.before(self)
    +108
    +109        await self.prehook()
    +110        results = await self.run_engine(*args, **kwargs)
    +111        results = Results(results, self.engine.query, self.engine_name)
    +112
    +113        for cb in self.callbacks:
    +114            cb.after(results)
    +115
    +116        return results
    +
    + + + + +
    +
    + +
    + + def + register_callback(self, cb): + + + +
    + +
    118    def register_callback(self, cb):
    +119        self.callbacks.append(cb)
    +
    + + + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    + + class + BM25Pipeline(NIRPipeline): + + + +
    + +
    122class BM25Pipeline(NIRPipeline):
    +123    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +124        for cb in self.callbacks:
    +125            cb.before(self)
    +126
    +127        results = await self.engine.run_all_queries(query_type="query",
    +128                                                    return_results=True)
    +129
    +130        results = Results(results, self.engine.query, self.engine_name)
    +131
    +132        for cb in self.callbacks:
    +133            cb.after(results)
    +134
    +135        return results
    +
    + + + + +
    + +
    + + async def + run_pipeline(self, *args, return_results=False, **kwargs): + + + +
    + +
    123    async def run_pipeline(self, *args, return_results=False, **kwargs):
    +124        for cb in self.callbacks:
    +125            cb.before(self)
    +126
    +127        results = await self.engine.run_all_queries(query_type="query",
    +128                                                    return_results=True)
    +129
    +130        results = Results(results, self.engine.query, self.engine_name)
    +131
    +132        for cb in self.callbacks:
    +133            cb.after(results)
    +134
    +135        return results
    +
    + + + + +
    + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/core/query.html b/docs/debeir/core/query.html new file mode 100644 index 0000000..99ba7aa --- /dev/null +++ b/docs/debeir/core/query.html @@ -0,0 +1,934 @@ + + + + + + + debeir.core.query API documentation + + + + + + + + + +
    +
    +

    +debeir.core.query

    + + + + + + +
      1import dataclasses
    +  2from typing import Dict, Optional, Union
    +  3
    +  4import loguru
    +  5from debeir.engines.elasticsearch.generate_script_score import generate_script
    +  6from debeir.core.config import GenericConfig, apply_config
    +  7from debeir.utils.scaler import get_z_value
    +  8
    +  9
    + 10@dataclasses.dataclass(init=True)
    + 11class Query:
    + 12    """
    + 13    A query interface class
    + 14    :param topics: Topics that the query will be composed of
    + 15    :param config: Config object that contains the settings for querying
    + 16    """
    + 17    topics: Dict[int, Dict[str, str]]
    + 18    config: GenericConfig
    + 19
    + 20
    + 21class GenericElasticsearchQuery(Query):
    + 22    """
    + 23    A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
    + 24    Requires topics, configs to be included
    + 25    """
    + 26    id_mapping: str = "Id"
    + 27
    + 28    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
    + 29        super().__init__(topics, config)
    + 30
    + 31        if id_mapping is None:
    + 32            self.id_mapping = "id"
    + 33
    + 34        if mappings is None:
    + 35            self.mappings = ["Text"]
    + 36        else:
    + 37            self.mappings = mappings
    + 38
    + 39        self.topics = topics
    + 40        self.config = config
    + 41        self.query_type = self.config.query_type
    + 42
    + 43        self.embed_mappings = ["Text_Embedding"]
    + 44
    + 45        self.query_funcs = {
    + 46            "query": self.generate_query,
    + 47            "embedding": self.generate_query_embedding,
    + 48        }
    + 49
    + 50        self.top_bm25_scores = top_bm25_scores
    + 51
    + 52    def _generate_base_query(self, topic_num):
    + 53        qfield = list(self.topics[topic_num].keys())[0]
    + 54        query = self.topics[topic_num][qfield]
    + 55        should = {"should": []}
    + 56
    + 57        for i, field in enumerate(self.mappings):
    + 58            should["should"].append(
    + 59                {
    + 60                    "match": {
    + 61                        f"{field}": {
    + 62                            "query": query,
    + 63                        }
    + 64                    }
    + 65                }
    + 66            )
    + 67
    + 68        return qfield, query, should
    + 69
    + 70    def generate_query(self, topic_num, *args, **kwargs):
    + 71        """
    + 72        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
    + 73        :param topic_num:
    + 74        :param args:
    + 75        :param kwargs:
    + 76        :return:
    + 77        """
    + 78        _, _, should = self._generate_base_query(topic_num)
    + 79
    + 80        query = {
    + 81            "query": {
    + 82                "bool": should,
    + 83            }
    + 84        }
    + 85
    + 86        return query
    + 87
    + 88    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
    + 89        """
    + 90        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
    + 91        for log normalization.
    + 92
    + 93        Score = log(bm25)/log(z) + embed_score
    + 94        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
    + 95        """
    + 96        self.top_bm25_scores = scores
    + 97
    + 98    def has_bm25_scores(self):
    + 99        """
    +100        Checks if BM25 scores have been set
    +101        :return:
    +102        """
    +103        return self.top_bm25_scores is not None
    +104
    +105    @apply_config
    +106    def generate_query_embedding(
    +107            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False, cosine_ceiling: Optional[float] = None,
    +108            cosine_offset: float = 1.0, **kwargs):
    +109        """
    +110        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.
    +111
    +112        :param topic_num: The topic number to search for
    +113        :param encoder: The encoder that will be used for encoding the topics
    +114        :param norm_weight: The BM25 log normalization constant
    +115        :param ablations: Whether to execute ablation style queries (i.e. one query facet
    +116                          or one document facet at a time)
    +117        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
    +118        :param args:
    +119        :param kwargs: Pass disable_cache to disable encoder caching
    +120        :return:
    +121            An elasticsearch script_score query
    +122        """
    +123
    +124        qfields = list(self.topics[topic_num].keys())
    +125        should = {"should": []}
    +126
    +127        if self.has_bm25_scores():
    +128            cosine_ceiling = len(self.embed_mappings) * len(qfields) if cosine_ceiling is None else cosine_ceiling
    +129            norm_weight = get_z_value(
    +130                cosine_ceiling=cosine_ceiling,
    +131                bm25_ceiling=self.top_bm25_scores[topic_num],
    +132            )
    +133            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")
    +134
    +135        params = {
    +136            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
    +137            "offset": cosine_offset,
    +138            "norm_weight": norm_weight,
    +139            "disable_bm25": ablations,
    +140        }
    +141
    +142        embed_fields = []
    +143
    +144        for qfield in qfields:
    +145            for field in self.mappings:
    +146                should["should"].append(
    +147                    {
    +148                        "match": {
    +149                            f"{field}": {
    +150                                "query": self.topics[topic_num][qfield],
    +151                            }
    +152                        }
    +153                    }
    +154                )
    +155
    +156            params[f"{qfield}_eb"] = encoder.encode(topic=self.topics[topic_num][qfield])
    +157            embed_fields.append(f"{qfield}_eb")
    +158
    +159        query = {
    +160            "query": {
    +161                "script_score": {
    +162                    "query": {
    +163                        "bool": should,
    +164                    },
    +165                    "script": generate_script(
    +166                        self.embed_mappings, params, qfields=embed_fields
    +167                    ),
    +168                }
    +169            }
    +170        }
    +171
    +172        loguru.logger.debug(query)
    +173        return query
    +174
    +175    @classmethod
    +176    def get_id_mapping(cls, hit):
    +177        """
    +178        Get the document ID
    +179
    +180        :param hit: The raw document result
    +181        :return:
    +182            The document's ID
    +183        """
    +184        return hit[cls.id_mapping]
    +
    + + +
    +
    + +
    +
    @dataclasses.dataclass(init=True)
    + + class + Query: + + + +
    + +
    11@dataclasses.dataclass(init=True)
    +12class Query:
    +13    """
    +14    A query interface class
    +15    :param topics: Topics that the query will be composed of
    +16    :param config: Config object that contains the settings for querying
    +17    """
    +18    topics: Dict[int, Dict[str, str]]
    +19    config: GenericConfig
    +
    + + +

    A query interface class

    + +
    Parameters
    + +
      +
    • topics: Topics that the query will be composed of
    • +
    • config: Config object that contains the settings for querying
    • +
    +
    + + +
    +
    + + Query( topics: Dict[int, Dict[str, str]], config: debeir.core.config.GenericConfig) + + +
    + + + + +
    +
    +
    + +
    + + class + GenericElasticsearchQuery(Query): + + + +
    + +
     22class GenericElasticsearchQuery(Query):
    + 23    """
    + 24    A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
    + 25    Requires topics, configs to be included
    + 26    """
    + 27    id_mapping: str = "Id"
    + 28
    + 29    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
    + 30        super().__init__(topics, config)
    + 31
    + 32        if id_mapping is None:
    + 33            self.id_mapping = "id"
    + 34
    + 35        if mappings is None:
    + 36            self.mappings = ["Text"]
    + 37        else:
    + 38            self.mappings = mappings
    + 39
    + 40        self.topics = topics
    + 41        self.config = config
    + 42        self.query_type = self.config.query_type
    + 43
    + 44        self.embed_mappings = ["Text_Embedding"]
    + 45
    + 46        self.query_funcs = {
    + 47            "query": self.generate_query,
    + 48            "embedding": self.generate_query_embedding,
    + 49        }
    + 50
    + 51        self.top_bm25_scores = top_bm25_scores
    + 52
    + 53    def _generate_base_query(self, topic_num):
    + 54        qfield = list(self.topics[topic_num].keys())[0]
    + 55        query = self.topics[topic_num][qfield]
    + 56        should = {"should": []}
    + 57
    + 58        for i, field in enumerate(self.mappings):
    + 59            should["should"].append(
    + 60                {
    + 61                    "match": {
    + 62                        f"{field}": {
    + 63                            "query": query,
    + 64                        }
    + 65                    }
    + 66                }
    + 67            )
    + 68
    + 69        return qfield, query, should
    + 70
    + 71    def generate_query(self, topic_num, *args, **kwargs):
    + 72        """
    + 73        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
    + 74        :param topic_num:
    + 75        :param args:
    + 76        :param kwargs:
    + 77        :return:
    + 78        """
    + 79        _, _, should = self._generate_base_query(topic_num)
    + 80
    + 81        query = {
    + 82            "query": {
    + 83                "bool": should,
    + 84            }
    + 85        }
    + 86
    + 87        return query
    + 88
    + 89    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
    + 90        """
    + 91        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
    + 92        for log normalization.
    + 93
    + 94        Score = log(bm25)/log(z) + embed_score
    + 95        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
    + 96        """
    + 97        self.top_bm25_scores = scores
    + 98
    + 99    def has_bm25_scores(self):
    +100        """
    +101        Checks if BM25 scores have been set
    +102        :return:
    +103        """
    +104        return self.top_bm25_scores is not None
    +105
    +106    @apply_config
    +107    def generate_query_embedding(
    +108            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False, cosine_ceiling: Optional[float] = None,
    +109            cosine_offset: float = 1.0, **kwargs):
    +110        """
    +111        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.
    +112
    +113        :param topic_num: The topic number to search for
    +114        :param encoder: The encoder that will be used for encoding the topics
    +115        :param norm_weight: The BM25 log normalization constant
    +116        :param ablations: Whether to execute ablation style queries (i.e. one query facet
    +117                          or one document facet at a time)
    +118        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
    +119        :param args:
    +120        :param kwargs: Pass disable_cache to disable encoder caching
    +121        :return:
    +122            An elasticsearch script_score query
    +123        """
    +124
    +125        qfields = list(self.topics[topic_num].keys())
    +126        should = {"should": []}
    +127
    +128        if self.has_bm25_scores():
    +129            cosine_ceiling = len(self.embed_mappings) * len(qfields) if cosine_ceiling is None else cosine_ceiling
    +130            norm_weight = get_z_value(
    +131                cosine_ceiling=cosine_ceiling,
    +132                bm25_ceiling=self.top_bm25_scores[topic_num],
    +133            )
    +134            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")
    +135
    +136        params = {
    +137            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
    +138            "offset": cosine_offset,
    +139            "norm_weight": norm_weight,
    +140            "disable_bm25": ablations,
    +141        }
    +142
    +143        embed_fields = []
    +144
    +145        for qfield in qfields:
    +146            for field in self.mappings:
    +147                should["should"].append(
    +148                    {
    +149                        "match": {
    +150                            f"{field}": {
    +151                                "query": self.topics[topic_num][qfield],
    +152                            }
    +153                        }
    +154                    }
    +155                )
    +156
    +157            params[f"{qfield}_eb"] = encoder.encode(topic=self.topics[topic_num][qfield])
    +158            embed_fields.append(f"{qfield}_eb")
    +159
    +160        query = {
    +161            "query": {
    +162                "script_score": {
    +163                    "query": {
    +164                        "bool": should,
    +165                    },
    +166                    "script": generate_script(
    +167                        self.embed_mappings, params, qfields=embed_fields
    +168                    ),
    +169                }
    +170            }
    +171        }
    +172
    +173        loguru.logger.debug(query)
    +174        return query
    +175
    +176    @classmethod
    +177    def get_id_mapping(cls, hit):
    +178        """
    +179        Get the document ID
    +180
    +181        :param hit: The raw document result
    +182        :return:
    +183            The document's ID
    +184        """
    +185        return hit[cls.id_mapping]
    +
    + + +

    A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries. Requires topics and a config to be included.
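    A minimal sketch of the plain BM25 path. The topic text and the Title/Abstract document fields are assumptions, and the config is assumed to be a GenericConfig loaded elsewhere with query_type set.

    from debeir.core.query import GenericElasticsearchQuery

    config = ...  # a debeir.core.config.GenericConfig, e.g. loaded from a dataset config

    topics = {1: {"query": "treatment options for type 2 diabetes"}}
    query_obj = GenericElasticsearchQuery(topics=topics, config=config,
                                          mappings=["Title", "Abstract"])

    es_body = query_obj.generate_query(topic_num=1)
    # es_body == {"query": {"bool": {"should": [
    #     {"match": {"Title": {"query": "treatment options for type 2 diabetes"}}},
    #     {"match": {"Abstract": {"query": "treatment options for type 2 diabetes"}}}]}}}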

    +
    + + +
    + +
    + + GenericElasticsearchQuery( topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs) + + + +
    + +
    29    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
    +30        super().__init__(topics, config)
    +31
    +32        if id_mapping is None:
    +33            self.id_mapping = "id"
    +34
    +35        if mappings is None:
    +36            self.mappings = ["Text"]
    +37        else:
    +38            self.mappings = mappings
    +39
    +40        self.topics = topics
    +41        self.config = config
    +42        self.query_type = self.config.query_type
    +43
    +44        self.embed_mappings = ["Text_Embedding"]
    +45
    +46        self.query_funcs = {
    +47            "query": self.generate_query,
    +48            "embedding": self.generate_query_embedding,
    +49        }
    +50
    +51        self.top_bm25_scores = top_bm25_scores
def generate_query(self, topic_num, *args, **kwargs):
    71    def generate_query(self, topic_num, *args, **kwargs):
    +72        """
    +73        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
    +74        :param topic_num:
    +75        :param args:
    +76        :param kwargs:
    +77        :return:
    +78        """
    +79        _, _, should = self._generate_base_query(topic_num)
    +80
    +81        query = {
    +82            "query": {
    +83                "bool": should,
    +84            }
    +85        }
    +86
    +87        return query

    Generates a simple BM25 query based off the query facets. Searches over all the document facets.

Parameters

• topic_num:
• args:
• kwargs:

Returns
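For orientation, a minimal sketch of the BM25 query this builds for a topic with a single text facet and the default mappings of ["Text"]. This assumes _generate_base_query emits one match clause per query facet and document facet, the same way generate_query_embedding does above; the topic text and the config object are placeholders:

    topics = {1: {"text": "chronic kidney disease treatment"}}           # made-up topic
    query_obj = GenericElasticsearchQuery(topics=topics, config=config)  # config: a GenericConfig loaded elsewhere
    bm25_query = query_obj.generate_query(topic_num=1)
    # Expected shape (illustrative, not captured from a run):
    # {"query": {"bool": {"should": [
    #     {"match": {"Text": {"query": "chronic kidney disease treatment"}}}
    # ]}}}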
def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
    89    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
    +90        """
    +91        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
    +92        for log normalization.
    +93
    +94        Score = log(bm25)/log(z) + embed_score
    +95        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
    +96        """
    +97        self.top_bm25_scores = scores

Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used for log normalization.


    Score = log(bm25)/log(z) + embed_score

Parameters

• scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
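As a sketch of the normalization (illustrative only; the library's get_z_value in debeir.utils.scaler performs the actual calculation, and the assumption here is that z is chosen so that the top BM25 score maps onto the cosine ceiling):

    import math

    def z_from_ceilings(bm25_ceiling: float, cosine_ceiling: float) -> float:
        # Assumption: pick z so that log(bm25_ceiling) / log(z) == cosine_ceiling
        return bm25_ceiling ** (1.0 / cosine_ceiling)

    def nir_score(bm25: float, embed_score: float, z: float) -> float:
        # Score = log(bm25) / log(z) + embed_score
        return math.log(bm25) / math.log(z) + embed_score

    top_bm25_scores = {1: 42.0}                      # of the form {topic_num: top_bm25_score}
    z = z_from_ceilings(top_bm25_scores[1], cosine_ceiling=4.0)
    print(nir_score(bm25=30.0, embed_score=2.5, z=z))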
def has_bm25_scores(self):
     99    def has_bm25_scores(self):
    +100        """
    +101        Checks if BM25 scores have been set
    +102        :return:
    +103        """
    +104        return self.top_bm25_scores is not None

    Checks if BM25 scores have been set

Returns
def generate_query_embedding(self, *args, **kwargs):
    229    def use_config(self, *args, **kwargs):
    +230        """
    +231        Replaces keywords and args passed to the function with ones from self.config.
    +232
    +233        :param self:
    +234        :param args: To be updated
    +235        :param kwargs: To be updated
    +236        :return:
    +237        """
    +238        if self.config is not None:
    +239            kwargs = self.config.__update__(**kwargs)
    +240
    +241        return func(self, *args, **kwargs)

    Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.

Parameters

• topic_num: The topic number to search for
• encoder: The encoder that will be used for encoding the topics
• norm_weight: The BM25 log normalization constant
• ablations: Whether to execute ablation style queries (i.e. one query facet or one document facet at a time)
• cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
• args:
• kwargs: Pass disable_cache to disable encoder caching

Returns

An elasticsearch script_score query
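A hedged end-to-end sketch of producing the script_score query (the topic text, the encoder model path, and the assumption that Encoder is constructed from a model path are placeholders; config is a GenericConfig loaded elsewhere):

    from debeir.core.query import GenericElasticsearchQuery
    from debeir.rankers.transformer_sent_encoder import Encoder

    topics = {1: {"text": "chronic kidney disease treatment"}}   # made-up topic
    query_obj = GenericElasticsearchQuery(topics=topics, config=config)

    query_obj.set_bm25_scores({1: 42.0})   # top BM25 score per topic; enables automatic z via get_z_value
    encoder = Encoder("path/to/sentence-encoder-model")          # hypothetical constructor arguments

    es_query = query_obj.generate_query_embedding(topic_num=1, encoder=encoder)
    # es_query["query"]["script_score"] combines the bool/should BM25 part with the embedding script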
@classmethod
def get_id_mapping(cls, hit):
    176    @classmethod
    +177    def get_id_mapping(cls, hit):
    +178        """
    +179        Get the document ID
    +180
    +181        :param hit: The raw document result
    +182        :return:
    +183            The document's ID
    +184        """
    +185        return hit[cls.id_mapping]

    Get the document ID

Parameters

• hit: The raw document result

Returns

The document's ID
\ No newline at end of file
diff --git a/docs/debeir/core/results.html b/docs/debeir/core/results.html
new file mode 100644
index 0000000..69d016f
--- /dev/null
+++ b/docs/debeir/core/results.html
@@ -0,0 +1,455 @@
+debeir.core.results API documentation

    +debeir.core.results

     1from typing import List
    + 2
    + 3from debeir.core.document import Document, document_factory
    + 4
    + 5LAZY_STATIC_DOCUMENT_KEY = "document_objects"
    + 6LAZY_STATIC_DOCUMENT_TOPICS = "document_topics"
    + 7LAZY_STATIC_DOCUMENT_HASHMAP = "document_topics"
    + 8
    + 9
    +10class Results:
    +11    document_cls: Document
    +12
    +13    def __init__(self, results: List, query_cls, engine_name):
    +14        self.results = results
    +15        self.document_cls: Document = document_factory[engine_name]
    +16        self.__doc_cur = 0
    +17        self.__topic_num = None
    +18        self.lazy_static = {}
    +19        self.query_cls = query_cls
    +20        self.topic_flag = False
    +21
    +22    def _as_documents(self, recompile=False):
    +23        if recompile or 'document_objects' not in self.lazy_static:
    +24            self.lazy_static[LAZY_STATIC_DOCUMENT_KEY] = self.document_cls.from_results(self.results,
    +25                                                                                        self.query_cls,
    +26                                                                                        ignore_facets=False)
    +27            self.lazy_static[LAZY_STATIC_DOCUMENT_TOPICS] = list(self.lazy_static[LAZY_STATIC_DOCUMENT_KEY].keys())
    +28
    +29        return self.lazy_static[LAZY_STATIC_DOCUMENT_KEY]
    +30
    +31    def get_topic_ids(self):
    +32        if LAZY_STATIC_DOCUMENT_KEY not in self.lazy_static:
    +33            self._as_documents()
    +34
    +35        return self.lazy_static[LAZY_STATIC_DOCUMENT_TOPICS]
    +36
    +37    def __iter__(self):
    +38        self._as_documents()
    +39        self.__doc_cur = 0
    +40
    +41        if not self.__topic_num:
    +42            self.__topic_num = 0
    +43
    +44        return self
    +45
    +46    def __next__(self):
    +47        if self.topic_flag:
    +48            topic_num = self.__topic_num
    +49        else:
    +50            topic_num = self.get_topic_ids()[self.__topic_num]
    +51
    +52        if self.__doc_cur >= len(self._as_documents()[topic_num]):
    +53            self.__doc_cur = 0
    +54            self.__topic_num += 1
    +55
    +56            if self.topic_flag or self.__topic_num >= len(self.get_topic_ids()):
    +57                raise StopIteration
    +58
    +59            topic_num = self.get_topic_ids()[self.__topic_num]
    +60
    +61        item = self._as_documents()[topic_num][self.__doc_cur]
    +62        self.__doc_cur += 1
    +63
    +64        return item
    +65
    +66    def __call__(self, topic_num=None):
    +67        self.__topic_num = topic_num
    +68        if topic_num:
    +69            self.topic_flag = True
    +70
    +71        return self
    +72
    +73    def __getitem__(self, item):
    +74        return self._as_documents()[item]
    +
class Results:
    11class Results:
    +12    document_cls: Document
    +13
    +14    def __init__(self, results: List, query_cls, engine_name):
    +15        self.results = results
    +16        self.document_cls: Document = document_factory[engine_name]
    +17        self.__doc_cur = 0
    +18        self.__topic_num = None
    +19        self.lazy_static = {}
    +20        self.query_cls = query_cls
    +21        self.topic_flag = False
    +22
    +23    def _as_documents(self, recompile=False):
    +24        if recompile or 'document_objects' not in self.lazy_static:
    +25            self.lazy_static[LAZY_STATIC_DOCUMENT_KEY] = self.document_cls.from_results(self.results,
    +26                                                                                        self.query_cls,
    +27                                                                                        ignore_facets=False)
    +28            self.lazy_static[LAZY_STATIC_DOCUMENT_TOPICS] = list(self.lazy_static[LAZY_STATIC_DOCUMENT_KEY].keys())
    +29
    +30        return self.lazy_static[LAZY_STATIC_DOCUMENT_KEY]
    +31
    +32    def get_topic_ids(self):
    +33        if LAZY_STATIC_DOCUMENT_KEY not in self.lazy_static:
    +34            self._as_documents()
    +35
    +36        return self.lazy_static[LAZY_STATIC_DOCUMENT_TOPICS]
    +37
    +38    def __iter__(self):
    +39        self._as_documents()
    +40        self.__doc_cur = 0
    +41
    +42        if not self.__topic_num:
    +43            self.__topic_num = 0
    +44
    +45        return self
    +46
    +47    def __next__(self):
    +48        if self.topic_flag:
    +49            topic_num = self.__topic_num
    +50        else:
    +51            topic_num = self.get_topic_ids()[self.__topic_num]
    +52
    +53        if self.__doc_cur >= len(self._as_documents()[topic_num]):
    +54            self.__doc_cur = 0
    +55            self.__topic_num += 1
    +56
    +57            if self.topic_flag or self.__topic_num >= len(self.get_topic_ids()):
    +58                raise StopIteration
    +59
    +60            topic_num = self.get_topic_ids()[self.__topic_num]
    +61
    +62        item = self._as_documents()[topic_num][self.__doc_cur]
    +63        self.__doc_cur += 1
    +64
    +65        return item
    +66
    +67    def __call__(self, topic_num=None):
    +68        self.__topic_num = topic_num
    +69        if topic_num:
    +70            self.topic_flag = True
    +71
    +72        return self
    +73
    +74    def __getitem__(self, item):
    +75        return self._as_documents()[item]
Results(results: List, query_cls, engine_name)
    14    def __init__(self, results: List, query_cls, engine_name):
    +15        self.results = results
    +16        self.document_cls: Document = document_factory[engine_name]
    +17        self.__doc_cur = 0
    +18        self.__topic_num = None
    +19        self.lazy_static = {}
    +20        self.query_cls = query_cls
    +21        self.topic_flag = False
def get_topic_ids(self):
    32    def get_topic_ids(self):
    +33        if LAZY_STATIC_DOCUMENT_KEY not in self.lazy_static:
    +34            self._as_documents()
    +35
    +36        return self.lazy_static[LAZY_STATIC_DOCUMENT_TOPICS]
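A minimal consumption sketch (raw_hits, my_query and the "elasticsearch" engine name are placeholders; the engine name is assumed to be a valid key of document_factory):

    results = Results(raw_hits, query_cls=my_query, engine_name="elasticsearch")

    # Iterate every document of every topic, in get_topic_ids() order
    for document in results:
        ...

    # Or pin iteration to a single topic via __call__
    for document in results(topic_num=results.get_topic_ids()[0]):
        ...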
\ No newline at end of file
diff --git a/docs/debeir/datasets.html b/docs/debeir/datasets.html
new file mode 100644
index 0000000..7c0bfb4
--- /dev/null
+++ b/docs/debeir/datasets.html
@@ -0,0 +1,264 @@
+debeir.datasets API documentation

    +debeir.datasets


    Contains data_sets implemented from nir.interfaces

1. Parser (For reading data from files into a Dict object)
2. Query object (Generating queries)
   • These query objects can be very lightweight containing only the mappings of the index.
    1"""
    +2Contains data_sets implemented from nir.interfaces
    +31. Parser (For reading data from files into a Dict object)
    +42. Query object (Generating queries)
    +5    - These query objects can be very lightweight containing only the mappings of the index.
    +6"""
\ No newline at end of file
diff --git a/docs/debeir/datasets/bioreddit.html b/docs/debeir/datasets/bioreddit.html
new file mode 100644
index 0000000..dae67d8
--- /dev/null
+++ b/docs/debeir/datasets/bioreddit.html
@@ -0,0 +1,548 @@
+debeir.datasets.bioreddit API documentation

    +debeir.datasets.bioreddit

     1from typing import Dict
    + 2
    + 3from debeir.core.parser import CSVParser
    + 4from debeir.core.query import GenericElasticsearchQuery
    + 5
    + 6
    + 7class BioRedditSubmissionParser(CSVParser):
    + 8    """
    + 9    Parser for the BioReddit Submission Dataset
    +10    """
    +11    parse_fields = ["id", "body"]
    +12
    +13    @classmethod
    +14    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +15        return super().get_topics(csvfile)
    +16
    +17
    +18class BioRedditCommentParser(CSVParser):
    +19    """
    +20    Parser for the BioReddit Comment Dataset
    +21    """
    +22    parse_fields = ["id", "parent_id", "selftext", "title"]
    +23
    +24    @classmethod
    +25    def get_topics(cls, csvfile) -> Dict[str, Dict[str, str]]:
    +26        topics = super().get_topics(csvfile)
    +27        temp = {}
    +28
    +29        for _, topic in topics.items():
    +30            topic["text"] = topic.pop("selftext")
    +31            topic["text2"] = topic.pop("title")
    +32            temp[topic["id"]] = topic
    +33
    +34        return temp
    +35
    +36
    +37class BioRedditElasticsearchQuery(GenericElasticsearchQuery):
    +38    """
    +39    Elasticsearch Query object for the BioReddit
    +40    """
    +41
    +42    def __init__(self, topics, config, *args, **kwargs):
    +43        super().__init__(topics, config, *args, **kwargs)
    +44        self.mappings = ["Text"]
    +45
    +46        self.topics = topics
    +47        self.config = config
    +48        self.query_type = self.config.query_type
    +49
    +50        self.embed_mappings = ["Text_Embedding"]
    +51
    +52        self.query_funcs = {
    +53            "query": self.generate_query,
    +54            "embedding": self.generate_query_embedding,
    +55        }
class BioRedditSubmissionParser(debeir.core.parser.CSVParser):
     8class BioRedditSubmissionParser(CSVParser):
    + 9    """
    +10    Parser for the BioReddit Submission Dataset
    +11    """
    +12    parse_fields = ["id", "body"]
    +13
    +14    @classmethod
    +15    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +16        return super().get_topics(csvfile)

    Parser for the BioReddit Submission Dataset

@classmethod
def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    14    @classmethod
    +15    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +16        return super().get_topics(csvfile)

    Instance method for getting topics, forwards instance self parameters to the _get_topics class method.

class BioRedditCommentParser(debeir.core.parser.CSVParser):
    19class BioRedditCommentParser(CSVParser):
    +20    """
    +21    Parser for the BioReddit Comment Dataset
    +22    """
    +23    parse_fields = ["id", "parent_id", "selftext", "title"]
    +24
    +25    @classmethod
    +26    def get_topics(cls, csvfile) -> Dict[str, Dict[str, str]]:
    +27        topics = super().get_topics(csvfile)
    +28        temp = {}
    +29
    +30        for _, topic in topics.items():
    +31            topic["text"] = topic.pop("selftext")
    +32            topic["text2"] = topic.pop("title")
    +33            temp[topic["id"]] = topic
    +34
    +35        return temp

    Parser for the BioReddit Comment Dataset

@classmethod
def get_topics(cls, csvfile) -> Dict[str, Dict[str, str]]:
    25    @classmethod
    +26    def get_topics(cls, csvfile) -> Dict[str, Dict[str, str]]:
    +27        topics = super().get_topics(csvfile)
    +28        temp = {}
    +29
    +30        for _, topic in topics.items():
    +31            topic["text"] = topic.pop("selftext")
    +32            topic["text2"] = topic.pop("title")
    +33            temp[topic["id"]] = topic
    +34
    +35        return temp

    Instance method for getting topics, forwards instance self parameters to the _get_topics class method.
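To make the transformation concrete, a sketch with made-up values (only the selftext -> text and title -> text2 renames and the re-keying by id come from the code above):

    topics = {"0": {"id": "abc1", "parent_id": "xyz9",
                    "selftext": "What does this rash mean?", "title": "Rash question"}}
    # BioRedditCommentParser.get_topics would return (illustrative):
    # {"abc1": {"id": "abc1", "parent_id": "xyz9",
    #           "text": "What does this rash mean?", "text2": "Rash question"}}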

class BioRedditElasticsearchQuery(debeir.core.query.GenericElasticsearchQuery):
    38class BioRedditElasticsearchQuery(GenericElasticsearchQuery):
    +39    """
    +40    Elasticsearch Query object for the BioReddit
    +41    """
    +42
    +43    def __init__(self, topics, config, *args, **kwargs):
    +44        super().__init__(topics, config, *args, **kwargs)
    +45        self.mappings = ["Text"]
    +46
    +47        self.topics = topics
    +48        self.config = config
    +49        self.query_type = self.config.query_type
    +50
    +51        self.embed_mappings = ["Text_Embedding"]
    +52
    +53        self.query_funcs = {
    +54            "query": self.generate_query,
    +55            "embedding": self.generate_query_embedding,
    +56        }

    Elasticsearch Query object for the BioReddit

BioRedditElasticsearchQuery(topics, config, *args, **kwargs)
    43    def __init__(self, topics, config, *args, **kwargs):
    +44        super().__init__(topics, config, *args, **kwargs)
    +45        self.mappings = ["Text"]
    +46
    +47        self.topics = topics
    +48        self.config = config
    +49        self.query_type = self.config.query_type
    +50
    +51        self.embed_mappings = ["Text_Embedding"]
    +52
    +53        self.query_funcs = {
    +54            "query": self.generate_query,
    +55            "embedding": self.generate_query_embedding,
    +56        }
\ No newline at end of file
diff --git a/docs/debeir/datasets/clinical_trials.html b/docs/debeir/datasets/clinical_trials.html
new file mode 100644
index 0000000..07189b0
--- /dev/null
+++ b/docs/debeir/datasets/clinical_trials.html
@@ -0,0 +1,2094 @@
+debeir.datasets.clinical_trials API documentation

    +debeir.datasets.clinical_trials

      1import csv
    +  2from dataclasses import dataclass
    +  3from typing import Dict, List, Optional, Union
    +  4
    +  5import loguru
    +  6from debeir.engines.elasticsearch.generate_script_score import generate_script
    +  7from debeir.core.config import GenericConfig, apply_config
    +  8from debeir.core.executor import GenericElasticsearchExecutor
    +  9from debeir.core.parser import Parser
    + 10from debeir.core.query import GenericElasticsearchQuery
    + 11from debeir.rankers.transformer_sent_encoder import Encoder
    + 12from debeir.utils.scaler import get_z_value
    + 13from elasticsearch import AsyncElasticsearch as Elasticsearch
    + 14
    + 15
    + 16@dataclass(init=True, unsafe_hash=True)
    + 17class TrialsQueryConfig(GenericConfig):
    + 18    query_field_usage: str = None
    + 19    embed_field_usage: str = None
    + 20    fields: List[str] = None
    + 21
    + 22    def validate(self):
    + 23        """
    + 24        Checks if query type is included, and checks if an encoder is included for embedding queries
    + 25        """
    + 26        if self.query_type == "embedding":
    + 27            assert self.query_field_usage and self.embed_field_usage, (
    + 28                "Must have both field usages" " if embedding query"
    + 29            )
    + 30            assert (
    + 31                    self.encoder_fp and self.encoder
    + 32            ), "Must provide encoder path for embedding model"
    + 33            assert self.norm_weight is not None or self.automatic is not None, (
    + 34                "Norm weight be specified or be " "automatic "
    + 35            )
    + 36
    + 37        assert (
    + 38                self.query_field_usage is not None or self.fields is not None
    + 39        ), "Must have a query field"
    + 40        assert self.query_type in [
    + 41            "ablation",
    + 42            "query",
    + 43            "query_best",
    + 44            "embedding",
    + 45        ], "Check your query type"
    + 46
    + 47    @classmethod
    + 48    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
    + 49        return super().from_toml(fp, cls, *args, **kwargs)
    + 50
    + 51    @classmethod
    + 52    def from_dict(cls, **kwargs) -> "GenericConfig":
    + 53        return super().from_dict(cls, **kwargs)
    + 54
    + 55
    + 56class TrialsElasticsearchQuery(GenericElasticsearchQuery):
    + 57    """
    + 58    Elasticsearch Query object for the Clinical Trials Index
    + 59    """
    + 60    topics: Dict[int, Dict[str, str]]
    + 61    query_type: str
    + 62    fields: List[int]
    + 63    query_funcs: Dict
    + 64    config: GenericConfig
    + 65    id_mapping: str = "_id"
    + 66    mappings: List[str]
    + 67    config: TrialsQueryConfig
    + 68
    + 69    def __init__(self, topics, query_type, config=None, *args, **kwargs):
    + 70        super().__init__(topics, config, *args, **kwargs)
    + 71        self.query_type = query_type
    + 72        self.config = config
    + 73        self.topics = topics
    + 74        self.fields = []
    + 75        self.mappings = [
    + 76            "HasExpandedAccess",
    + 77            "BriefSummary.Textblock",
    + 78            "CompletionDate.Type",
    + 79            "OversightInfo.Text",
    + 80            "OverallContactBackup.PhoneExt",
    + 81            "RemovedCountries.Text",
    + 82            "SecondaryOutcome",
    + 83            "Sponsors.LeadSponsor.Text",
    + 84            "BriefTitle",
    + 85            "IDInfo.NctID",
    + 86            "IDInfo.SecondaryID",
    + 87            "OverallContactBackup.Phone",
    + 88            "Eligibility.StudyPop.Textblock",
    + 89            "DetailedDescription.Textblock",
    + 90            "Eligibility.MinimumAge",
    + 91            "Sponsors.Collaborator",
    + 92            "Reference",
    + 93            "Eligibility.Criteria.Textblock",
    + 94            "XMLName.Space",
    + 95            "Rank",
    + 96            "OverallStatus",
    + 97            "InterventionBrowse.Text",
    + 98            "Eligibility.Text",
    + 99            "Intervention",
    +100            "BiospecDescr.Textblock",
    +101            "ResponsibleParty.NameTitle",
    +102            "NumberOfArms",
    +103            "ResponsibleParty.ResponsiblePartyType",
    +104            "IsSection801",
    +105            "Acronym",
    +106            "Eligibility.MaximumAge",
    +107            "DetailedDescription.Text",
    +108            "StudyDesign",
    +109            "OtherOutcome",
    +110            "VerificationDate",
    +111            "ConditionBrowse.MeshTerm",
    +112            "Enrollment.Text",
    +113            "IDInfo.Text",
    +114            "ConditionBrowse.Text",
    +115            "FirstreceivedDate",
    +116            "NumberOfGroups",
    +117            "OversightInfo.HasDmc",
    +118            "PrimaryCompletionDate.Text",
    +119            "ResultsReference",
    +120            "Eligibility.StudyPop.Text",
    +121            "IsFdaRegulated",
    +122            "WhyStopped",
    +123            "ArmGroup",
    +124            "OverallContact.LastName",
    +125            "Phase",
    +126            "RemovedCountries.Country",
    +127            "InterventionBrowse.MeshTerm",
    +128            "Eligibility.HealthyVolunteers",
    +129            "Location",
    +130            "OfficialTitle",
    +131            "OverallContact.Email",
    +132            "RequiredHeader.Text",
    +133            "RequiredHeader.URL",
    +134            "LocationCountries.Country",
    +135            "OverallContact.PhoneExt",
    +136            "Condition",
    +137            "PrimaryOutcome",
    +138            "LocationCountries.Text",
    +139            "BiospecDescr.Text",
    +140            "IDInfo.OrgStudyID",
    +141            "Link",
    +142            "OverallContact.Phone",
    +143            "Source",
    +144            "ResponsibleParty.InvestigatorAffiliation",
    +145            "StudyType",
    +146            "FirstreceivedResultsDate",
    +147            "Enrollment.Type",
    +148            "Eligibility.Gender",
    +149            "OverallContactBackup.LastName",
    +150            "Keyword",
    +151            "BiospecRetention",
    +152            "CompletionDate.Text",
    +153            "OverallContact.Text",
    +154            "RequiredHeader.DownloadDate",
    +155            "Sponsors.Text",
    +156            "Text",
    +157            "Eligibility.SamplingMethod",
    +158            "LastchangedDate",
    +159            "ResponsibleParty.InvestigatorFullName",
    +160            "StartDate",
    +161            "RequiredHeader.LinkText",
    +162            "OverallOfficial",
    +163            "Sponsors.LeadSponsor.AgencyClass",
    +164            "OverallContactBackup.Text",
    +165            "Eligibility.Criteria.Text",
    +166            "XMLName.Local",
    +167            "OversightInfo.Authority",
    +168            "PrimaryCompletionDate.Type",
    +169            "ResponsibleParty.Organization",
    +170            "IDInfo.NctAlias",
    +171            "ResponsibleParty.Text",
    +172            "TargetDuration",
    +173            "Sponsors.LeadSponsor.Agency",
    +174            "BriefSummary.Text",
    +175            "OverallContactBackup.Email",
    +176            "ResponsibleParty.InvestigatorTitle",
    +177        ]
    +178
    +179        self.best_recall_fields = [
    +180            "LocationCountries.Country",
    +181            "BiospecRetention",
    +182            "DetailedDescription.Textblock",
    +183            "HasExpandedAccess",
    +184            "ConditionBrowse.MeshTerm",
    +185            "RequiredHeader.LinkText",
    +186            "WhyStopped",
    +187            "BriefSummary.Textblock",
    +188            "Eligibility.Criteria.Textblock",
    +189            "OfficialTitle",
    +190            "Eligibility.MaximumAge",
    +191            "Eligibility.StudyPop.Textblock",
    +192            "BiospecDescr.Textblock",
    +193            "BriefTitle",
    +194            "Eligibility.MinimumAge",
    +195            "ResponsibleParty.Organization",
    +196            "TargetDuration",
    +197            "Condition",
    +198            "IDInfo.OrgStudyID",
    +199            "Keyword",
    +200            "Source",
    +201            "Sponsors.LeadSponsor.Agency",
    +202            "ResponsibleParty.InvestigatorAffiliation",
    +203            "OversightInfo.Authority",
    +204            "OversightInfo.HasDmc",
    +205            "OverallContact.Phone",
    +206            "Phase",
    +207            "OverallContactBackup.LastName",
    +208            "Acronym",
    +209            "InterventionBrowse.MeshTerm",
    +210            "RemovedCountries.Country",
    +211        ]
    +212        self.best_map_fields = [
    +213            "Eligibility.Gender",
    +214            "LocationCountries.Country",
    +215            "DetailedDescription.Textblock",
    +216            "BriefSummary.Textblock",
    +217            "ConditionBrowse.MeshTerm",
    +218            "Eligibility.Criteria.Textblock",
    +219            "InterventionBrowse.MeshTerm",
    +220            "StudyType",
    +221            "IsFdaRegulated",
    +222            "HasExpandedAccess",
    +223            "RequiredHeader.LinkText",
    +224            "BiospecRetention",
    +225            "OfficialTitle",
    +226            "Eligibility.SamplingMethod",
    +227            "Eligibility.StudyPop.Textblock",
    +228            "Condition",
    +229            "Eligibility.MinimumAge",
    +230            "Keyword",
    +231            "Eligibility.MaximumAge",
    +232            "BriefTitle",
    +233        ]
    +234        self.best_embed_fields = [
    +235            "WhyStopped",
    +236            "HasExpandedAccess",
    +237            "BiospecRetention",
    +238            "BriefSummary.Textblock",
    +239            "LocationCountries.Country",
    +240            "ConditionBrowse.MeshTerm",
    +241            "DetailedDescription.Textblock",
    +242            "RequiredHeader.LinkText",
    +243            "Eligibility.Criteria.Textblock",
    +244        ]
    +245
    +246        self.sensible = [
    +247            "BriefSummary.Textblock" "BriefTitle",
    +248            "Eligibility.StudyPop.Textblock",
    +249            "DetailedDescription.Textblock",
    +250            "Eligibility.MinimumAge",
    +251            "Eligibility.Criteria.Textblock",
    +252            "InterventionBrowse.Text",
    +253            "Eligibility.Text",
    +254            "BiospecDescr.Textblock",
    +255            "Eligibility.MaximumAge",
    +256            "DetailedDescription.Text",
    +257            "ConditionBrowse.MeshTerm",
    +258            "ConditionBrowse.Text",
    +259            "Eligibility.StudyPop.Text",
    +260            "InterventionBrowse.MeshTerm",
    +261            "OfficialTitle",
    +262            "Condition",
    +263            "PrimaryOutcome",
    +264            "BiospecDescr.Text",
    +265            "Eligibility.Gender",
    +266            "Keyword",
    +267            "BiospecRetention",
    +268            "Eligibility.Criteria.Text",
    +269            "BriefSummary.Text",
    +270        ]
    +271
    +272        self.sensible_embed = [
    +273            "BriefSummary.Textblock" "BriefTitle",
    +274            "Eligibility.StudyPop.Textblock",
    +275            "DetailedDescription.Textblock",
    +276            "Eligibility.Criteria.Textblock",
    +277            "InterventionBrowse.Text",
    +278            "Eligibility.Text",
    +279            "BiospecDescr.Textblock",
    +280            "DetailedDescription.Text",
    +281            "ConditionBrowse.MeshTerm",
    +282            "ConditionBrowse.Text",
    +283            "Eligibility.StudyPop.Text",
    +284            "InterventionBrowse.MeshTerm",
    +285            "OfficialTitle",
    +286            "Condition",
    +287            "PrimaryOutcome",
    +288            "BiospecDescr.Text",
    +289            "Keyword",
    +290            "BiospecRetention",
    +291            "Eligibility.Criteria.Text",
    +292            "BriefSummary.Text",
    +293        ]
    +294
    +295        self.sensible_embed_safe = list(
    +296            set(self.best_recall_fields).intersection(set(self.sensible_embed))
    +297        )
    +298
    +299        self.query_funcs = {
    +300            "query": self.generate_query,
    +301            "ablation": self.generate_query_ablation,
    +302            "embedding": self.generate_query_embedding,
    +303        }
    +304
    +305        loguru.logger.debug(self.sensible_embed_safe)
    +306
    +307        self.field_usage = {
    +308            "best_recall_fields": self.best_recall_fields,
    +309            "all": self.mappings,
    +310            "best_map_fields": self.best_map_fields,
    +311            "best_embed_fields": self.best_embed_fields,
    +312            "sensible": self.sensible,
    +313            "sensible_embed": self.sensible_embed,
    +314            "sensible_embed_safe": self.sensible_embed_safe,
    +315        }
    +316
    +317    @apply_config
    +318    def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
    +319        """
    +320        Generates a query for the clinical trials index
    +321
    +322        :param topic_num: Topic number to search
    +323        :param query_field_usage: Which document facets to search over
    +324        :param kwargs:
    +325        :return:
    +326            A basic elasticsearch query for clinical trials
    +327        """
    +328        fields = self.field_usage[query_field_usage]
    +329        should = {"should": []}
    +330
    +331        qfield = list(self.topics[topic_num].keys())[0]
    +332        query = self.topics[topic_num][qfield]
    +333
    +334        for i, field in enumerate(fields):
    +335            should["should"].append(
    +336                {
    +337                    "match": {
    +338                        f"{field}": {
    +339                            "query": query,
    +340                        }
    +341                    }
    +342                }
    +343            )
    +344
    +345        query = {
    +346            "query": {
    +347                "bool": should,
    +348            }
    +349        }
    +350
    +351        return query
    +352
    +353    def generate_query_ablation(self, topic_num, **kwargs):
    +354        """
    +355        Only search one document facet at a time
    +356        :param topic_num:
    +357        :param kwargs:
    +358        :return:
    +359        """
    +360        query = {"query": {"match": {}}}
    +361
    +362        for field in self.fields:
    +363            query["query"]["match"][self.mappings[field]] = ""
    +364
    +365        for qfield in self.fields:
    +366            qfield = self.mappings[qfield]
    +367            for field in self.topics[topic_num]:
    +368                query["query"]["match"][qfield] += self.topics[topic_num][field]
    +369
    +370        return query
    +371
    +372    @apply_config
    +373    def generate_query_embedding(
    +374            self,
    +375            topic_num,
    +376            encoder,
    +377            query_field_usage,
    +378            embed_field_usage,
    +379            cosine_weights: List[float] = None,
    +380            query_weight: List[float] = None,
    +381            norm_weight=2.15,
    +382            ablations=False,
    +383            automatic_scores=None,
    +384            **kwargs,
    +385    ):
    +386        """
    +387        Computes the NIR score for a given topic
    +388
    +389        Score = log(BM25)/log(norm_weight) + embedding_score
    +390
    +391        :param topic_num:
    +392        :param encoder:
    +393        :param query_field_usage:
    +394        :param embed_field_usage:
    +395        :param cosine_weights:
    +396        :param query_weight:
    +397        :param norm_weight:
    +398        :param ablations:
    +399        :param automatic_scores:
    +400        :param kwargs:
    +401        :return:
    +402        """
    +403        should = {"should": []}
    +404
    +405        assert norm_weight or automatic_scores
    +406
    +407        query_fields = self.field_usage[query_field_usage]
    +408        embed_fields = self.field_usage[embed_field_usage]
    +409
    +410        qfield = list(self.topics[topic_num].keys())[0]
    +411        query = self.topics[topic_num][qfield]
    +412
    +413        for i, field in enumerate(query_fields):
    +414            should["should"].append(
    +415                {
    +416                    "match": {
    +417                        f"{field}": {
    +418                            "query": query,
    +419                            "boost": query_weight[i] if query_weight else 1,
    +420                        }
    +421                    }
    +422                }
    +423            )
    +424
    +425        if automatic_scores is not None:
    +426            norm_weight = get_z_value(
    +427                cosine_ceiling=len(embed_fields) * len(query_fields),
    +428                bm25_ceiling=automatic_scores[topic_num],
    +429            )
    +430
    +431        params = {
    +432            "weights": cosine_weights if cosine_weights else [1] * len(embed_fields),
    +433            "q_eb": encoder.encode(self.topics[topic_num][qfield]),
    +434            "offset": 1.0,
    +435            "norm_weight": norm_weight,
    +436            "disable_bm25": ablations,
    +437        }
    +438
    +439        query = {
    +440            "query": {
    +441                "script_score": {
    +442                    "query": {
    +443                        "bool": should,
    +444                    },
    +445                    "script": generate_script(self.best_embed_fields, params=params),
    +446                },
    +447            }
    +448        }
    +449
    +450        return query
    +451
    +452    def get_query_type(self, *args, **kwargs):
    +453        return self.query_funcs[self.query_type](*args, **kwargs)
    +454
    +455    def get_id_mapping(self, hit):
    +456        return hit[self.id_mapping]
    +457
    +458
    +459class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    +460    """
    +461    Executes queries given a query object.
    +462    """
    +463    query: TrialsElasticsearchQuery
    +464
    +465    def __init__(
    +466            self,
    +467            topics: Dict[Union[str, int], Dict[str, str]],
    +468            client: Elasticsearch,
    +469            index_name: str,
    +470            output_file: str,
    +471            query: TrialsElasticsearchQuery,
    +472            encoder: Optional[Encoder] = None,
    +473            config=None,
    +474            *args,
    +475            **kwargs,
    +476    ):
    +477        super().__init__(
    +478            topics,
    +479            client,
    +480            index_name,
    +481            output_file,
    +482            query,
    +483            encoder,
    +484            config=config,
    +485            *args,
    +486            **kwargs,
    +487        )
    +488
    +489        self.query_fns = {
    +490            "query": self.generate_query,
    +491            "ablation": self.generate_query_ablation,
    +492            "embedding": self.generate_embedding_query,
    +493        }
    +494
    +495
    +496class ClinicalTrialParser(Parser):
    +497    """
    +498    Parser for Clinical Trials topics
    +499    """
    +500
    +501    @classmethod
    +502    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +503        topics = {}
    +504        reader = csv.reader(csvfile)
    +505        for i, row in enumerate(reader):
    +506            if i == 0:
    +507                continue
    +508
    +509            _id = row[0]
    +510            text = row[1]
    +511
    +512            topics[_id] = {"text": text}
    +513
    +514        return topics
@dataclass(init=True, unsafe_hash=True)
class TrialsQueryConfig(debeir.core.config.GenericConfig):
    17@dataclass(init=True, unsafe_hash=True)
    +18class TrialsQueryConfig(GenericConfig):
    +19    query_field_usage: str = None
    +20    embed_field_usage: str = None
    +21    fields: List[str] = None
    +22
    +23    def validate(self):
    +24        """
    +25        Checks if query type is included, and checks if an encoder is included for embedding queries
    +26        """
    +27        if self.query_type == "embedding":
    +28            assert self.query_field_usage and self.embed_field_usage, (
    +29                "Must have both field usages" " if embedding query"
    +30            )
    +31            assert (
    +32                    self.encoder_fp and self.encoder
    +33            ), "Must provide encoder path for embedding model"
    +34            assert self.norm_weight is not None or self.automatic is not None, (
    +35                "Norm weight be specified or be " "automatic "
    +36            )
    +37
    +38        assert (
    +39                self.query_field_usage is not None or self.fields is not None
    +40        ), "Must have a query field"
    +41        assert self.query_type in [
    +42            "ablation",
    +43            "query",
    +44            "query_best",
    +45            "embedding",
    +46        ], "Check your query type"
    +47
    +48    @classmethod
    +49    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
    +50        return super().from_toml(fp, cls, *args, **kwargs)
    +51
    +52    @classmethod
    +53    def from_dict(cls, **kwargs) -> "GenericConfig":
    +54        return super().from_dict(cls, **kwargs)
    +
TrialsQueryConfig(query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None, query_field_usage: str = None, embed_field_usage: str = None, fields: List[str] = None)

def validate(self):
    23    def validate(self):
    +24        """
    +25        Checks if query type is included, and checks if an encoder is included for embedding queries
    +26        """
    +27        if self.query_type == "embedding":
    +28            assert self.query_field_usage and self.embed_field_usage, (
    +29                "Must have both field usages" " if embedding query"
    +30            )
    +31            assert (
    +32                    self.encoder_fp and self.encoder
    +33            ), "Must provide encoder path for embedding model"
    +34            assert self.norm_weight is not None or self.automatic is not None, (
    +35                "Norm weight be specified or be " "automatic "
    +36            )
    +37
    +38        assert (
    +39                self.query_field_usage is not None or self.fields is not None
    +40        ), "Must have a query field"
    +41        assert self.query_type in [
    +42            "ablation",
    +43            "query",
    +44            "query_best",
    +45            "embedding",
    +46        ], "Check your query type"

    Checks if query type is included, and checks if an encoder is included for embedding queries

@classmethod
def from_toml(cls, fp: str, *args, **kwargs) -> debeir.core.config.GenericConfig:
    48    @classmethod
    +49    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
    +50        return super().from_toml(fp, cls, *args, **kwargs)

    Instantiates a Config object from a toml file

Parameters

• fp: File path of the Config TOML file
• field_class: Class of the Config object to be instantiated
• args: Arguments to be passed to Config
• kwargs: Keyword arguments to be passed

Returns

An instantiated and validated Config object.
@classmethod
def from_dict(cls, **kwargs) -> debeir.core.config.GenericConfig:
    52    @classmethod
    +53    def from_dict(cls, **kwargs) -> "GenericConfig":
    +54        return super().from_dict(cls, **kwargs)

    Instantiates a Config object from a dictionary

Parameters

• data_class:
• kwargs:

Returns
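A hedged sketch of building a config that passes TrialsQueryConfig.validate() for an embedding run (field names follow the dataclass signature above; the index name, encoder object and encoder path are placeholders):

    config = TrialsQueryConfig.from_dict(
        query_type="embedding",
        index="clinical_trials_2021",            # placeholder index name
        query_field_usage="best_map_fields",
        embed_field_usage="best_embed_fields",
        encoder=encoder,                         # an instantiated Encoder (placeholder)
        encoder_fp="path/to/sentence-encoder",   # placeholder path
        automatic=True,                          # or set norm_weight=2.15 for a fixed z value
    )
    config.validate()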
Inherited Members
class TrialsElasticsearchQuery(debeir.core.query.GenericElasticsearchQuery):
     57class TrialsElasticsearchQuery(GenericElasticsearchQuery):
    + 58    """
    + 59    Elasticsearch Query object for the Clinical Trials Index
    + 60    """
    + 61    topics: Dict[int, Dict[str, str]]
    + 62    query_type: str
    + 63    fields: List[int]
    + 64    query_funcs: Dict
    + 65    config: GenericConfig
    + 66    id_mapping: str = "_id"
    + 67    mappings: List[str]
    + 68    config: TrialsQueryConfig
    + 69
    + 70    def __init__(self, topics, query_type, config=None, *args, **kwargs):
    + 71        super().__init__(topics, config, *args, **kwargs)
    + 72        self.query_type = query_type
    + 73        self.config = config
    + 74        self.topics = topics
    + 75        self.fields = []
    + 76        self.mappings = [
    + 77            "HasExpandedAccess",
    + 78            "BriefSummary.Textblock",
    + 79            "CompletionDate.Type",
    + 80            "OversightInfo.Text",
    + 81            "OverallContactBackup.PhoneExt",
    + 82            "RemovedCountries.Text",
    + 83            "SecondaryOutcome",
    + 84            "Sponsors.LeadSponsor.Text",
    + 85            "BriefTitle",
    + 86            "IDInfo.NctID",
    + 87            "IDInfo.SecondaryID",
    + 88            "OverallContactBackup.Phone",
    + 89            "Eligibility.StudyPop.Textblock",
    + 90            "DetailedDescription.Textblock",
    + 91            "Eligibility.MinimumAge",
    + 92            "Sponsors.Collaborator",
    + 93            "Reference",
    + 94            "Eligibility.Criteria.Textblock",
    + 95            "XMLName.Space",
    + 96            "Rank",
    + 97            "OverallStatus",
    + 98            "InterventionBrowse.Text",
    + 99            "Eligibility.Text",
    +100            "Intervention",
    +101            "BiospecDescr.Textblock",
    +102            "ResponsibleParty.NameTitle",
    +103            "NumberOfArms",
    +104            "ResponsibleParty.ResponsiblePartyType",
    +105            "IsSection801",
    +106            "Acronym",
    +107            "Eligibility.MaximumAge",
    +108            "DetailedDescription.Text",
    +109            "StudyDesign",
    +110            "OtherOutcome",
    +111            "VerificationDate",
    +112            "ConditionBrowse.MeshTerm",
    +113            "Enrollment.Text",
    +114            "IDInfo.Text",
    +115            "ConditionBrowse.Text",
    +116            "FirstreceivedDate",
    +117            "NumberOfGroups",
    +118            "OversightInfo.HasDmc",
    +119            "PrimaryCompletionDate.Text",
    +120            "ResultsReference",
    +121            "Eligibility.StudyPop.Text",
    +122            "IsFdaRegulated",
    +123            "WhyStopped",
    +124            "ArmGroup",
    +125            "OverallContact.LastName",
    +126            "Phase",
    +127            "RemovedCountries.Country",
    +128            "InterventionBrowse.MeshTerm",
    +129            "Eligibility.HealthyVolunteers",
    +130            "Location",
    +131            "OfficialTitle",
    +132            "OverallContact.Email",
    +133            "RequiredHeader.Text",
    +134            "RequiredHeader.URL",
    +135            "LocationCountries.Country",
    +136            "OverallContact.PhoneExt",
    +137            "Condition",
    +138            "PrimaryOutcome",
    +139            "LocationCountries.Text",
    +140            "BiospecDescr.Text",
    +141            "IDInfo.OrgStudyID",
    +142            "Link",
    +143            "OverallContact.Phone",
    +144            "Source",
    +145            "ResponsibleParty.InvestigatorAffiliation",
    +146            "StudyType",
    +147            "FirstreceivedResultsDate",
    +148            "Enrollment.Type",
    +149            "Eligibility.Gender",
    +150            "OverallContactBackup.LastName",
    +151            "Keyword",
    +152            "BiospecRetention",
    +153            "CompletionDate.Text",
    +154            "OverallContact.Text",
    +155            "RequiredHeader.DownloadDate",
    +156            "Sponsors.Text",
    +157            "Text",
    +158            "Eligibility.SamplingMethod",
    +159            "LastchangedDate",
    +160            "ResponsibleParty.InvestigatorFullName",
    +161            "StartDate",
    +162            "RequiredHeader.LinkText",
    +163            "OverallOfficial",
    +164            "Sponsors.LeadSponsor.AgencyClass",
    +165            "OverallContactBackup.Text",
    +166            "Eligibility.Criteria.Text",
    +167            "XMLName.Local",
    +168            "OversightInfo.Authority",
    +169            "PrimaryCompletionDate.Type",
    +170            "ResponsibleParty.Organization",
    +171            "IDInfo.NctAlias",
    +172            "ResponsibleParty.Text",
    +173            "TargetDuration",
    +174            "Sponsors.LeadSponsor.Agency",
    +175            "BriefSummary.Text",
    +176            "OverallContactBackup.Email",
    +177            "ResponsibleParty.InvestigatorTitle",
    +178        ]
    +179
    +180        self.best_recall_fields = [
    +181            "LocationCountries.Country",
    +182            "BiospecRetention",
    +183            "DetailedDescription.Textblock",
    +184            "HasExpandedAccess",
    +185            "ConditionBrowse.MeshTerm",
    +186            "RequiredHeader.LinkText",
    +187            "WhyStopped",
    +188            "BriefSummary.Textblock",
    +189            "Eligibility.Criteria.Textblock",
    +190            "OfficialTitle",
    +191            "Eligibility.MaximumAge",
    +192            "Eligibility.StudyPop.Textblock",
    +193            "BiospecDescr.Textblock",
    +194            "BriefTitle",
    +195            "Eligibility.MinimumAge",
    +196            "ResponsibleParty.Organization",
    +197            "TargetDuration",
    +198            "Condition",
    +199            "IDInfo.OrgStudyID",
    +200            "Keyword",
    +201            "Source",
    +202            "Sponsors.LeadSponsor.Agency",
    +203            "ResponsibleParty.InvestigatorAffiliation",
    +204            "OversightInfo.Authority",
    +205            "OversightInfo.HasDmc",
    +206            "OverallContact.Phone",
    +207            "Phase",
    +208            "OverallContactBackup.LastName",
    +209            "Acronym",
    +210            "InterventionBrowse.MeshTerm",
    +211            "RemovedCountries.Country",
    +212        ]
    +213        self.best_map_fields = [
    +214            "Eligibility.Gender",
    +215            "LocationCountries.Country",
    +216            "DetailedDescription.Textblock",
    +217            "BriefSummary.Textblock",
    +218            "ConditionBrowse.MeshTerm",
    +219            "Eligibility.Criteria.Textblock",
    +220            "InterventionBrowse.MeshTerm",
    +221            "StudyType",
    +222            "IsFdaRegulated",
    +223            "HasExpandedAccess",
    +224            "RequiredHeader.LinkText",
    +225            "BiospecRetention",
    +226            "OfficialTitle",
    +227            "Eligibility.SamplingMethod",
    +228            "Eligibility.StudyPop.Textblock",
    +229            "Condition",
    +230            "Eligibility.MinimumAge",
    +231            "Keyword",
    +232            "Eligibility.MaximumAge",
    +233            "BriefTitle",
    +234        ]
    +235        self.best_embed_fields = [
    +236            "WhyStopped",
    +237            "HasExpandedAccess",
    +238            "BiospecRetention",
    +239            "BriefSummary.Textblock",
    +240            "LocationCountries.Country",
    +241            "ConditionBrowse.MeshTerm",
    +242            "DetailedDescription.Textblock",
    +243            "RequiredHeader.LinkText",
    +244            "Eligibility.Criteria.Textblock",
    +245        ]
    +246
    +247        self.sensible = [
    +248            "BriefSummary.Textblock" "BriefTitle",
    +249            "Eligibility.StudyPop.Textblock",
    +250            "DetailedDescription.Textblock",
    +251            "Eligibility.MinimumAge",
    +252            "Eligibility.Criteria.Textblock",
    +253            "InterventionBrowse.Text",
    +254            "Eligibility.Text",
    +255            "BiospecDescr.Textblock",
    +256            "Eligibility.MaximumAge",
    +257            "DetailedDescription.Text",
    +258            "ConditionBrowse.MeshTerm",
    +259            "ConditionBrowse.Text",
    +260            "Eligibility.StudyPop.Text",
    +261            "InterventionBrowse.MeshTerm",
    +262            "OfficialTitle",
    +263            "Condition",
    +264            "PrimaryOutcome",
    +265            "BiospecDescr.Text",
    +266            "Eligibility.Gender",
    +267            "Keyword",
    +268            "BiospecRetention",
    +269            "Eligibility.Criteria.Text",
    +270            "BriefSummary.Text",
    +271        ]
    +272
    +273        self.sensible_embed = [
    +274            "BriefSummary.Textblock" "BriefTitle",
    +275            "Eligibility.StudyPop.Textblock",
    +276            "DetailedDescription.Textblock",
    +277            "Eligibility.Criteria.Textblock",
    +278            "InterventionBrowse.Text",
    +279            "Eligibility.Text",
    +280            "BiospecDescr.Textblock",
    +281            "DetailedDescription.Text",
    +282            "ConditionBrowse.MeshTerm",
    +283            "ConditionBrowse.Text",
    +284            "Eligibility.StudyPop.Text",
    +285            "InterventionBrowse.MeshTerm",
    +286            "OfficialTitle",
    +287            "Condition",
    +288            "PrimaryOutcome",
    +289            "BiospecDescr.Text",
    +290            "Keyword",
    +291            "BiospecRetention",
    +292            "Eligibility.Criteria.Text",
    +293            "BriefSummary.Text",
    +294        ]
    +295
    +296        self.sensible_embed_safe = list(
    +297            set(self.best_recall_fields).intersection(set(self.sensible_embed))
    +298        )
    +299
    +300        self.query_funcs = {
    +301            "query": self.generate_query,
    +302            "ablation": self.generate_query_ablation,
    +303            "embedding": self.generate_query_embedding,
    +304        }
    +305
    +306        loguru.logger.debug(self.sensible_embed_safe)
    +307
    +308        self.field_usage = {
    +309            "best_recall_fields": self.best_recall_fields,
    +310            "all": self.mappings,
    +311            "best_map_fields": self.best_map_fields,
    +312            "best_embed_fields": self.best_embed_fields,
    +313            "sensible": self.sensible,
    +314            "sensible_embed": self.sensible_embed,
    +315            "sensible_embed_safe": self.sensible_embed_safe,
    +316        }
    +317
    +318    @apply_config
    +319    def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
    +320        """
    +321        Generates a query for the clinical trials index
    +322
    +323        :param topic_num: Topic number to search
    +324        :param query_field_usage: Which document facets to search over
    +325        :param kwargs:
    +326        :return:
    +327            A basic elasticsearch query for clinical trials
    +328        """
    +329        fields = self.field_usage[query_field_usage]
    +330        should = {"should": []}
    +331
    +332        qfield = list(self.topics[topic_num].keys())[0]
    +333        query = self.topics[topic_num][qfield]
    +334
    +335        for i, field in enumerate(fields):
    +336            should["should"].append(
    +337                {
    +338                    "match": {
    +339                        f"{field}": {
    +340                            "query": query,
    +341                        }
    +342                    }
    +343                }
    +344            )
    +345
    +346        query = {
    +347            "query": {
    +348                "bool": should,
    +349            }
    +350        }
    +351
    +352        return query
    +353
    +354    def generate_query_ablation(self, topic_num, **kwargs):
    +355        """
    +356        Only search one document facet at a time
    +357        :param topic_num:
    +358        :param kwargs:
    +359        :return:
    +360        """
    +361        query = {"query": {"match": {}}}
    +362
    +363        for field in self.fields:
    +364            query["query"]["match"][self.mappings[field]] = ""
    +365
    +366        for qfield in self.fields:
    +367            qfield = self.mappings[qfield]
    +368            for field in self.topics[topic_num]:
    +369                query["query"]["match"][qfield] += self.topics[topic_num][field]
    +370
    +371        return query
    +372
    +373    @apply_config
    +374    def generate_query_embedding(
    +375            self,
    +376            topic_num,
    +377            encoder,
    +378            query_field_usage,
    +379            embed_field_usage,
    +380            cosine_weights: List[float] = None,
    +381            query_weight: List[float] = None,
    +382            norm_weight=2.15,
    +383            ablations=False,
    +384            automatic_scores=None,
    +385            **kwargs,
    +386    ):
    +387        """
    +388        Computes the NIR score for a given topic
    +389
    +390        Score = log(BM25)/log(norm_weight) + embedding_score
    +391
    +392        :param topic_num:
    +393        :param encoder:
    +394        :param query_field_usage:
    +395        :param embed_field_usage:
    +396        :param cosine_weights:
    +397        :param query_weight:
    +398        :param norm_weight:
    +399        :param ablations:
    +400        :param automatic_scores:
    +401        :param kwargs:
    +402        :return:
    +403        """
    +404        should = {"should": []}
    +405
    +406        assert norm_weight or automatic_scores
    +407
    +408        query_fields = self.field_usage[query_field_usage]
    +409        embed_fields = self.field_usage[embed_field_usage]
    +410
    +411        qfield = list(self.topics[topic_num].keys())[0]
    +412        query = self.topics[topic_num][qfield]
    +413
    +414        for i, field in enumerate(query_fields):
    +415            should["should"].append(
    +416                {
    +417                    "match": {
    +418                        f"{field}": {
    +419                            "query": query,
    +420                            "boost": query_weight[i] if query_weight else 1,
    +421                        }
    +422                    }
    +423                }
    +424            )
    +425
    +426        if automatic_scores is not None:
    +427            norm_weight = get_z_value(
    +428                cosine_ceiling=len(embed_fields) * len(query_fields),
    +429                bm25_ceiling=automatic_scores[topic_num],
    +430            )
    +431
    +432        params = {
    +433            "weights": cosine_weights if cosine_weights else [1] * len(embed_fields),
    +434            "q_eb": encoder.encode(self.topics[topic_num][qfield]),
    +435            "offset": 1.0,
    +436            "norm_weight": norm_weight,
    +437            "disable_bm25": ablations,
    +438        }
    +439
    +440        query = {
    +441            "query": {
    +442                "script_score": {
    +443                    "query": {
    +444                        "bool": should,
    +445                    },
    +446                    "script": generate_script(self.best_embed_fields, params=params),
    +447                },
    +448            }
    +449        }
    +450
    +451        return query
    +452
    +453    def get_query_type(self, *args, **kwargs):
    +454        return self.query_funcs[self.query_type](*args, **kwargs)
    +455
    +456    def get_id_mapping(self, hit):
    +457        return hit[self.id_mapping]
    +
    + + +

    Elasticsearch Query object for the Clinical Trials Index

    +
    + + +
    + +
    + + TrialsElasticsearchQuery(topics, query_type, config=None, *args, **kwargs) + + + +
    + +
     70    def __init__(self, topics, query_type, config=None, *args, **kwargs):
    + 71        super().__init__(topics, config, *args, **kwargs)
    + 72        self.query_type = query_type
    + 73        self.config = config
    + 74        self.topics = topics
    + 75        self.fields = []
    + 76        self.mappings = [
    + 77            "HasExpandedAccess",
    + 78            "BriefSummary.Textblock",
    + 79            "CompletionDate.Type",
    + 80            "OversightInfo.Text",
    + 81            "OverallContactBackup.PhoneExt",
    + 82            "RemovedCountries.Text",
    + 83            "SecondaryOutcome",
    + 84            "Sponsors.LeadSponsor.Text",
    + 85            "BriefTitle",
    + 86            "IDInfo.NctID",
    + 87            "IDInfo.SecondaryID",
    + 88            "OverallContactBackup.Phone",
    + 89            "Eligibility.StudyPop.Textblock",
    + 90            "DetailedDescription.Textblock",
    + 91            "Eligibility.MinimumAge",
    + 92            "Sponsors.Collaborator",
    + 93            "Reference",
    + 94            "Eligibility.Criteria.Textblock",
    + 95            "XMLName.Space",
    + 96            "Rank",
    + 97            "OverallStatus",
    + 98            "InterventionBrowse.Text",
    + 99            "Eligibility.Text",
    +100            "Intervention",
    +101            "BiospecDescr.Textblock",
    +102            "ResponsibleParty.NameTitle",
    +103            "NumberOfArms",
    +104            "ResponsibleParty.ResponsiblePartyType",
    +105            "IsSection801",
    +106            "Acronym",
    +107            "Eligibility.MaximumAge",
    +108            "DetailedDescription.Text",
    +109            "StudyDesign",
    +110            "OtherOutcome",
    +111            "VerificationDate",
    +112            "ConditionBrowse.MeshTerm",
    +113            "Enrollment.Text",
    +114            "IDInfo.Text",
    +115            "ConditionBrowse.Text",
    +116            "FirstreceivedDate",
    +117            "NumberOfGroups",
    +118            "OversightInfo.HasDmc",
    +119            "PrimaryCompletionDate.Text",
    +120            "ResultsReference",
    +121            "Eligibility.StudyPop.Text",
    +122            "IsFdaRegulated",
    +123            "WhyStopped",
    +124            "ArmGroup",
    +125            "OverallContact.LastName",
    +126            "Phase",
    +127            "RemovedCountries.Country",
    +128            "InterventionBrowse.MeshTerm",
    +129            "Eligibility.HealthyVolunteers",
    +130            "Location",
    +131            "OfficialTitle",
    +132            "OverallContact.Email",
    +133            "RequiredHeader.Text",
    +134            "RequiredHeader.URL",
    +135            "LocationCountries.Country",
    +136            "OverallContact.PhoneExt",
    +137            "Condition",
    +138            "PrimaryOutcome",
    +139            "LocationCountries.Text",
    +140            "BiospecDescr.Text",
    +141            "IDInfo.OrgStudyID",
    +142            "Link",
    +143            "OverallContact.Phone",
    +144            "Source",
    +145            "ResponsibleParty.InvestigatorAffiliation",
    +146            "StudyType",
    +147            "FirstreceivedResultsDate",
    +148            "Enrollment.Type",
    +149            "Eligibility.Gender",
    +150            "OverallContactBackup.LastName",
    +151            "Keyword",
    +152            "BiospecRetention",
    +153            "CompletionDate.Text",
    +154            "OverallContact.Text",
    +155            "RequiredHeader.DownloadDate",
    +156            "Sponsors.Text",
    +157            "Text",
    +158            "Eligibility.SamplingMethod",
    +159            "LastchangedDate",
    +160            "ResponsibleParty.InvestigatorFullName",
    +161            "StartDate",
    +162            "RequiredHeader.LinkText",
    +163            "OverallOfficial",
    +164            "Sponsors.LeadSponsor.AgencyClass",
    +165            "OverallContactBackup.Text",
    +166            "Eligibility.Criteria.Text",
    +167            "XMLName.Local",
    +168            "OversightInfo.Authority",
    +169            "PrimaryCompletionDate.Type",
    +170            "ResponsibleParty.Organization",
    +171            "IDInfo.NctAlias",
    +172            "ResponsibleParty.Text",
    +173            "TargetDuration",
    +174            "Sponsors.LeadSponsor.Agency",
    +175            "BriefSummary.Text",
    +176            "OverallContactBackup.Email",
    +177            "ResponsibleParty.InvestigatorTitle",
    +178        ]
    +179
    +180        self.best_recall_fields = [
    +181            "LocationCountries.Country",
    +182            "BiospecRetention",
    +183            "DetailedDescription.Textblock",
    +184            "HasExpandedAccess",
    +185            "ConditionBrowse.MeshTerm",
    +186            "RequiredHeader.LinkText",
    +187            "WhyStopped",
    +188            "BriefSummary.Textblock",
    +189            "Eligibility.Criteria.Textblock",
    +190            "OfficialTitle",
    +191            "Eligibility.MaximumAge",
    +192            "Eligibility.StudyPop.Textblock",
    +193            "BiospecDescr.Textblock",
    +194            "BriefTitle",
    +195            "Eligibility.MinimumAge",
    +196            "ResponsibleParty.Organization",
    +197            "TargetDuration",
    +198            "Condition",
    +199            "IDInfo.OrgStudyID",
    +200            "Keyword",
    +201            "Source",
    +202            "Sponsors.LeadSponsor.Agency",
    +203            "ResponsibleParty.InvestigatorAffiliation",
    +204            "OversightInfo.Authority",
    +205            "OversightInfo.HasDmc",
    +206            "OverallContact.Phone",
    +207            "Phase",
    +208            "OverallContactBackup.LastName",
    +209            "Acronym",
    +210            "InterventionBrowse.MeshTerm",
    +211            "RemovedCountries.Country",
    +212        ]
    +213        self.best_map_fields = [
    +214            "Eligibility.Gender",
    +215            "LocationCountries.Country",
    +216            "DetailedDescription.Textblock",
    +217            "BriefSummary.Textblock",
    +218            "ConditionBrowse.MeshTerm",
    +219            "Eligibility.Criteria.Textblock",
    +220            "InterventionBrowse.MeshTerm",
    +221            "StudyType",
    +222            "IsFdaRegulated",
    +223            "HasExpandedAccess",
    +224            "RequiredHeader.LinkText",
    +225            "BiospecRetention",
    +226            "OfficialTitle",
    +227            "Eligibility.SamplingMethod",
    +228            "Eligibility.StudyPop.Textblock",
    +229            "Condition",
    +230            "Eligibility.MinimumAge",
    +231            "Keyword",
    +232            "Eligibility.MaximumAge",
    +233            "BriefTitle",
    +234        ]
    +235        self.best_embed_fields = [
    +236            "WhyStopped",
    +237            "HasExpandedAccess",
    +238            "BiospecRetention",
    +239            "BriefSummary.Textblock",
    +240            "LocationCountries.Country",
    +241            "ConditionBrowse.MeshTerm",
    +242            "DetailedDescription.Textblock",
    +243            "RequiredHeader.LinkText",
    +244            "Eligibility.Criteria.Textblock",
    +245        ]
    +246
    +247        self.sensible = [
    +248            "BriefSummary.Textblock", "BriefTitle",
    +249            "Eligibility.StudyPop.Textblock",
    +250            "DetailedDescription.Textblock",
    +251            "Eligibility.MinimumAge",
    +252            "Eligibility.Criteria.Textblock",
    +253            "InterventionBrowse.Text",
    +254            "Eligibility.Text",
    +255            "BiospecDescr.Textblock",
    +256            "Eligibility.MaximumAge",
    +257            "DetailedDescription.Text",
    +258            "ConditionBrowse.MeshTerm",
    +259            "ConditionBrowse.Text",
    +260            "Eligibility.StudyPop.Text",
    +261            "InterventionBrowse.MeshTerm",
    +262            "OfficialTitle",
    +263            "Condition",
    +264            "PrimaryOutcome",
    +265            "BiospecDescr.Text",
    +266            "Eligibility.Gender",
    +267            "Keyword",
    +268            "BiospecRetention",
    +269            "Eligibility.Criteria.Text",
    +270            "BriefSummary.Text",
    +271        ]
    +272
    +273        self.sensible_embed = [
    +274            "BriefSummary.Textblock", "BriefTitle",
    +275            "Eligibility.StudyPop.Textblock",
    +276            "DetailedDescription.Textblock",
    +277            "Eligibility.Criteria.Textblock",
    +278            "InterventionBrowse.Text",
    +279            "Eligibility.Text",
    +280            "BiospecDescr.Textblock",
    +281            "DetailedDescription.Text",
    +282            "ConditionBrowse.MeshTerm",
    +283            "ConditionBrowse.Text",
    +284            "Eligibility.StudyPop.Text",
    +285            "InterventionBrowse.MeshTerm",
    +286            "OfficialTitle",
    +287            "Condition",
    +288            "PrimaryOutcome",
    +289            "BiospecDescr.Text",
    +290            "Keyword",
    +291            "BiospecRetention",
    +292            "Eligibility.Criteria.Text",
    +293            "BriefSummary.Text",
    +294        ]
    +295
    +296        self.sensible_embed_safe = list(
    +297            set(self.best_recall_fields).intersection(set(self.sensible_embed))
    +298        )
    +299
    +300        self.query_funcs = {
    +301            "query": self.generate_query,
    +302            "ablation": self.generate_query_ablation,
    +303            "embedding": self.generate_query_embedding,
    +304        }
    +305
    +306        loguru.logger.debug(self.sensible_embed_safe)
    +307
    +308        self.field_usage = {
    +309            "best_recall_fields": self.best_recall_fields,
    +310            "all": self.mappings,
    +311            "best_map_fields": self.best_map_fields,
    +312            "best_embed_fields": self.best_embed_fields,
    +313            "sensible": self.sensible,
    +314            "sensible_embed": self.sensible_embed,
    +315            "sensible_embed_safe": self.sensible_embed_safe,
    +316        }
    +
    + + + + +
    +
    + +
    + + def + generate_query(self, *args, **kwargs): + + + +
    + +
    229    def use_config(self, *args, **kwargs):
    +230        """
    +231        Replaces keywords and args passed to the function with ones from self.config.
    +232
    +233        :param self:
    +234        :param args: To be updated
    +235        :param kwargs: To be updated
    +236        :return:
    +237        """
    +238        if self.config is not None:
    +239            kwargs = self.config.__update__(**kwargs)
    +240
    +241        return func(self, *args, **kwargs)
    +
    + + +

    Generates a query for the clinical trials index

    + +
    Parameters
    + +
      +
    • topic_num: Topic number to search
    • +
    • query_field_usage: Which document facets to search over
    • +
    • kwargs:
    • +
    + +
    Returns
    + +
    +
    A basic elasticsearch query for clinical trials
    +
    +
    +
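For orientation, here is a minimal sketch of the request body this method assembles, assuming the topic's first field holds the free-text query; the field subset and query text below are placeholders, not taken from the source.

    # Sketch only: the field subset and query text are illustrative placeholders.
    fields = ["BriefTitle", "Condition", "DetailedDescription.Textblock"]
    query_text = "65-year-old man with type 2 diabetes"

    es_query = {
        "query": {
            "bool": {
                "should": [
                    {"match": {field: {"query": query_text}}} for field in fields
                ]
            }
        }
    }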
    + + +
    +
    + +
    + + def + generate_query_ablation(self, topic_num, **kwargs): + + + +
    + +
    354    def generate_query_ablation(self, topic_num, **kwargs):
    +355        """
    +356        Only search one document facet at a time
    +357        :param topic_num:
    +358        :param kwargs:
    +359        :return:
    +360        """
    +361        query = {"query": {"match": {}}}
    +362
    +363        for field in self.fields:
    +364            query["query"]["match"][self.mappings[field]] = ""
    +365
    +366        for qfield in self.fields:
    +367            qfield = self.mappings[qfield]
    +368            for field in self.topics[topic_num]:
    +369                query["query"]["match"][qfield] += self.topics[topic_num][field]
    +370
    +371        return query
    +
    + + +

    Only search one document facet at a time

    + +
    Parameters
    + +
      +
    • topic_num:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
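As a rough sketch of the ablation idea, one run per document facet can be produced by issuing a separate single-field query for each facet; the facet list and topic text here are placeholders.

    # Sketch of a per-facet ablation run (placeholder facets and topic text).
    facets = ["BriefTitle", "Condition", "Eligibility.Criteria.Textblock"]
    topic_text = "chronic kidney disease"

    ablation_queries = [{"query": {"match": {facet: topic_text}}} for facet in facets]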
    + + +
    +
    + +
    + + def + generate_query_embedding(self, *args, **kwargs): + + + +
    + +
    229    def use_config(self, *args, **kwargs):
    +230        """
    +231        Replaces keywords and args passed to the function with ones from self.config.
    +232
    +233        :param self:
    +234        :param args: To be updated
    +235        :param kwargs: To be updated
    +236        :return:
    +237        """
    +238        if self.config is not None:
    +239            kwargs = self.config.__update__(**kwargs)
    +240
    +241        return func(self, *args, **kwargs)
    +
    + + +

    Computes the NIR score for a given topic

    + +

    Score = log(BM25)/log(norm_weight) + embedding_score

    + +
    Parameters
    + +
      +
    • topic_num:
    • +
    • encoder:
    • +
    • query_field_usage:
    • +
    • embed_field_usage:
    • +
    • cosine_weights:
    • +
    • query_weight:
    • +
    • norm_weight:
    • +
    • ablations:
    • +
    • automatic_scores:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
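The combination described in the docstring can be written out directly. Only the formula Score = log(BM25)/log(norm_weight) + embedding_score is taken from the source; the numbers below are illustrative.

    import math

    bm25_score = 18.4        # hypothetical BM25 score returned by Elasticsearch
    embedding_score = 0.82   # hypothetical summed cosine similarity
    norm_weight = 2.15       # normalisation constant, matching the default above

    nir_score = math.log(bm25_score) / math.log(norm_weight) + embedding_score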
    + + +
    +
    + +
    + + def + get_query_type(self, *args, **kwargs): + + + +
    + +
    453    def get_query_type(self, *args, **kwargs):
    +454        return self.query_funcs[self.query_type](*args, **kwargs)
    +
    + + + + +
    +
    + +
    + + def + get_id_mapping(self, hit): + + + +
    + +
    456    def get_id_mapping(self, hit):
    +457        return hit[self.id_mapping]
    +
    + + +

    Get the document ID

    + +
    Parameters
    + +
      +
    • hit: The raw document result
    • +
    + +
    Returns
    + +
    +
    The document's ID
    +
    +
    +
    + + +
    + +
    +
    + +
    + + class + ClinicalTrialsElasticsearchExecutor(debeir.core.executor.GenericElasticsearchExecutor): + + + +
    + +
    460class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    +461    """
    +462    Executes queries given a query object.
    +463    """
    +464    query: TrialsElasticsearchQuery
    +465
    +466    def __init__(
    +467            self,
    +468            topics: Dict[Union[str, int], Dict[str, str]],
    +469            client: Elasticsearch,
    +470            index_name: str,
    +471            output_file: str,
    +472            query: TrialsElasticsearchQuery,
    +473            encoder: Optional[Encoder] = None,
    +474            config=None,
    +475            *args,
    +476            **kwargs,
    +477    ):
    +478        super().__init__(
    +479            topics,
    +480            client,
    +481            index_name,
    +482            output_file,
    +483            query,
    +484            encoder,
    +485            config=config,
    +486            *args,
    +487            **kwargs,
    +488        )
    +489
    +490        self.query_fns = {
    +491            "query": self.generate_query,
    +492            "ablation": self.generate_query_ablation,
    +493            "embedding": self.generate_embedding_query,
    +494        }
    +
    + + +

    Executes queries given a query object.

    +
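A rough construction sketch, assuming topics and a config are produced elsewhere (for example by a parser and config_factory); the host, index name and output file are placeholders.

    # Construction sketch; host, index name and output file are placeholders.
    from elasticsearch import AsyncElasticsearch

    topics = {"1": {"text": "chronic kidney disease"}}   # hypothetical parsed topics
    client = AsyncElasticsearch("http://localhost:9200")
    query = TrialsElasticsearchQuery(topics, query_type="query", config=None)

    executor = ClinicalTrialsElasticsearchExecutor(
        topics=topics,
        client=client,
        index_name="clinical_trials",
        output_file="run.res",
        query=query,
    )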
    + + +
    + +
    + + ClinicalTrialsElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.datasets.clinical_trials.TrialsElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs) + + + +
    + +
    466    def __init__(
    +467            self,
    +468            topics: Dict[Union[str, int], Dict[str, str]],
    +469            client: Elasticsearch,
    +470            index_name: str,
    +471            output_file: str,
    +472            query: TrialsElasticsearchQuery,
    +473            encoder: Optional[Encoder] = None,
    +474            config=None,
    +475            *args,
    +476            **kwargs,
    +477    ):
    +478        super().__init__(
    +479            topics,
    +480            client,
    +481            index_name,
    +482            output_file,
    +483            query,
    +484            encoder,
    +485            config=config,
    +486            *args,
    +487            **kwargs,
    +488        )
    +489
    +490        self.query_fns = {
    +491            "query": self.generate_query,
    +492            "ablation": self.generate_query_ablation,
    +493            "embedding": self.generate_embedding_query,
    +494        }
    +
    + + + + +
    + +
    +
    + +
    + + class + ClinicalTrialParser(debeir.core.parser.Parser): + + + +
    + +
    497class ClinicalTrialParser(Parser):
    +498    """
    +499    Parser for Clinical Trials topics
    +500    """
    +501
    +502    @classmethod
    +503    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +504        topics = {}
    +505        reader = csv.reader(csvfile)
    +506        for i, row in enumerate(reader):
    +507            if i == 0:
    +508                continue
    +509
    +510            _id = row[0]
    +511            text = row[1]
    +512
    +513            topics[_id] = {"text": text}
    +514
    +515        return topics
    +
    + + +

    Parser for Clinical Trials topics

    +
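A hypothetical call, assuming a topics CSV with a header row followed by "topic id, topic text" rows:

    # Hypothetical usage; topics.csv is a placeholder file name.
    with open("topics.csv") as csvfile:
        topics = ClinicalTrialParser.get_topics(csvfile)
    # topics maps each topic id to {"text": <topic text>}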
    + + +
    + +
    +
    @classmethod
    + + def + get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]: + + + +
    + +
    502    @classmethod
    +503    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    +504        topics = {}
    +505        reader = csv.reader(csvfile)
    +506        for i, row in enumerate(reader):
    +507            if i == 0:
    +508                continue
    +509
    +510            _id = row[0]
    +511            text = row[1]
    +512
    +513            topics[_id] = {"text": text}
    +514
    +515        return topics
    +
    + + +

    Instance method for getting topics, forwards instance self parameters to the _get_topics class method.

    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/datasets/factory.html b/docs/debeir/datasets/factory.html new file mode 100644 index 0000000..b2357fa --- /dev/null +++ b/docs/debeir/datasets/factory.html @@ -0,0 +1,676 @@ + + + + + + + debeir.datasets.factory API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.factory

    + + + + + + +
      1from pathlib import Path
    +  2from typing import Dict, Type, Union
    +  3
    +  4import toml
    +  5from debeir.datasets.bioreddit import BioRedditCommentParser, BioRedditSubmissionParser
    +  6from debeir.datasets.clinical_trials import ClinicalTrialParser, ClinicalTrialsElasticsearchExecutor, \
    +  7    TrialsElasticsearchQuery, TrialsQueryConfig
    +  8from debeir.datasets.marco import MarcoElasticsearchExecutor, MarcoQueryConfig
    +  9from debeir.datasets.trec_clinical_trials import TrecClincialElasticsearchQuery, TrecClinicalTrialsParser
    + 10from debeir.datasets.trec_covid import TrecCovidParser, TrecElasticsearchQuery
    + 11from debeir.evaluation.evaluator import Evaluator
    + 12from debeir.evaluation.residual_scoring import ResidualEvaluator
    + 13from debeir.core.config import Config, ElasticsearchConfig, GenericConfig, MetricsConfig, NIRConfig, SolrConfig, \
    + 14    _NIRMasterConfig
    + 15from debeir.core.executor import GenericElasticsearchExecutor
    + 16from debeir.core.parser import (
    + 17    CSVParser, Parser, TSVParser,
    + 18)
    + 19from debeir.core.query import GenericElasticsearchQuery, Query
    + 20
    + 21str_to_config_cls = {
    + 22    "clinical_trials": TrialsQueryConfig,
    + 23    "test_trials": TrialsQueryConfig,
    + 24    "med-marco": MarcoQueryConfig,
    + 25    "generic": MarcoQueryConfig,
    + 26}
    + 27
    + 28query_factory = {
    + 29    "clinical_trials": TrialsElasticsearchQuery,
    + 30    "test_trials": TrialsElasticsearchQuery,
    + 31    "generic": GenericElasticsearchQuery,
    + 32    "trec_covid": TrecElasticsearchQuery,
    + 33    "trec_clinical": TrecClincialElasticsearchQuery,
    + 34}
    + 35
    + 36parser_factory = {
    + 37    "trec_covid": TrecCovidParser,
    + 38    "bioreddit-comment": BioRedditCommentParser,
    + 39    "bioreddit-submission": BioRedditSubmissionParser,
    + 40    "test_trials": ClinicalTrialParser,
    + 41    "med-marco": CSVParser,
    + 42    "tsv": TSVParser,
    + 43    "trec_clinical": TrecClinicalTrialsParser
    + 44}
    + 45
    + 46executor_factory = {
    + 47    "clinical": ClinicalTrialsElasticsearchExecutor,
    + 48    "med-marco": MarcoElasticsearchExecutor,
    + 49    "generic": GenericElasticsearchExecutor,
    + 50}
    + 51
    + 52evaluator_factory = {
    + 53    "residual": ResidualEvaluator,
    + 54    "trec": Evaluator,
    + 55}
    + 56
    + 57
    + 58def get_index_name(config_fp):
    + 59    """
    + 60    Get the index name from the config without parsing as a TOML
    + 61
    + 62    :param config_fp:
    + 63    :return:
    + 64    """
    + 65    with open(config_fp, "r") as reader:
    + 66        for line in reader:
    + 67            if line.startswith("index"):
    + 68                line = line.replace('"', "")
    + 69                return line.split("=")[-1].strip()
    + 70    return None
    + 71
    + 72
    + 73def factory_fn(config_fp, index=None) -> (Query, GenericConfig,
    + 74                                          Parser, GenericElasticsearchExecutor, Evaluator):
    + 75    """
    + 76    Factory method for creating the parsed topics, config object, query object and query executor object
    + 77
    + 78    :param config_fp: Config file path
    + 79    :param index: Index to search
    + 80    :return:
    + 81        Query, Config, Parser, Executor, Evaluator
    + 82    """
    + 83    config = config_factory(config_fp)
    + 84    assert config.index is not None
    + 85    query_cls = query_factory[config.query_fn]
    + 86    parser = parser_factory[config.parser_fn]
    + 87    executor = executor_factory[config.executor_fn]
    + 88
    + 89    return query_cls, config, parser, executor
    + 90
    + 91
    + 92def config_factory(path: Union[str, Path] = None, config_cls: Type[Config] = None, args_dict: Dict = None):
    + 93    """
    + 94    Factory method for creating configs
    + 95
    + 96    :param path: Config path
    + 97    :param config_cls: Config class to instantiate
    + 98    :param args_dict: Arguments to consider
    + 99    :return:
    +100        A config object
    +101    """
    +102    if path:
    +103        args_dict = toml.load(path)
    +104
    +105    if not config_cls:
    +106        if "config_fn" in args_dict:
    +107            config_cls = str_to_config_cls[args_dict["config_fn"]]
    +108        else:
    +109            raise NotImplementedError()
    +110
    +111    return config_cls.from_args(args_dict, config_cls)
    +112
    +113
    +114def get_nir_config(nir_config, *args, ignore_errors=False, **kwargs):
    +115    main_config = config_factory(nir_config, config_cls=_NIRMasterConfig)
    +116    search_engine_config = None
    +117
    +118    supported_search_engines = {"solr": SolrConfig,
    +119                                "elasticsearch": ElasticsearchConfig}
    +120
    +121    search_engine_config = None
    +122
    +123    if 'engine' in kwargs and kwargs['engine'] in supported_search_engines:
    +124        search_engine = kwargs['engine']
    +125        search_engine_config = config_factory(args_dict=main_config.get_search_engine_settings(search_engine),
    +126                                              config_cls=supported_search_engines[search_engine])
    +127
    +128    # for search_engine in supported_search_engines:
    +129    #    if search_engine in kwargs and kwargs[search_engine] and kwargs['engine'] == search_engine:
    +130    #        search_engine_config = config_factory(args_dict=main_config.get_search_engine_settings(search_engine),
    +131    #                                              config_cls=supported_search_engines[search_engine])
    +132
    +133    if not ignore_errors and search_engine_config is None:
    +134        raise RuntimeError("Unable to get a search engine configuration.")
    +135
    +136    metrics_config = config_factory(args_dict=main_config.get_metrics(), config_cls=MetricsConfig)
    +137    nir_config = config_factory(args_dict=main_config.get_nir_settings(), config_cls=NIRConfig)
    +138
    +139    return nir_config, search_engine_config, metrics_config
    +140
    +141
    +142def apply_nir_config(func):
    +143    """
    +144    Decorator that applies the NIR config settings to the current function
    +145    Replaces arguments and keyword arguments with those found in the config
    +146
    +147    :param func:
    +148    :return:
    +149    """
    +150
    +151    def parse_nir_config(*args, ignore_errors=False, **kwargs):
    +152        """
    +153        Parses the NIR config for the different setting groups: Search Engine, Metrics and NIR settings
    +154        Applies these settings to the current function
    +155        :param ignore_errors:
    +156        :param args:
    +157        :param kwargs:
    +158        :return:
    +159        """
    +160
    +161        nir_config, search_engine_config, metrics_config = get_nir_config(*args,
    +162                                                                          ignore_errors=ignore_errors,
    +163                                                                          **kwargs)
    +164
    +165        kwargs = nir_config.__update__(
    +166            **search_engine_config.__update__(
    +167                **metrics_config.__update__(**kwargs)
    +168            )
    +169        )
    +170
    +171        return func(*args, **kwargs)
    +172
    +173    return parse_nir_config
    +
    + + +
    +
    + +
    + + def + get_index_name(config_fp): + + + +
    + +
    59def get_index_name(config_fp):
    +60    """
    +61    Get the index name from the config without parsing as a TOML
    +62
    +63    :param config_fp:
    +64    :return:
    +65    """
    +66    with open(config_fp, "r") as reader:
    +67        for line in reader:
    +68            if line.startswith("index"):
    +69                line = line.replace('"', "")
    +70                return line.split("=")[-1].strip()
    +71    return None
    +
    + + +

    Get the index name from the config without parsing as a TOML

    + +
    Parameters
    + +
      +
    • config_fp:
    • +
    + +
    Returns
    +
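A hypothetical call; the file name and index value are placeholders:

    # Assumes config.toml contains a line such as:  index = "clinical_trials"
    index_name = get_index_name("config.toml")   # -> "clinical_trials", or None if no "index" line exists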
    + + +
    +
    + +
    + + def + factory_fn( config_fp, index=None) -> (<class 'debeir.core.query.Query'>, <class 'debeir.core.config.GenericConfig'>, <class 'debeir.core.parser.Parser'>, <class 'debeir.core.executor.GenericElasticsearchExecutor'>, <class 'debeir.evaluation.evaluator.Evaluator'>): + + + +
    + +
    74def factory_fn(config_fp, index=None) -> (Query, GenericConfig,
    +75                                          Parser, GenericElasticsearchExecutor, Evaluator):
    +76    """
    +77    Factory method for creating the parsed topics, config object, query object and query executor object
    +78
    +79    :param config_fp: Config file path
    +80    :param index: Index to search
    +81    :return:
    +82        Query, Config, Parser, Executor, Evaluator
    +83    """
    +84    config = config_factory(config_fp)
    +85    assert config.index is not None
    +86    query_cls = query_factory[config.query_fn]
    +87    parser = parser_factory[config.parser_fn]
    +88    executor = executor_factory[config.executor_fn]
    +89
    +90    return query_cls, config, parser, executor
    +
    + + +

    Factory method for creating the parsed topics, config object, query object and query executor object

    + +
    Parameters
    + +
      +
    • config_fp: Config file path
    • +
    • index: Index to search
    • +
    + +
    Returns
    + +
    +
    Query, Config, Parser, Executor, Evaluator
    +
    +
    +
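A hypothetical call; the TOML path is a placeholder and is assumed to define config_fn, query_fn, parser_fn and executor_fn keys known to the factories above. As written, the function returns four objects:

    # Hypothetical usage; "trials.toml" is a placeholder config file.
    query_cls, config, parser_cls, executor_cls = factory_fn("trials.toml")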
    + + +
    +
    + +
    + + def + config_factory( path: Union[str, pathlib.Path] = None, config_cls: Type[debeir.core.config.Config] = None, args_dict: Dict = None): + + + +
    + +
     93def config_factory(path: Union[str, Path] = None, config_cls: Type[Config] = None, args_dict: Dict = None):
    + 94    """
    + 95    Factory method for creating configs
    + 96
    + 97    :param path: Config path
    + 98    :param config_cls: Config class to instantiate
    + 99    :param args_dict: Arguments to consider
    +100    :return:
    +101        A config object
    +102    """
    +103    if path:
    +104        args_dict = toml.load(path)
    +105
    +106    if not config_cls:
    +107        if "config_fn" in args_dict:
    +108            config_cls = str_to_config_cls[args_dict["config_fn"]]
    +109        else:
    +110            raise NotImplementedError()
    +111
    +112    return config_cls.from_args(args_dict, config_cls)
    +
    + + +

    Factory method for creating configs

    + +
    Parameters
    + +
      +
    • path: Config path
    • +
    • config_cls: Config class to instantiate
    • +
    • args_dict: Arguments to consider
    • +
    + +
    Returns
    + +
    +
    A config object
    +
    +
    +
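Two hypothetical ways of calling it, either from a TOML path or from an in-memory dictionary; the path and values are placeholders:

    # From a TOML file (the config class is chosen via its "config_fn" key):
    config = config_factory("trials.toml")

    # From an in-memory dictionary with an explicit config class:
    config = config_factory(args_dict={"query_type": "query"}, config_cls=MarcoQueryConfig)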
    + + +
    +
    + +
    + + def + get_nir_config(nir_config, *args, ignore_errors=False, **kwargs): + + + +
    + +
    115def get_nir_config(nir_config, *args, ignore_errors=False, **kwargs):
    +116    main_config = config_factory(nir_config, config_cls=_NIRMasterConfig)
    +117    search_engine_config = None
    +118
    +119    supported_search_engines = {"solr": SolrConfig,
    +120                                "elasticsearch": ElasticsearchConfig}
    +121
    +122    search_engine_config = None
    +123
    +124    if 'engine' in kwargs and kwargs['engine'] in supported_search_engines:
    +125        search_engine = kwargs['engine']
    +126        search_engine_config = config_factory(args_dict=main_config.get_search_engine_settings(search_engine),
    +127                                              config_cls=supported_search_engines[search_engine])
    +128
    +129    # for search_engine in supported_search_engines:
    +130    #    if search_engine in kwargs and kwargs[search_engine] and kwargs['engine'] == search_engine:
    +131    #        search_engine_config = config_factory(args_dict=main_config.get_search_engine_settings(search_engine),
    +132    #                                              config_cls=supported_search_engines[search_engine])
    +133
    +134    if not ignore_errors and search_engine_config is None:
    +135        raise RuntimeError("Unable to get a search engine configuration.")
    +136
    +137    metrics_config = config_factory(args_dict=main_config.get_metrics(), config_cls=MetricsConfig)
    +138    nir_config = config_factory(args_dict=main_config.get_nir_settings(), config_cls=NIRConfig)
    +139
    +140    return nir_config, search_engine_config, metrics_config
    +
    + + + + +
    +
    + +
    + + def + apply_nir_config(func): + + + +
    + +
    143def apply_nir_config(func):
    +144    """
    +145    Decorator that applies the NIR config settings to the current function
    +146    Replaces arguments and keyword arguments with those found in the config
    +147
    +148    :param func:
    +149    :return:
    +150    """
    +151
    +152    def parse_nir_config(*args, ignore_errors=False, **kwargs):
    +153        """
    +154        Parses the NIR config for the different setting groups: Search Engine, Metrics and NIR settings
    +155        Applies these settings to the current function
    +156        :param ignore_errors:
    +157        :param args:
    +158        :param kwargs:
    +159        :return:
    +160        """
    +161
    +162        nir_config, search_engine_config, metrics_config = get_nir_config(*args,
    +163                                                                          ignore_errors=ignore_errors,
    +164                                                                          **kwargs)
    +165
    +166        kwargs = nir_config.__update__(
    +167            **search_engine_config.__update__(
    +168                **metrics_config.__update__(**kwargs)
    +169            )
    +170        )
    +171
    +172        return func(*args, **kwargs)
    +173
    +174    return parse_nir_config
    +
    + + +

    Decorator that applies the NIR config settings to the current function. +Replaces arguments and keyword arguments with those found in the config

    + +
    Parameters
    + +
      +
    • func:
    • +
    + +
    Returns
    +
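A hypothetical decorated function; the config path and engine name are placeholders, and the master TOML is assumed to contain the engine, metrics and NIR sections that get_nir_config expects:

    # Hypothetical usage; "nir.toml" and the keyword names are placeholders.
    @apply_nir_config
    def run_search(nir_config, *args, index=None, norm_weight=None, **kwargs):
        print(index, norm_weight)

    run_search("nir.toml", engine="elasticsearch")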
    + + +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/datasets/marco.html b/docs/debeir/datasets/marco.html new file mode 100644 index 0000000..e55cc41 --- /dev/null +++ b/docs/debeir/datasets/marco.html @@ -0,0 +1,770 @@ + + + + + + + debeir.datasets.marco API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.marco

    + + + + + + +
     1from dataclasses import dataclass
    + 2from typing import Dict, Optional, Union
    + 3
    + 4from debeir.core.config import GenericConfig
    + 5from debeir.core.executor import GenericElasticsearchExecutor
    + 6from debeir.core.query import GenericElasticsearchQuery
    + 7from debeir.rankers.transformer_sent_encoder import Encoder
    + 8from elasticsearch import AsyncElasticsearch as Elasticsearch
    + 9
    +10
    +11class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    +12    query: GenericElasticsearchQuery
    +13
    +14    def __init__(
    +15            self,
    +16            topics: Dict[Union[str, int], Dict[str, str]],
    +17            client: Elasticsearch,
    +18            index_name: str,
    +19            output_file: str,
    +20            query: GenericElasticsearchQuery,
    +21            encoder: Optional[Encoder] = None,
    +22            config=None,
    +23            *args,
    +24            **kwargs,
    +25    ):
    +26        super().__init__(
    +27            topics,
    +28            client,
    +29            index_name,
    +30            output_file,
    +31            query,
    +32            encoder,
    +33            config=config,
    +34            *args,
    +35            **kwargs,
    +36        )
    +37
    +38        self.query_fns = {
    +39            "query": self.generate_query,
    +40            "embedding": self.generate_embedding_query,
    +41        }
    +42
    +43    def generate_query(self, topic_num, best_fields=True, **kwargs):
    +44        return self.query.generate_query(topic_num)
    +45
    +46    def generate_embedding_query(
    +47            self,
    +48            topic_num,
    +49            cosine_weights=None,
    +50            query_weights=None,
    +51            norm_weight=2.15,
    +52            automatic_scores=None,
    +53            **kwargs,
    +54    ):
    +55        return super().generate_embedding_query(
    +56            topic_num,
    +57            cosine_weights=cosine_weights,
    +58            query_weights=query_weights,
    +59            norm_weight=norm_weight,
    +60            automatic_scores=automatic_scores,
    +61            **kwargs,
    +62        )
    +63
    +64    async def execute_query(
    +65            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    +66    ):
    +67        return super().execute_query(
    +68            query, topic_num, ablation, query_type=query_type, **kwargs
    +69        )
    +70
    +71
    +72@dataclass(init=True, unsafe_hash=True)
    +73class MarcoQueryConfig(GenericConfig):
    +74    def validate(self):
    +75        if self.query_type == "embedding":
    +76            assert (
    +77                    self.encoder_fp and self.encoder
    +78            ), "Must provide encoder path for embedding model"
    +79            assert self.norm_weight is not None or self.automatic is not None, (
    +80                "Norm weight must be specified or be automatic"
    +81            )
    +82
    +83    @classmethod
    +84    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
    +85        return super().from_toml(fp, cls, *args, **kwargs)
    +86
    +87    @classmethod
    +88    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
    +89        return super().from_dict(cls, **kwargs)
    +
    + + +
    +
    + +
    + + class + MarcoElasticsearchExecutor(debeir.core.executor.GenericElasticsearchExecutor): + + + +
    + +
    12class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    +13    query: GenericElasticsearchQuery
    +14
    +15    def __init__(
    +16            self,
    +17            topics: Dict[Union[str, int], Dict[str, str]],
    +18            client: Elasticsearch,
    +19            index_name: str,
    +20            output_file: str,
    +21            query: GenericElasticsearchQuery,
    +22            encoder: Optional[Encoder] = None,
    +23            config=None,
    +24            *args,
    +25            **kwargs,
    +26    ):
    +27        super().__init__(
    +28            topics,
    +29            client,
    +30            index_name,
    +31            output_file,
    +32            query,
    +33            encoder,
    +34            config=config,
    +35            *args,
    +36            **kwargs,
    +37        )
    +38
    +39        self.query_fns = {
    +40            "query": self.generate_query,
    +41            "embedding": self.generate_embedding_query,
    +42        }
    +43
    +44    def generate_query(self, topic_num, best_fields=True, **kwargs):
    +45        return self.query.generate_query(topic_num)
    +46
    +47    def generate_embedding_query(
    +48            self,
    +49            topic_num,
    +50            cosine_weights=None,
    +51            query_weights=None,
    +52            norm_weight=2.15,
    +53            automatic_scores=None,
    +54            **kwargs,
    +55    ):
    +56        return super().generate_embedding_query(
    +57            topic_num,
    +58            cosine_weights=cosine_weights,
    +59            query_weights=query_weights,
    +60            norm_weight=2.15,
    +61            automatic_scores=None,
    +62            **kwargs,
    +63        )
    +64
    +65    async def execute_query(
    +66            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    +67    ):
    +68        return super().execute_query(
    +69            query, topic_num, ablation, query_type=query_type, **kwargs
    +70        )
    +
    + + +

    Generic Executor class for Elasticsearch

    +
    + + +
    + +
    + + MarcoElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.core.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs) + + + +
    + +
    15    def __init__(
    +16            self,
    +17            topics: Dict[Union[str, int], Dict[str, str]],
    +18            client: Elasticsearch,
    +19            index_name: str,
    +20            output_file: str,
    +21            query: GenericElasticsearchQuery,
    +22            encoder: Optional[Encoder] = None,
    +23            config=None,
    +24            *args,
    +25            **kwargs,
    +26    ):
    +27        super().__init__(
    +28            topics,
    +29            client,
    +30            index_name,
    +31            output_file,
    +32            query,
    +33            encoder,
    +34            config=config,
    +35            *args,
    +36            **kwargs,
    +37        )
    +38
    +39        self.query_fns = {
    +40            "query": self.generate_query,
    +41            "embedding": self.generate_embedding_query,
    +42        }
    +
    + + + + +
    +
    + +
    + + def + generate_query(self, topic_num, best_fields=True, **kwargs): + + + +
    + +
    44    def generate_query(self, topic_num, best_fields=True, **kwargs):
    +45        return self.query.generate_query(topic_num)
    +
    + + +

    Generates a standard BM25 query given the topic number

    + +
    Parameters
    + +
      +
    • topic_num: Query topic number to generate
    • +
    • best_fields: Whether to use a curated list of fields
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    + + def + generate_embedding_query( self, topic_num, cosine_weights=None, query_weights=None, norm_weight=2.15, automatic_scores=None, **kwargs): + + + +
    + +
    47    def generate_embedding_query(
    +48            self,
    +49            topic_num,
    +50            cosine_weights=None,
    +51            query_weights=None,
    +52            norm_weight=2.15,
    +53            automatic_scores=None,
    +54            **kwargs,
    +55    ):
    +56        return super().generate_embedding_query(
    +57            topic_num,
    +58            cosine_weights=cosine_weights,
    +59            query_weights=query_weights,
    +60            norm_weight=norm_weight,
    +61            automatic_scores=automatic_scores,
    +62            **kwargs,
    +63        )
    +
    + + +

    Executes an NIR-style query with combined scoring.

    + +
    Parameters
    + +
      +
    • topic_num:
    • +
    • cosine_weights:
    • +
    • query_weights:
    • +
    • norm_weight:
    • +
    • automatic_scores:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    + +
    + + async def + execute_query( self, query=None, topic_num=None, ablation=False, query_type='query', **kwargs): + + + +
    + +
    65    async def execute_query(
    +66            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    +67    ):
    +68        return super().execute_query(
    +69            query, topic_num, ablation, query_type=query_type, **kwargs
    +70        )
    +
    + + +

    Execute a query given parameters

    + +
    Parameters
    + +
      +
    • args:
    • +
    • kwargs:
    • +
    +
    + + +
    + +
    +
    + +
    +
    @dataclass(init=True, unsafe_hash=True)
    + + class + MarcoQueryConfig(debeir.core.config.GenericConfig): + + + +
    + +
    73@dataclass(init=True, unsafe_hash=True)
    +74class MarcoQueryConfig(GenericConfig):
    +75    def validate(self):
    +76        if self.query_type == "embedding":
    +77            assert (
    +78                    self.encoder_fp and self.encoder
    +79            ), "Must provide encoder path for embedding model"
    +80            assert self.norm_weight is not None or self.automatic is not None, (
    +81                "Norm weight must be specified or be automatic"
    +82            )
    +83
    +84    @classmethod
    +85    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
    +86        return super().from_toml(fp, cls, *args, **kwargs)
    +87
    +88    @classmethod
    +89    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
    +90        return super().from_dict(cls, **kwargs)
    +
    + + + + +
    +
    + + MarcoQueryConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None) + + +
    + + + + +
    +
    + +
    + + def + validate(self): + + + +
    + +
    75    def validate(self):
    +76        if self.query_type == "embedding":
    +77            assert (
    +78                    self.encoder_fp and self.encoder
    +79            ), "Must provide encoder path for embedding model"
    +80            assert self.norm_weight is not None or self.automatic is not None, (
    +81                "Norm weight must be specified or be automatic"
    +82            )
    +
    + + +

    Validates if the config is correct. +Must be implemented by inherited classes.

    +
    + + +
    +
    + +
    +
    @classmethod
    + + def + from_toml(cls, fp: str, *args, **kwargs) -> debeir.datasets.marco.MarcoQueryConfig: + + + +
    + +
    84    @classmethod
    +85    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
    +86        return super().from_toml(fp, cls, *args, **kwargs)
    +
    + + +

    Instantiates a Config object from a toml file

    + +
    Parameters
    + +
      +
    • fp: File path of the Config TOML file
    • +
    • field_class: Class of the Config object to be instantiated
    • +
    • args: Arguments to be passed to Config
    • +
    • kwargs: Keyword arguments to be passed
    • +
    + +
    Returns
    + +
    +
    An instantiated and validated Config object.
    +
    +
    +
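A hypothetical TOML and call; the keys shown follow the MarcoQueryConfig fields listed above and the values are placeholders:

    # marco.toml (hypothetical):
    #   query_type = "query"
    #   index = "med-marco"
    #   config_fn = "med-marco"
    config = MarcoQueryConfig.from_toml("marco.toml")
    config.validate()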
    + + +
    +
    + +
    +
    @classmethod
    + + def + from_dict(cls, **kwargs) -> debeir.datasets.marco.MarcoQueryConfig: + + + +
    + +
    88    @classmethod
    +89    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
    +90        return super().from_dict(cls, **kwargs)
    +
    + + +

    Instantiates a Config object from a dictionary

    + +
    Parameters
    + +
      +
    • data_class:
    • +
    • kwargs:
    • +
    + +
    Returns
    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + + \ No newline at end of file diff --git a/docs/debeir/datasets/trec_clinical_trials.html b/docs/debeir/datasets/trec_clinical_trials.html new file mode 100644 index 0000000..493f44b --- /dev/null +++ b/docs/debeir/datasets/trec_clinical_trials.html @@ -0,0 +1,610 @@ + + + + + + + debeir.datasets.trec_clinical_trials API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.trec_clinical_trials

    + + + + + + +
      1import pathlib
    +  2import xml.etree.ElementTree as ET
    +  3from collections import defaultdict
    +  4from typing import Dict, List
    +  5
    +  6import pandas as pd
    +  7from debeir.core.parser import JsonLinesParser, XMLParser
    +  8from debeir.core.query import GenericElasticsearchQuery
    +  9
    + 10
    + 11class TREClinicalTrialDocumentParser(XMLParser):
    + 12    """
    + 13    Parser for Clinical Trials topics
    + 14    """
    + 15
    + 16    parse_fields: List[str] = ["brief_title", "official_title",
    + 17                               "brief_summary", "detailed_description",
    + 18                               "eligibility", "condition_browse",
    + 19                               "intervention_browse"]
    + 20    topic_field_name: str
    + 21    id_field: str
    + 22
    + 23    @classmethod
    + 24    def extract(cls, path) -> Dict:
    + 25        document = ET.parse(path).getroot()
    + 26        document_dict = defaultdict(lambda: defaultdict(lambda: []))
    + 27        document_dict['doc_id'] = pathlib.Path(path).parts[-1].strip(".xml")
    + 28
    + 29        for parse_field in cls.parse_fields:
    + 30            node = document.find(parse_field)
    + 31            nodes: List[ET.Element] = []
    + 32
    + 33            if node is not None:
    + 34                cls._recurse_to_child_node(node, nodes)
    + 35
    + 36            if len(nodes) == 0 and node is not None:
    + 37                document_dict[parse_field] = node.text
    + 38
    + 39            for node in nodes:
    + 40                text = node.text.strip()
    + 41
    + 42                if not text:
    + 43                    continue
    + 44
    + 45                if document_dict[parse_field][node.tag]:
    + 46                    document_dict[parse_field][node.tag].append(text)
    + 47                else:
    + 48                    document_dict[parse_field][node.tag] = [text]
    + 49
    + 50            cls.unwrap(document_dict, parse_field)
    + 51
    + 52        document_dict = pd.io.json.json_normalize(document_dict,
    + 53                                                  sep=".").to_dict(orient='records')[0]
    + 54
    + 55        return document_dict
    + 56
    + 57
    + 58TrecClinicalTrialTripletParser = JsonLinesParser(
    + 59    parse_fields=["q_text", "brief_title", "official_title",
    + 60                  "brief_summary", "detailed_description", "rel"],
    + 61    id_field="qid",
    + 62    secondary_id="doc_id",
    + 63    ignore_full_match=True
    + 64)
    + 65
    + 66TrecClinicalTrialsParser = XMLParser(
    + 67    parse_fields=None,
    + 68    id_field="number",
    + 69    topic_field_name="topic")
    + 70
    + 71
    + 72class TrecClincialElasticsearchQuery(GenericElasticsearchQuery):
    + 73    def __init__(self, topics, config, *args, **kwargs):
    + 74        super().__init__(topics, config, *args, **kwargs)
    + 75
    + 76        # self.mappings = ['BriefTitle_Text',
    + 77        #                 'BriefSummary_Text',
    + 78        #                 'DetailedDescription_Text']
    + 79
    + 80        self.mappings = [
    + 81            "BriefSummary_Text",
    + 82            "BriefTitle_Text",
    + 83            'DetailedDescription_Text',
    + 84            'Eligibility.Criteria.Textblock',
    + 85            'Eligibility.StudyPop.Textblock',
    + 86            'ConditionBrowse.MeshTerm',
    + 87            'InterventionBrowse.MeshTerm',
    + 88            'Condition',
    + 89            'Eligibility.Gender',
    + 90            "OfficialTitle"]
    + 91
    + 92        self.topics = topics
    + 93        self.config = config
    + 94        self.query_type = self.config.query_type
    + 95
    + 96        self.embed_mappings = ['BriefTitle_Embedding',
    + 97                               'BriefSummary_Embedding',
    + 98                               'DetailedDescription_Embedding']
    + 99
    +100        self.id_mapping = "docid"
    +101
    +102        self.query_funcs = {
    +103            "query": self.generate_query,
    +104            "embedding": self.generate_query_embedding,
    +105        }
    +
    + + +
    +
    + +
    + + class + TREClinicalTrialDocumentParser(debeir.core.parser.XMLParser): + + + +
    + +
    12class TREClinicalTrialDocumentParser(XMLParser):
    +13    """
    +14    Parser for Clinical Trials topics
    +15    """
    +16
    +17    parse_fields: List[str] = ["brief_title", "official_title",
    +18                               "brief_summary", "detailed_description",
    +19                               "eligibility", "condition_browse",
    +20                               "intervention_browse"]
    +21    topic_field_name: str
    +22    id_field: str
    +23
    +24    @classmethod
    +25    def extract(cls, path) -> Dict:
    +26        document = ET.parse(path).getroot()
    +27        document_dict = defaultdict(lambda: defaultdict(lambda: []))
    +28        document_dict['doc_id'] = pathlib.Path(path).parts[-1].strip(".xml")
    +29
    +30        for parse_field in cls.parse_fields:
    +31            node = document.find(parse_field)
    +32            nodes: List[ET.Element] = []
    +33
    +34            if node is not None:
    +35                cls._recurse_to_child_node(node, nodes)
    +36
    +37            if len(nodes) == 0 and node is not None:
    +38                document_dict[parse_field] = node.text
    +39
    +40            for node in nodes:
    +41                text = node.text.strip()
    +42
    +43                if not text:
    +44                    continue
    +45
    +46                if document_dict[parse_field][node.tag]:
    +47                    document_dict[parse_field][node.tag].append(text)
    +48                else:
    +49                    document_dict[parse_field][node.tag] = [text]
    +50
    +51            cls.unwrap(document_dict, parse_field)
    +52
    +53        document_dict = pd.io.json.json_normalize(document_dict,
    +54                                                  sep=".").to_dict(orient='records')[0]
    +55
    +56        return document_dict
    +
    + + +

Parser for Clinical Trials documents

    +
    + + +
    + +
    +
    @classmethod
    + + def + extract(cls, path) -> Dict: + + + +
    + +
    24    @classmethod
    +25    def extract(cls, path) -> Dict:
    +26        document = ET.parse(path).getroot()
    +27        document_dict = defaultdict(lambda: defaultdict(lambda: []))
    +28        document_dict['doc_id'] = pathlib.Path(path).parts[-1].strip(".xml")
    +29
    +30        for parse_field in cls.parse_fields:
    +31            node = document.find(parse_field)
    +32            nodes: List[ET.Element] = []
    +33
    +34            if node is not None:
    +35                cls._recurse_to_child_node(node, nodes)
    +36
    +37            if len(nodes) == 0 and node is not None:
    +38                document_dict[parse_field] = node.text
    +39
    +40            for node in nodes:
    +41                text = node.text.strip()
    +42
    +43                if not text:
    +44                    continue
    +45
    +46                if document_dict[parse_field][node.tag]:
    +47                    document_dict[parse_field][node.tag].append(text)
    +48                else:
    +49                    document_dict[parse_field][node.tag] = [text]
    +50
    +51            cls.unwrap(document_dict, parse_field)
    +52
    +53        document_dict = pd.io.json.json_normalize(document_dict,
    +54                                                  sep=".").to_dict(orient='records')[0]
    +55
    +56        return document_dict
    +
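A sketch of what extract() returns for a single trial record, with a hypothetical file path:

    from debeir.datasets.trec_clinical_trials import TREClinicalTrialDocumentParser

    doc = TREClinicalTrialDocumentParser.extract("trials/NCT00000102.xml")

    doc["doc_id"]       # file name with ".xml" stripped from the ends, e.g. "NCT00000102"
    doc["brief_title"]  # text of the <brief_title> node, when present
    # Nested children are flattened by pandas json_normalize with "." separators,
    # producing keys such as "eligibility.criteria.textblock".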
    + + + + +
    + +
    +
    + +
    + + class + TrecClincialElasticsearchQuery(debeir.core.query.GenericElasticsearchQuery): + + + +
    + +
     73class TrecClincialElasticsearchQuery(GenericElasticsearchQuery):
    + 74    def __init__(self, topics, config, *args, **kwargs):
    + 75        super().__init__(topics, config, *args, **kwargs)
    + 76
    + 77        # self.mappings = ['BriefTitle_Text',
    + 78        #                 'BriefSummary_Text',
    + 79        #                 'DetailedDescription_Text']
    + 80
    + 81        self.mappings = [
    + 82            "BriefSummary_Text",
    + 83            "BriefTitle_Text",
    + 84            'DetailedDescription_Text',
+ 85            'Eligibility.Criteria.Textblock',
    + 86            'Eligibility.StudyPop.Textblock',
    + 87            'ConditionBrowse.MeshTerm',
    + 88            'InterventionBrowse.MeshTerm',
    + 89            'Condition',
    + 90            'Eligibility.Gender',
    + 91            "OfficialTitle"]
    + 92
    + 93        self.topics = topics
    + 94        self.config = config
    + 95        self.query_type = self.config.query_type
    + 96
    + 97        self.embed_mappings = ['BriefTitle_Embedding',
    + 98                               'BriefSummary_Embedding',
    + 99                               'DetailedDescription_Embedding']
    +100
    +101        self.id_mapping = "docid"
    +102
    +103        self.query_funcs = {
    +104            "query": self.generate_query,
    +105            "embedding": self.generate_query_embedding,
    +106        }
    +
    + + +

A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries. Requires topics and configs to be included.

    +
    + + +
    + +
    + + TrecClincialElasticsearchQuery(topics, config, *args, **kwargs) + + + +
    + +
     74    def __init__(self, topics, config, *args, **kwargs):
    + 75        super().__init__(topics, config, *args, **kwargs)
    + 76
    + 77        # self.mappings = ['BriefTitle_Text',
    + 78        #                 'BriefSummary_Text',
    + 79        #                 'DetailedDescription_Text']
    + 80
    + 81        self.mappings = [
    + 82            "BriefSummary_Text",
    + 83            "BriefTitle_Text",
    + 84            'DetailedDescription_Text',
+ 85            'Eligibility.Criteria.Textblock',
    + 86            'Eligibility.StudyPop.Textblock',
    + 87            'ConditionBrowse.MeshTerm',
    + 88            'InterventionBrowse.MeshTerm',
    + 89            'Condition',
    + 90            'Eligibility.Gender',
    + 91            "OfficialTitle"]
    + 92
    + 93        self.topics = topics
    + 94        self.config = config
    + 95        self.query_type = self.config.query_type
    + 96
    + 97        self.embed_mappings = ['BriefTitle_Embedding',
    + 98                               'BriefSummary_Embedding',
    + 99                               'DetailedDescription_Embedding']
    +100
    +101        self.id_mapping = "docid"
    +102
    +103        self.query_funcs = {
    +104            "query": self.generate_query,
    +105            "embedding": self.generate_query_embedding,
    +106        }
    +
    + + + + +
    + +
    +
+
+
\ No newline at end of file
diff --git a/docs/debeir/datasets/trec_covid.html b/docs/debeir/datasets/trec_covid.html
new file mode 100644
index 0000000..8be718e
--- /dev/null
+++ b/docs/debeir/datasets/trec_covid.html
@@ -0,0 +1,453 @@
+ + + + + + + debeir.datasets.trec_covid API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.trec_covid

    + + + + + + +
     1from typing import Dict
    + 2
    + 3from debeir.core.parser import XMLParser
    + 4from debeir.core.query import GenericElasticsearchQuery
    + 5
    + 6
    + 7class TrecCovidParser(XMLParser):
    + 8    parse_fields = ["query", "question", "narrative"]
    + 9    topic_field_name = "topic"
    +10    id_field = "number"
    +11
    +12    @classmethod
    +13    def get_topics(cls, xmlfile) -> Dict[int, Dict[str, str]]:
    +14        return super().get_topics(xmlfile)
    +15
    +16
    +17class TrecElasticsearchQuery(GenericElasticsearchQuery):
    +18    def __init__(self, topics, config, *args, **kwargs):
    +19        super().__init__(topics, config, *args, **kwargs)
    +20
    +21        self.mappings = ["title", "abstract", "fulltext"]
    +22
    +23        self.topics = topics
    +24        self.config = config
    +25        self.query_type = self.config.query_type
    +26
    +27        self.embed_mappings = [
    +28            "title_embedding",
    +29            "abstract_embedding",
    +30            "fulltext_embedding",
    +31        ]
    +32
    +33        self.id_mapping = "id"
    +34
    +35        self.query_funcs = {
    +36            "query": self.generate_query,
    +37            "embedding": self.generate_query_embedding,
    +38        }
    +
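Usage is analogous for TREC-COVID; a sketch with a hypothetical topics file:

    from debeir.datasets.trec_covid import TrecCovidParser

    topics = TrecCovidParser.get_topics("topics-rnd5.xml")

    # Per the annotation above, each entry maps a topic number to its
    # "query", "question" and "narrative" fields.
    for number, topic in topics.items():
        print(number, topic["query"])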
    + + +
    +
    + +
    + + class + TrecCovidParser(debeir.core.parser.XMLParser): + + + +
    + +
     8class TrecCovidParser(XMLParser):
    + 9    parse_fields = ["query", "question", "narrative"]
    +10    topic_field_name = "topic"
    +11    id_field = "number"
    +12
    +13    @classmethod
    +14    def get_topics(cls, xmlfile) -> Dict[int, Dict[str, str]]:
    +15        return super().get_topics(xmlfile)
    +
    + + +

    Load topics from an XML file

    +
    + + +
    + +
    +
    @classmethod
    + + def + get_topics(cls, xmlfile) -> Dict[int, Dict[str, str]]: + + + +
    + +
    13    @classmethod
    +14    def get_topics(cls, xmlfile) -> Dict[int, Dict[str, str]]:
    +15        return super().get_topics(xmlfile)
    +
    + + +

    Instance method for getting topics, forwards instance self parameters to the _get_topics class method.

    +
    + + +
    + +
    +
    + +
    + + class + TrecElasticsearchQuery(debeir.core.query.GenericElasticsearchQuery): + + + +
    + +
    18class TrecElasticsearchQuery(GenericElasticsearchQuery):
    +19    def __init__(self, topics, config, *args, **kwargs):
    +20        super().__init__(topics, config, *args, **kwargs)
    +21
    +22        self.mappings = ["title", "abstract", "fulltext"]
    +23
    +24        self.topics = topics
    +25        self.config = config
    +26        self.query_type = self.config.query_type
    +27
    +28        self.embed_mappings = [
    +29            "title_embedding",
    +30            "abstract_embedding",
    +31            "fulltext_embedding",
    +32        ]
    +33
    +34        self.id_mapping = "id"
    +35
    +36        self.query_funcs = {
    +37            "query": self.generate_query,
    +38            "embedding": self.generate_query_embedding,
    +39        }
    +
    + + +

A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries. Requires topics and configs to be included.

    +
    + + +
    + +
    + + TrecElasticsearchQuery(topics, config, *args, **kwargs) + + + +
    + +
    19    def __init__(self, topics, config, *args, **kwargs):
    +20        super().__init__(topics, config, *args, **kwargs)
    +21
    +22        self.mappings = ["title", "abstract", "fulltext"]
    +23
    +24        self.topics = topics
    +25        self.config = config
    +26        self.query_type = self.config.query_type
    +27
    +28        self.embed_mappings = [
    +29            "title_embedding",
    +30            "abstract_embedding",
    +31            "fulltext_embedding",
    +32        ]
    +33
    +34        self.id_mapping = "id"
    +35
    +36        self.query_funcs = {
    +37            "query": self.generate_query,
    +38            "embedding": self.generate_query_embedding,
    +39        }
    +
    + + + + +
    + +
    +
+
+
\ No newline at end of file
diff --git a/docs/debeir/datasets/types.html b/docs/debeir/datasets/types.html
new file mode 100644
index 0000000..9034c17
--- /dev/null
+++ b/docs/debeir/datasets/types.html
@@ -0,0 +1,731 @@
+ + + + + + + debeir.datasets.types API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.types

    + + + + + + +
     1import string
    + 2from collections import defaultdict
    + 3from enum import Enum
    + 4from typing import List, Union
    + 5
    + 6
    + 7class InputExample:
    + 8    """
    + 9    Copied from Sentence Transformer Library
    +10    Structure for one input example with texts, the label and a unique id
    +11    """
    +12
    +13    def __init__(self, guid: str = '', texts: List[str] = None, label: Union[int, float] = 0):
    +14        """
    +15        Creates one InputExample with the given texts, guid and label
    +16
    +17        :param guid
    +18            id for the example
    +19        :param texts
    +20            the texts for the example. Note, str.strip() is called on the texts
    +21        :param label
    +22            the label for the example
    +23        """
    +24        self.guid = guid
    +25        self.texts = [text.strip() for text in texts]
    +26        self.label = label
    +27
    +28    def __str__(self):
    +29        return "<InputExample> label: {}, texts: {}".format(str(self.label), "; ".join(self.texts))
    +30
    +31    def get_label(self):
    +32        return self.label
    +33
    +34    # def __getattr__(self, key):
    +35    #    if key == "label":
    +36    #        return self.get_label()
    +37
    +38    #    if key == "texts":
    +39    #        return self.texts
    +40
    +41    #    if key in ["guid", "id"]:
    +42    #        return self.guid
    +43
    +44    #    raise KeyError()
    +45
    +46    @classmethod
    +47    def to_dict(cls, data: List['InputExample']):
    +48        text_len = len(data[0].texts)
    +49        processed_data = defaultdict(lambda: [])
    +50
    +51        for datum in data:
    +52            # string.ascii_lowercase
    +53
    +54            processed_data["id"].append(datum.guid)
    +55            processed_data["label"].append(datum.get_label())
    +56
    +57            for i in range(text_len):
    +58                letter = string.ascii_lowercase[i]  # abcdefghi
    +59                # processed_data[text_a] = ...
    +60                processed_data[f"text_{letter}"].append(datum.texts[i])
    +61
    +62        return processed_data
    +63
    +64    @classmethod
    +65    def from_parser_output(cls, data):
    +66        pass
    +67
    +68
    +69class RelevanceExample(InputExample):
    +70    """
    +71    Converts Relevance Labels to 0 - 1
    +72    """
    +73
    +74    def __init__(self, max_score=2, *args, **kwargs):
    +75        super().__init__(*args, **kwargs)
    +76        self.max_score = max_score
    +77
    +78    def get_label(self):
    +79        return self.relevance()
    +80
    +81    def relevance(self):
    +82        """
    +83        :return:
    +84            Returns a normalised score for relevance between 0 - 1
    +85        """
    +86        return self.label / self.max_score
    +87
    +88
    +89class DatasetTypes(Enum):
    +90    """
    +91    A collection of common dataset types that is usable in the library.
    +92    """
+93    List = "List"
+94    ListInputExample = "ListInputExample"
+95    ListDict = "ListDict"
+96    HuggingfaceDataset = "HuggingfaceDataset"
    +
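A sketch of how InputExample.to_dict() reshapes a list of examples into parallel columns (behaviour read directly off the listing above):

    from debeir.datasets.types import InputExample

    examples = [
        InputExample(guid="q1-d1", texts=["query text", "relevant doc"], label=1),
        InputExample(guid="q1-d2", texts=["query text", "other doc"], label=0),
    ]

    columns = InputExample.to_dict(examples)
    # columns["id"]     == ["q1-d1", "q1-d2"]
    # columns["label"]  == [1, 0]
    # columns["text_a"] == ["query text", "query text"]
    # columns["text_b"] == ["relevant doc", "other doc"]

This column layout is what debeir.datasets.utils later feeds to datasets.Dataset.from_dict() when building folds.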
    + + +
    +
    + +
    + + class + InputExample: + + + +
    + +
     8class InputExample:
    + 9    """
    +10    Copied from Sentence Transformer Library
    +11    Structure for one input example with texts, the label and a unique id
    +12    """
    +13
    +14    def __init__(self, guid: str = '', texts: List[str] = None, label: Union[int, float] = 0):
    +15        """
    +16        Creates one InputExample with the given texts, guid and label
    +17
    +18        :param guid
    +19            id for the example
    +20        :param texts
    +21            the texts for the example. Note, str.strip() is called on the texts
    +22        :param label
    +23            the label for the example
    +24        """
    +25        self.guid = guid
    +26        self.texts = [text.strip() for text in texts]
    +27        self.label = label
    +28
    +29    def __str__(self):
    +30        return "<InputExample> label: {}, texts: {}".format(str(self.label), "; ".join(self.texts))
    +31
    +32    def get_label(self):
    +33        return self.label
    +34
    +35    # def __getattr__(self, key):
    +36    #    if key == "label":
    +37    #        return self.get_label()
    +38
    +39    #    if key == "texts":
    +40    #        return self.texts
    +41
    +42    #    if key in ["guid", "id"]:
    +43    #        return self.guid
    +44
    +45    #    raise KeyError()
    +46
    +47    @classmethod
    +48    def to_dict(cls, data: List['InputExample']):
    +49        text_len = len(data[0].texts)
    +50        processed_data = defaultdict(lambda: [])
    +51
    +52        for datum in data:
    +53            # string.ascii_lowercase
    +54
    +55            processed_data["id"].append(datum.guid)
    +56            processed_data["label"].append(datum.get_label())
    +57
    +58            for i in range(text_len):
    +59                letter = string.ascii_lowercase[i]  # abcdefghi
    +60                # processed_data[text_a] = ...
    +61                processed_data[f"text_{letter}"].append(datum.texts[i])
    +62
    +63        return processed_data
    +64
    +65    @classmethod
    +66    def from_parser_output(cls, data):
    +67        pass
    +
    + + +

Copied from Sentence Transformer Library. Structure for one input example with texts, the label and a unique id.

    +
    + + +
    + +
    + + InputExample( guid: str = '', texts: List[str] = None, label: Union[int, float] = 0) + + + +
    + +
    14    def __init__(self, guid: str = '', texts: List[str] = None, label: Union[int, float] = 0):
    +15        """
    +16        Creates one InputExample with the given texts, guid and label
    +17
    +18        :param guid
    +19            id for the example
    +20        :param texts
    +21            the texts for the example. Note, str.strip() is called on the texts
    +22        :param label
    +23            the label for the example
    +24        """
    +25        self.guid = guid
    +26        self.texts = [text.strip() for text in texts]
    +27        self.label = label
    +
    + + +

    Creates one InputExample with the given texts, guid and label

    + +

:param guid: id for the example
:param texts: the texts for the example. Note, str.strip() is called on the texts
:param label: the label for the example

    +
    + + +
    +
    + +
    + + def + get_label(self): + + + +
    + +
    32    def get_label(self):
    +33        return self.label
    +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + to_dict(cls, data: List[debeir.datasets.types.InputExample]): + + + +
    + +
    47    @classmethod
    +48    def to_dict(cls, data: List['InputExample']):
    +49        text_len = len(data[0].texts)
    +50        processed_data = defaultdict(lambda: [])
    +51
    +52        for datum in data:
    +53            # string.ascii_lowercase
    +54
    +55            processed_data["id"].append(datum.guid)
    +56            processed_data["label"].append(datum.get_label())
    +57
    +58            for i in range(text_len):
    +59                letter = string.ascii_lowercase[i]  # abcdefghi
    +60                # processed_data[text_a] = ...
    +61                processed_data[f"text_{letter}"].append(datum.texts[i])
    +62
    +63        return processed_data
    +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + from_parser_output(cls, data): + + + +
    + +
    65    @classmethod
    +66    def from_parser_output(cls, data):
    +67        pass
    +
    + + + + +
    +
    +
    + +
    + + class + RelevanceExample(InputExample): + + + +
    + +
    70class RelevanceExample(InputExample):
    +71    """
    +72    Converts Relevance Labels to 0 - 1
    +73    """
    +74
    +75    def __init__(self, max_score=2, *args, **kwargs):
    +76        super().__init__(*args, **kwargs)
    +77        self.max_score = max_score
    +78
    +79    def get_label(self):
    +80        return self.relevance()
    +81
    +82    def relevance(self):
    +83        """
    +84        :return:
    +85            Returns a normalised score for relevance between 0 - 1
    +86        """
    +87        return self.label / self.max_score
    +
    + + +

    Converts Relevance Labels to 0 - 1

    +
    + + +
    + +
    + + RelevanceExample(max_score=2, *args, **kwargs) + + + +
    + +
    75    def __init__(self, max_score=2, *args, **kwargs):
    +76        super().__init__(*args, **kwargs)
    +77        self.max_score = max_score
    +
    + + +

    Creates one InputExample with the given texts, guid and label

    + +

:param guid: id for the example
:param texts: the texts for the example. Note, str.strip() is called on the texts
:param label: the label for the example

    +
    + + +
    +
    + +
    + + def + get_label(self): + + + +
    + +
    79    def get_label(self):
    +80        return self.relevance()
    +
    + + + + +
    +
    + +
    + + def + relevance(self): + + + +
    + +
    82    def relevance(self):
    +83        """
    +84        :return:
    +85            Returns a normalised score for relevance between 0 - 1
    +86        """
    +87        return self.label / self.max_score
    +
    + + +
    Returns
    + +
    +
    Returns a normalised score for relevance between 0 - 1
    +
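For example, with the default max_score of 2, a graded relevance label of 1 normalises to 0.5 (a sketch using the definitions above):

    from debeir.datasets.types import RelevanceExample

    rel = RelevanceExample(max_score=2, guid="q1-d7", texts=["query", "doc"], label=1)
    rel.relevance()  # 1 / 2 == 0.5
    rel.get_label()  # same value, used wherever InputExample.get_label() is expected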
    +
    +
    + + +
    +
    +
    Inherited Members
    +
    + +
    +
    +
    +
    + +
    + + class + DatasetTypes(enum.Enum): + + + +
    + +
    90class DatasetTypes(Enum):
    +91    """
    +92    A collection of common dataset types that is usable in the library.
    +93    """
+94    List = "List"
+95    ListInputExample = "ListInputExample"
+96    ListDict = "ListDict"
+97    HuggingfaceDataset = "HuggingfaceDataset"
    +
    + + +

    A collection of common dataset types that is usable in the library.

    +
    + + +
    +
    Inherited Members
    +
    +
    enum.Enum
    +
    name
    +
    value
    + +
    +
    +
    +
    +
+
+
\ No newline at end of file
diff --git a/docs/debeir/datasets/utils.html b/docs/debeir/datasets/utils.html
new file mode 100644
index 0000000..9a5595e
--- /dev/null
+++ b/docs/debeir/datasets/utils.html
@@ -0,0 +1,544 @@
+ + + + + + + debeir.datasets.utils API documentation + + + + + + + + + +
    +
    +

    +debeir.datasets.utils

    + + + + + + +
     1# TODO: Convert a Parser Return Dict (Dict[int, Dict[str, ...])
    + 2
    + 3from debeir.datasets.types import DatasetTypes, InputExample
    + 4from debeir.evaluation.cross_validation import CrossValidator
    + 5from debeir.evaluation.evaluator import Evaluator
    + 6
    + 7import datasets
    + 8
    + 9
    +10class CrossValidatorDataset:
    +11    """
    +12    Cross Validator Dataset
    +13    """
    +14    cross_val_cls: CrossValidator
    +15
    +16    def __init__(self, dataset, cross_validator, n_folds, x_attr='text', y_attr='label'):
    +17        self.cross_val_cls = cross_validator
    +18        self.dataset = dataset
    +19        self.fold = 0
    +20        self.n_folds = n_folds
    +21        self.x_attr = x_attr
    +22        self.y_attr = y_attr
    +23        self.folds = []
    +24
    +25    @classmethod
    +26    def prepare_cross_validator(cls, data, evaluator: Evaluator,
    +27                                n_splits: int, x_attr, y_attr, seed=42) -> 'CrossValidatorDataset':
    +28        """
    +29        Prepare the cross validator dataset object that will internally produce the folds.
    +30
    +31        :param data: Dataset to be used. Should be a list of dicts, or list of [x,y] or a Dataset object from data_sets
    +32        :param evaluator: Evaluator to use for checking results
    +33        :param n_splits: Number of cross validation splits, k-fold (stratified)
    +34        :param seed: Seed to use (default 42)
    +35        :param y_attr: Label, or idx of the y label
    +36        :param x_attr: Label or idx of the x label (not directly used)
    +37        """
    +38
    +39        return cls(data, CrossValidator(evaluator, data, x_attr, y_attr,
    +40                                        n_splits=n_splits, seed=seed),
    +41                   x_attr=x_attr, y_attr=y_attr,
    +42                   n_folds=n_splits)
    +43
    +44    def get_fold(self, idx) -> datasets.DatasetDict:
    +45        """
    +46
+47        Gets the fold and returns a datasets.DatasetDict object with
+48        DatasetDict{'train': ..., 'val': ...}
    +49
    +50        :param idx:
    +51        """
    +52
    +53        train_idxs, val_idxs = self.cross_val_cls.get_fold(idx)
+54        dataset_dict = datasets.DatasetDict()
    +55
    +56        if self.cross_val_cls.dataset_type in [DatasetTypes.List, DatasetTypes.ListDict]:
    +57            # TODO: figure out how to make this into a huggingface dataset object generically
    +58            train_subset = [self.dataset[i] for i in train_idxs]
    +59            val_subset = [self.dataset[i] for i in val_idxs]
    +60        elif self.cross_val_cls.dataset_type == DatasetTypes.ListInputExample:
    +61            train_subset = InputExample.to_dict([self.dataset[i] for i in train_idxs])
    +62            val_subset = InputExample.to_dict([self.dataset[i] for i in val_idxs])
    +63
    +64            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +65            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +66
    +67        elif self.cross_val_cls.dataset_type == DatasetTypes.HuggingfaceDataset:
    +68            train_subset = self.dataset.select(train_idxs)
    +69            val_subset = self.dataset.select(val_idxs)
    +70
    +71            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +72            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +73
    +74        return dataset_dict
    +
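A usage sketch for CrossValidatorDataset (not part of the patch). The Evaluator constructor is not shown in this diff, so it is left as a placeholder; the rest follows the listing above.

    from debeir.datasets.types import InputExample
    from debeir.datasets.utils import CrossValidatorDataset

    evaluator = ...  # placeholder: build a debeir.evaluation.evaluator.Evaluator here

    data = [InputExample(guid=str(i), texts=[f"q{i}", f"d{i}"], label=i % 2) for i in range(10)]

    cv_dataset = CrossValidatorDataset.prepare_cross_validator(
        data, evaluator, n_splits=5, x_attr="text", y_attr="label", seed=42
    )
    fold_0 = cv_dataset.get_fold(0)  # datasets.DatasetDict with "train" and "val" splits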
    + + +
    +
    + +
    + + class + CrossValidatorDataset: + + + +
    + +
    11class CrossValidatorDataset:
    +12    """
    +13    Cross Validator Dataset
    +14    """
    +15    cross_val_cls: CrossValidator
    +16
    +17    def __init__(self, dataset, cross_validator, n_folds, x_attr='text', y_attr='label'):
    +18        self.cross_val_cls = cross_validator
    +19        self.dataset = dataset
    +20        self.fold = 0
    +21        self.n_folds = n_folds
    +22        self.x_attr = x_attr
    +23        self.y_attr = y_attr
    +24        self.folds = []
    +25
    +26    @classmethod
    +27    def prepare_cross_validator(cls, data, evaluator: Evaluator,
    +28                                n_splits: int, x_attr, y_attr, seed=42) -> 'CrossValidatorDataset':
    +29        """
    +30        Prepare the cross validator dataset object that will internally produce the folds.
    +31
    +32        :param data: Dataset to be used. Should be a list of dicts, or list of [x,y] or a Dataset object from data_sets
    +33        :param evaluator: Evaluator to use for checking results
    +34        :param n_splits: Number of cross validation splits, k-fold (stratified)
    +35        :param seed: Seed to use (default 42)
    +36        :param y_attr: Label, or idx of the y label
    +37        :param x_attr: Label or idx of the x label (not directly used)
    +38        """
    +39
    +40        return cls(data, CrossValidator(evaluator, data, x_attr, y_attr,
    +41                                        n_splits=n_splits, seed=seed),
    +42                   x_attr=x_attr, y_attr=y_attr,
    +43                   n_folds=n_splits)
    +44
    +45    def get_fold(self, idx) -> datasets.DatasetDict:
    +46        """
    +47
+48        Gets the fold and returns a datasets.DatasetDict object with
+49        DatasetDict{'train': ..., 'val': ...}
    +50
    +51        :param idx:
    +52        """
    +53
    +54        train_idxs, val_idxs = self.cross_val_cls.get_fold(idx)
+55        dataset_dict = datasets.DatasetDict()
    +56
    +57        if self.cross_val_cls.dataset_type in [DatasetTypes.List, DatasetTypes.ListDict]:
    +58            # TODO: figure out how to make this into a huggingface dataset object generically
    +59            train_subset = [self.dataset[i] for i in train_idxs]
    +60            val_subset = [self.dataset[i] for i in val_idxs]
    +61        elif self.cross_val_cls.dataset_type == DatasetTypes.ListInputExample:
    +62            train_subset = InputExample.to_dict([self.dataset[i] for i in train_idxs])
    +63            val_subset = InputExample.to_dict([self.dataset[i] for i in val_idxs])
    +64
    +65            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +66            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +67
    +68        elif self.cross_val_cls.dataset_type == DatasetTypes.HuggingfaceDataset:
    +69            train_subset = self.dataset.select(train_idxs)
    +70            val_subset = self.dataset.select(val_idxs)
    +71
    +72            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +73            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +74
    +75        return dataset_dict
    +
    + + +

    Cross Validator Dataset

    +
    + + +
    + +
    + + CrossValidatorDataset(dataset, cross_validator, n_folds, x_attr='text', y_attr='label') + + + +
    + +
    17    def __init__(self, dataset, cross_validator, n_folds, x_attr='text', y_attr='label'):
    +18        self.cross_val_cls = cross_validator
    +19        self.dataset = dataset
    +20        self.fold = 0
    +21        self.n_folds = n_folds
    +22        self.x_attr = x_attr
    +23        self.y_attr = y_attr
    +24        self.folds = []
    +
    + + + + +
    +
    + +
    +
    @classmethod
    + + def + prepare_cross_validator( cls, data, evaluator: debeir.evaluation.evaluator.Evaluator, n_splits: int, x_attr, y_attr, seed=42) -> debeir.datasets.utils.CrossValidatorDataset: + + + +
    + +
    26    @classmethod
    +27    def prepare_cross_validator(cls, data, evaluator: Evaluator,
    +28                                n_splits: int, x_attr, y_attr, seed=42) -> 'CrossValidatorDataset':
    +29        """
    +30        Prepare the cross validator dataset object that will internally produce the folds.
    +31
    +32        :param data: Dataset to be used. Should be a list of dicts, or list of [x,y] or a Dataset object from data_sets
    +33        :param evaluator: Evaluator to use for checking results
    +34        :param n_splits: Number of cross validation splits, k-fold (stratified)
    +35        :param seed: Seed to use (default 42)
    +36        :param y_attr: Label, or idx of the y label
    +37        :param x_attr: Label or idx of the x label (not directly used)
    +38        """
    +39
    +40        return cls(data, CrossValidator(evaluator, data, x_attr, y_attr,
    +41                                        n_splits=n_splits, seed=seed),
    +42                   x_attr=x_attr, y_attr=y_attr,
    +43                   n_folds=n_splits)
    +
    + + +

    Prepare the cross validator dataset object that will internally produce the folds.

    + +
    Parameters
    + +
      +
    • data: Dataset to be used. Should be a list of dicts, or list of [x,y] or a Dataset object from data_sets
    • +
    • evaluator: Evaluator to use for checking results
    • +
    • n_splits: Number of cross validation splits, k-fold (stratified)
    • +
    • seed: Seed to use (default 42)
    • +
    • y_attr: Label, or idx of the y label
    • +
    • x_attr: Label or idx of the x label (not directly used)
    • +
    +
    + + +
    +
    + +
    + + def + get_fold(self, idx) -> datasets.dataset_dict.DatasetDict: + + + +
    + +
    45    def get_fold(self, idx) -> datasets.DatasetDict:
    +46        """
    +47
+48        Gets the fold and returns a datasets.DatasetDict object with
+49        DatasetDict{'train': ..., 'val': ...}
    +50
    +51        :param idx:
    +52        """
    +53
    +54        train_idxs, val_idxs = self.cross_val_cls.get_fold(idx)
+55        dataset_dict = datasets.DatasetDict()
    +56
    +57        if self.cross_val_cls.dataset_type in [DatasetTypes.List, DatasetTypes.ListDict]:
    +58            # TODO: figure out how to make this into a huggingface dataset object generically
    +59            train_subset = [self.dataset[i] for i in train_idxs]
    +60            val_subset = [self.dataset[i] for i in val_idxs]
    +61        elif self.cross_val_cls.dataset_type == DatasetTypes.ListInputExample:
    +62            train_subset = InputExample.to_dict([self.dataset[i] for i in train_idxs])
    +63            val_subset = InputExample.to_dict([self.dataset[i] for i in val_idxs])
    +64
    +65            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +66            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +67
    +68        elif self.cross_val_cls.dataset_type == DatasetTypes.HuggingfaceDataset:
    +69            train_subset = self.dataset.select(train_idxs)
    +70            val_subset = self.dataset.select(val_idxs)
    +71
    +72            dataset_dict['train'] = datasets.Dataset.from_dict(train_subset)
    +73            dataset_dict['val'] = datasets.Dataset.from_dict(val_subset)
    +74
    +75        return dataset_dict
    +
    + + +

Gets the fold and returns a datasets.DatasetDict object with DatasetDict{'train': ..., 'val': ...}

    + +
    Parameters
    + +
      +
    • idx:
    • +
    +
    + + +
    +
    +
+
+
\ No newline at end of file
diff --git a/docs/debeir/engines.html b/docs/debeir/engines.html
index ae8ad74..58a049b 100644
--- a/docs/debeir/engines.html
+++ b/docs/debeir/engines.html
@@ -3,7 +3,7 @@
- + debeir.engines API documentation
@@ -49,7 +49,9 @@

    Submodules

    debeir.engines

    -

    Implemented Search Engines to run queries against.

    +

    WIP

    + +

    Implemented Search Engines to run queries against.

    @@ -57,8 +59,10 @@

    1"""
    -2Implemented Search Engines to run queries against.
    -3"""
    +2WIP 
    +3
    +4Implemented Search Engines to run queries against.
    +5"""
     
    @@ -164,7 +168,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -191,7 +195,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/client.html b/docs/debeir/engines/client.html index f829d51..25a753f 100644 --- a/docs/debeir/engines/client.html +++ b/docs/debeir/engines/client.html @@ -3,7 +3,7 @@ - + debeir.engines.client API documentation @@ -70,56 +70,57 @@

     1import dataclasses
    - 2from elasticsearch import AsyncElasticsearch
    - 3
    + 2
    + 3from elasticsearch import AsyncElasticsearch
      4
    - 5@dataclasses.dataclass(init=True)
    - 6class Client:
    - 7    """
    - 8    Overarching client interface object that contains references to different clients for search
    - 9    Allows sharing between function calls
    -10    """
    -11    es_client: AsyncElasticsearch = None
    -12    solr_client: object = None
    -13    generic_client: object = None
    -14
    -15    @classmethod
    -16    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    -17        """
    -18        Build client from engine config
    -19        :param engine_type:
    -20        :param engine_config:
    -21        :return:
    -22        """
    -23
    -24        client = Client()
    -25
    -26        if engine_type == "elasticsearch":
    -27            es_client = AsyncElasticsearch(
    -28                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    -29                timeout=engine_config.timeout
    -30            )
    -31
    -32            client.es_client = es_client
    -33
    -34        return client
    -35
    -36    def get_client(self, engine):
    -37        if engine == "elasticsearch":
    -38            return self.es_client
    -39
    -40    async def close(self):
    -41        """
    -42        Generically close all contained client objects
    -43        """
    -44        if self.es_client:
    -45            await self.es_client.close()
    -46
    -47        if self.solr_client:
    -48            await self.solr_client.close()
    -49
    -50        if self.generic_client:
    -51            await self.generic_client.close()
    + 5
    + 6@dataclasses.dataclass(init=True)
    + 7class Client:
    + 8    """
    + 9    Overarching client interface object that contains references to different clients for search
    +10    Allows sharing between function calls
    +11    """
    +12    es_client: AsyncElasticsearch = None
    +13    solr_client: object = None
    +14    generic_client: object = None
    +15
    +16    @classmethod
    +17    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    +18        """
    +19        Build client from engine config
    +20        :param engine_type:
    +21        :param engine_config:
    +22        :return:
    +23        """
    +24
    +25        client = Client()
    +26
    +27        if engine_type == "elasticsearch":
    +28            es_client = AsyncElasticsearch(
    +29                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    +30                timeout=engine_config.timeout
    +31            )
    +32
    +33            client.es_client = es_client
    +34
    +35        return client
    +36
    +37    def get_client(self, engine):
    +38        if engine == "elasticsearch":
    +39            return self.es_client
    +40
    +41    async def close(self):
    +42        """
    +43        Generically close all contained client objects
    +44        """
    +45        if self.es_client:
    +46            await self.es_client.close()
    +47
    +48        if self.solr_client:
    +49            await self.solr_client.close()
    +50
    +51        if self.generic_client:
    +52            await self.generic_client.close()
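A build-and-close sketch for Client; the engine_config object is a stand-in exposing only the attributes referenced above:

    import asyncio
    from types import SimpleNamespace

    from debeir.engines.client import Client

    engine_config = SimpleNamespace(protocol="http", ip="localhost", port=9200, timeout=60)

    client = Client.build_from_config("elasticsearch", engine_config)
    es = client.get_client("elasticsearch")  # the AsyncElasticsearch instance built above

    asyncio.run(client.close())  # closes whichever underlying clients were set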
     
    @@ -136,53 +137,53 @@

    -
     6@dataclasses.dataclass(init=True)
    - 7class Client:
    - 8    """
    - 9    Overarching client interface object that contains references to different clients for search
    -10    Allows sharing between function calls
    -11    """
    -12    es_client: AsyncElasticsearch = None
    -13    solr_client: object = None
    -14    generic_client: object = None
    -15
    -16    @classmethod
    -17    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    -18        """
    -19        Build client from engine config
    -20        :param engine_type:
    -21        :param engine_config:
    -22        :return:
    -23        """
    -24
    -25        client = Client()
    -26
    -27        if engine_type == "elasticsearch":
    -28            es_client = AsyncElasticsearch(
    -29                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    -30                timeout=engine_config.timeout
    -31            )
    -32
    -33            client.es_client = es_client
    -34
    -35        return client
    -36
    -37    def get_client(self, engine):
    -38        if engine == "elasticsearch":
    -39            return self.es_client
    -40
    -41    async def close(self):
    -42        """
    -43        Generically close all contained client objects
    -44        """
    -45        if self.es_client:
    -46            await self.es_client.close()
    -47
    -48        if self.solr_client:
    -49            await self.solr_client.close()
    -50
    -51        if self.generic_client:
    -52            await self.generic_client.close()
    +            
     7@dataclasses.dataclass(init=True)
    + 8class Client:
    + 9    """
    +10    Overarching client interface object that contains references to different clients for search
    +11    Allows sharing between function calls
    +12    """
    +13    es_client: AsyncElasticsearch = None
    +14    solr_client: object = None
    +15    generic_client: object = None
    +16
    +17    @classmethod
    +18    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    +19        """
    +20        Build client from engine config
    +21        :param engine_type:
    +22        :param engine_config:
    +23        :return:
    +24        """
    +25
    +26        client = Client()
    +27
    +28        if engine_type == "elasticsearch":
    +29            es_client = AsyncElasticsearch(
    +30                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    +31                timeout=engine_config.timeout
    +32            )
    +33
    +34            client.es_client = es_client
    +35
    +36        return client
    +37
    +38    def get_client(self, engine):
    +39        if engine == "elasticsearch":
    +40            return self.es_client
    +41
    +42    async def close(self):
    +43        """
    +44        Generically close all contained client objects
    +45        """
    +46        if self.es_client:
    +47            await self.es_client.close()
    +48
    +49        if self.solr_client:
    +50            await self.solr_client.close()
    +51
    +52        if self.generic_client:
    +53            await self.generic_client.close()
     
    @@ -215,26 +216,26 @@

    -
    16    @classmethod
    -17    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    -18        """
    -19        Build client from engine config
    -20        :param engine_type:
    -21        :param engine_config:
    -22        :return:
    -23        """
    -24
    -25        client = Client()
    -26
    -27        if engine_type == "elasticsearch":
    -28            es_client = AsyncElasticsearch(
    -29                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    -30                timeout=engine_config.timeout
    -31            )
    -32
    -33            client.es_client = es_client
    -34
    -35        return client
    +            
    17    @classmethod
    +18    def build_from_config(cls, engine_type, engine_config) -> 'Client':
    +19        """
    +20        Build client from engine config
    +21        :param engine_type:
    +22        :param engine_config:
    +23        :return:
    +24        """
    +25
    +26        client = Client()
    +27
    +28        if engine_type == "elasticsearch":
    +29            es_client = AsyncElasticsearch(
    +30                f"{engine_config.protocol}://{engine_config.ip}:{engine_config.port}",
    +31                timeout=engine_config.timeout
    +32            )
    +33
    +34            client.es_client = es_client
    +35
    +36        return client
     
    @@ -263,9 +264,9 @@
    Returns
    -
    37    def get_client(self, engine):
    -38        if engine == "elasticsearch":
    -39            return self.es_client
    +            
    38    def get_client(self, engine):
    +39        if engine == "elasticsearch":
    +40            return self.es_client
     
    @@ -283,18 +284,18 @@
    Returns
    -
    41    async def close(self):
    -42        """
    -43        Generically close all contained client objects
    -44        """
    -45        if self.es_client:
    -46            await self.es_client.close()
    -47
    -48        if self.solr_client:
    -49            await self.solr_client.close()
    -50
    -51        if self.generic_client:
    -52            await self.generic_client.close()
    +            
    42    async def close(self):
    +43        """
    +44        Generically close all contained client objects
    +45        """
    +46        if self.es_client:
    +47            await self.es_client.close()
    +48
    +49        if self.solr_client:
    +50            await self.solr_client.close()
    +51
    +52        if self.generic_client:
    +53            await self.generic_client.close()
     
    @@ -405,7 +406,7 @@
    Returns
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -432,7 +433,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/dummyindex.html b/docs/debeir/engines/dummyindex.html index 3bd7896..02e3761 100644 --- a/docs/debeir/engines/dummyindex.html +++ b/docs/debeir/engines/dummyindex.html @@ -3,7 +3,7 @@ - + debeir.engines.dummyindex API documentation @@ -152,7 +152,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -179,7 +179,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/dummyindex/index.html b/docs/debeir/engines/dummyindex/index.html index 3241ba3..0345846 100644 --- a/docs/debeir/engines/dummyindex/index.html +++ b/docs/debeir/engines/dummyindex/index.html @@ -3,7 +3,7 @@ - + debeir.engines.dummyindex.index API documentation @@ -307,7 +307,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -334,7 +334,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/elasticsearch.html b/docs/debeir/engines/elasticsearch.html index ee12a23..841169b 100644 --- a/docs/debeir/engines/elasticsearch.html +++ b/docs/debeir/engines/elasticsearch.html @@ -3,7 +3,7 @@ - + debeir.engines.elasticsearch API documentation @@ -48,12 +48,20 @@

    Submodules

    debeir.engines.elasticsearch

    - +

    Library code for interacting with the elasticsearch engine

    + +

    Contains many helper functions for asynchronous and fast querying, with optional caching available

    +
    + -
    1
    +                        
    1"""
    +2Library code for interacting with the elasticsearch engine
    +3
    +4Contains many helper functions for asynchronous and fast querying, with optional caching available
    +5"""
     
    @@ -159,7 +167,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -186,7 +194,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/elasticsearch/change_bm25.html b/docs/debeir/engines/elasticsearch/change_bm25.html index e792756..087f54f 100644 --- a/docs/debeir/engines/elasticsearch/change_bm25.html +++ b/docs/debeir/engines/elasticsearch/change_bm25.html @@ -3,7 +3,7 @@ - + debeir.engines.elasticsearch.change_bm25 API documentation @@ -56,77 +56,76 @@

     1import json
      2
    - 3import elasticsearch
    - 4import requests
    - 5from loguru import logger
    - 6
    - 7# echo "k = $k b = $b"
    - 8#
    - 9# curl -X POST "localhost:9200/${INDEX}/_close?pretty"
    -10#
    -11# curl -X PUT "localhost:9200/${INDEX}/_settings?pretty" -H 'Content-Type: application/json' -d"
    -12# {
    -13#  \"index\": {
    -14#    \"similarity\": {
    -15#      \"default\": {
    -16#        \"type\": \"BM25\",
    -17#        \"b\": ${b},
    -18#        \"k1\": ${k}
    -19#      }
    -20#    }
    -21#  }
    -22# }"
    -23# curl -X POST "localhost:9200/${INDEX}/_open?pretty"
    -24#
    -25# sleep 10
    + 3import requests
    + 4
    + 5
    + 6# echo "k = $k b = $b"
    + 7#
    + 8# curl -X POST "localhost:9200/${INDEX}/_close?pretty"
    + 9#
    +10# curl -X PUT "localhost:9200/${INDEX}/_settings?pretty" -H 'Content-Type: application/json' -d"
    +11# {
    +12#  \"index\": {
    +13#    \"similarity\": {
    +14#      \"default\": {
    +15#        \"type\": \"BM25\",
    +16#        \"b\": ${b},
    +17#        \"k1\": ${k}
    +18#      }
    +19#    }
    +20#  }
    +21# }"
    +22# curl -X POST "localhost:9200/${INDEX}/_open?pretty"
    +23#
    +24# sleep 10
    +25
     26
    -27
    -28def change_bm25_params(index, k1: float, b: float, base_url: str="http://localhost:9200"):
    -29    """
    -30    Change the BM25 parameters of the elasticsearch BM25 ranker.
    -31
    -32    :param index: The elasticsearch index name
    -33    :param k1: The k parameter for BM25 (default 1.2) [Usually 0-3] [Term saturation constant] ->
    -34               The higher the k value, the more weight given to document that repeat terms.
    -35    :param b: The b parameter for BM25 (default 0.75) [Usually 0-1] [Document length constant] ->
    -36              The higher the b value, the higher it penalises longer documents.
    -37    :param base_url: The elasticsearch base URL for API requests (without index suffix)
    -38    """
    -39    base_url = f"{base_url}/{index}"
    -40
    -41    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    -42
    -43    if not resp.ok:
    -44        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -45
    -46    resp = requests.post(base_url + "/_close?pretty", timeout=60)
    -47
    -48    if not resp.ok:
    -49        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -50
    -51    headers = {"Content-type": "application/json"}
    -52
    -53    data = {
    -54      "index": {
    -55        "similarity": {
    -56          "default": {
    -57            "type": "BM25",
    -58            "b": b,
    -59            "k1": k1,
    -60          }
    +27def change_bm25_params(index, k1: float, b: float, base_url: str = "http://localhost:9200"):
    +28    """
    +29    Change the BM25 parameters of the elasticsearch BM25 ranker.
    +30
    +31    :param index: The elasticsearch index name
    +32    :param k1: The k parameter for BM25 (default 1.2) [Usually 0-3] [Term saturation constant] ->
    +33               The higher the k value, the more weight given to document that repeat terms.
    +34    :param b: The b parameter for BM25 (default 0.75) [Usually 0-1] [Document length constant] ->
    +35              The higher the b value, the higher it penalises longer documents.
    +36    :param base_url: The elasticsearch base URL for API requests (without index suffix)
    +37    """
    +38    base_url = f"{base_url}/{index}"
    +39
    +40    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    +41
    +42    if not resp.ok:
    +43        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +44
    +45    resp = requests.post(base_url + "/_close?pretty", timeout=60)
    +46
    +47    if not resp.ok:
    +48        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +49
    +50    headers = {"Content-type": "application/json"}
    +51
    +52    data = {
    +53        "index": {
    +54            "similarity": {
    +55                "default": {
    +56                    "type": "BM25",
    +57                    "b": b,
    +58                    "k1": k1,
    +59                }
    +60            }
     61        }
    -62      }
    -63     }
    -64
    -65    resp = requests.put(base_url+"/_settings", headers=headers, data=json.dumps(data), timeout=60)
    -66
    -67    if not resp.ok:
    -68        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -69
    -70    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    -71
    -72    if not resp.ok:
    -73        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +62    }
    +63
    +64    resp = requests.put(base_url + "/_settings", headers=headers, data=json.dumps(data), timeout=60)
    +65
    +66    if not resp.ok:
    +67        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +68
    +69    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    +70
    +71    if not resp.ok:
    +72        raise RuntimeError("Response code:", resp.status_code, resp.text)
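A typical call, with a hypothetical index name; the index is closed and reopened around the settings change:

    from debeir.engines.elasticsearch.change_bm25 import change_bm25_params

    change_bm25_params("trec-ct-2022", k1=1.2, b=0.75, base_url="http://localhost:9200")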
     
    @@ -142,52 +141,52 @@

    -
    29def change_bm25_params(index, k1: float, b: float, base_url: str="http://localhost:9200"):
    -30    """
    -31    Change the BM25 parameters of the elasticsearch BM25 ranker.
    -32
    -33    :param index: The elasticsearch index name
    -34    :param k1: The k parameter for BM25 (default 1.2) [Usually 0-3] [Term saturation constant] ->
    -35               The higher the k value, the more weight given to document that repeat terms.
    -36    :param b: The b parameter for BM25 (default 0.75) [Usually 0-1] [Document length constant] ->
    -37              The higher the b value, the higher it penalises longer documents.
    -38    :param base_url: The elasticsearch base URL for API requests (without index suffix)
    -39    """
    -40    base_url = f"{base_url}/{index}"
    -41
    -42    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    -43
    -44    if not resp.ok:
    -45        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -46
    -47    resp = requests.post(base_url + "/_close?pretty", timeout=60)
    -48
    -49    if not resp.ok:
    -50        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -51
    -52    headers = {"Content-type": "application/json"}
    -53
    -54    data = {
    -55      "index": {
    -56        "similarity": {
    -57          "default": {
    -58            "type": "BM25",
    -59            "b": b,
    -60            "k1": k1,
    -61          }
    +            
    28def change_bm25_params(index, k1: float, b: float, base_url: str = "http://localhost:9200"):
    +29    """
    +30    Change the BM25 parameters of the elasticsearch BM25 ranker.
    +31
    +32    :param index: The elasticsearch index name
    +33    :param k1: The k parameter for BM25 (default 1.2) [Usually 0-3] [Term saturation constant] ->
    +34               The higher the k value, the more weight given to document that repeat terms.
    +35    :param b: The b parameter for BM25 (default 0.75) [Usually 0-1] [Document length constant] ->
    +36              The higher the b value, the higher it penalises longer documents.
    +37    :param base_url: The elasticsearch base URL for API requests (without index suffix)
    +38    """
    +39    base_url = f"{base_url}/{index}"
    +40
    +41    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    +42
    +43    if not resp.ok:
    +44        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +45
    +46    resp = requests.post(base_url + "/_close?pretty", timeout=60)
    +47
    +48    if not resp.ok:
    +49        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +50
    +51    headers = {"Content-type": "application/json"}
    +52
    +53    data = {
    +54        "index": {
    +55            "similarity": {
    +56                "default": {
    +57                    "type": "BM25",
    +58                    "b": b,
    +59                    "k1": k1,
    +60                }
    +61            }
     62        }
    -63      }
    -64     }
    -65
    -66    resp = requests.put(base_url+"/_settings", headers=headers, data=json.dumps(data), timeout=60)
    -67
    -68    if not resp.ok:
    -69        raise RuntimeError("Response code:", resp.status_code, resp.text)
    -70
    -71    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    -72
    -73    if not resp.ok:
    -74        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +63    }
    +64
    +65    resp = requests.put(base_url + "/_settings", headers=headers, data=json.dumps(data), timeout=60)
    +66
    +67    if not resp.ok:
    +68        raise RuntimeError("Response code:", resp.status_code, resp.text)
    +69
    +70    resp = requests.post(base_url + "/_open?pretty", timeout=60)
    +71
    +72    if not resp.ok:
    +73        raise RuntimeError("Response code:", resp.status_code, resp.text)
     
    @@ -308,7 +307,7 @@
    Parameters
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -335,7 +334,7 @@
    Parameters
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/elasticsearch/executor.html b/docs/debeir/engines/elasticsearch/executor.html index 6594d03..d3635c1 100644 --- a/docs/debeir/engines/elasticsearch/executor.html +++ b/docs/debeir/engines/elasticsearch/executor.html @@ -3,7 +3,7 @@ - + debeir.engines.elasticsearch.executor API documentation @@ -69,15 +69,15 @@

    -
      1from typing import Dict, Optional, Union, List
    +                        
      1from typing import Dict, List, Optional, Union
       2
       3import tqdm.asyncio
    -  4from elasticsearch import AsyncElasticsearch as Elasticsearch
    -  5
    -  6from debeir.rankers.transformer_sent_encoder import Encoder
    -  7from debeir.interfaces.query import GenericElasticsearchQuery
    -  8from debeir.interfaces.config import apply_config
    -  9from debeir.utils.utils import unpack_coroutine
    +  4from debeir.core.config import apply_config
    +  5from debeir.core.document import document_factory
    +  6from debeir.core.query import GenericElasticsearchQuery
    +  7from debeir.rankers.transformer_sent_encoder import Encoder
    +  8from debeir.utils.utils import unpack_coroutine
    +  9from elasticsearch import AsyncElasticsearch as Elasticsearch
      10
      11
      12class ElasticsearchExecutor:
    @@ -90,109 +90,111 @@ 

    19 2. End-to-End Neural IR 20 3. Statistical keyword matching 21 """ - 22 def __init__( - 23 self, - 24 topics: Dict[Union[str, int], Dict[str, str]], - 25 client: Elasticsearch, - 26 index_name: str, - 27 output_file: str, - 28 query: GenericElasticsearchQuery, - 29 encoder: Optional[Encoder], - 30 return_size: int = 1000, - 31 test=False, - 32 return_id_only=True, - 33 config=None, - 34 ): - 35 self.topics = {"1": topics["1"]} if test else topics - 36 self.client = client - 37 self.index_name = index_name - 38 self.output_file = output_file - 39 self.return_size = return_size - 40 self.query = query - 41 self.encoder = encoder - 42 self.return_id_only = return_id_only - 43 self.config = config - 44 - 45 def generate_query(self, topic_num): - 46 """ - 47 Generates a query given a topic number from the list of topics - 48 - 49 :param topic_num: - 50 """ - 51 raise NotImplementedError - 52 - 53 def execute_query(self, *args, **kwargs): - 54 """ - 55 Execute a query given parameters - 56 - 57 :param args: - 58 :param kwargs: - 59 """ - 60 raise NotImplementedError - 61 - 62 @apply_config - 63 def _update_kwargs(self, **kwargs): - 64 return kwargs - 65 - 66 async def run_all_queries( - 67 self, query_type=None, return_results=False, - 68 return_size: int = None, return_id_only: bool = False, **kwargs - 69 ) -> List: - 70 """ - 71 A generic function that will asynchronously run all topics using the execute_query() method - 72 - 73 :param query_type: Which query to execute. Query_type determines which method is used to generate the queries - 74 from self.query.query_funcs: Dict[str, func] - 75 :param return_results: Whether to return raw results from the client. Useful for analysing results directly or - 76 for computing the BM25 scores for log normalization in NIR-style scoring - 77 :param return_size: Number of documents to return. Overrides the config value if exists. - 78 :param return_id_only: Return the ID of the document only, rather than the full source document. - 79 :param args: Arguments to pass to the execute_query method - 80 :param kwargs: Keyword arguments to pass to the execute_query method - 81 :return: - 82 A list of results if return_results = True else an empty list is returned. 
- 83 """ - 84 if not await self.client.ping(): - 85 await self.client.close() - 86 raise RuntimeError( - 87 f"Elasticsearch instance cannot be reached at {self.client}" - 88 ) - 89 - 90 kwargs = self._update_kwargs(**kwargs) + 22 + 23 def __init__( + 24 self, + 25 topics: Dict[Union[str, int], Dict[str, str]], + 26 client: Elasticsearch, + 27 index_name: str, + 28 output_file: str, + 29 query: GenericElasticsearchQuery, + 30 encoder: Optional[Encoder], + 31 return_size: int = 1000, + 32 test=False, + 33 return_id_only=True, + 34 config=None, + 35 ): + 36 self.topics = {"1": topics["1"]} if test else topics + 37 self.client = client + 38 self.index_name = index_name + 39 self.output_file = output_file + 40 self.return_size = return_size + 41 self.query = query + 42 self.encoder = encoder + 43 self.return_id_only = return_id_only + 44 self.config = config + 45 self.document_cls = document_factory['elasticsearch'] + 46 + 47 def generate_query(self, topic_num): + 48 """ + 49 Generates a query given a topic number from the list of topics + 50 + 51 :param topic_num: + 52 """ + 53 raise NotImplementedError + 54 + 55 def execute_query(self, *args, **kwargs): + 56 """ + 57 Execute a query given parameters + 58 + 59 :param args: + 60 :param kwargs: + 61 """ + 62 raise NotImplementedError + 63 + 64 @apply_config + 65 def _update_kwargs(self, **kwargs): + 66 return kwargs + 67 + 68 async def run_all_queries( + 69 self, query_type=None, return_results=False, + 70 return_size: int = None, return_id_only: bool = False, **kwargs + 71 ) -> List: + 72 """ + 73 A generic function that will asynchronously run all topics using the execute_query() method + 74 + 75 :param query_type: Which query to execute. Query_type determines which method is used to generate the queries + 76 from self.query.query_funcs: Dict[str, func] + 77 :param return_results: Whether to return raw results from the client. Useful for analysing results directly or + 78 for computing the BM25 scores for log normalization in NIR-style scoring + 79 :param return_size: Number of documents to return. Overrides the config value if exists. + 80 :param return_id_only: Return the ID of the document only, rather than the full source document. + 81 :param args: Arguments to pass to the execute_query method + 82 :param kwargs: Keyword arguments to pass to the execute_query method + 83 :return: + 84 A list of results if return_results = True else an empty list is returned. 
+ 85 """ + 86 if not await self.client.ping(): + 87 await self.client.close() + 88 raise RuntimeError( + 89 f"Elasticsearch instance cannot be reached at {self.client}" + 90 ) 91 - 92 if return_size is None: - 93 return_size = self.return_size - 94 - 95 if return_id_only is None: - 96 return_id_only = self.return_id_only - 97 - 98 if query_type is None: - 99 query_type = self.config.query_type -100 -101 kwargs.pop('return_size', None) -102 kwargs.pop('return_id_only', None) -103 kwargs.pop('query_type', None) -104 -105 tasks = [ -106 self.execute_query( -107 topic_num=topic_num, -108 query_type=query_type, -109 return_size=return_size, -110 return_id_only=return_id_only, -111 **kwargs -112 ) -113 for topic_num in self.topics -114 ] -115 -116 results = [] + 92 kwargs = self._update_kwargs(**kwargs) + 93 + 94 if return_size is None: + 95 return_size = self.return_size + 96 + 97 if return_id_only is None: + 98 return_id_only = self.return_id_only + 99 +100 if query_type is None: +101 query_type = self.config.query_type +102 +103 kwargs.pop('return_size', None) +104 kwargs.pop('return_id_only', None) +105 kwargs.pop('query_type', None) +106 +107 tasks = [ +108 self.execute_query( +109 topic_num=topic_num, +110 query_type=query_type, +111 return_size=return_size, +112 return_id_only=return_id_only, +113 **kwargs +114 ) +115 for topic_num in self.topics +116 ] 117 -118 for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"): -119 res = await unpack_coroutine(f) -120 -121 if return_results: -122 results.append(res) -123 -124 return results +118 results = [] +119 +120 for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"): +121 res = await unpack_coroutine(f) +122 +123 if return_results: +124 results.append(res) +125 +126 return results

    @@ -218,109 +220,111 @@

    20 2. End-to-End Neural IR 21 3. Statistical keyword matching 22 """ - 23 def __init__( - 24 self, - 25 topics: Dict[Union[str, int], Dict[str, str]], - 26 client: Elasticsearch, - 27 index_name: str, - 28 output_file: str, - 29 query: GenericElasticsearchQuery, - 30 encoder: Optional[Encoder], - 31 return_size: int = 1000, - 32 test=False, - 33 return_id_only=True, - 34 config=None, - 35 ): - 36 self.topics = {"1": topics["1"]} if test else topics - 37 self.client = client - 38 self.index_name = index_name - 39 self.output_file = output_file - 40 self.return_size = return_size - 41 self.query = query - 42 self.encoder = encoder - 43 self.return_id_only = return_id_only - 44 self.config = config - 45 - 46 def generate_query(self, topic_num): - 47 """ - 48 Generates a query given a topic number from the list of topics - 49 - 50 :param topic_num: - 51 """ - 52 raise NotImplementedError - 53 - 54 def execute_query(self, *args, **kwargs): - 55 """ - 56 Execute a query given parameters - 57 - 58 :param args: - 59 :param kwargs: - 60 """ - 61 raise NotImplementedError - 62 - 63 @apply_config - 64 def _update_kwargs(self, **kwargs): - 65 return kwargs - 66 - 67 async def run_all_queries( - 68 self, query_type=None, return_results=False, - 69 return_size: int = None, return_id_only: bool = False, **kwargs - 70 ) -> List: - 71 """ - 72 A generic function that will asynchronously run all topics using the execute_query() method - 73 - 74 :param query_type: Which query to execute. Query_type determines which method is used to generate the queries - 75 from self.query.query_funcs: Dict[str, func] - 76 :param return_results: Whether to return raw results from the client. Useful for analysing results directly or - 77 for computing the BM25 scores for log normalization in NIR-style scoring - 78 :param return_size: Number of documents to return. Overrides the config value if exists. - 79 :param return_id_only: Return the ID of the document only, rather than the full source document. - 80 :param args: Arguments to pass to the execute_query method - 81 :param kwargs: Keyword arguments to pass to the execute_query method - 82 :return: - 83 A list of results if return_results = True else an empty list is returned. 
- 84 """ - 85 if not await self.client.ping(): - 86 await self.client.close() - 87 raise RuntimeError( - 88 f"Elasticsearch instance cannot be reached at {self.client}" - 89 ) - 90 - 91 kwargs = self._update_kwargs(**kwargs) + 23 + 24 def __init__( + 25 self, + 26 topics: Dict[Union[str, int], Dict[str, str]], + 27 client: Elasticsearch, + 28 index_name: str, + 29 output_file: str, + 30 query: GenericElasticsearchQuery, + 31 encoder: Optional[Encoder], + 32 return_size: int = 1000, + 33 test=False, + 34 return_id_only=True, + 35 config=None, + 36 ): + 37 self.topics = {"1": topics["1"]} if test else topics + 38 self.client = client + 39 self.index_name = index_name + 40 self.output_file = output_file + 41 self.return_size = return_size + 42 self.query = query + 43 self.encoder = encoder + 44 self.return_id_only = return_id_only + 45 self.config = config + 46 self.document_cls = document_factory['elasticsearch'] + 47 + 48 def generate_query(self, topic_num): + 49 """ + 50 Generates a query given a topic number from the list of topics + 51 + 52 :param topic_num: + 53 """ + 54 raise NotImplementedError + 55 + 56 def execute_query(self, *args, **kwargs): + 57 """ + 58 Execute a query given parameters + 59 + 60 :param args: + 61 :param kwargs: + 62 """ + 63 raise NotImplementedError + 64 + 65 @apply_config + 66 def _update_kwargs(self, **kwargs): + 67 return kwargs + 68 + 69 async def run_all_queries( + 70 self, query_type=None, return_results=False, + 71 return_size: int = None, return_id_only: bool = False, **kwargs + 72 ) -> List: + 73 """ + 74 A generic function that will asynchronously run all topics using the execute_query() method + 75 + 76 :param query_type: Which query to execute. Query_type determines which method is used to generate the queries + 77 from self.query.query_funcs: Dict[str, func] + 78 :param return_results: Whether to return raw results from the client. Useful for analysing results directly or + 79 for computing the BM25 scores for log normalization in NIR-style scoring + 80 :param return_size: Number of documents to return. Overrides the config value if exists. + 81 :param return_id_only: Return the ID of the document only, rather than the full source document. + 82 :param args: Arguments to pass to the execute_query method + 83 :param kwargs: Keyword arguments to pass to the execute_query method + 84 :return: + 85 A list of results if return_results = True else an empty list is returned. 
+ 86 """ + 87 if not await self.client.ping(): + 88 await self.client.close() + 89 raise RuntimeError( + 90 f"Elasticsearch instance cannot be reached at {self.client}" + 91 ) 92 - 93 if return_size is None: - 94 return_size = self.return_size - 95 - 96 if return_id_only is None: - 97 return_id_only = self.return_id_only - 98 - 99 if query_type is None: -100 query_type = self.config.query_type -101 -102 kwargs.pop('return_size', None) -103 kwargs.pop('return_id_only', None) -104 kwargs.pop('query_type', None) -105 -106 tasks = [ -107 self.execute_query( -108 topic_num=topic_num, -109 query_type=query_type, -110 return_size=return_size, -111 return_id_only=return_id_only, -112 **kwargs -113 ) -114 for topic_num in self.topics -115 ] -116 -117 results = [] + 93 kwargs = self._update_kwargs(**kwargs) + 94 + 95 if return_size is None: + 96 return_size = self.return_size + 97 + 98 if return_id_only is None: + 99 return_id_only = self.return_id_only +100 +101 if query_type is None: +102 query_type = self.config.query_type +103 +104 kwargs.pop('return_size', None) +105 kwargs.pop('return_id_only', None) +106 kwargs.pop('query_type', None) +107 +108 tasks = [ +109 self.execute_query( +110 topic_num=topic_num, +111 query_type=query_type, +112 return_size=return_size, +113 return_id_only=return_id_only, +114 **kwargs +115 ) +116 for topic_num in self.topics +117 ] 118 -119 for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"): -120 res = await unpack_coroutine(f) -121 -122 if return_results: -123 results.append(res) -124 -125 return results +119 results = [] +120 +121 for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"): +122 res = await unpack_coroutine(f) +123 +124 if return_results: +125 results.append(res) +126 +127 return results

    @@ -338,34 +342,35 @@

- ElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.interfaces.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder], return_size: int = 1000, test=False, return_id_only=True, config=None)
+ ElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.core.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder], return_size: int = 1000, test=False, return_id_only=True, config=None)
    -
    23    def __init__(
    -24        self,
    -25        topics: Dict[Union[str, int], Dict[str, str]],
    -26        client: Elasticsearch,
    -27        index_name: str,
    -28        output_file: str,
    -29        query: GenericElasticsearchQuery,
    -30        encoder: Optional[Encoder],
    -31        return_size: int = 1000,
    -32        test=False,
    -33        return_id_only=True,
    -34        config=None,
    -35    ):
    -36        self.topics = {"1": topics["1"]} if test else topics
    -37        self.client = client
    -38        self.index_name = index_name
    -39        self.output_file = output_file
    -40        self.return_size = return_size
    -41        self.query = query
    -42        self.encoder = encoder
    -43        self.return_id_only = return_id_only
    -44        self.config = config
    +            
    24    def __init__(
    +25            self,
    +26            topics: Dict[Union[str, int], Dict[str, str]],
    +27            client: Elasticsearch,
    +28            index_name: str,
    +29            output_file: str,
    +30            query: GenericElasticsearchQuery,
    +31            encoder: Optional[Encoder],
    +32            return_size: int = 1000,
    +33            test=False,
    +34            return_id_only=True,
    +35            config=None,
    +36    ):
    +37        self.topics = {"1": topics["1"]} if test else topics
    +38        self.client = client
    +39        self.index_name = index_name
    +40        self.output_file = output_file
    +41        self.return_size = return_size
    +42        self.query = query
    +43        self.encoder = encoder
    +44        self.return_id_only = return_id_only
    +45        self.config = config
    +46        self.document_cls = document_factory['elasticsearch']
     
    @@ -383,13 +388,13 @@

    -
    46    def generate_query(self, topic_num):
    -47        """
    -48        Generates a query given a topic number from the list of topics
    -49
    -50        :param topic_num:
    -51        """
    -52        raise NotImplementedError
    +            
    48    def generate_query(self, topic_num):
    +49        """
    +50        Generates a query given a topic number from the list of topics
    +51
    +52        :param topic_num:
    +53        """
    +54        raise NotImplementedError
     
    @@ -415,14 +420,14 @@
    Parameters
    -
    54    def execute_query(self, *args, **kwargs):
    -55        """
    -56        Execute a query given parameters
    -57
    -58        :param args:
    -59        :param kwargs:
    -60        """
    -61        raise NotImplementedError
    +            
    56    def execute_query(self, *args, **kwargs):
    +57        """
    +58        Execute a query given parameters
    +59
    +60        :param args:
    +61        :param kwargs:
    +62        """
    +63        raise NotImplementedError
     
    @@ -449,65 +454,65 @@
    Parameters
    -
     67    async def run_all_queries(
    - 68        self, query_type=None, return_results=False,
    - 69            return_size: int = None, return_id_only: bool = False, **kwargs
    - 70    ) -> List:
    - 71        """
    - 72        A generic function that will asynchronously run all topics using the execute_query() method
    - 73
    - 74        :param query_type: Which query to execute. Query_type determines which method is used to generate the queries
    - 75               from self.query.query_funcs: Dict[str, func]
    - 76        :param return_results: Whether to return raw results from the client. Useful for analysing results directly or
    - 77               for computing the BM25 scores for log normalization in NIR-style scoring
    - 78        :param return_size: Number of documents to return. Overrides the config value if exists.
    - 79        :param return_id_only: Return the ID of the document only, rather than the full source document.
    - 80        :param args: Arguments to pass to the execute_query method
    - 81        :param kwargs: Keyword arguments to pass to the execute_query method
    - 82        :return:
    - 83            A list of results if return_results = True else an empty list is returned.
    - 84        """
    - 85        if not await self.client.ping():
    - 86            await self.client.close()
    - 87            raise RuntimeError(
    - 88                f"Elasticsearch instance cannot be reached at {self.client}"
    - 89            )
    - 90
    - 91        kwargs = self._update_kwargs(**kwargs)
    +            
     69    async def run_all_queries(
    + 70            self, query_type=None, return_results=False,
    + 71            return_size: int = None, return_id_only: bool = False, **kwargs
    + 72    ) -> List:
    + 73        """
    + 74        A generic function that will asynchronously run all topics using the execute_query() method
    + 75
    + 76        :param query_type: Which query to execute. Query_type determines which method is used to generate the queries
    + 77               from self.query.query_funcs: Dict[str, func]
    + 78        :param return_results: Whether to return raw results from the client. Useful for analysing results directly or
    + 79               for computing the BM25 scores for log normalization in NIR-style scoring
    + 80        :param return_size: Number of documents to return. Overrides the config value if exists.
    + 81        :param return_id_only: Return the ID of the document only, rather than the full source document.
    + 82        :param args: Arguments to pass to the execute_query method
    + 83        :param kwargs: Keyword arguments to pass to the execute_query method
    + 84        :return:
    + 85            A list of results if return_results = True else an empty list is returned.
    + 86        """
    + 87        if not await self.client.ping():
    + 88            await self.client.close()
    + 89            raise RuntimeError(
    + 90                f"Elasticsearch instance cannot be reached at {self.client}"
    + 91            )
      92
    - 93        if return_size is None:
    - 94            return_size = self.return_size
    - 95
    - 96        if return_id_only is None:
    - 97            return_id_only = self.return_id_only
    - 98
    - 99        if query_type is None:
    -100            query_type = self.config.query_type
    -101
    -102        kwargs.pop('return_size', None)
    -103        kwargs.pop('return_id_only', None)
    -104        kwargs.pop('query_type', None)
    -105
    -106        tasks = [
    -107            self.execute_query(
    -108                topic_num=topic_num,
    -109                query_type=query_type,
    -110                return_size=return_size,
    -111                return_id_only=return_id_only,
    -112                **kwargs
    -113            )
    -114            for topic_num in self.topics
    -115        ]
    -116
    -117        results = []
    + 93        kwargs = self._update_kwargs(**kwargs)
    + 94
    + 95        if return_size is None:
    + 96            return_size = self.return_size
    + 97
    + 98        if return_id_only is None:
    + 99            return_id_only = self.return_id_only
    +100
    +101        if query_type is None:
    +102            query_type = self.config.query_type
    +103
    +104        kwargs.pop('return_size', None)
    +105        kwargs.pop('return_id_only', None)
    +106        kwargs.pop('query_type', None)
    +107
    +108        tasks = [
    +109            self.execute_query(
    +110                topic_num=topic_num,
    +111                query_type=query_type,
    +112                return_size=return_size,
    +113                return_id_only=return_id_only,
    +114                **kwargs
    +115            )
    +116            for topic_num in self.topics
    +117        ]
     118
    -119        for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"):
    -120            res = await unpack_coroutine(f)
    -121
    -122            if return_results:
    -123                results.append(res)
    -124
    -125        return results
    +119        results = []
    +120
    +121        for f in tqdm.asyncio.tqdm.as_completed(tasks, desc="Running Queries"):
    +122            res = await unpack_coroutine(f)
    +123
    +124            if return_results:
    +125                results.append(res)
    +126
    +127        return results
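
The renamed executor is driven the same way as before; only the import paths change. A minimal sketch of running all topics against a local cluster, assuming a hypothetical MyDatasetExecutor subclass that implements execute_query(), a GenericElasticsearchQuery built elsewhere, and placeholder values for the index name, topics, and query_type:

import asyncio

from elasticsearch import AsyncElasticsearch

# MyDatasetExecutor is a hypothetical subclass of ElasticsearchExecutor that
# implements generate_query() / execute_query() for a concrete collection.
from my_project.executors import MyDatasetExecutor


async def run(query_obj):
    # query_obj: a GenericElasticsearchQuery built elsewhere for the same topics.
    client = AsyncElasticsearch("http://localhost:9200")
    topics = {"1": {"text": "example topic"}}  # topic_num -> facet dict

    executor = MyDatasetExecutor(
        topics=topics,
        client=client,
        index_name="my-index",   # placeholder index name
        output_file="run.res",
        query=query_obj,
        encoder=None,            # no neural encoder for a plain keyword run
        return_size=100,
    )

    # Every topic is executed concurrently; query_type selects the generator in
    # query_obj.query_funcs, and return_results=True keeps the raw hits.
    results = await executor.run_all_queries(
        query_type="query", return_results=True, return_id_only=True
    )
    await client.close()
    return results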
     
    @@ -638,7 +643,7 @@
    Returns
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -665,7 +670,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/elasticsearch/generate_script_score.html b/docs/debeir/engines/elasticsearch/generate_script_score.html index 72e7d5a..4fcff58 100644 --- a/docs/debeir/engines/elasticsearch/generate_script_score.html +++ b/docs/debeir/engines/elasticsearch/generate_script_score.html @@ -3,7 +3,7 @@ - + debeir.engines.elasticsearch.generate_script_score API documentation @@ -82,7 +82,7 @@

      1import copy
    -  2from typing import Union, Dict
    +  2from typing import Dict, Union
       3
       4base_script = {
       5    "lang": "painless",
    @@ -99,168 +99,169 @@ 

    16 17 This is a string builder class 18 """ - 19 def __init__(self): - 20 self.s = "" - 21 self.i = 0 - 22 self.variables = [] - 23 - 24 def _add_line(self, line): - 25 self.s = self.s + line.strip() + "\n" - 26 - 27 def add_preamble(self): - 28 """ - 29 Adds preamble to the internal string - 30 This will return the bm25 score if the normalization constant is below 0 - 31 """ - 32 self._add_line( - 33 """ - 34 if (params.norm_weight < 0.0) { - 35 return _score; - 36 } - 37 """ - 38 ) - 39 - 40 def add_log_score(self, ignore_below_one=False) -> "SourceBuilder": - 41 """ - 42 Adds the BM25 log score line - 43 :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under). - 44 :return: - 45 SourceBuilder - 46 """ - 47 if ignore_below_one: - 48 self._add_line( - 49 #"def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" - 50 "def log_score = params.disable_bm25 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" - 51 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" - 52 ) - 53 else: - 54 self._add_line( - 55 "def log_score = _score <= 0.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" - 56 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" - 57 ) - 58 - 59 return self - 60 - 61 def add_embed_field(self, qfield, field) -> "SourceBuilder": - 62 """ - 63 Adds a cosine score line. - 64 :param qfield: Query field - 65 :param field: Document facet field - 66 :return: - 67 """ - 68 if "embedding" not in field.lower(): - 69 field = field.replace(".", "_") + "_Embedding" - 70 - 71 variable_name = f"{field}_{qfield}_score" - 72 - 73 self._add_line( - 74 f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}" - 75 f", '{field}') + params.offset; " - 76 # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; " - 77 ) - 78 self.variables.append(variable_name) - 79 - 80 self.i += 1 - 81 - 82 return self - 83 - 84 def finish(self): - 85 """ - 86 Finalises the script score and returns the internal string - 87 :return: - 88 A string containing the script score query - 89 """ - 90 self._add_line("double embed_score = " + " + ".join(self.variables) + ";") - 91 self._add_line( - 92 #"return params.disable_bm25 == true ? embed_score : embed_score + log_score;" - 93 "return embed_score + log_score;" - 94 ) - 95 - 96 return self.s - 97 + 19 + 20 def __init__(self): + 21 self.s = "" + 22 self.i = 0 + 23 self.variables = [] + 24 + 25 def _add_line(self, line): + 26 self.s = self.s + line.strip() + "\n" + 27 + 28 def add_preamble(self): + 29 """ + 30 Adds preamble to the internal string + 31 This will return the bm25 score if the normalization constant is below 0 + 32 """ + 33 self._add_line( + 34 """ + 35 if (params.norm_weight < 0.0) { + 36 return _score; + 37 } + 38 """ + 39 ) + 40 + 41 def add_log_score(self, ignore_below_one=False) -> "SourceBuilder": + 42 """ + 43 Adds the BM25 log score line + 44 :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under). + 45 :return: + 46 SourceBuilder + 47 """ + 48 if ignore_below_one: + 49 self._add_line( + 50 # "def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" + 51 "def log_score = params.disable_bm25 ? 
0.0 : Math.log(_score)/Math.log(params.norm_weight);" + 52 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" + 53 ) + 54 else: + 55 self._add_line( + 56 "def log_score = _score <= 0.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" + 57 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" + 58 ) + 59 + 60 return self + 61 + 62 def add_embed_field(self, qfield, field) -> "SourceBuilder": + 63 """ + 64 Adds a cosine score line. + 65 :param qfield: Query field + 66 :param field: Document facet field + 67 :return: + 68 """ + 69 if "embedding" not in field.lower(): + 70 field = field.replace(".", "_") + "_Embedding" + 71 + 72 variable_name = f"{field}_{qfield}_score" + 73 + 74 self._add_line( + 75 f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}" + 76 f", '{field}') + params.offset; " + 77 # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; " + 78 ) + 79 self.variables.append(variable_name) + 80 + 81 self.i += 1 + 82 + 83 return self + 84 + 85 def finish(self): + 86 """ + 87 Finalises the script score and returns the internal string + 88 :return: + 89 A string containing the script score query + 90 """ + 91 self._add_line("double embed_score = " + " + ".join(self.variables) + ";") + 92 self._add_line( + 93 # "return params.disable_bm25 == true ? embed_score : embed_score + log_score;" + 94 "return embed_score + log_score;" + 95 ) + 96 + 97 return self.s 98 - 99def generate_source(qfields: Union[list, str], fields) -> str: -100 """ -101 Generates the script source based off a set of input fields and facets -102 -103 :param qfields: Query fields (or topic fields) -104 :param fields: Document facets to compute cosine similarity on -105 :return: -106 """ -107 sb = SourceBuilder() -108 sb.add_log_score(ignore_below_one=True) -109 -110 if isinstance(qfields, str): -111 qfields = [qfields] -112 -113 for qfield in qfields: -114 for field in fields: -115 sb.add_embed_field(qfield, field) -116 -117 s = sb.finish() -118 -119 return s -120 + 99 +100def generate_source(qfields: Union[list, str], fields) -> str: +101 """ +102 Generates the script source based off a set of input fields and facets +103 +104 :param qfields: Query fields (or topic fields) +105 :param fields: Document facets to compute cosine similarity on +106 :return: +107 """ +108 sb = SourceBuilder() +109 sb.add_log_score(ignore_below_one=True) +110 +111 if isinstance(qfields, str): +112 qfields = [qfields] +113 +114 for qfield in qfields: +115 for field in fields: +116 sb.add_embed_field(qfield, field) +117 +118 s = sb.finish() +119 +120 return s 121 -122# def generate_source(fields, log_ignore=False): -123# s = "" -124# -125# if log_ignore: -126# -127# s = """ -128# def log_score = _score < 1.0 ? _score : Math.log(_score)/Math.log(params.norm_weight); -129# def weights = params.weights;""".strip()+"\n" -130# -131# variables = [] -132# -133# for i, field in enumerate(fields): -134# field = field.replace(".", '_') + '_Embedding' -135# s += f"double {field}_score = doc['{field}'].size() == 0 ? 0 : weights[{i}]*cosineSimilarity(params.q_eb, '{field}') + params.offset;\n" -136# -137# variables.append(f"{field}_score") -138# -139# s = s.strip() -140# -141# s = s + "\n double embed_score = " + " + ".join(variables) + ";" -142# s = s + " \n return params.disable_bm25 == true ? 
embed_score : embed_score + Math.log(_score)/Math.log(params.norm_weight);" -143# -144# return s -145 +122 +123# def generate_source(fields, log_ignore=False): +124# s = "" +125# +126# if log_ignore: +127# +128# s = """ +129# def log_score = _score < 1.0 ? _score : Math.log(_score)/Math.log(params.norm_weight); +130# def weights = params.weights;""".strip()+"\n" +131# +132# variables = [] +133# +134# for i, field in enumerate(fields): +135# field = field.replace(".", '_') + '_Embedding' +136# s += f"double {field}_score = doc['{field}'].size() == 0 ? 0 : weights[{i}]*cosineSimilarity(params.q_eb, '{field}') + params.offset;\n" +137# +138# variables.append(f"{field}_score") +139# +140# s = s.strip() +141# +142# s = s + "\n double embed_score = " + " + ".join(variables) + ";" +143# s = s + " \n return params.disable_bm25 == true ? embed_score : embed_score + Math.log(_score)/Math.log(params.norm_weight);" +144# +145# return s 146 -147def check_params_is_valid(params, qfields): -148 """ -149 Validate if the parameters for the script score passes a simple sanity check. -150 -151 :param params: -152 :param qfields: -153 """ -154 for qfield in qfields: -155 assert qfield in params -156 -157 assert "weights" in params -158 assert "offset" in params -159 +147 +148def check_params_is_valid(params, qfields): +149 """ +150 Validate if the parameters for the script score passes a simple sanity check. +151 +152 :param params: +153 :param qfields: +154 """ +155 for qfield in qfields: +156 assert qfield in params +157 +158 assert "weights" in params +159 assert "offset" in params 160 -161def generate_script( -162 fields, params, source_generator=generate_source, qfields="q_eb" -163) -> Dict: -164 """ -165 Parameters for creating the script -166 -167 :param fields: Document fields to search -168 :param params: Parameters for the script -169 :param source_generator: Function that will generate the script -170 :param qfields: Query fields to search from (topic facets) -171 :return: -172 """ -173 script = copy.deepcopy(base_script) -174 check_params_is_valid(params, qfields) -175 -176 script["lang"] = "painless" -177 script["source"] = source_generator(qfields, fields) -178 script["params"] = params -179 -180 return script +161 +162def generate_script( +163 fields, params, source_generator=generate_source, qfields="q_eb" +164) -> Dict: +165 """ +166 Parameters for creating the script +167 +168 :param fields: Document fields to search +169 :param params: Parameters for the script +170 :param source_generator: Function that will generate the script +171 :param qfields: Query fields to search from (topic facets) +172 :return: +173 """ +174 script = copy.deepcopy(base_script) +175 check_params_is_valid(params, qfields) +176 +177 script["lang"] = "painless" +178 script["source"] = source_generator(qfields, fields) +179 script["params"] = params +180 +181 return script

    @@ -283,84 +284,85 @@

    17 18 This is a string builder class 19 """ -20 def __init__(self): -21 self.s = "" -22 self.i = 0 -23 self.variables = [] -24 -25 def _add_line(self, line): -26 self.s = self.s + line.strip() + "\n" -27 -28 def add_preamble(self): -29 """ -30 Adds preamble to the internal string -31 This will return the bm25 score if the normalization constant is below 0 -32 """ -33 self._add_line( -34 """ -35 if (params.norm_weight < 0.0) { -36 return _score; -37 } -38 """ -39 ) -40 -41 def add_log_score(self, ignore_below_one=False) -> "SourceBuilder": -42 """ -43 Adds the BM25 log score line -44 :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under). -45 :return: -46 SourceBuilder -47 """ -48 if ignore_below_one: -49 self._add_line( -50 #"def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" -51 "def log_score = params.disable_bm25 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" -52 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" -53 ) -54 else: -55 self._add_line( -56 "def log_score = _score <= 0.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" -57 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" -58 ) -59 -60 return self -61 -62 def add_embed_field(self, qfield, field) -> "SourceBuilder": -63 """ -64 Adds a cosine score line. -65 :param qfield: Query field -66 :param field: Document facet field -67 :return: -68 """ -69 if "embedding" not in field.lower(): -70 field = field.replace(".", "_") + "_Embedding" -71 -72 variable_name = f"{field}_{qfield}_score" -73 -74 self._add_line( -75 f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}" -76 f", '{field}') + params.offset; " -77 # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; " -78 ) -79 self.variables.append(variable_name) -80 -81 self.i += 1 -82 -83 return self -84 -85 def finish(self): -86 """ -87 Finalises the script score and returns the internal string -88 :return: -89 A string containing the script score query -90 """ -91 self._add_line("double embed_score = " + " + ".join(self.variables) + ";") -92 self._add_line( -93 #"return params.disable_bm25 == true ? embed_score : embed_score + log_score;" -94 "return embed_score + log_score;" -95 ) -96 -97 return self.s +20 +21 def __init__(self): +22 self.s = "" +23 self.i = 0 +24 self.variables = [] +25 +26 def _add_line(self, line): +27 self.s = self.s + line.strip() + "\n" +28 +29 def add_preamble(self): +30 """ +31 Adds preamble to the internal string +32 This will return the bm25 score if the normalization constant is below 0 +33 """ +34 self._add_line( +35 """ +36 if (params.norm_weight < 0.0) { +37 return _score; +38 } +39 """ +40 ) +41 +42 def add_log_score(self, ignore_below_one=False) -> "SourceBuilder": +43 """ +44 Adds the BM25 log score line +45 :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under). +46 :return: +47 SourceBuilder +48 """ +49 if ignore_below_one: +50 self._add_line( +51 # "def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" +52 "def log_score = params.disable_bm25 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);" +53 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" +54 ) +55 else: +56 self._add_line( +57 "def log_score = _score <= 0.0 ? 
0.0 : Math.log(_score)/Math.log(params.norm_weight);" +58 # "def log_score = Math.log(_score)/Math.log(params.norm_weight);" +59 ) +60 +61 return self +62 +63 def add_embed_field(self, qfield, field) -> "SourceBuilder": +64 """ +65 Adds a cosine score line. +66 :param qfield: Query field +67 :param field: Document facet field +68 :return: +69 """ +70 if "embedding" not in field.lower(): +71 field = field.replace(".", "_") + "_Embedding" +72 +73 variable_name = f"{field}_{qfield}_score" +74 +75 self._add_line( +76 f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}" +77 f", '{field}') + params.offset; " +78 # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; " +79 ) +80 self.variables.append(variable_name) +81 +82 self.i += 1 +83 +84 return self +85 +86 def finish(self): +87 """ +88 Finalises the script score and returns the internal string +89 :return: +90 A string containing the script score query +91 """ +92 self._add_line("double embed_score = " + " + ".join(self.variables) + ";") +93 self._add_line( +94 # "return params.disable_bm25 == true ? embed_score : embed_score + log_score;" +95 "return embed_score + log_score;" +96 ) +97 +98 return self.s

    @@ -381,10 +383,10 @@

    -
    20    def __init__(self):
    -21        self.s = ""
    -22        self.i = 0
    -23        self.variables = []
    +            
    21    def __init__(self):
    +22        self.s = ""
    +23        self.i = 0
    +24        self.variables = []
     
    @@ -402,18 +404,18 @@

    -
    28    def add_preamble(self):
    -29        """
    -30        Adds preamble to the internal string
    -31        This will return the bm25 score if the normalization constant is below 0
    -32        """
    -33        self._add_line(
    -34            """
    -35            if (params.norm_weight < 0.0) {
    -36                return _score;
    -37            }
    -38        """
    -39        )
    +            
    29    def add_preamble(self):
    +30        """
    +31        Adds preamble to the internal string
    +32        This will return the bm25 score if the normalization constant is below 0
    +33        """
    +34        self._add_line(
    +35            """
    +36            if (params.norm_weight < 0.0) {
    +37                return _score;
    +38            }
    +39        """
    +40        )
     
    @@ -434,26 +436,26 @@

    -
    41    def add_log_score(self, ignore_below_one=False) -> "SourceBuilder":
    -42        """
    -43        Adds the BM25 log score line
    -44        :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under).
    -45        :return:
    -46            SourceBuilder
    -47        """
    -48        if ignore_below_one:
    -49            self._add_line(
    -50                #"def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    -51                "def log_score = params.disable_bm25 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    -52                # "def log_score = Math.log(_score)/Math.log(params.norm_weight);"
    -53            )
    -54        else:
    -55            self._add_line(
    -56                "def log_score = _score <= 0.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    -57                # "def log_score = Math.log(_score)/Math.log(params.norm_weight);"
    -58            )
    -59
    -60        return self
    +            
    42    def add_log_score(self, ignore_below_one=False) -> "SourceBuilder":
    +43        """
    +44        Adds the BM25 log score line
    +45        :param ignore_below_one: Ignore all scores below 1.0 as Log(1) = 0. Otherwise, just ignore Log(0 and under).
    +46        :return:
    +47            SourceBuilder
    +48        """
    +49        if ignore_below_one:
    +50            self._add_line(
    +51                # "def log_score = _score < 1.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    +52                "def log_score = params.disable_bm25 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    +53                # "def log_score = Math.log(_score)/Math.log(params.norm_weight);"
    +54            )
    +55        else:
    +56            self._add_line(
    +57                "def log_score = _score <= 0.0 ? 0.0 : Math.log(_score)/Math.log(params.norm_weight);"
    +58                # "def log_score = Math.log(_score)/Math.log(params.norm_weight);"
    +59            )
    +60
    +61        return self
     
    @@ -486,28 +488,28 @@
    Returns
    -
    62    def add_embed_field(self, qfield, field) -> "SourceBuilder":
    -63        """
    -64        Adds a cosine score line.
    -65        :param qfield: Query field
    -66        :param field: Document facet field
    -67        :return:
    -68        """
    -69        if "embedding" not in field.lower():
    -70            field = field.replace(".", "_") + "_Embedding"
    -71
    -72        variable_name = f"{field}_{qfield}_score"
    -73
    -74        self._add_line(
    -75            f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}"
    -76            f", '{field}') + params.offset; "
    -77            # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; "
    -78        )
    -79        self.variables.append(variable_name)
    -80
    -81        self.i += 1
    -82
    -83        return self
    +            
    63    def add_embed_field(self, qfield, field) -> "SourceBuilder":
    +64        """
    +65        Adds a cosine score line.
    +66        :param qfield: Query field
    +67        :param field: Document facet field
    +68        :return:
    +69        """
    +70        if "embedding" not in field.lower():
    +71            field = field.replace(".", "_") + "_Embedding"
    +72
    +73        variable_name = f"{field}_{qfield}_score"
    +74
    +75        self._add_line(
    +76            f"double {variable_name} = doc['{field}'].isEmpty() ? 0.0 : params.weights[{self.i}]*cosineSimilarity(params.{qfield}"
    +77            f", '{field}') + params.offset; "
    +78            # f"double {variable_name} = cosineSimilarity(params.{qfield}, '{field}') + 1.0; "
    +79        )
    +80        self.variables.append(variable_name)
    +81
    +82        self.i += 1
    +83
    +84        return self
     
    @@ -536,19 +538,19 @@
    Returns
    -
    85    def finish(self):
    -86        """
    -87        Finalises the script score and returns the internal string
    -88        :return:
    -89            A string containing the script score query
    -90        """
    -91        self._add_line("double embed_score = " + " + ".join(self.variables) + ";")
    -92        self._add_line(
    -93            #"return params.disable_bm25 == true ? embed_score : embed_score + log_score;"
    -94            "return embed_score + log_score;"
    -95        )
    -96
    -97        return self.s
    +            
    86    def finish(self):
    +87        """
    +88        Finalises the script score and returns the internal string
    +89        :return:
    +90            A string containing the script score query
    +91        """
    +92        self._add_line("double embed_score = " + " + ".join(self.variables) + ";")
    +93        self._add_line(
    +94            # "return params.disable_bm25 == true ? embed_score : embed_score + log_score;"
    +95            "return embed_score + log_score;"
    +96        )
    +97
    +98        return self.s
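
In practice the builder is chained: one log-normalised BM25 term, one cosine term per (query field, document field) pair, then finish() to emit the painless source. A small sketch with made-up field names; the comments describe the lines produced by add_embed_field's naming scheme:

from debeir.engines.elasticsearch.generate_script_score import SourceBuilder

sb = SourceBuilder()
sb.add_log_score(ignore_below_one=True)

# One cosine-similarity term per document facet; "q_eb" is the query embedding
# parameter name and the field names are placeholders.
for field in ["Title", "Abstract"]:
    sb.add_embed_field("q_eb", field)

source = sb.finish()
# source now holds the "def log_score = ..." line, one
# "double Title_Embedding_q_eb_score = ..." style line per field,
# and a final "return embed_score + log_score;".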
     
    @@ -576,27 +578,27 @@
    Returns
    -
    100def generate_source(qfields: Union[list, str], fields) -> str:
    -101    """
    -102    Generates the script source based off a set of input fields and facets
    -103
    -104    :param qfields: Query fields (or topic fields)
    -105    :param fields: Document facets to compute cosine similarity on
    -106    :return:
    -107    """
    -108    sb = SourceBuilder()
    -109    sb.add_log_score(ignore_below_one=True)
    -110
    -111    if isinstance(qfields, str):
    -112        qfields = [qfields]
    -113
    -114    for qfield in qfields:
    -115        for field in fields:
    -116            sb.add_embed_field(qfield, field)
    -117
    -118    s = sb.finish()
    -119
    -120    return s
    +            
    101def generate_source(qfields: Union[list, str], fields) -> str:
    +102    """
    +103    Generates the script source based off a set of input fields and facets
    +104
    +105    :param qfields: Query fields (or topic fields)
    +106    :param fields: Document facets to compute cosine similarity on
    +107    :return:
    +108    """
    +109    sb = SourceBuilder()
    +110    sb.add_log_score(ignore_below_one=True)
    +111
    +112    if isinstance(qfields, str):
    +113        qfields = [qfields]
    +114
    +115    for qfield in qfields:
    +116        for field in fields:
    +117            sb.add_embed_field(qfield, field)
    +118
    +119    s = sb.finish()
    +120
    +121    return s
     
    @@ -625,18 +627,18 @@
    Returns
    -
    148def check_params_is_valid(params, qfields):
    -149    """
    -150    Validate if the parameters for the script score passes a simple sanity check.
    -151
    -152    :param params:
    -153    :param qfields:
    -154    """
    -155    for qfield in qfields:
    -156        assert qfield in params
    -157
    -158    assert "weights" in params
    -159    assert "offset" in params
    +            
    149def check_params_is_valid(params, qfields):
    +150    """
    +151    Validate if the parameters for the script score passes a simple sanity check.
    +152
    +153    :param params:
    +154    :param qfields:
    +155    """
    +156    for qfield in qfields:
    +157        assert qfield in params
    +158
    +159    assert "weights" in params
    +160    assert "offset" in params
     
    @@ -663,26 +665,26 @@
    Parameters
    -
    162def generate_script(
    -163    fields, params, source_generator=generate_source, qfields="q_eb"
    -164) -> Dict:
    -165    """
    -166    Parameters for creating the script
    -167
    -168    :param fields: Document fields to search
    -169    :param params: Parameters for the script
    -170    :param source_generator:  Function that will generate the script
    -171    :param qfields: Query fields to search from (topic facets)
    -172    :return:
    -173    """
    -174    script = copy.deepcopy(base_script)
    -175    check_params_is_valid(params, qfields)
    -176
    -177    script["lang"] = "painless"
    -178    script["source"] = source_generator(qfields, fields)
    -179    script["params"] = params
    -180
    -181    return script
    +            
    163def generate_script(
    +164        fields, params, source_generator=generate_source, qfields="q_eb"
    +165) -> Dict:
    +166    """
    +167    Parameters for creating the script
    +168
    +169    :param fields: Document fields to search
    +170    :param params: Parameters for the script
    +171    :param source_generator:  Function that will generate the script
    +172    :param qfields: Query fields to search from (topic facets)
    +173    :return:
    +174    """
    +175    script = copy.deepcopy(base_script)
    +176    check_params_is_valid(params, qfields)
    +177
    +178    script["lang"] = "painless"
    +179    script["source"] = source_generator(qfields, fields)
    +180    script["params"] = params
    +181
    +182    return script
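
Tying the pieces together, generate_script only needs the document fields, a params dict that check_params_is_valid will accept (each query field plus weights and offset), and optionally a custom source generator. A hedged sketch with illustrative values; qfields is passed as a list here so the validation loop checks the full field name:

from debeir.engines.elasticsearch.generate_script_score import generate_script

# The embedding vector, weights, offset and norm_weight are illustrative only;
# "q_eb" must appear in params because it is also passed as the query field.
params = {
    "q_eb": [0.1, 0.2, 0.3],   # query embedding
    "weights": [1.0, 1.0],     # one weight per document field
    "offset": 1.0,
    "norm_weight": 2.15,
    "disable_bm25": False,
}

script = generate_script(
    fields=["Title", "Abstract"],
    params=params,
    qfields=["q_eb"],
)
# script is {"lang": "painless", "source": <painless from generate_source>, "params": params}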
     
    @@ -803,7 +805,7 @@
    Returns
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -830,7 +832,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/engines/solr.html b/docs/debeir/engines/solr.html index 2f937f1..4ade05d 100644 --- a/docs/debeir/engines/solr.html +++ b/docs/debeir/engines/solr.html @@ -3,7 +3,7 @@ - + debeir.engines.solr API documentation @@ -148,7 +148,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -175,7 +175,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/evaluation.html b/docs/debeir/evaluation.html index 467906b..4e7f031 100644 --- a/docs/debeir/evaluation.html +++ b/docs/debeir/evaluation.html @@ -3,7 +3,7 @@ - + debeir.evaluation API documentation @@ -167,7 +167,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -194,7 +194,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/evaluation/cross_validation.html b/docs/debeir/evaluation/cross_validation.html index 1743571..5a23d21 100644 --- a/docs/debeir/evaluation/cross_validation.html +++ b/docs/debeir/evaluation/cross_validation.html @@ -3,7 +3,7 @@ - + debeir.evaluation.cross_validation API documentation @@ -36,6 +36,12 @@

    API Documentation

  • CrossValidatorTypes
  • @@ -73,57 +79,57 @@

      1from enum import Enum
    -  2from typing import List, Union, Dict
    +  2from typing import Dict, List, Union
       3
       4import numpy as np
    -  5from sklearn.model_selection import KFold, StratifiedKFold
    -  6from debeir.data_sets.types import DatasetTypes
    -  7from debeir.data_sets.types import InputExample
    -  8
    -  9import datasets
    +  5from debeir.datasets.types import DatasetTypes, InputExample
    +  6from sklearn.model_selection import KFold, StratifiedKFold
    +  7
    +  8import datasets
    +  9
      10
    - 11
    - 12def split_k_fold(n_fold, data_files):
    - 13    percentage = 100 // n_fold
    - 14
    - 15    vals_ds = datasets.load_dataset('csv', split=[
    - 16        f'train[{k}%:{k + percentage}%]' for k in range(0, 100, percentage)
    - 17    ], data_files=data_files)
    - 18
    - 19    trains_ds = datasets.load_dataset('csv', split=[
    - 20        f'train[:{k}%]+train[{k + percentage}%:]' for k in range(0, 100, percentage)
    - 21    ], data_files=data_files)
    - 22
    - 23    return trains_ds, vals_ds
    + 11def split_k_fold(n_fold, data_files):
    + 12    percentage = 100 // n_fold
    + 13
    + 14    vals_ds = datasets.load_dataset('csv', split=[
    + 15        f'train[{k}%:{k + percentage}%]' for k in range(0, 100, percentage)
    + 16    ], data_files=data_files)
    + 17
    + 18    trains_ds = datasets.load_dataset('csv', split=[
    + 19        f'train[:{k}%]+train[{k + percentage}%:]' for k in range(0, 100, percentage)
    + 20    ], data_files=data_files)
    + 21
    + 22    return trains_ds, vals_ds
    + 23
      24
    - 25
    - 26class CrossValidatorTypes(Enum):
    - 27    """
    - 28    Cross Validator Strategies for separating the dataset
    - 29    """
    - 30    Stratified = "StratifiedKFold"
    - 31    KFold = "KFold"
    + 25class CrossValidatorTypes(Enum):
    + 26    """
    + 27    Cross Validator Strategies for separating the dataset
    + 28    """
    + 29    Stratified = "StratifiedKFold"
    + 30    KFold = "KFold"
    + 31
      32
    - 33
    - 34str_to_fn = {
    - 35    "StratifiedKFold": StratifiedKFold,
    - 36    "KFold": KFold
    - 37}
    + 33str_to_fn = {
    + 34    "StratifiedKFold": StratifiedKFold,
    + 35    "KFold": KFold
    + 36}
    + 37
      38
    - 39
    - 40class CrossValidator:
    - 41    """
    - 42    Cross Validator Class for different types of data_sets
    - 43
    - 44    E.g. List -> [[Data], label]
    - 45         List[Dict] -> {"data": Data, "label": label}
    - 46         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)
    - 47    """
    + 39class CrossValidator:
    + 40    """
    + 41    Cross Validator Class for different types of data_sets
    + 42
    + 43    E.g. List -> [[Data], label]
    + 44         List[Dict] -> {"data": Data, "label": label}
    + 45         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)
    + 46    """
    + 47
      48    def __init__(self, dataset: Union[List, List[Dict], datasets.Dataset],
      49                 x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int],
      50                 cross_validator_type: [str, CrossValidatorTypes] = CrossValidatorTypes.Stratified,
      51                 seed=42, n_splits=5):
    - 52        #self.evaluator = evaluator
    + 52        # self.evaluator = evaluator
      53        self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits,
      54                                                             shuffle=True,
      55                                                             random_state=seed)
    @@ -202,18 +208,18 @@ 

    -
    13def split_k_fold(n_fold, data_files):
    -14    percentage = 100 // n_fold
    -15
    -16    vals_ds = datasets.load_dataset('csv', split=[
    -17        f'train[{k}%:{k + percentage}%]' for k in range(0, 100, percentage)
    -18    ], data_files=data_files)
    -19
    -20    trains_ds = datasets.load_dataset('csv', split=[
    -21        f'train[:{k}%]+train[{k + percentage}%:]' for k in range(0, 100, percentage)
    -22    ], data_files=data_files)
    -23
    -24    return trains_ds, vals_ds
    +            
    12def split_k_fold(n_fold, data_files):
    +13    percentage = 100 // n_fold
    +14
    +15    vals_ds = datasets.load_dataset('csv', split=[
    +16        f'train[{k}%:{k + percentage}%]' for k in range(0, 100, percentage)
    +17    ], data_files=data_files)
    +18
    +19    trains_ds = datasets.load_dataset('csv', split=[
    +20        f'train[:{k}%]+train[{k + percentage}%:]' for k in range(0, 100, percentage)
    +21    ], data_files=data_files)
    +22
    +23    return trains_ds, vals_ds
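
Each fold is expressed purely through datasets' percentage slicing: the k-th validation split is a contiguous slice of the 'train' split and the matching training split is its complement. A minimal sketch, assuming a local train.csv that datasets.load_dataset('csv', ...) can read:

from debeir.evaluation.cross_validation import split_k_fold

# "train.csv" is a placeholder path; any CSV accepted by datasets.load_dataset works.
trains_ds, vals_ds = split_k_fold(n_fold=5, data_files="train.csv")

for fold, (train_split, val_split) in enumerate(zip(trains_ds, vals_ds)):
    print(f"fold {fold}: {len(train_split)} train rows, {len(val_split)} validation rows")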
     
    @@ -231,12 +237,12 @@

    -
    27class CrossValidatorTypes(Enum):
    -28    """
    -29    Cross Validator Strategies for separating the dataset
    -30    """
    -31    Stratified = "StratifiedKFold"
    -32    KFold = "KFold"
    +            
    26class CrossValidatorTypes(Enum):
    +27    """
    +28    Cross Validator Strategies for separating the dataset
    +29    """
    +30    Stratified = "StratifiedKFold"
    +31    KFold = "KFold"
     
    @@ -244,6 +250,28 @@

    +
    +
    + Stratified = <CrossValidatorTypes.Stratified: 'StratifiedKFold'> + + +
    + + + + +
    +
    +
    + KFold = <CrossValidatorTypes.KFold: 'KFold'> + + +
    + + + + +
    Inherited Members
    @@ -266,19 +294,20 @@
    Inherited Members
    -
     41class CrossValidator:
    - 42    """
    - 43    Cross Validator Class for different types of data_sets
    - 44
    - 45    E.g. List -> [[Data], label]
    - 46         List[Dict] -> {"data": Data, "label": label}
    - 47         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)
    - 48    """
    +            
     40class CrossValidator:
    + 41    """
    + 42    Cross Validator Class for different types of data_sets
    + 43
    + 44    E.g. List -> [[Data], label]
    + 45         List[Dict] -> {"data": Data, "label": label}
    + 46         Huggingface Dataset Object -> Data(set="train", label = "label").select(idx)
    + 47    """
    + 48
      49    def __init__(self, dataset: Union[List, List[Dict], datasets.Dataset],
      50                 x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int],
      51                 cross_validator_type: [str, CrossValidatorTypes] = CrossValidatorTypes.Stratified,
      52                 seed=42, n_splits=5):
    - 53        #self.evaluator = evaluator
    + 53        # self.evaluator = evaluator
      54        self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits,
      55                                                             shuffle=True,
      56                                                             random_state=seed)
    @@ -357,7 +386,7 @@ 
    Inherited Members
    - CrossValidator( dataset: Union[List, List[Dict], datasets.arrow_dataset.Dataset], x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int], cross_validator_type: [<class 'str'>, <enum 'CrossValidatorTypes'>] = <CrossValidatorTypes.Stratified: 'StratifiedKFold'>, seed=42, n_splits=5) + CrossValidator( dataset: Union[List, List[Dict], datasets.arrow_dataset.Dataset], x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int], cross_validator_type: [<class 'str'>, <enum 'CrossValidatorTypes'>] = <CrossValidatorTypes.Stratified: 'StratifiedKFold'>, seed=42, n_splits=5) @@ -367,7 +396,7 @@
    Inherited Members
    50 x_idx_label_or_attr: Union[str, int], y_idx_label_or_attr: Union[str, int], 51 cross_validator_type: [str, CrossValidatorTypes] = CrossValidatorTypes.Stratified, 52 seed=42, n_splits=5): -53 #self.evaluator = evaluator +53 # self.evaluator = evaluator 54 self.cross_vali_fn = str_to_fn[cross_validator_type](n_splits=n_splits, 55 shuffle=True, 56 random_state=seed) @@ -527,7 +556,7 @@
    Returns
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -554,7 +583,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/evaluation/evaluator.html b/docs/debeir/evaluation/evaluator.html index cb6f072..849f94c 100644 --- a/docs/debeir/evaluation/evaluator.html +++ b/docs/debeir/evaluation/evaluator.html @@ -3,7 +3,7 @@ - + debeir.evaluation.evaluator API documentation @@ -72,84 +72,85 @@

    -
     1import loguru
    - 2from typing import Union, List, Dict
    - 3from collections import defaultdict
    - 4
    +                        
     1from collections import defaultdict
    + 2from typing import Dict, List, Union
    + 3
    + 4import loguru
      5from analysis_tools_ir import evaluate, sigtests
    - 6from debeir.interfaces.config import MetricsConfig, GenericConfig
    + 6from debeir.core.config import GenericConfig, MetricsConfig
      7
      8
      9class Evaluator:
     10    """
     11    Evaluation class for computing metrics from TREC-style files
     12    """
    -13    def __init__(self, qrels: str, metrics: List[str]):
    -14        self.qrels = qrels
    -15        self.metrics = []
    -16        self.depths = []
    -17
    -18        try:
    -19            self._validate_and_setup_metrics(metrics)
    -20        except AssertionError:
    -21            raise ValueError("Metrics must be of the form metric@depth")
    -22
    -23    def _validate_and_setup_metrics(self, metrics):
    -24        for metric in metrics:
    -25            assert "@" in metric
    -26            try:
    -27                metric, depth = metric.split("@")
    -28            except:
    -29                raise RuntimeError(f"Unable to parse metric {metric}")
    -30
    -31            assert metric.isalpha()
    -32            assert depth.isdigit()
    -33
    -34            self.metrics.append(metric)
    -35            self.depths.append(int(depth))
    -36
    -37    def evaluate_runs(self, res: Union[str, List[str]], **kwargs):
    -38        """
    -39        Evaluates the TREC-style results from an input result list or file
    -40
    -41        :param res: Results file path or raw results list
    -42        :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library
    -43        :return:
    -44        """
    -45        results = defaultdict(lambda: {})
    -46        for metric, depth in zip(self.metrics, self.depths):
    -47            results[metric][depth] = evaluate.parse_run(
    -48                res, self.qrels,
    -49                metric=metric, depth=depth,
    -50                **kwargs
    -51            )
    -52
    -53        return results
    -54
    -55    def average_all_metrics(self, runs: Dict, logger: loguru.logger):
    -56        """
    -57        Averages the metric per topic scores into a single averaged score.
    -58
    -59        :param runs: Parsed run dictionary: {metric_name@depth: Run object}
    -60        :param logger: Logger to print metrics
    -61        """
    -62        for metric, depth in zip(self.metrics, self.depths):
    -63            run = runs[metric][depth].run
    -64            logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}")
    -65
    -66    def sigtests(self, results_a, results_b):
    -67        """
    -68        Run a paired significance test on two result files
    -69
    -70        :param results_a:
    -71        :param results_b:
    -72        :return:
    -73        """
    -74        return sigtests.paired.paired_t_test(results_a, results_b, self.qrels)
    -75
    -76    @classmethod
    -77    def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig):
    -78        return cls(config.qrels, metrics_config.metrics)
    +13
    +14    def __init__(self, qrels: str, metrics: List[str]):
    +15        self.qrels = qrels
    +16        self.metrics = []
    +17        self.depths = []
    +18
    +19        try:
    +20            self._validate_and_setup_metrics(metrics)
    +21        except AssertionError:
    +22            raise ValueError("Metrics must be of the form metric@depth")
    +23
    +24    def _validate_and_setup_metrics(self, metrics):
    +25        for metric in metrics:
    +26            assert "@" in metric
    +27            try:
    +28                metric, depth = metric.split("@")
    +29            except:
    +30                raise RuntimeError(f"Unable to parse metric {metric}")
    +31
    +32            assert metric.isalpha()
    +33            assert depth.isdigit()
    +34
    +35            self.metrics.append(metric)
    +36            self.depths.append(int(depth))
    +37
    +38    def evaluate_runs(self, res: Union[str, List[str]], **kwargs):
    +39        """
    +40        Evaluates the TREC-style results from an input result list or file
    +41
    +42        :param res: Results file path or raw results list
    +43        :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library
    +44        :return:
    +45        """
    +46        results = defaultdict(lambda: {})
    +47        for metric, depth in zip(self.metrics, self.depths):
    +48            results[metric][depth] = evaluate.parse_run(
    +49                res, self.qrels,
    +50                metric=metric, depth=depth,
    +51                **kwargs
    +52            )
    +53
    +54        return results
    +55
    +56    def average_all_metrics(self, runs: Dict, logger: loguru.logger):
    +57        """
    +58        Averages the metric per topic scores into a single averaged score.
    +59
    +60        :param runs: Parsed run dictionary: {metric_name@depth: Run object}
    +61        :param logger: Logger to print metrics
    +62        """
    +63        for metric, depth in zip(self.metrics, self.depths):
    +64            run = runs[metric][depth].run
    +65            logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}")
    +66
    +67    def sigtests(self, results_a, results_b):
    +68        """
    +69        Run a paired significance test on two result files
    +70
    +71        :param results_a:
    +72        :param results_b:
    +73        :return:
    +74        """
    +75        return sigtests.paired.paired_t_test(results_a, results_b, self.qrels)
    +76
    +77    @classmethod
    +78    def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig):
    +79        return cls(config.qrels, metrics_config.metrics)
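For orientation after the debeir.interfaces -> debeir.core rename, a small usage sketch of the class above; the qrels and run file paths are placeholders.

    from debeir.evaluation.evaluator import Evaluator

    # Each metric must follow the metric@depth convention validated above.
    evaluator = Evaluator("data/qrels.txt", metrics=["ndcg@10", "map@1000"])
    results = evaluator.evaluate_runs("runs/bm25_run.txt")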
     
    @@ -169,72 +170,73 @@

    11 """ 12 Evaluation class for computing metrics from TREC-style files 13 """ -14 def __init__(self, qrels: str, metrics: List[str]): -15 self.qrels = qrels -16 self.metrics = [] -17 self.depths = [] -18 -19 try: -20 self._validate_and_setup_metrics(metrics) -21 except AssertionError: -22 raise ValueError("Metrics must be of the form metric@depth") -23 -24 def _validate_and_setup_metrics(self, metrics): -25 for metric in metrics: -26 assert "@" in metric -27 try: -28 metric, depth = metric.split("@") -29 except: -30 raise RuntimeError(f"Unable to parse metric {metric}") -31 -32 assert metric.isalpha() -33 assert depth.isdigit() -34 -35 self.metrics.append(metric) -36 self.depths.append(int(depth)) -37 -38 def evaluate_runs(self, res: Union[str, List[str]], **kwargs): -39 """ -40 Evaluates the TREC-style results from an input result list or file -41 -42 :param res: Results file path or raw results list -43 :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library -44 :return: -45 """ -46 results = defaultdict(lambda: {}) -47 for metric, depth in zip(self.metrics, self.depths): -48 results[metric][depth] = evaluate.parse_run( -49 res, self.qrels, -50 metric=metric, depth=depth, -51 **kwargs -52 ) -53 -54 return results -55 -56 def average_all_metrics(self, runs: Dict, logger: loguru.logger): -57 """ -58 Averages the metric per topic scores into a single averaged score. -59 -60 :param runs: Parsed run dictionary: {metric_name@depth: Run object} -61 :param logger: Logger to print metrics -62 """ -63 for metric, depth in zip(self.metrics, self.depths): -64 run = runs[metric][depth].run -65 logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}") -66 -67 def sigtests(self, results_a, results_b): -68 """ -69 Run a paired significance test on two result files -70 -71 :param results_a: -72 :param results_b: -73 :return: -74 """ -75 return sigtests.paired.paired_t_test(results_a, results_b, self.qrels) -76 -77 @classmethod -78 def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig): -79 return cls(config.qrels, metrics_config.metrics) +14 +15 def __init__(self, qrels: str, metrics: List[str]): +16 self.qrels = qrels +17 self.metrics = [] +18 self.depths = [] +19 +20 try: +21 self._validate_and_setup_metrics(metrics) +22 except AssertionError: +23 raise ValueError("Metrics must be of the form metric@depth") +24 +25 def _validate_and_setup_metrics(self, metrics): +26 for metric in metrics: +27 assert "@" in metric +28 try: +29 metric, depth = metric.split("@") +30 except: +31 raise RuntimeError(f"Unable to parse metric {metric}") +32 +33 assert metric.isalpha() +34 assert depth.isdigit() +35 +36 self.metrics.append(metric) +37 self.depths.append(int(depth)) +38 +39 def evaluate_runs(self, res: Union[str, List[str]], **kwargs): +40 """ +41 Evaluates the TREC-style results from an input result list or file +42 +43 :param res: Results file path or raw results list +44 :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library +45 :return: +46 """ +47 results = defaultdict(lambda: {}) +48 for metric, depth in zip(self.metrics, self.depths): +49 results[metric][depth] = evaluate.parse_run( +50 res, self.qrels, +51 metric=metric, depth=depth, +52 **kwargs +53 ) +54 +55 return results +56 +57 def average_all_metrics(self, runs: Dict, logger: loguru.logger): +58 """ +59 Averages the metric per topic scores into a single averaged score. 
+60 +61 :param runs: Parsed run dictionary: {metric_name@depth: Run object} +62 :param logger: Logger to print metrics +63 """ +64 for metric, depth in zip(self.metrics, self.depths): +65 run = runs[metric][depth].run +66 logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}") +67 +68 def sigtests(self, results_a, results_b): +69 """ +70 Run a paired significance test on two result files +71 +72 :param results_a: +73 :param results_b: +74 :return: +75 """ +76 return sigtests.paired.paired_t_test(results_a, results_b, self.qrels) +77 +78 @classmethod +79 def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig): +80 return cls(config.qrels, metrics_config.metrics)

    @@ -252,15 +254,15 @@

    -
    14    def __init__(self, qrels: str, metrics: List[str]):
    -15        self.qrels = qrels
    -16        self.metrics = []
    -17        self.depths = []
    -18
    -19        try:
    -20            self._validate_and_setup_metrics(metrics)
    -21        except AssertionError:
    -22            raise ValueError("Metrics must be of the form metric@depth")
    +            
    15    def __init__(self, qrels: str, metrics: List[str]):
    +16        self.qrels = qrels
    +17        self.metrics = []
    +18        self.depths = []
    +19
    +20        try:
    +21            self._validate_and_setup_metrics(metrics)
    +22        except AssertionError:
    +23            raise ValueError("Metrics must be of the form metric@depth")
     
    @@ -278,23 +280,23 @@

    -
    38    def evaluate_runs(self, res: Union[str, List[str]], **kwargs):
    -39        """
    -40        Evaluates the TREC-style results from an input result list or file
    -41
    -42        :param res: Results file path or raw results list
    -43        :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library
    -44        :return:
    -45        """
    -46        results = defaultdict(lambda: {})
    -47        for metric, depth in zip(self.metrics, self.depths):
    -48            results[metric][depth] = evaluate.parse_run(
    -49                res, self.qrels,
    -50                metric=metric, depth=depth,
    -51                **kwargs
    -52            )
    -53
    -54        return results
    +            
    39    def evaluate_runs(self, res: Union[str, List[str]], **kwargs):
    +40        """
    +41        Evaluates the TREC-style results from an input result list or file
    +42
    +43        :param res: Results file path or raw results list
    +44        :param kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library
    +45        :return:
    +46        """
    +47        results = defaultdict(lambda: {})
    +48        for metric, depth in zip(self.metrics, self.depths):
    +49            results[metric][depth] = evaluate.parse_run(
    +50                res, self.qrels,
    +51                metric=metric, depth=depth,
    +52                **kwargs
    +53            )
    +54
    +55        return results
     
    @@ -317,22 +319,22 @@
    Returns
def
- average_all_metrics( self, runs: Dict, logger: <loguru.logger handlers=[(id=0, level=10, sink=<_io.StringIO object at 0x7f966d8cdea0>)]>):
+ average_all_metrics( self, runs: Dict, logger: <loguru.logger handlers=[(id=0, level=10, sink=<_io.StringIO object at 0x105cfa710>)]>):
    -
    56    def average_all_metrics(self, runs: Dict, logger: loguru.logger):
    -57        """
    -58        Averages the metric per topic scores into a single averaged score.
    -59
    -60        :param runs: Parsed run dictionary: {metric_name@depth: Run object}
    -61        :param logger: Logger to print metrics
    -62        """
    -63        for metric, depth in zip(self.metrics, self.depths):
    -64            run = runs[metric][depth].run
    -65            logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}")
    +            
    57    def average_all_metrics(self, runs: Dict, logger: loguru.logger):
    +58        """
    +59        Averages the metric per topic scores into a single averaged score.
    +60
    +61        :param runs: Parsed run dictionary: {metric_name@depth: Run object}
    +62        :param logger: Logger to print metrics
    +63        """
    +64        for metric, depth in zip(self.metrics, self.depths):
    +65            run = runs[metric][depth].run
    +66            logger.info(f"{metric}@{depth} Average: {sum(run.values()) / len(run):.4}")
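A self-contained sketch of averaging and logging the per-topic scores with a loguru logger; the qrels path, run file, and metric name are placeholders.

    import loguru
    from debeir.evaluation.evaluator import Evaluator

    evaluator = Evaluator("data/qrels.txt", metrics=["ndcg@10"])
    runs = evaluator.evaluate_runs("runs/bm25_run.txt")
    evaluator.average_all_metrics(runs, logger=loguru.logger)  # logs e.g. "ndcg@10 Average: ..."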
     
    @@ -359,15 +361,15 @@
    Parameters
    -
    67    def sigtests(self, results_a, results_b):
    -68        """
    -69        Run a paired significance test on two result files
    -70
    -71        :param results_a:
    -72        :param results_b:
    -73        :return:
    -74        """
    -75        return sigtests.paired.paired_t_test(results_a, results_b, self.qrels)
    +            
    68    def sigtests(self, results_a, results_b):
    +69        """
    +70        Run a paired significance test on two result files
    +71
    +72        :param results_a:
    +73        :param results_b:
    +74        :return:
    +75        """
    +76        return sigtests.paired.paired_t_test(results_a, results_b, self.qrels)
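A hedged sketch of the paired significance test above; the run file paths are placeholders and the exact return value is whatever analysis_tools_ir's paired t-test produces.

    from debeir.evaluation.evaluator import Evaluator

    evaluator = Evaluator("data/qrels.txt", metrics=["ndcg@10"])
    # Paired t-test between two TREC-style run files over the same qrels.
    outcome = evaluator.sigtests("runs/bm25_run.txt", "runs/neural_run.txt")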
     
    @@ -391,15 +393,15 @@
    Returns
    @classmethod
def
- build_from_config( cls, config: debeir.interfaces.config.GenericConfig, metrics_config: debeir.interfaces.config.MetricsConfig):
+ build_from_config( cls, config: debeir.core.config.GenericConfig, metrics_config: debeir.core.config.MetricsConfig):
    -
    77    @classmethod
    -78    def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig):
    -79        return cls(config.qrels, metrics_config.metrics)
    +            
    78    @classmethod
    +79    def build_from_config(cls, config: GenericConfig, metrics_config: MetricsConfig):
    +80        return cls(config.qrels, metrics_config.metrics)
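build_from_config only reads config.qrels and metrics_config.metrics, so a thin helper like the following sketch (assuming the config objects are constructed elsewhere) is enough to wire it into a pipeline.

    from debeir.core.config import GenericConfig, MetricsConfig
    from debeir.evaluation.evaluator import Evaluator

    def evaluator_from_configs(config: GenericConfig, metrics_config: MetricsConfig) -> Evaluator:
        # Only .qrels and .metrics are consumed by build_from_config.
        return Evaluator.build_from_config(config, metrics_config)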
     
    @@ -508,7 +510,7 @@
    Returns
}
let heading;
- switch (result.doc.type) {
+ switch (result.doc.kind) {
case "function":
if (doc.fullname.endsWith(".__init__")) {
heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`;
@@ -535,7 +537,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
`;
diff --git a/docs/debeir/evaluation/residual_scoring.html b/docs/debeir/evaluation/residual_scoring.html
index addf46d..c1d3328 100644
--- a/docs/debeir/evaluation/residual_scoring.html
+++ b/docs/debeir/evaluation/residual_scoring.html
@@ -3,7 +3,7 @@
-
+
debeir.evaluation.residual_scoring API documentation
@@ -66,66 +66,82 @@

     1import os
      2import subprocess
      3import tempfile
    - 4from typing import List, Union, Dict
    - 5
    - 6from debeir.evaluation.evaluator import Evaluator
    - 7import uuid
    + 4import uuid
    + 5from typing import Dict, List, Union
    + 6
    + 7from debeir.evaluation.evaluator import Evaluator
      8
    - 9# Remove all documents that exist in the training set
    -10# Evaluate on remaining
    -11# Normalize for result set length, cut off at ????
    -12
    + 9
    +10# Remove all documents that exist in the training set
    +11# Evaluate on remaining
    +12# Normalize for result set length, cut off at ????
     13
    -14class ResidualEvaluator(Evaluator):
    -15    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    -16        super().__init__(qrels, metrics)
    -17        self.qrels_fp = qrels
    -18        self.filter_ids = filter_ids
    -19
    -20    def _filter_run(self, res: str):
    -21        if self.filter_ids is None:
    -22            return res
    -23
    -24        tmpdir = tempfile.mkdtemp()
    -25        tmpfp = os.path.join(tmpdir, str(uuid.uuid4()))
    -26
    -27        writer = open(tmpfp, 'w+')
    -28
    -29        with open(res) as out_file:
    -30            for line in out_file:
    -31                topic_num, _, doc_id, _, _, _ = line.split()
    -32                if doc_id in self.filter_ids[topic_num]:
    -33                    continue
    -34
    -35                writer.write(line)
    +14
    +15class ResidualEvaluator(Evaluator):
+16    """ Residual Scoring is the scoring of a subset of documents, or the residual. The residual is created by removing documents from the collection and qrels.
    +17    """
    +18
    +19    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    +20        """ 
    +21        Args:
    +22            qrels (str): Path to qrels 
    +23            metrics (List[str]): A list of metrics with depth e.g. NDCG@1000
    +24            filter_ids (Dict[str, List[str]]): A list of IDs to remove from the collection given by Dict[Topic_num, [Docids]]
    +25        """
    +26        super().__init__(qrels, metrics)
    +27        self.qrels_fp = qrels
    +28        self.filter_ids = filter_ids
    +29
    +30    def _filter_run(self, res: str):
    +31        if self.filter_ids is None:
    +32            return res
    +33
    +34        tmpdir = tempfile.mkdtemp()
    +35        tmpfp = os.path.join(tmpdir, str(uuid.uuid4()))
     36
    -37        writer.close()
    +37        writer = open(tmpfp, 'w+')
     38
    -39        return tmpfp
    -40
    -41    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    -42        if with_trec_binary:
    -43            return self._evaluate_with_binary(res, **kwargs)
    +39        with open(res) as out_file:
    +40            for line in out_file:
    +41                topic_num, _, doc_id, _, _, _ = line.split()
    +42                if doc_id in self.filter_ids[topic_num]:
    +43                    continue
     44
    -45        fp = self._filter_run(res)
    +45                writer.write(line)
     46
    -47        return super().evaluate_runs(fp, **kwargs)
    +47        writer.close()
     48
    -49    def _evaluate_with_binary(self, res, **kwargs):
    -50        fp = self._filter_run(res)
    -51
    -52        output = subprocess.check_output(["trec_eval", self.qrels_fp, fp]).decode()
    +49        return tmpfp
    +50
    +51    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    +52        """ Run the residual evaluation for the runs
     53
    -54        metrics = {}
    -55
    -56        for line in str(output).split("\n"):
    -57            try:
    -58                metric, _, value = line.split()
    -59                metrics[metric] = value
    -60            except:
    -61                continue
    +54        :param res: The results to run the evaluator against
    +55        :param with_trec_binary: Use the TREC C binary instead of the default Python library, defaults to False
    +56        :return: A dictionary of supplied metrics of the results against the qrels 
    +57        """
    +58        if with_trec_binary:
    +59            return self._evaluate_with_binary(res, **kwargs)
    +60
    +61        fp = self._filter_run(res)
     62
    -63        return metrics
    +63        return super().evaluate_runs(fp, **kwargs)
    +64
    +65    def _evaluate_with_binary(self, res, **kwargs):
    +66        fp = self._filter_run(res)
    +67
    +68        output = subprocess.check_output(["trec_eval", self.qrels_fp, fp]).decode()
    +69
    +70        metrics = {}
    +71
    +72        for line in str(output).split("\n"):
    +73            try:
    +74                metric, _, value = line.split()
    +75                metrics[metric] = value
    +76            except:
    +77                continue
    +78
    +79        return metrics
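A usage sketch of the class above; the qrels path, run file, topic numbers, and document IDs are all placeholders.

    from debeir.evaluation.residual_scoring import ResidualEvaluator

    # Remove known training documents per topic before scoring the residual.
    # Every topic appearing in the run file needs a key here.
    filter_ids = {"1": ["NCT00000102"], "2": []}

    residual_evaluator = ResidualEvaluator(
        "data/qrels.txt",
        metrics=["ndcg@1000"],
        filter_ids=filter_ids,
    )
    scores = residual_evaluator.evaluate_runs("runs/bm25_run.txt")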
     
    @@ -141,60 +157,75 @@

    -
    15class ResidualEvaluator(Evaluator):
    -16    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    -17        super().__init__(qrels, metrics)
    -18        self.qrels_fp = qrels
    -19        self.filter_ids = filter_ids
    -20
    -21    def _filter_run(self, res: str):
    -22        if self.filter_ids is None:
    -23            return res
    -24
    -25        tmpdir = tempfile.mkdtemp()
    -26        tmpfp = os.path.join(tmpdir, str(uuid.uuid4()))
    -27
    -28        writer = open(tmpfp, 'w+')
    -29
    -30        with open(res) as out_file:
    -31            for line in out_file:
    -32                topic_num, _, doc_id, _, _, _ = line.split()
    -33                if doc_id in self.filter_ids[topic_num]:
    -34                    continue
    -35
    -36                writer.write(line)
    +            
    16class ResidualEvaluator(Evaluator):
+17    """ Residual Scoring is the scoring of a subset of documents, or the residual. The residual is created by removing documents from the collection and qrels.
    +18    """
    +19
    +20    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    +21        """ 
    +22        Args:
    +23            qrels (str): Path to qrels 
    +24            metrics (List[str]): A list of metrics with depth e.g. NDCG@1000
    +25            filter_ids (Dict[str, List[str]]): A list of IDs to remove from the collection given by Dict[Topic_num, [Docids]]
    +26        """
    +27        super().__init__(qrels, metrics)
    +28        self.qrels_fp = qrels
    +29        self.filter_ids = filter_ids
    +30
    +31    def _filter_run(self, res: str):
    +32        if self.filter_ids is None:
    +33            return res
    +34
    +35        tmpdir = tempfile.mkdtemp()
    +36        tmpfp = os.path.join(tmpdir, str(uuid.uuid4()))
     37
    -38        writer.close()
    +38        writer = open(tmpfp, 'w+')
     39
    -40        return tmpfp
    -41
    -42    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    -43        if with_trec_binary:
    -44            return self._evaluate_with_binary(res, **kwargs)
    +40        with open(res) as out_file:
    +41            for line in out_file:
    +42                topic_num, _, doc_id, _, _, _ = line.split()
    +43                if doc_id in self.filter_ids[topic_num]:
    +44                    continue
     45
    -46        fp = self._filter_run(res)
    +46                writer.write(line)
     47
    -48        return super().evaluate_runs(fp, **kwargs)
    +48        writer.close()
     49
    -50    def _evaluate_with_binary(self, res, **kwargs):
    -51        fp = self._filter_run(res)
    -52
    -53        output = subprocess.check_output(["trec_eval", self.qrels_fp, fp]).decode()
    +50        return tmpfp
    +51
    +52    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    +53        """ Run the residual evaluation for the runs
     54
    -55        metrics = {}
    -56
    -57        for line in str(output).split("\n"):
    -58            try:
    -59                metric, _, value = line.split()
    -60                metrics[metric] = value
    -61            except:
    -62                continue
    +55        :param res: The results to run the evaluator against
    +56        :param with_trec_binary: Use the TREC C binary instead of the default Python library, defaults to False
    +57        :return: A dictionary of supplied metrics of the results against the qrels 
    +58        """
    +59        if with_trec_binary:
    +60            return self._evaluate_with_binary(res, **kwargs)
    +61
    +62        fp = self._filter_run(res)
     63
    -64        return metrics
    +64        return super().evaluate_runs(fp, **kwargs)
    +65
    +66    def _evaluate_with_binary(self, res, **kwargs):
    +67        fp = self._filter_run(res)
    +68
    +69        output = subprocess.check_output(["trec_eval", self.qrels_fp, fp]).decode()
    +70
    +71        metrics = {}
    +72
    +73        for line in str(output).split("\n"):
    +74            try:
    +75                metric, _, value = line.split()
    +76                metrics[metric] = value
    +77            except:
    +78                continue
    +79
    +80        return metrics
     
    -

    Evaluation class for computing metrics from TREC-style files

    +

Residual Scoring is the scoring of a subset of documents, or the residual. The residual is created by removing documents from the collection and qrels.

    @@ -208,14 +239,25 @@

    -
    16    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    -17        super().__init__(qrels, metrics)
    -18        self.qrels_fp = qrels
    -19        self.filter_ids = filter_ids
    +            
    20    def __init__(self, qrels: str, metrics: List[str], filter_ids: Dict[str, List[str]]):
    +21        """ 
    +22        Args:
    +23            qrels (str): Path to qrels 
    +24            metrics (List[str]): A list of metrics with depth e.g. NDCG@1000
    +25            filter_ids (Dict[str, List[str]]): A list of IDs to remove from the collection given by Dict[Topic_num, [Docids]]
    +26        """
    +27        super().__init__(qrels, metrics)
    +28        self.qrels_fp = qrels
    +29        self.filter_ids = filter_ids
     
    - +

Args:
+ qrels (str): Path to qrels
+ metrics (List[str]): A list of metrics with depth e.g. NDCG@1000
+ filter_ids (Dict[str, List[str]]): A list of IDs to remove from the collection given by Dict[Topic_num, [Docids]]

    +
    +
    @@ -229,26 +271,36 @@

    -
    42    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    -43        if with_trec_binary:
    -44            return self._evaluate_with_binary(res, **kwargs)
    -45
    -46        fp = self._filter_run(res)
    -47
    -48        return super().evaluate_runs(fp, **kwargs)
    +            
    52    def evaluate_runs(self, res: Union[str, List[str]], with_trec_binary=False, **kwargs):
    +53        """ Run the residual evaluation for the runs
    +54
    +55        :param res: The results to run the evaluator against
    +56        :param with_trec_binary: Use the TREC C binary instead of the default Python library, defaults to False
    +57        :return: A dictionary of supplied metrics of the results against the qrels 
    +58        """
    +59        if with_trec_binary:
    +60            return self._evaluate_with_binary(res, **kwargs)
    +61
    +62        fp = self._filter_run(res)
    +63
    +64        return super().evaluate_runs(fp, **kwargs)
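The with_trec_binary path shells out to trec_eval, so that binary must be on PATH; otherwise the call is identical. Reusing the residual_evaluator from the sketch above (run file path is again a placeholder):

    # Same filtering step, but scored by the external trec_eval binary.
    binary_scores = residual_evaluator.evaluate_runs("runs/bm25_run.txt", with_trec_binary=True)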
     
    -

    Evaluates the TREC-style results from an input result list or file

    +

    Run the residual evaluation for the runs

    Parameters
      -
    • res: Results file path or raw results list
    • -
    • kwargs: Keyword arguments to pass to the underlying analysis_tools_ir.parse_run library
    • +
    • res: The results to run the evaluator against
    • +
    • with_trec_binary: Use the TREC C binary instead of the default Python library, defaults to False
    Returns
    + +
    +

    A dictionary of supplied metrics of the results against the qrels

    +
    @@ -366,7 +418,7 @@
    Inherited Members
}
let heading;
- switch (result.doc.type) {
+ switch (result.doc.kind) {
case "function":
if (doc.fullname.endsWith(".__init__")) {
heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`;
@@ -393,7 +445,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
`;
diff --git a/docs/debeir/models.html b/docs/debeir/models.html
index c538219..50cef15 100644
--- a/docs/debeir/models.html
+++ b/docs/debeir/models.html
@@ -3,7 +3,7 @@
-
+
debeir.models API documentation
@@ -152,7 +152,7 @@

}
let heading;
- switch (result.doc.type) {
+ switch (result.doc.kind) {
case "function":
if (doc.fullname.endsWith(".__init__")) {
heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`;
@@ -179,7 +179,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
`;
diff --git a/docs/debeir/models/colbert.html b/docs/debeir/models/colbert.html
index 63b5811..a067485 100644
--- a/docs/debeir/models/colbert.html
+++ b/docs/debeir/models/colbert.html
@@ -3,7 +3,7 @@
-
+
debeir.models.colbert API documentation
@@ -147,13 +147,13 @@

    -
      1import logging
    -  2import os
    -  3
    -  4import torch
    -  5from torch import nn
    -  6from transformers import BertModel, BertConfig
    -  7import json
    +                        
      1import json
    +  2import logging
    +  3import os
    +  4
    +  5import torch
    +  6from torch import nn
    +  7from transformers import BertConfig, BertModel
       8
       9logger = logging.getLogger(__name__)
      10
    @@ -161,81 +161,81 @@ 

    12 "relu": nn.ReLU, 13} 14 - 15 - 16LOSS_FUNCS = { - 17 'cross_entropy_loss': nn.CrossEntropyLoss, - 18} + 15LOSS_FUNCS = { + 16 'cross_entropy_loss': nn.CrossEntropyLoss, + 17} + 18 19 - 20 - 21class CoLBERTConfig(object): - 22 default_fname = "colbert_config.json" - 23 - 24 def __init__(self, **kwargs): - 25 self.kwargs = kwargs - 26 self.__dict__.update(kwargs) - 27 - 28 def save(self, path, fname=default_fname): - 29 """ - 30 :param fname: file name - 31 :param path: Path to save - 32 """ - 33 json.dump(self.kwargs, open(os.path.join(path, fname), 'w+')) - 34 - 35 @classmethod - 36 def load(cls, path, fname=default_fname): - 37 """ - 38 Load the ColBERT config from path (don't point to file name just directory) - 39 :return ColBERTConfig: - 40 """ - 41 - 42 kwargs = json.load(open(os.path.join(path, fname))) - 43 - 44 return CoLBERTConfig(**kwargs) + 20class CoLBERTConfig(object): + 21 default_fname = "colbert_config.json" + 22 + 23 def __init__(self, **kwargs): + 24 self.kwargs = kwargs + 25 self.__dict__.update(kwargs) + 26 + 27 def save(self, path, fname=default_fname): + 28 """ + 29 :param fname: file name + 30 :param path: Path to save + 31 """ + 32 json.dump(self.kwargs, open(os.path.join(path, fname), 'w+')) + 33 + 34 @classmethod + 35 def load(cls, path, fname=default_fname): + 36 """ + 37 Load the ColBERT config from path (don't point to file name just directory) + 38 :return ColBERTConfig: + 39 """ + 40 + 41 kwargs = json.load(open(os.path.join(path, fname))) + 42 + 43 return CoLBERTConfig(**kwargs) + 44 45 - 46 - 47class ConvolutionalBlock(nn.Module): - 48 - 49 def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU): - 50 super(ConvolutionalBlock, self).__init__() - 51 - 52 padding = int((kernel_size - 1) / 2) - 53 if kernel_size == 3: - 54 assert padding == 1 # checks - 55 if kernel_size == 5: - 56 assert padding == 2 # checks - 57 layers = [ - 58 nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding), - 59 nn.BatchNorm1d(num_features=out_channels) - 60 ] - 61 - 62 if act_func is not None: - 63 layers.append(act_func()) - 64 - 65 self.sequential = nn.Sequential(*layers) - 66 - 67 def forward(self, x): - 68 return self.sequential(x) + 46class ConvolutionalBlock(nn.Module): + 47 + 48 def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU): + 49 super(ConvolutionalBlock, self).__init__() + 50 + 51 padding = int((kernel_size - 1) / 2) + 52 if kernel_size == 3: + 53 assert padding == 1 # checks + 54 if kernel_size == 5: + 55 assert padding == 2 # checks + 56 layers = [ + 57 nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding), + 58 nn.BatchNorm1d(num_features=out_channels) + 59 ] + 60 + 61 if act_func is not None: + 62 layers.append(act_func()) + 63 + 64 self.sequential = nn.Sequential(*layers) + 65 + 66 def forward(self, x): + 67 return self.sequential(x) + 68 69 - 70 - 71class KMaxPool(nn.Module): - 72 def __init__(self, k=1): - 73 super(KMaxPool, self).__init__() - 74 - 75 self.k = k - 76 - 77 def forward(self, x): - 78 # x : batch_size, channel, time_steps - 79 if self.k == 'half': - 80 time_steps = x.shape(2) - 81 self.k = time_steps // 2 - 82 - 83 kmax, kargmax = torch.topk(x, self.k, sorted=True) - 84 # kmax, kargmax = x.topk(self.k, dim=2) - 85 return kmax + 70class KMaxPool(nn.Module): + 71 def __init__(self, k=1): + 72 super(KMaxPool, self).__init__() + 73 + 74 self.k = k + 75 + 76 def 
forward(self, x): + 77 # x : batch_size, channel, time_steps + 78 if self.k == 'half': + 79 time_steps = x.shape(2) + 80 self.k = time_steps // 2 + 81 + 82 kmax, kargmax = torch.topk(x, self.k, sorted=True) + 83 # kmax, kargmax = x.topk(self.k, dim=2) + 84 return kmax + 85 86 - 87 - 88def visualisation_dump(argmax, input_tensors): - 89 pass + 87def visualisation_dump(argmax, input_tensors): + 88 pass + 89 90 91class ResidualBlock(nn.Module): 92 @@ -258,7 +258,7 @@

    109 110class ColBERT(nn.Module): 111 def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int -112 = 128, k: int = 8, +112 = 128, k: int = 8, 113 optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, 114 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, 115 act_func="mish", loss_func='cross_entropy_loss', **kwargs): # kwargs for compat @@ -318,16 +318,16 @@

    169 # Create the MLP to compress the k signals 170 linear_layers = list() 171 linear_layers.append(nn.Linear(hidden_dim * k, num_labels)) # Downsample into Kmaxpool? -172 #linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons)) -173 #linear_layers.append(nn.Dropout(dropout_perc)) -174 #linear_layers.append(nn.Linear(hidden_neurons, num_labels)) +172 # linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons)) +173 # linear_layers.append(nn.Dropout(dropout_perc)) +174 # linear_layers.append(nn.Linear(hidden_neurons, num_labels)) 175 176 self.linear_layers = nn.Sequential(*linear_layers) 177 self.apply(weight_init) 178 self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs, -179 config=self.bert_config) # Add Bert model after random initialisation +179 config=self.bert_config) # Add Bert model after random initialisation 180 -181 for param in self.bert.pooler.parameters(): # We don't need the pooler +181 for param in self.bert.pooler.parameters(): # We don't need the pooler 182 param.requires_grad = False 183 184 self.bert.to(self.device) @@ -353,13 +353,13 @@

    204 assert len(self.transformation_blocks) == len(hidden_states) 205 zip_args.append(self.transformation_blocks) 206 else: -207 zip_args.append([identity for i in range(self.num_layers+1)]) +207 zip_args.append([identity for i in range(self.num_layers + 1)]) 208 209 if self.use_batch_norms: 210 assert len(self.batch_norms) == len(hidden_states) 211 zip_args.append(self.batch_norms) 212 else: -213 zip_args.append([identity for i in range(self.num_layers+1)]) +213 zip_args.append([identity for i in range(self.num_layers + 1)]) 214 215 out = None 216 for co, hi, tr, bn in zip(*zip_args): @@ -377,184 +377,181 @@

    228 229 return self.loss_func(logits, labels), logits 230 -231 -232 @classmethod -233 def from_config(cls, *args, config_path): -234 kwargs = torch.load(config_path) -235 return ColBERT(*args, **kwargs) -236 -237 @classmethod -238 def from_pretrained(cls, output_dir, **kwargs): -239 config_found = True -240 colbert_config = None -241 -242 try: -243 colbert_config = CoLBERTConfig.load(output_dir) -244 except: -245 config_found = False -246 -247 bert_config = None -248 -249 if 'config' in kwargs: -250 bert_config = kwargs['config'] -251 del kwargs['config'] -252 else: -253 bert_config = BertConfig.from_pretrained(output_dir) -254 -255 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -256 model = None -257 -258 if config_found: -259 model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs) -260 model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth')) -261 logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}") -262 -263 else: -264 model = ColBERT((output_dir,), {}, config=bert_config, **kwargs) -265 logger.info(f"*** Create New CNN Bert Model ***") -266 -267 return model -268 -269 def save_pretrained(self, output_dir): -270 logger.info(f"*** Saved Bert Model Weights to {output_dir}") -271 self.bert.save_pretrained(output_dir) -272 torch.save(self.state_dict(), output_dir + '/cnn_bert.pth') -273 self.bert_config.save_pretrained(output_dir) -274 self.colbert_config.save(output_dir) -275 logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}") +231 @classmethod +232 def from_config(cls, *args, config_path): +233 kwargs = torch.load(config_path) +234 return ColBERT(*args, **kwargs) +235 +236 @classmethod +237 def from_pretrained(cls, output_dir, **kwargs): +238 config_found = True +239 colbert_config = None +240 +241 try: +242 colbert_config = CoLBERTConfig.load(output_dir) +243 except: +244 config_found = False +245 +246 bert_config = None +247 +248 if 'config' in kwargs: +249 bert_config = kwargs['config'] +250 del kwargs['config'] +251 else: +252 bert_config = BertConfig.from_pretrained(output_dir) +253 +254 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +255 model = None +256 +257 if config_found: +258 model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs) +259 model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth')) +260 logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}") +261 +262 else: +263 model = ColBERT((output_dir,), {}, config=bert_config, **kwargs) +264 logger.info(f"*** Create New CNN Bert Model ***") +265 +266 return model +267 +268 def save_pretrained(self, output_dir): +269 logger.info(f"*** Saved Bert Model Weights to {output_dir}") +270 self.bert.save_pretrained(output_dir) +271 torch.save(self.state_dict(), output_dir + '/cnn_bert.pth') +272 self.bert_config.save_pretrained(output_dir) +273 self.colbert_config.save(output_dir) +274 logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}") +275 276 -277 -278class ComBERT(nn.Module): -279 def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int= 128, -280 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, -281 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, -282 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs): # kwargs for compat -283 -284 
super().__init__() -285 self.device = device -286 hidden_dim = config.hidden_size -287 self.seq_length = max_seq_len -288 self.use_trans_blocks = use_trans_blocks -289 self.use_batch_norms = use_batch_norms -290 self.num_layers = config.num_hidden_layers -291 num_labels = config.num_labels -292 self.num_blocks = num_blocks -293 self.loss_func = LOSS_FUNCS[loss_func.lower()]() -294 -295 # Save our kwargs to reinitialise the model during evaluation -296 self.bert_config = config -297 self.colbert_config = CoLBERTConfig(k=k, -298 optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons, -299 use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks, -300 residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc, -301 act_func=act_func, bert_model_args=bert_model_args, -302 bert_model_kwargs=bert_model_kwargs) -303 -304 logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs)) -305 -306 # relax this constraint later -307 assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}" -308 act_func = ACT_FUNCS[act_func.lower()] -309 -310 # CNN Part -311 conv_layers = [] -312 -313 # Adds up to num_layers + 1 embedding layer -314 conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)) -315 -316 -317 for i in range(num_blocks): -318 conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut, -319 kernel_size=residual_kernel_size, act_func=act_func)) -320 -321 self.conv_layers = nn.ModuleList(conv_layers) -322 self.kmax_pooling = KMaxPool(k) -323 -324 # Create the MLP to compress the k signals -325 linear_layers = list() -326 linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons)) # Downsample into Kmaxpool? -327 linear_layers.append(nn.Dropout(dropout_perc)) -328 linear_layers.append(nn.Linear(hidden_neurons, num_labels)) -329 -330 self.linear_layers = nn.Sequential(*linear_layers) -331 self.apply(weight_init) -332 self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs, -333 config=self.bert_config) # Add Bert model after random initialisation -334 self.bert.to(self.device) -335 -336 def forward(self, *args, **kwargs): -337 # input_ids: batch_size x seq_length x hidden_dim -338 -339 labels = kwargs['labels'] if 'labels' in kwargs else None -340 if labels is not None: del kwargs['labels'] -341 -342 bert_outputs = self.bert(*args, **kwargs) -343 hidden_states = list(bert_outputs[-1]) -344 embedding_layer = hidden_states.pop(0) +277class ComBERT(nn.Module): +278 def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128, +279 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, +280 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, +281 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs): # kwargs for compat +282 +283 super().__init__() +284 self.device = device +285 hidden_dim = config.hidden_size +286 self.seq_length = max_seq_len +287 self.use_trans_blocks = use_trans_blocks +288 self.use_batch_norms = use_batch_norms +289 self.num_layers = config.num_hidden_layers +290 num_labels = config.num_labels +291 self.num_blocks = num_blocks +292 self.loss_func = LOSS_FUNCS[loss_func.lower()]() +293 +294 # Save our kwargs to reinitialise the model during evaluation +295 self.bert_config = config +296 self.colbert_config = CoLBERTConfig(k=k, +297 optional_shortcut=optional_shortcut, 
hidden_neurons=hidden_neurons, +298 use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks, +299 residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc, +300 act_func=act_func, bert_model_args=bert_model_args, +301 bert_model_kwargs=bert_model_kwargs) +302 +303 logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs)) +304 +305 # relax this constraint later +306 assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}" +307 act_func = ACT_FUNCS[act_func.lower()] +308 +309 # CNN Part +310 conv_layers = [] +311 +312 # Adds up to num_layers + 1 embedding layer +313 conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)) +314 +315 for i in range(num_blocks): +316 conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut, +317 kernel_size=residual_kernel_size, act_func=act_func)) +318 +319 self.conv_layers = nn.ModuleList(conv_layers) +320 self.kmax_pooling = KMaxPool(k) +321 +322 # Create the MLP to compress the k signals +323 linear_layers = list() +324 linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons)) # Downsample into Kmaxpool? +325 linear_layers.append(nn.Dropout(dropout_perc)) +326 linear_layers.append(nn.Linear(hidden_neurons, num_labels)) +327 +328 self.linear_layers = nn.Sequential(*linear_layers) +329 self.apply(weight_init) +330 self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs, +331 config=self.bert_config) # Add Bert model after random initialisation +332 self.bert.to(self.device) +333 +334 def forward(self, *args, **kwargs): +335 # input_ids: batch_size x seq_length x hidden_dim +336 +337 labels = kwargs['labels'] if 'labels' in kwargs else None +338 if labels is not None: del kwargs['labels'] +339 +340 bert_outputs = self.bert(*args, **kwargs) +341 hidden_states = list(bert_outputs[-1]) +342 embedding_layer = hidden_states.pop(0) +343 +344 split_size = len(hidden_states) // self.num_blocks 345 -346 split_size = len(hidden_states) // self.num_blocks -347 -348 assert split_size % 2 == 0, "must be an even number" -349 split_layers = [hidden_states[x:x+split_size] for x in range(0, len(hidden_states), split_size)] -350 split_layers.insert(0, embedding_layer) +346 assert split_size % 2 == 0, "must be an even number" +347 split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)] +348 split_layers.insert(0, embedding_layer) +349 +350 assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length" 351 -352 assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length" +352 outputs = [] 353 -354 outputs = [] -355 -356 for cnv, layer in zip(self.conv_layers, split_layers): -357 outputs.append(self.kmax_pooling(cnv(layer))) -358 -359 # batch_size x seq_len x hidden -> batch_size x flatten -360 logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1)) +354 for cnv, layer in zip(self.conv_layers, split_layers): +355 outputs.append(self.kmax_pooling(cnv(layer))) +356 +357 # batch_size x seq_len x hidden -> batch_size x flatten +358 logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1)) +359 +360 return self.loss_func(logits, labels), logits 361 -362 return self.loss_func(logits, labels), logits -363 -364 -365 @classmethod -366 def from_config(cls, *args, config_path): -367 kwargs = torch.load(config_path) -368 return ComBERT(*args, **kwargs) -369 -370 @classmethod -371 def 
from_pretrained(cls, output_dir, **kwargs): -372 config_found = True -373 colbert_config = None -374 -375 try: -376 colbert_config = CoLBERTConfig.load(output_dir) -377 except: -378 config_found = False -379 -380 bert_config = None -381 -382 if 'config' in kwargs: -383 bert_config = kwargs['config'] -384 del kwargs['config'] -385 else: -386 bert_config = BertConfig.from_pretrained(output_dir) +362 @classmethod +363 def from_config(cls, *args, config_path): +364 kwargs = torch.load(config_path) +365 return ComBERT(*args, **kwargs) +366 +367 @classmethod +368 def from_pretrained(cls, output_dir, **kwargs): +369 config_found = True +370 colbert_config = None +371 +372 try: +373 colbert_config = CoLBERTConfig.load(output_dir) +374 except: +375 config_found = False +376 +377 bert_config = None +378 +379 if 'config' in kwargs: +380 bert_config = kwargs['config'] +381 del kwargs['config'] +382 else: +383 bert_config = BertConfig.from_pretrained(output_dir) +384 +385 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +386 model = None 387 -388 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -389 model = None -390 -391 if config_found: -392 model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs) -393 model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth')) -394 logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}") -395 -396 else: -397 model = ComBERT((output_dir,), {}, config=bert_config, **kwargs) -398 logger.info(f"*** Create New CNN Bert Model ***") -399 -400 return model -401 -402 def save_pretrained(self, output_dir): -403 logger.info(f"*** Saved Bert Model Weights to {output_dir}") -404 self.bert.save_pretrained(output_dir) -405 torch.save(self.state_dict(), output_dir + '/cnn_bert.pth') -406 self.bert_config.save_pretrained(output_dir) -407 self.colbert_config.save(output_dir) -408 logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}") +388 if config_found: +389 model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs) +390 model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth')) +391 logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}") +392 +393 else: +394 model = ComBERT((output_dir,), {}, config=bert_config, **kwargs) +395 logger.info(f"*** Create New CNN Bert Model ***") +396 +397 return model +398 +399 def save_pretrained(self, output_dir): +400 logger.info(f"*** Saved Bert Model Weights to {output_dir}") +401 self.bert.save_pretrained(output_dir) +402 torch.save(self.state_dict(), output_dir + '/cnn_bert.pth') +403 self.bert_config.save_pretrained(output_dir) +404 self.colbert_config.save(output_dir) +405 logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")

    @@ -570,30 +567,30 @@

    -
    22class CoLBERTConfig(object):
    -23    default_fname = "colbert_config.json"
    -24
    -25    def __init__(self, **kwargs):
    -26        self.kwargs = kwargs
    -27        self.__dict__.update(kwargs)
    -28
    -29    def save(self, path, fname=default_fname):
    -30        """
    -31        :param fname: file name
    -32        :param path: Path to save
    -33        """
    -34        json.dump(self.kwargs, open(os.path.join(path, fname), 'w+'))
    -35
    -36    @classmethod
    -37    def load(cls, path, fname=default_fname):
    -38        """
    -39        Load the ColBERT config from path (don't point to file name just directory)
    -40        :return ColBERTConfig:
    -41        """
    -42
    -43        kwargs = json.load(open(os.path.join(path, fname)))
    -44
    -45        return CoLBERTConfig(**kwargs)
    +            
    21class CoLBERTConfig(object):
    +22    default_fname = "colbert_config.json"
    +23
    +24    def __init__(self, **kwargs):
    +25        self.kwargs = kwargs
    +26        self.__dict__.update(kwargs)
    +27
    +28    def save(self, path, fname=default_fname):
    +29        """
    +30        :param fname: file name
    +31        :param path: Path to save
    +32        """
    +33        json.dump(self.kwargs, open(os.path.join(path, fname), 'w+'))
    +34
    +35    @classmethod
    +36    def load(cls, path, fname=default_fname):
    +37        """
    +38        Load the ColBERT config from path (don't point to file name just directory)
    +39        :return ColBERTConfig:
    +40        """
    +41
    +42        kwargs = json.load(open(os.path.join(path, fname)))
    +43
    +44        return CoLBERTConfig(**kwargs)
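A small sketch of the save/load round trip above; the hyper-parameters and the checkpoints/ directory are placeholders, and the directory must already exist since save only opens a file inside it.

    from debeir.models.colbert import CoLBERTConfig

    cfg = CoLBERTConfig(k=8, optional_shortcut=True, act_func="relu")
    cfg.save("checkpoints/")                       # writes checkpoints/colbert_config.json
    restored = CoLBERTConfig.load("checkpoints/")  # pass the directory, not the file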
     
    @@ -609,9 +606,9 @@

    -
    25    def __init__(self, **kwargs):
    -26        self.kwargs = kwargs
    -27        self.__dict__.update(kwargs)
    +            
    24    def __init__(self, **kwargs):
    +25        self.kwargs = kwargs
    +26        self.__dict__.update(kwargs)
     
    @@ -629,12 +626,12 @@

    -
    29    def save(self, path, fname=default_fname):
    -30        """
    -31        :param fname: file name
    -32        :param path: Path to save
    -33        """
    -34        json.dump(self.kwargs, open(os.path.join(path, fname), 'w+'))
    +            
    28    def save(self, path, fname=default_fname):
    +29        """
    +30        :param fname: file name
    +31        :param path: Path to save
    +32        """
    +33        json.dump(self.kwargs, open(os.path.join(path, fname), 'w+'))
     
    @@ -660,16 +657,16 @@

    -
    36    @classmethod
    -37    def load(cls, path, fname=default_fname):
    -38        """
    -39        Load the ColBERT config from path (don't point to file name just directory)
    -40        :return ColBERTConfig:
    -41        """
    -42
    -43        kwargs = json.load(open(os.path.join(path, fname)))
    -44
    -45        return CoLBERTConfig(**kwargs)
    +            
    35    @classmethod
    +36    def load(cls, path, fname=default_fname):
    +37        """
    +38        Load the ColBERT config from path (don't point to file name just directory)
    +39        :return ColBERTConfig:
    +40        """
    +41
    +42        kwargs = json.load(open(os.path.join(path, fname)))
    +43
    +44        return CoLBERTConfig(**kwargs)
     
    @@ -692,28 +689,28 @@
    Returns
    -
    48class ConvolutionalBlock(nn.Module):
    -49
    -50    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
    -51        super(ConvolutionalBlock, self).__init__()
    -52
    -53        padding = int((kernel_size - 1) / 2)
    -54        if kernel_size == 3:
    -55            assert padding == 1  # checks
    -56        if kernel_size == 5:
    -57            assert padding == 2  # checks
    -58        layers = [
    -59            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
    -60            nn.BatchNorm1d(num_features=out_channels)
    -61        ]
    -62
    -63        if act_func is not None:
    -64            layers.append(act_func())
    -65
    -66        self.sequential = nn.Sequential(*layers)
    -67
    -68    def forward(self, x):
    -69        return self.sequential(x)
    +            
    47class ConvolutionalBlock(nn.Module):
    +48
    +49    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
    +50        super(ConvolutionalBlock, self).__init__()
    +51
    +52        padding = int((kernel_size - 1) / 2)
    +53        if kernel_size == 3:
    +54            assert padding == 1  # checks
    +55        if kernel_size == 5:
    +56            assert padding == 2  # checks
    +57        layers = [
    +58            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
    +59            nn.BatchNorm1d(num_features=out_channels)
    +60        ]
    +61
    +62        if act_func is not None:
    +63            layers.append(act_func())
    +64
    +65        self.sequential = nn.Sequential(*layers)
    +66
    +67    def forward(self, x):
    +68        return self.sequential(x)
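A quick shape check for the block above; the sizes are placeholders chosen to match a BERT-base hidden size.

    import torch
    from debeir.models.colbert import ConvolutionalBlock

    block = ConvolutionalBlock(in_channels=768, out_channels=768, kernel_size=3)
    x = torch.randn(2, 768, 128)   # (batch, channels, time_steps)
    y = block(x)                   # stride 1 with padding 1 keeps the shape: (2, 768, 128)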
     
    @@ -739,7 +736,7 @@
    Returns

    Submodules assigned in this way will be registered, and will have their -parameters converted too when you call to, etc.

    +parameters converted too when you call to(), etc.

    @@ -764,23 +761,23 @@
    Returns
    -
    50    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
    -51        super(ConvolutionalBlock, self).__init__()
    -52
    -53        padding = int((kernel_size - 1) / 2)
    -54        if kernel_size == 3:
    -55            assert padding == 1  # checks
    -56        if kernel_size == 5:
    -57            assert padding == 2  # checks
    -58        layers = [
    -59            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
    -60            nn.BatchNorm1d(num_features=out_channels)
    -61        ]
    -62
    -63        if act_func is not None:
    -64            layers.append(act_func())
    -65
    -66        self.sequential = nn.Sequential(*layers)
    +            
    49    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
    +50        super(ConvolutionalBlock, self).__init__()
    +51
    +52        padding = int((kernel_size - 1) / 2)
    +53        if kernel_size == 3:
    +54            assert padding == 1  # checks
    +55        if kernel_size == 5:
    +56            assert padding == 2  # checks
    +57        layers = [
    +58            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
    +59            nn.BatchNorm1d(num_features=out_channels)
    +60        ]
    +61
    +62        if act_func is not None:
    +63            layers.append(act_func())
    +64
    +65        self.sequential = nn.Sequential(*layers)
     
    @@ -800,8 +797,8 @@
    Returns
    -
    68    def forward(self, x):
    -69        return self.sequential(x)
    +            
    67    def forward(self, x):
    +68        return self.sequential(x)
     
    @@ -883,21 +880,21 @@
    Inherited Members
    -
    72class KMaxPool(nn.Module):
    -73    def __init__(self, k=1):
    -74        super(KMaxPool, self).__init__()
    -75
    -76        self.k = k
    -77
    -78    def forward(self, x):
    -79        # x : batch_size, channel, time_steps
    -80        if self.k == 'half':
    -81            time_steps = x.shape(2)
    -82            self.k = time_steps // 2
    -83
    -84        kmax, kargmax = torch.topk(x, self.k, sorted=True)
    -85        # kmax, kargmax = x.topk(self.k, dim=2)
    -86        return kmax
    +            
    71class KMaxPool(nn.Module):
    +72    def __init__(self, k=1):
    +73        super(KMaxPool, self).__init__()
    +74
    +75        self.k = k
    +76
    +77    def forward(self, x):
    +78        # x : batch_size, channel, time_steps
    +79        if self.k == 'half':
    +80            time_steps = x.shape(2)
    +81            self.k = time_steps // 2
    +82
    +83        kmax, kargmax = torch.topk(x, self.k, sorted=True)
    +84        # kmax, kargmax = x.topk(self.k, dim=2)
    +85        return kmax
     
    @@ -923,7 +920,7 @@
    Inherited Members

Submodules assigned in this way will be registered, and will have their
-parameters converted too when you call to, etc.

    +parameters converted too when you call to(), etc.

    @@ -948,10 +945,10 @@
    Inherited Members
    -
    73    def __init__(self, k=1):
    -74        super(KMaxPool, self).__init__()
    -75
    -76        self.k = k
    +            
    72    def __init__(self, k=1):
    +73        super(KMaxPool, self).__init__()
    +74
    +75        self.k = k
     
    @@ -971,15 +968,15 @@
    Inherited Members
    -
    78    def forward(self, x):
    -79        # x : batch_size, channel, time_steps
    -80        if self.k == 'half':
    -81            time_steps = x.shape(2)
    -82            self.k = time_steps // 2
    -83
    -84        kmax, kargmax = torch.topk(x, self.k, sorted=True)
    -85        # kmax, kargmax = x.topk(self.k, dim=2)
    -86        return kmax
    +            
    77    def forward(self, x):
    +78        # x : batch_size, channel, time_steps
    +79        if self.k == 'half':
    +80            time_steps = x.shape(2)
    +81            self.k = time_steps // 2
    +82
    +83        kmax, kargmax = torch.topk(x, self.k, sorted=True)
    +84        # kmax, kargmax = x.topk(self.k, dim=2)
    +85        return kmax
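A small stand-alone illustration (assumed shapes, not from the patched sources) of the torch.topk call that KMaxPool.forward relies on: the top-k values are taken along the last (time) dimension for every channel independently.

import torch

x = torch.randn(2, 4, 10)                      # batch_size, channel, time_steps
kmax, kargmax = torch.topk(x, 3, sorted=True)  # top-3 values along the last dimension
print(kmax.shape)                              # torch.Size([2, 4, 3])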
     
    @@ -1061,8 +1058,8 @@
    Inherited Members
    -
    89def visualisation_dump(argmax, input_tensors):
    -90    pass
    +            
    88def visualisation_dump(argmax, input_tensors):
    +89    pass
     
    @@ -1122,7 +1119,7 @@
    Inherited Members

Submodules assigned in this way will be registered, and will have their
-parameters converted too when you call to, etc.

    +parameters converted too when you call to(), etc.

    @@ -1263,7 +1260,7 @@
    Inherited Members
    111class ColBERT(nn.Module):
     112    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int
    -113            = 128, k: int = 8,
    +113    = 128, k: int = 8,
     114                 optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
     115                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
     116                 act_func="mish", loss_func='cross_entropy_loss', **kwargs):  # kwargs for compat
    @@ -1323,16 +1320,16 @@ 
    Inherited Members
 170        # Create the MLP to compress the k signals
 171        linear_layers = list()
 172        linear_layers.append(nn.Linear(hidden_dim * k, num_labels))  # Downsample into Kmaxpool?
-173        #linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
-174        #linear_layers.append(nn.Dropout(dropout_perc))
-175        #linear_layers.append(nn.Linear(hidden_neurons, num_labels))
+173        # linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
+174        # linear_layers.append(nn.Dropout(dropout_perc))
+175        # linear_layers.append(nn.Linear(hidden_neurons, num_labels))
 176
 177        self.linear_layers = nn.Sequential(*linear_layers)
 178        self.apply(weight_init)
 179        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
-180                                              config=self.bert_config) # Add Bert model after random initialisation
+180                                              config=self.bert_config)  # Add Bert model after random initialisation
 181
-182        for param in self.bert.pooler.parameters(): # We don't need the pooler
+182        for param in self.bert.pooler.parameters():  # We don't need the pooler
 183            param.requires_grad = False
 184
 185        self.bert.to(self.device)
@@ -1358,13 +1355,13 @@
    Inherited Members
 205            assert len(self.transformation_blocks) == len(hidden_states)
 206            zip_args.append(self.transformation_blocks)
 207        else:
-208            zip_args.append([identity for i in range(self.num_layers+1)])
+208            zip_args.append([identity for i in range(self.num_layers + 1)])
 209
 210        if self.use_batch_norms:
 211            assert len(self.batch_norms) == len(hidden_states)
 212            zip_args.append(self.batch_norms)
 213        else:
-214            zip_args.append([identity for i in range(self.num_layers+1)])
+214            zip_args.append([identity for i in range(self.num_layers + 1)])
 215
 216        out = None
 217        for co, hi, tr, bn in zip(*zip_args):
@@ -1382,51 +1379,50 @@
    Inherited Members
 229
 230        return self.loss_func(logits, labels), logits
 231
-232
-233    @classmethod
-234    def from_config(cls, *args, config_path):
-235        kwargs = torch.load(config_path)
-236        return ColBERT(*args, **kwargs)
-237
-238    @classmethod
-239    def from_pretrained(cls, output_dir, **kwargs):
-240        config_found = True
-241        colbert_config = None
-242
-243        try:
-244            colbert_config = CoLBERTConfig.load(output_dir)
-245        except:
-246            config_found = False
-247
-248        bert_config = None
-249
-250        if 'config' in kwargs:
-251            bert_config = kwargs['config']
-252            del kwargs['config']
-253        else:
-254            bert_config = BertConfig.from_pretrained(output_dir)
-255
-256        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-257        model = None
-258
-259        if config_found:
-260            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
-261            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
-262            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
-263
-264        else:
-265            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
-266            logger.info(f"*** Create New CNN Bert Model ***")
-267
-268        return model
-269
-270    def save_pretrained(self, output_dir):
-271        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
-272        self.bert.save_pretrained(output_dir)
-273        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
-274        self.bert_config.save_pretrained(output_dir)
-275        self.colbert_config.save(output_dir)
-276        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
+232    @classmethod
+233    def from_config(cls, *args, config_path):
+234        kwargs = torch.load(config_path)
+235        return ColBERT(*args, **kwargs)
+236
+237    @classmethod
+238    def from_pretrained(cls, output_dir, **kwargs):
+239        config_found = True
+240        colbert_config = None
+241
+242        try:
+243            colbert_config = CoLBERTConfig.load(output_dir)
+244        except:
+245            config_found = False
+246
+247        bert_config = None
+248
+249        if 'config' in kwargs:
+250            bert_config = kwargs['config']
+251            del kwargs['config']
+252        else:
+253            bert_config = BertConfig.from_pretrained(output_dir)
+254
+255        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+256        model = None
+257
+258        if config_found:
+259            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
+260            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
+261            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
+262
+263        else:
+264            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
+265            logger.info(f"*** Create New CNN Bert Model ***")
+266
+267        return model
+268
+269    def save_pretrained(self, output_dir):
+270        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
+271        self.bert.save_pretrained(output_dir)
+272        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
+273        self.bert_config.save_pretrained(output_dir)
+274        self.colbert_config.save(output_dir)
+275        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
    @@ -1452,7 +1448,7 @@
    Inherited Members

Submodules assigned in this way will be registered, and will have their
-parameters converted too when you call to, etc.

    +parameters converted too when you call to(), etc.

    @@ -1478,7 +1474,7 @@
    Inherited Members
    112    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int
    -113            = 128, k: int = 8,
    +113    = 128, k: int = 8,
     114                 optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
     115                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
     116                 act_func="mish", loss_func='cross_entropy_loss', **kwargs):  # kwargs for compat
    @@ -1538,16 +1534,16 @@ 
    Inherited Members
 170        # Create the MLP to compress the k signals
 171        linear_layers = list()
 172        linear_layers.append(nn.Linear(hidden_dim * k, num_labels))  # Downsample into Kmaxpool?
-173        #linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
-174        #linear_layers.append(nn.Dropout(dropout_perc))
-175        #linear_layers.append(nn.Linear(hidden_neurons, num_labels))
+173        # linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
+174        # linear_layers.append(nn.Dropout(dropout_perc))
+175        # linear_layers.append(nn.Linear(hidden_neurons, num_labels))
 176
 177        self.linear_layers = nn.Sequential(*linear_layers)
 178        self.apply(weight_init)
 179        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
-180                                              config=self.bert_config) # Add Bert model after random initialisation
+180                                              config=self.bert_config)  # Add Bert model after random initialisation
 181
-182        for param in self.bert.pooler.parameters(): # We don't need the pooler
+182        for param in self.bert.pooler.parameters():  # We don't need the pooler
 183            param.requires_grad = False
 184
 185        self.bert.to(self.device)
@@ -1591,13 +1587,13 @@
    Inherited Members
 205            assert len(self.transformation_blocks) == len(hidden_states)
 206            zip_args.append(self.transformation_blocks)
 207        else:
-208            zip_args.append([identity for i in range(self.num_layers+1)])
+208            zip_args.append([identity for i in range(self.num_layers + 1)])
 209
 210        if self.use_batch_norms:
 211            assert len(self.batch_norms) == len(hidden_states)
 212            zip_args.append(self.batch_norms)
 213        else:
-214            zip_args.append([identity for i in range(self.num_layers+1)])
+214            zip_args.append([identity for i in range(self.num_layers + 1)])
 215
 216        out = None
 217        for co, hi, tr, bn in zip(*zip_args):
@@ -1645,10 +1641,10 @@
    Inherited Members
    -
    233    @classmethod
    -234    def from_config(cls, *args, config_path):
    -235        kwargs = torch.load(config_path)
    -236        return ColBERT(*args, **kwargs)
    +            
    232    @classmethod
    +233    def from_config(cls, *args, config_path):
    +234        kwargs = torch.load(config_path)
    +235        return ColBERT(*args, **kwargs)
     
    @@ -1667,37 +1663,37 @@
    Inherited Members
    -
    238    @classmethod
    -239    def from_pretrained(cls, output_dir, **kwargs):
    -240        config_found = True
    -241        colbert_config = None
    -242
    -243        try:
    -244            colbert_config = CoLBERTConfig.load(output_dir)
    -245        except:
    -246            config_found = False
    -247
    -248        bert_config = None
    -249
    -250        if 'config' in kwargs:
    -251            bert_config = kwargs['config']
    -252            del kwargs['config']
    -253        else:
    -254            bert_config = BertConfig.from_pretrained(output_dir)
    -255
    -256        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    -257        model = None
    -258
    -259        if config_found:
    -260            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
    -261            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    -262            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    -263
    -264        else:
    -265            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
    -266            logger.info(f"*** Create New CNN Bert Model ***")
    -267
    -268        return model
    +            
    237    @classmethod
    +238    def from_pretrained(cls, output_dir, **kwargs):
    +239        config_found = True
    +240        colbert_config = None
    +241
    +242        try:
    +243            colbert_config = CoLBERTConfig.load(output_dir)
    +244        except:
    +245            config_found = False
    +246
    +247        bert_config = None
    +248
    +249        if 'config' in kwargs:
    +250            bert_config = kwargs['config']
    +251            del kwargs['config']
    +252        else:
    +253            bert_config = BertConfig.from_pretrained(output_dir)
    +254
    +255        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +256        model = None
    +257
    +258        if config_found:
    +259            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
    +260            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    +261            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    +262
    +263        else:
    +264            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
    +265            logger.info(f"*** Create New CNN Bert Model ***")
    +266
    +267        return model
     
    @@ -1715,13 +1711,13 @@
    Inherited Members
    -
    270    def save_pretrained(self, output_dir):
    -271        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    -272        self.bert.save_pretrained(output_dir)
    -273        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    -274        self.bert_config.save_pretrained(output_dir)
    -275        self.colbert_config.save(output_dir)
    -276        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
    +            
    269    def save_pretrained(self, output_dir):
    +270        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    +271        self.bert.save_pretrained(output_dir)
    +272        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    +273        self.bert_config.save_pretrained(output_dir)
    +274        self.colbert_config.save(output_dir)
    +275        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
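A hypothetical save/load round-trip for the two methods listed above; the directory name is illustrative only, and it assumes an already-constructed ColBERT instance named model.

output_dir = "checkpoints/cnn_bert"              # illustrative path, not from the patch
model.save_pretrained(output_dir)                # writes BERT weights, cnn_bert.pth and both configs
reloaded = ColBERT.from_pretrained(output_dir)   # rebuilds the model from the saved CoLBERTConfig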
     
    @@ -1790,137 +1786,135 @@
    Inherited Members
    -
    279class ComBERT(nn.Module):
    -280    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int= 128,
    -281                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
    -282                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
    -283                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat
    -284
    -285        super().__init__()
    -286        self.device = device
    -287        hidden_dim = config.hidden_size
    -288        self.seq_length = max_seq_len
    -289        self.use_trans_blocks = use_trans_blocks
    -290        self.use_batch_norms = use_batch_norms
    -291        self.num_layers = config.num_hidden_layers
    -292        num_labels = config.num_labels
    -293        self.num_blocks = num_blocks
    -294        self.loss_func = LOSS_FUNCS[loss_func.lower()]()
    -295
    -296        # Save our kwargs to reinitialise the model during evaluation
    -297        self.bert_config = config
    -298        self.colbert_config = CoLBERTConfig(k=k,
    -299                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
    -300                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
    -301                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
    -302                                            act_func=act_func, bert_model_args=bert_model_args,
    -303                                            bert_model_kwargs=bert_model_kwargs)
    -304
    -305        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))
    -306
    -307        # relax this constraint later
    -308        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
    -309        act_func = ACT_FUNCS[act_func.lower()]
    -310
    -311        # CNN Part
    -312        conv_layers = []
    -313
    -314        # Adds up to num_layers + 1 embedding layer
    -315        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
    -316
    -317
    -318        for i in range(num_blocks):
    -319            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
    -320                                             kernel_size=residual_kernel_size, act_func=act_func))
    -321
    -322        self.conv_layers = nn.ModuleList(conv_layers)
    -323        self.kmax_pooling = KMaxPool(k)
    -324
    -325        # Create the MLP to compress the k signals
    -326        linear_layers = list()
    -327        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
    -328        linear_layers.append(nn.Dropout(dropout_perc))
    -329        linear_layers.append(nn.Linear(hidden_neurons, num_labels))
    -330
    -331        self.linear_layers = nn.Sequential(*linear_layers)
    -332        self.apply(weight_init)
    -333        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
    -334                                              config=self.bert_config)  # Add Bert model after random initialisation
    -335        self.bert.to(self.device)
    -336
    -337    def forward(self, *args, **kwargs):
    -338        # input_ids: batch_size x seq_length x hidden_dim
    -339
    -340        labels = kwargs['labels'] if 'labels' in kwargs else None
    -341        if labels is not None: del kwargs['labels']
    -342
    -343        bert_outputs = self.bert(*args, **kwargs)
    -344        hidden_states = list(bert_outputs[-1])
    -345        embedding_layer = hidden_states.pop(0)
    +            
    278class ComBERT(nn.Module):
    +279    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
    +280                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
    +281                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
    +282                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat
    +283
    +284        super().__init__()
    +285        self.device = device
    +286        hidden_dim = config.hidden_size
    +287        self.seq_length = max_seq_len
    +288        self.use_trans_blocks = use_trans_blocks
    +289        self.use_batch_norms = use_batch_norms
    +290        self.num_layers = config.num_hidden_layers
    +291        num_labels = config.num_labels
    +292        self.num_blocks = num_blocks
    +293        self.loss_func = LOSS_FUNCS[loss_func.lower()]()
    +294
    +295        # Save our kwargs to reinitialise the model during evaluation
    +296        self.bert_config = config
    +297        self.colbert_config = CoLBERTConfig(k=k,
    +298                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
    +299                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
    +300                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
    +301                                            act_func=act_func, bert_model_args=bert_model_args,
    +302                                            bert_model_kwargs=bert_model_kwargs)
    +303
    +304        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))
    +305
    +306        # relax this constraint later
    +307        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
    +308        act_func = ACT_FUNCS[act_func.lower()]
    +309
    +310        # CNN Part
    +311        conv_layers = []
    +312
    +313        # Adds up to num_layers + 1 embedding layer
    +314        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
    +315
    +316        for i in range(num_blocks):
    +317            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
    +318                                             kernel_size=residual_kernel_size, act_func=act_func))
    +319
    +320        self.conv_layers = nn.ModuleList(conv_layers)
    +321        self.kmax_pooling = KMaxPool(k)
    +322
    +323        # Create the MLP to compress the k signals
    +324        linear_layers = list()
    +325        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
    +326        linear_layers.append(nn.Dropout(dropout_perc))
    +327        linear_layers.append(nn.Linear(hidden_neurons, num_labels))
    +328
    +329        self.linear_layers = nn.Sequential(*linear_layers)
    +330        self.apply(weight_init)
    +331        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
    +332                                              config=self.bert_config)  # Add Bert model after random initialisation
    +333        self.bert.to(self.device)
    +334
    +335    def forward(self, *args, **kwargs):
    +336        # input_ids: batch_size x seq_length x hidden_dim
    +337
    +338        labels = kwargs['labels'] if 'labels' in kwargs else None
    +339        if labels is not None: del kwargs['labels']
    +340
    +341        bert_outputs = self.bert(*args, **kwargs)
    +342        hidden_states = list(bert_outputs[-1])
    +343        embedding_layer = hidden_states.pop(0)
    +344
    +345        split_size = len(hidden_states) // self.num_blocks
     346
    -347        split_size = len(hidden_states) // self.num_blocks
    -348
    -349        assert split_size % 2 == 0, "must be an even number"
    -350        split_layers = [hidden_states[x:x+split_size] for x in range(0, len(hidden_states), split_size)]
    -351        split_layers.insert(0, embedding_layer)
    +347        assert split_size % 2 == 0, "must be an even number"
    +348        split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)]
    +349        split_layers.insert(0, embedding_layer)
    +350
    +351        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"
     352
    -353        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"
    +353        outputs = []
     354
    -355        outputs = []
    -356
    -357        for cnv, layer in zip(self.conv_layers, split_layers):
    -358            outputs.append(self.kmax_pooling(cnv(layer)))
    -359
    -360        # batch_size x seq_len x hidden -> batch_size x flatten
    -361        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))
    +355        for cnv, layer in zip(self.conv_layers, split_layers):
    +356            outputs.append(self.kmax_pooling(cnv(layer)))
    +357
    +358        # batch_size x seq_len x hidden -> batch_size x flatten
    +359        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))
    +360
    +361        return self.loss_func(logits, labels), logits
     362
    -363        return self.loss_func(logits, labels), logits
    -364
    -365
    -366    @classmethod
    -367    def from_config(cls, *args, config_path):
    -368        kwargs = torch.load(config_path)
    -369        return ComBERT(*args, **kwargs)
    -370
    -371    @classmethod
    -372    def from_pretrained(cls, output_dir, **kwargs):
    -373        config_found = True
    -374        colbert_config = None
    -375
    -376        try:
    -377            colbert_config = CoLBERTConfig.load(output_dir)
    -378        except:
    -379            config_found = False
    -380
    -381        bert_config = None
    -382
    -383        if 'config' in kwargs:
    -384            bert_config = kwargs['config']
    -385            del kwargs['config']
    -386        else:
    -387            bert_config = BertConfig.from_pretrained(output_dir)
    +363    @classmethod
    +364    def from_config(cls, *args, config_path):
    +365        kwargs = torch.load(config_path)
    +366        return ComBERT(*args, **kwargs)
    +367
    +368    @classmethod
    +369    def from_pretrained(cls, output_dir, **kwargs):
    +370        config_found = True
    +371        colbert_config = None
    +372
    +373        try:
    +374            colbert_config = CoLBERTConfig.load(output_dir)
    +375        except:
    +376            config_found = False
    +377
    +378        bert_config = None
    +379
    +380        if 'config' in kwargs:
    +381            bert_config = kwargs['config']
    +382            del kwargs['config']
    +383        else:
    +384            bert_config = BertConfig.from_pretrained(output_dir)
    +385
    +386        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +387        model = None
     388
    -389        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    -390        model = None
    -391
    -392        if config_found:
    -393            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
    -394            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    -395            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    -396
    -397        else:
    -398            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
    -399            logger.info(f"*** Create New CNN Bert Model ***")
    -400
    -401        return model
    -402
    -403    def save_pretrained(self, output_dir):
    -404        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    -405        self.bert.save_pretrained(output_dir)
    -406        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    -407        self.bert_config.save_pretrained(output_dir)
    -408        self.colbert_config.save(output_dir)
    -409        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
    +389        if config_found:
    +390            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
    +391            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    +392            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    +393
    +394        else:
    +395            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
    +396            logger.info(f"*** Create New CNN Bert Model ***")
    +397
    +398        return model
    +399
    +400    def save_pretrained(self, output_dir):
    +401        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    +402        self.bert.save_pretrained(output_dir)
    +403        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    +404        self.bert_config.save_pretrained(output_dir)
    +405        self.colbert_config.save(output_dir)
    +406        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
     
    @@ -1946,7 +1940,7 @@
    Inherited Members

Submodules assigned in this way will be registered, and will have their
-parameters converted too when you call to, etc.

    +parameters converted too when you call to(), etc.

    @@ -1971,62 +1965,61 @@
    Inherited Members
    -
    280    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int= 128,
    -281                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
    -282                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
    -283                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat
    -284
    -285        super().__init__()
    -286        self.device = device
    -287        hidden_dim = config.hidden_size
    -288        self.seq_length = max_seq_len
    -289        self.use_trans_blocks = use_trans_blocks
    -290        self.use_batch_norms = use_batch_norms
    -291        self.num_layers = config.num_hidden_layers
    -292        num_labels = config.num_labels
    -293        self.num_blocks = num_blocks
    -294        self.loss_func = LOSS_FUNCS[loss_func.lower()]()
    -295
    -296        # Save our kwargs to reinitialise the model during evaluation
    -297        self.bert_config = config
    -298        self.colbert_config = CoLBERTConfig(k=k,
    -299                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
    -300                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
    -301                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
    -302                                            act_func=act_func, bert_model_args=bert_model_args,
    -303                                            bert_model_kwargs=bert_model_kwargs)
    -304
    -305        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))
    -306
    -307        # relax this constraint later
    -308        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
    -309        act_func = ACT_FUNCS[act_func.lower()]
    -310
    -311        # CNN Part
    -312        conv_layers = []
    -313
    -314        # Adds up to num_layers + 1 embedding layer
    -315        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
    -316
    -317
    -318        for i in range(num_blocks):
    -319            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
    -320                                             kernel_size=residual_kernel_size, act_func=act_func))
    -321
    -322        self.conv_layers = nn.ModuleList(conv_layers)
    -323        self.kmax_pooling = KMaxPool(k)
    -324
    -325        # Create the MLP to compress the k signals
    -326        linear_layers = list()
    -327        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
    -328        linear_layers.append(nn.Dropout(dropout_perc))
    -329        linear_layers.append(nn.Linear(hidden_neurons, num_labels))
    -330
    -331        self.linear_layers = nn.Sequential(*linear_layers)
    -332        self.apply(weight_init)
    -333        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
    -334                                              config=self.bert_config)  # Add Bert model after random initialisation
    -335        self.bert.to(self.device)
    +            
    279    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
    +280                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
    +281                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
    +282                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat
    +283
    +284        super().__init__()
    +285        self.device = device
    +286        hidden_dim = config.hidden_size
    +287        self.seq_length = max_seq_len
    +288        self.use_trans_blocks = use_trans_blocks
    +289        self.use_batch_norms = use_batch_norms
    +290        self.num_layers = config.num_hidden_layers
    +291        num_labels = config.num_labels
    +292        self.num_blocks = num_blocks
    +293        self.loss_func = LOSS_FUNCS[loss_func.lower()]()
    +294
    +295        # Save our kwargs to reinitialise the model during evaluation
    +296        self.bert_config = config
    +297        self.colbert_config = CoLBERTConfig(k=k,
    +298                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
    +299                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
    +300                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
    +301                                            act_func=act_func, bert_model_args=bert_model_args,
    +302                                            bert_model_kwargs=bert_model_kwargs)
    +303
    +304        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))
    +305
    +306        # relax this constraint later
    +307        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
    +308        act_func = ACT_FUNCS[act_func.lower()]
    +309
    +310        # CNN Part
    +311        conv_layers = []
    +312
    +313        # Adds up to num_layers + 1 embedding layer
    +314        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
    +315
    +316        for i in range(num_blocks):
    +317            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
    +318                                             kernel_size=residual_kernel_size, act_func=act_func))
    +319
    +320        self.conv_layers = nn.ModuleList(conv_layers)
    +321        self.kmax_pooling = KMaxPool(k)
    +322
    +323        # Create the MLP to compress the k signals
    +324        linear_layers = list()
    +325        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
    +326        linear_layers.append(nn.Dropout(dropout_perc))
    +327        linear_layers.append(nn.Linear(hidden_neurons, num_labels))
    +328
    +329        self.linear_layers = nn.Sequential(*linear_layers)
    +330        self.apply(weight_init)
    +331        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
    +332                                              config=self.bert_config)  # Add Bert model after random initialisation
    +333        self.bert.to(self.device)
     
    @@ -2046,33 +2039,33 @@
    Inherited Members
    -
    337    def forward(self, *args, **kwargs):
    -338        # input_ids: batch_size x seq_length x hidden_dim
    -339
    -340        labels = kwargs['labels'] if 'labels' in kwargs else None
    -341        if labels is not None: del kwargs['labels']
    -342
    -343        bert_outputs = self.bert(*args, **kwargs)
    -344        hidden_states = list(bert_outputs[-1])
    -345        embedding_layer = hidden_states.pop(0)
    +            
    335    def forward(self, *args, **kwargs):
    +336        # input_ids: batch_size x seq_length x hidden_dim
    +337
    +338        labels = kwargs['labels'] if 'labels' in kwargs else None
    +339        if labels is not None: del kwargs['labels']
    +340
    +341        bert_outputs = self.bert(*args, **kwargs)
    +342        hidden_states = list(bert_outputs[-1])
    +343        embedding_layer = hidden_states.pop(0)
    +344
    +345        split_size = len(hidden_states) // self.num_blocks
     346
    -347        split_size = len(hidden_states) // self.num_blocks
    -348
    -349        assert split_size % 2 == 0, "must be an even number"
    -350        split_layers = [hidden_states[x:x+split_size] for x in range(0, len(hidden_states), split_size)]
    -351        split_layers.insert(0, embedding_layer)
    +347        assert split_size % 2 == 0, "must be an even number"
    +348        split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)]
    +349        split_layers.insert(0, embedding_layer)
    +350
    +351        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"
     352
    -353        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"
    +353        outputs = []
     354
    -355        outputs = []
    -356
    -357        for cnv, layer in zip(self.conv_layers, split_layers):
    -358            outputs.append(self.kmax_pooling(cnv(layer)))
    -359
    -360        # batch_size x seq_len x hidden -> batch_size x flatten
    -361        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))
    -362
    -363        return self.loss_func(logits, labels), logits
    +355        for cnv, layer in zip(self.conv_layers, split_layers):
    +356            outputs.append(self.kmax_pooling(cnv(layer)))
    +357
    +358        # batch_size x seq_len x hidden -> batch_size x flatten
    +359        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))
    +360
    +361        return self.loss_func(logits, labels), logits
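As a side note, the chunking in ComBERT.forward can be illustrated with plain lists (stand-in values, not the real tensors): the layer outputs are split into num_blocks contiguous chunks and the embedding layer is prepended as its own entry, giving num_blocks + 1 inputs for the conv layers.

hidden_states = [f"layer_{i}" for i in range(1, 13)]   # stand-ins for 12 BERT layer outputs
num_blocks = 2

split_size = len(hidden_states) // num_blocks           # 6
split_layers = [hidden_states[x:x + split_size]
                for x in range(0, len(hidden_states), split_size)]
split_layers.insert(0, "embedding_layer")                # -> 3 entries: embeddings + 2 chunks of 6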
     
    @@ -2104,10 +2097,10 @@
    Inherited Members
    -
    366    @classmethod
    -367    def from_config(cls, *args, config_path):
    -368        kwargs = torch.load(config_path)
    -369        return ComBERT(*args, **kwargs)
    +            
    363    @classmethod
    +364    def from_config(cls, *args, config_path):
    +365        kwargs = torch.load(config_path)
    +366        return ComBERT(*args, **kwargs)
     
    @@ -2126,37 +2119,37 @@
    Inherited Members
    -
    371    @classmethod
    -372    def from_pretrained(cls, output_dir, **kwargs):
    -373        config_found = True
    -374        colbert_config = None
    -375
    -376        try:
    -377            colbert_config = CoLBERTConfig.load(output_dir)
    -378        except:
    -379            config_found = False
    -380
    -381        bert_config = None
    -382
    -383        if 'config' in kwargs:
    -384            bert_config = kwargs['config']
    -385            del kwargs['config']
    -386        else:
    -387            bert_config = BertConfig.from_pretrained(output_dir)
    +            
    368    @classmethod
    +369    def from_pretrained(cls, output_dir, **kwargs):
    +370        config_found = True
    +371        colbert_config = None
    +372
    +373        try:
    +374            colbert_config = CoLBERTConfig.load(output_dir)
    +375        except:
    +376            config_found = False
    +377
    +378        bert_config = None
    +379
    +380        if 'config' in kwargs:
    +381            bert_config = kwargs['config']
    +382            del kwargs['config']
    +383        else:
    +384            bert_config = BertConfig.from_pretrained(output_dir)
    +385
    +386        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +387        model = None
     388
    -389        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    -390        model = None
    -391
    -392        if config_found:
    -393            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
    -394            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    -395            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    -396
    -397        else:
    -398            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
    -399            logger.info(f"*** Create New CNN Bert Model ***")
    -400
    -401        return model
    +389        if config_found:
    +390            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
    +391            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
    +392            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")
    +393
    +394        else:
    +395            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
    +396            logger.info(f"*** Create New CNN Bert Model ***")
    +397
    +398        return model
     
    @@ -2174,13 +2167,13 @@
    Inherited Members
    -
    403    def save_pretrained(self, output_dir):
    -404        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    -405        self.bert.save_pretrained(output_dir)
    -406        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    -407        self.bert_config.save_pretrained(output_dir)
    -408        self.colbert_config.save(output_dir)
    -409        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
    +            
    400    def save_pretrained(self, output_dir):
    +401        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
    +402        self.bert.save_pretrained(output_dir)
    +403        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
    +404        self.bert_config.save_pretrained(output_dir)
    +405        self.colbert_config.save(output_dir)
    +406        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
     
    @@ -2339,7 +2332,7 @@
    Inherited Members
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -2366,7 +2359,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers.html b/docs/debeir/rankers.html index ee5dd0d..9feba0b 100644 --- a/docs/debeir/rankers.html +++ b/docs/debeir/rankers.html @@ -3,7 +3,7 @@ - + debeir.rankers API documentation @@ -172,7 +172,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -199,7 +199,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers/reranking.html b/docs/debeir/rankers/reranking.html index eeb1b26..837ab1a 100644 --- a/docs/debeir/rankers/reranking.html +++ b/docs/debeir/rankers/reranking.html @@ -3,7 +3,7 @@ - + debeir.rankers.reranking API documentation @@ -154,7 +154,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -181,7 +181,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers/reranking/nir.html b/docs/debeir/rankers/reranking/nir.html index 6de8d2a..857d6a5 100644 --- a/docs/debeir/rankers/reranking/nir.html +++ b/docs/debeir/rankers/reranking/nir.html @@ -3,7 +3,7 @@ - + debeir.rankers.reranking.nir API documentation @@ -70,61 +70,84 @@

    4[Insert paper link here] 5""" 6 - 7from typing import List, Dict - 8 - 9from tqdm import tqdm -10 -11from debeir.utils import scaler -12from debeir.interfaces.document import Document -13from debeir.rankers.reranking.reranker import DocumentReRanker -14from debeir.rankers.transformer_sent_encoder import Encoder -15from scipy import spatial -16import math + 7import math + 8from typing import Dict, List + 9 +10from debeir.core.document import Document +11from debeir.rankers.reranking.reranker import DocumentReRanker +12from debeir.rankers.transformer_sent_encoder import Encoder +13from debeir.utils import scaler +14from scipy import spatial +15from tqdm import tqdm +16 17 -18 -19class NIReRanker(DocumentReRanker): -20 """ -21 Re-ranker which uses the NIR scoring method -22 score = log(bm25)/log(z) + cosine_sum -23 """ -24 -25 def __init__(self, query, ranked_list: List[Document], encoder: Encoder, -26 distance_fn=spatial.distance.cosine, facets_weights: Dict = None, +18class NIReRanker(DocumentReRanker): +19 """ +20 Re-ranker which uses the NIR scoring method +21 score = log(bm25)/log(z) + cosine_sum +22 """ +23 +24 def __init__(self, query, ranked_list: List[Document], encoder: Encoder, +25 distance_fn=spatial.distance.cosine, facets_weights: Dict = None, +26 presort=False, fields_to_encode=None, 27 *args, **kwargs): -28 super().__init__(query, ranked_list, *args, **kwargs) -29 self.encoder = encoder -30 self.top_score = self._get_top_score() -31 self.top_cosine_score = -1 -32 -33 self.query_vec = self.encoder(self.query) -34 self.distance_fn = distance_fn -35 -36 # Compute all the cosine scores -37 self.pre_calc = {} -38 self._compute_scores_helper() -39 self.log_norm = scaler.get_z_value(self.top_score, self.top_cosine_score) -40 self.facets_weights = facets_weights -41 -42 def _get_top_score(self): -43 return self.ranked_list[0].score -44 -45 def _compute_scores_helper(self): -46 for document in tqdm(self.ranked_list, desc="Calculating cosine scores"): -47 facet_scores = {} -48 for facet in document.facets: -49 document_vec = self.encoder(facet) +28 +29 if presort: +30 ranked_list.sort(key=lambda k: k.score) +31 +32 super().__init__(query, ranked_list, *args, **kwargs) +33 self.encoder = encoder +34 self.top_score = self._get_top_score() +35 self.top_cosine_score = -1 +36 +37 self.query_vec = self.encoder(self.query) +38 self.distance_fn = distance_fn +39 self.fields_to_encode = fields_to_encode +40 +41 if facets_weights: +42 self.facets_weights = facets_weights +43 else: +44 self.facets_weights = {} +45 +46 # Compute all the cosine scores +47 self.pre_calc = {} +48 self.pre_calc_finished = False +49 self.log_norm = None 50 -51 facet_weight = self.facets_weights[facet] if facet in self.facets_weights else 1.0 -52 facet_scores[facet] = self.distance_fn(self.query_vec, document_vec) * facet_weight +51 def _get_top_score(self): +52 return self.ranked_list[0].score 53 -54 sum_score = sum(facet_scores.values()) -55 facet_scores["cosine_sum"] = sum_score -56 -57 self.top_score = max(self.top_score, sum_score) -58 self.pre_calc[document.doc_id] = facet_scores -59 -60 def _compute_scores(self, document): -61 return math.log(document.score, self.log_norm) + self.pre_calc[document.id]["cosine_sum"] +54 def _compute_scores_helper(self): +55 for document in tqdm(self.ranked_list, desc="Calculating cosine scores"): +56 facet_scores = {} +57 for facet in self.fields_to_encode if self.fields_to_encode else document.facets: +58 if "embedding" in facet.lower(): +59 continue +60 +61 document_facet = 
document.facets[facet] +62 facet_weight = self.facets_weights[document_facet] if facet in self.facets_weights else 1.0 +63 +64 # Early exit +65 if facet_weight == 0: +66 continue +67 +68 document_vec = self.encoder(document_facet) +69 facet_scores[facet] = self.distance_fn(self.query_vec, document_vec) * facet_weight +70 +71 sum_score = sum(facet_scores.values()) +72 facet_scores["cosine_sum"] = sum_score +73 +74 self.top_cosine_score = max(self.top_cosine_score, sum_score) +75 self.pre_calc[document.doc_id] = facet_scores +76 +77 self.pre_calc_finished = True +78 +79 def _compute_scores(self, document): +80 if not self.pre_calc_finished: +81 self._compute_scores_helper() +82 self.log_norm = scaler.get_z_value(self.top_cosine_score, self.top_score) +83 +84 return math.log(document.score, self.log_norm) + self.pre_calc[document.doc_id]["cosine_sum"]

    @@ -140,49 +163,73 @@

    -
    20class NIReRanker(DocumentReRanker):
    -21    """
    -22    Re-ranker which uses the NIR scoring method
    -23        score = log(bm25)/log(z) + cosine_sum
    -24    """
    -25
    -26    def __init__(self, query, ranked_list: List[Document], encoder: Encoder,
    -27                 distance_fn=spatial.distance.cosine, facets_weights: Dict = None,
    +            
    19class NIReRanker(DocumentReRanker):
    +20    """
    +21    Re-ranker which uses the NIR scoring method
    +22        score = log(bm25)/log(z) + cosine_sum
    +23    """
    +24
    +25    def __init__(self, query, ranked_list: List[Document], encoder: Encoder,
    +26                 distance_fn=spatial.distance.cosine, facets_weights: Dict = None,
    +27                 presort=False, fields_to_encode=None,
     28                 *args, **kwargs):
    -29        super().__init__(query, ranked_list, *args, **kwargs)
    -30        self.encoder = encoder
    -31        self.top_score = self._get_top_score()
    -32        self.top_cosine_score = -1
    -33
    -34        self.query_vec = self.encoder(self.query)
    -35        self.distance_fn = distance_fn
    -36
    -37        # Compute all the cosine scores
    -38        self.pre_calc = {}
    -39        self._compute_scores_helper()
    -40        self.log_norm = scaler.get_z_value(self.top_score, self.top_cosine_score)
    -41        self.facets_weights = facets_weights
    -42
    -43    def _get_top_score(self):
    -44        return self.ranked_list[0].score
    -45
    -46    def _compute_scores_helper(self):
    -47        for document in tqdm(self.ranked_list, desc="Calculating cosine scores"):
    -48            facet_scores = {}
    -49            for facet in document.facets:
    -50                document_vec = self.encoder(facet)
    +29
    +30        if presort:
    +31            ranked_list.sort(key=lambda k: k.score)
    +32
    +33        super().__init__(query, ranked_list, *args, **kwargs)
    +34        self.encoder = encoder
    +35        self.top_score = self._get_top_score()
    +36        self.top_cosine_score = -1
    +37
    +38        self.query_vec = self.encoder(self.query)
    +39        self.distance_fn = distance_fn
    +40        self.fields_to_encode = fields_to_encode
    +41
    +42        if facets_weights:
    +43            self.facets_weights = facets_weights
    +44        else:
    +45            self.facets_weights = {}
    +46
    +47        # Compute all the cosine scores
    +48        self.pre_calc = {}
    +49        self.pre_calc_finished = False
    +50        self.log_norm = None
     51
    -52                facet_weight = self.facets_weights[facet] if facet in self.facets_weights else 1.0
    -53                facet_scores[facet] = self.distance_fn(self.query_vec, document_vec) * facet_weight
    +52    def _get_top_score(self):
    +53        return self.ranked_list[0].score
     54
    -55                sum_score = sum(facet_scores.values())
    -56                facet_scores["cosine_sum"] = sum_score
    -57
    -58                self.top_score = max(self.top_score, sum_score)
    -59                self.pre_calc[document.doc_id] = facet_scores
    -60
    -61    def _compute_scores(self, document):
    -62        return math.log(document.score, self.log_norm) + self.pre_calc[document.id]["cosine_sum"]
    +55    def _compute_scores_helper(self):
    +56        for document in tqdm(self.ranked_list, desc="Calculating cosine scores"):
    +57            facet_scores = {}
    +58            for facet in self.fields_to_encode if self.fields_to_encode else document.facets:
    +59                if "embedding" in facet.lower():
    +60                    continue
    +61
    +62                document_facet = document.facets[facet]
    +63                facet_weight = self.facets_weights[document_facet] if facet in self.facets_weights else 1.0
    +64
    +65                # Early exit
    +66                if facet_weight == 0:
    +67                    continue
    +68
    +69                document_vec = self.encoder(document_facet)
    +70                facet_scores[facet] = self.distance_fn(self.query_vec, document_vec) * facet_weight
    +71
    +72                sum_score = sum(facet_scores.values())
    +73                facet_scores["cosine_sum"] = sum_score
    +74
    +75                self.top_cosine_score = max(self.top_cosine_score, sum_score)
    +76                self.pre_calc[document.doc_id] = facet_scores
    +77
    +78        self.pre_calc_finished = True
    +79
    +80    def _compute_scores(self, document):
    +81        if not self.pre_calc_finished:
    +82            self._compute_scores_helper()
    +83            self.log_norm = scaler.get_z_value(self.top_cosine_score, self.top_score)
    +84
    +85        return math.log(document.score, self.log_norm) + self.pre_calc[document.doc_id]["cosine_sum"]
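For readers of the listing above, a minimal sketch of the combined score, score = log(bm25)/log(z) + cosine_sum; here z stands in for the value that scaler.get_z_value is assumed to produce from the top cosine and BM25 scores, and the function name is purely illustrative.

import math

def nir_score(bm25_score: float, cosine_sum: float, z: float) -> float:
    # Log-normalised first-stage (BM25) score plus the summed facet cosine scores,
    # mirroring NIReRanker._compute_scores; illustrative only.
    return math.log(bm25_score, z) + cosine_sum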
     
    @@ -195,28 +242,38 @@

    - NIReRanker( query, ranked_list: List[debeir.interfaces.document.Document], encoder: debeir.rankers.transformer_sent_encoder.Encoder, distance_fn=<function cosine>, facets_weights: Dict = None, *args, **kwargs) + NIReRanker( query, ranked_list: List[debeir.core.document.Document], encoder: debeir.rankers.transformer_sent_encoder.Encoder, distance_fn=<function cosine>, facets_weights: Dict = None, presort=False, fields_to_encode=None, *args, **kwargs)
    -
    26    def __init__(self, query, ranked_list: List[Document], encoder: Encoder,
    -27                 distance_fn=spatial.distance.cosine, facets_weights: Dict = None,
    +            
    25    def __init__(self, query, ranked_list: List[Document], encoder: Encoder,
    +26                 distance_fn=spatial.distance.cosine, facets_weights: Dict = None,
    +27                 presort=False, fields_to_encode=None,
     28                 *args, **kwargs):
    -29        super().__init__(query, ranked_list, *args, **kwargs)
    -30        self.encoder = encoder
    -31        self.top_score = self._get_top_score()
    -32        self.top_cosine_score = -1
    -33
    -34        self.query_vec = self.encoder(self.query)
    -35        self.distance_fn = distance_fn
    -36
    -37        # Compute all the cosine scores
    -38        self.pre_calc = {}
    -39        self._compute_scores_helper()
    -40        self.log_norm = scaler.get_z_value(self.top_score, self.top_cosine_score)
    -41        self.facets_weights = facets_weights
    +29
    +30        if presort:
    +31            ranked_list.sort(key=lambda k: k.score)
    +32
    +33        super().__init__(query, ranked_list, *args, **kwargs)
    +34        self.encoder = encoder
    +35        self.top_score = self._get_top_score()
    +36        self.top_cosine_score = -1
    +37
    +38        self.query_vec = self.encoder(self.query)
    +39        self.distance_fn = distance_fn
    +40        self.fields_to_encode = fields_to_encode
    +41
    +42        if facets_weights:
    +43            self.facets_weights = facets_weights
    +44        else:
    +45            self.facets_weights = {}
    +46
    +47        # Compute all the cosine scores
    +48        self.pre_calc = {}
    +49        self.pre_calc_finished = False
    +50        self.log_norm = None
     
    @@ -228,7 +285,6 @@
    Inherited Members
    @@ -335,7 +391,7 @@
    Inherited Members
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -362,7 +418,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers/reranking/reranker.html b/docs/debeir/rankers/reranking/reranker.html index 18a88f5..6546d59 100644 --- a/docs/debeir/rankers/reranking/reranker.html +++ b/docs/debeir/rankers/reranking/reranker.html @@ -3,7 +3,7 @@ - + debeir.rankers.reranking.reranker API documentation @@ -39,9 +39,6 @@

    API Documentation

  • rerank
  • -
  • - rrerank -
@@ -54,6 +51,15 @@

API Documentation

+
  • + ReRankerPool + + +
  • @@ -82,9 +88,9 @@

    3""" 4 5import abc - 6from typing import List, AnyStr + 6from typing import AnyStr, List, Union 7 - 8from debeir.interfaces.document import Document + 8from debeir.core.document import Document 9 10 11class ReRanker: @@ -112,49 +118,45 @@

    33 34 def rerank(self) -> List: 35 """ -36 Re-ranks the internal list +36 Re-rank the passed ranked list based on implemented private _compute_scores method. 37 -38 :return: -39 """ -40 return self.rrerank(self.ranked_list) -41 -42 @classmethod -43 def rrerank(cls, ranked_list: List) -> List: -44 """ -45 Re-rank the passed ranked list based on implemented private _compute_scores method. -46 -47 :param ranked_list: -48 :return: -49 A ranked list in descending order of the score field (which will be the last item in the list) -50 """ -51 ranking = [] -52 -53 for document in ranked_list: -54 doc_id, doc_repr = cls._get_document_representation(document) -55 score = cls._compute_scores(doc_repr) -56 -57 ranking.append([doc_id, doc_repr, score]) -58 -59 ranking.sort(key=lambda k: k[-1], reverse=True) -60 -61 return ranking +38 :param ranked_list: +39 :return: +40 A ranked list in descending order of the score field (which will be the last item in the list) +41 """ +42 ranking = [] +43 +44 for document in self.ranked_list: +45 doc_id, doc_repr = self._get_document_representation(document) +46 score = self._compute_scores(doc_repr) +47 +48 ranking.append([doc_id, doc_repr, score]) +49 +50 ranking.sort(key=lambda k: k[-1], reverse=True) +51 +52 return ranking +53 +54 +55class DocumentReRanker(ReRanker): +56 """ +57 Reranking interface for a ranked list of Document objects. +58 """ +59 +60 def __init__(self, query, ranked_list: List[Document], *args, **kwargs): +61 super().__init__(query, ranked_list, *args, **kwargs) 62 -63 -64class DocumentReRanker(ReRanker): -65 """ -66 Reranking interface for a ranked list of Document objects. -67 """ -68 -69 def __init__(self, query, ranked_list: List[Document], *args, **kwargs): -70 super().__init__(query, ranked_list, *args, **kwargs) +63 @abc.abstractmethod +64 def _compute_scores(self, document_repr): +65 pass +66 +67 @classmethod +68 def _get_document_representation(cls, document: Document) -> (Union[int, str, float], Document): +69 return document.doc_id, document +70 71 -72 @abc.abstractmethod -73 def _compute_scores(self, document_repr): -74 pass -75 -76 @classmethod -77 def _get_document_representation(cls, document: Document) -> (AnyStr, AnyStr): -78 return " ".join(document.facets.values()) +72class ReRankerPool: +73 # Reranks per topic using threads. +74 pass @@ -195,32 +197,23 @@

    34 35 def rerank(self) -> List: 36 """ -37 Re-ranks the internal list +37 Re-rank the passed ranked list based on implemented private _compute_scores method. 38 -39 :return: -40 """ -41 return self.rrerank(self.ranked_list) -42 -43 @classmethod -44 def rrerank(cls, ranked_list: List) -> List: -45 """ -46 Re-rank the passed ranked list based on implemented private _compute_scores method. -47 -48 :param ranked_list: -49 :return: -50 A ranked list in descending order of the score field (which will be the last item in the list) -51 """ -52 ranking = [] -53 -54 for document in ranked_list: -55 doc_id, doc_repr = cls._get_document_representation(document) -56 score = cls._compute_scores(doc_repr) -57 -58 ranking.append([doc_id, doc_repr, score]) -59 -60 ranking.sort(key=lambda k: k[-1], reverse=True) -61 -62 return ranking +39 :param ranked_list: +40 :return: +41 A ranked list in descending order of the score field (which will be the last item in the list) +42 """ +43 ranking = [] +44 +45 for document in self.ranked_list: +46 doc_id, doc_repr = self._get_document_representation(document) +47 score = self._compute_scores(doc_repr) +48 +49 ranking.append([doc_id, doc_repr, score]) +50 +51 ranking.sort(key=lambda k: k[-1], reverse=True) +52 +53 return ranking @@ -262,53 +255,23 @@

    35    def rerank(self) -> List:
     36        """
    -37        Re-ranks the internal list
    +37        Re-rank the passed ranked list based on implemented private _compute_scores method.
     38
    -39        :return:
    -40        """
    -41        return self.rrerank(self.ranked_list)
    -
    - - -

    Re-ranks the internal list

    - -
    Returns
    -
    - - - -
    - -
    -
    @classmethod
    - - def - rrerank(cls, ranked_list: List) -> List: - - - -
    - -
    43    @classmethod
    -44    def rrerank(cls, ranked_list: List) -> List:
    -45        """
    -46        Re-rank the passed ranked list based on implemented private _compute_scores method.
    -47
    -48        :param ranked_list:
    -49        :return:
    -50            A ranked list in descending order of the score field (which will be the last item in the list)
    -51        """
    -52        ranking = []
    -53
    -54        for document in ranked_list:
    -55            doc_id, doc_repr = cls._get_document_representation(document)
    -56            score = cls._compute_scores(doc_repr)
    -57
    -58            ranking.append([doc_id, doc_repr, score])
    -59
    -60        ranking.sort(key=lambda k: k[-1], reverse=True)
    -61
    -62        return ranking
    +39        :param ranked_list:
    +40        :return:
    +41            A ranked list in descending order of the score field (which will be the last item in the list)
    +42        """
    +43        ranking = []
    +44
    +45        for document in self.ranked_list:
    +46            doc_id, doc_repr = self._get_document_representation(document)
    +47            score = self._compute_scores(doc_repr)
    +48
    +49            ranking.append([doc_id, doc_repr, score])
    +50
    +51        ranking.sort(key=lambda k: k[-1], reverse=True)
    +52
    +53        return ranking
     
    @@ -342,21 +305,21 @@
    Returns
    -
    65class DocumentReRanker(ReRanker):
    -66    """
    -67    Reranking interface for a ranked list of Document objects.
    -68    """
    -69
    -70    def __init__(self, query, ranked_list: List[Document], *args, **kwargs):
    -71        super().__init__(query, ranked_list, *args, **kwargs)
    -72
    -73    @abc.abstractmethod
    -74    def _compute_scores(self, document_repr):
    -75        pass
    -76
    -77    @classmethod
    -78    def _get_document_representation(cls, document: Document) -> (AnyStr, AnyStr):
    -79        return " ".join(document.facets.values())
    +            
    56class DocumentReRanker(ReRanker):
    +57    """
    +58    Reranking interface for a ranked list of Document objects.
    +59    """
    +60
    +61    def __init__(self, query, ranked_list: List[Document], *args, **kwargs):
    +62        super().__init__(query, ranked_list, *args, **kwargs)
    +63
    +64    @abc.abstractmethod
    +65    def _compute_scores(self, document_repr):
    +66        pass
    +67
    +68    @classmethod
    +69    def _get_document_representation(cls, document: Document) -> (Union[int, str, float], Document):
    +70        return document.doc_id, document
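With rerank() now an instance method and _get_document_representation returning (doc_id, document), a concrete reranker only needs to implement _compute_scores. A minimal sketch; the title-length scoring rule and the facets lookup are invented for illustration and are not part of the library:

    # Hypothetical subclass of DocumentReRanker
    class TitleLengthReRanker(DocumentReRanker):
        def _compute_scores(self, document_repr):
            # document_repr is the Document itself, per _get_document_representation above.
            return len(document_repr.facets.get("title", ""))

    # reranker = TitleLengthReRanker(query, ranked_list)
    # ranking = reranker.rerank()  # [[doc_id, document, score], ...], highest score first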
     
    @@ -368,14 +331,14 @@
    Returns
    - DocumentReRanker( query, ranked_list: List[debeir.interfaces.document.Document], *args, **kwargs) + DocumentReRanker( query, ranked_list: List[debeir.core.document.Document], *args, **kwargs)
    -
    70    def __init__(self, query, ranked_list: List[Document], *args, **kwargs):
    -71        super().__init__(query, ranked_list, *args, **kwargs)
    +            
    61    def __init__(self, query, ranked_list: List[Document], *args, **kwargs):
    +62        super().__init__(query, ranked_list, *args, **kwargs)
     
    @@ -387,10 +350,41 @@
    Inherited Members
    +
    + +
    + +
    + + class + ReRankerPool: + + + +
    + +
    73class ReRankerPool:
    +74    # Reranks per topic using threads.
    +75    pass
    +
    + + + + +
    +
    + + ReRankerPool() + + +
    + + + +
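ReRankerPool is only a stub in this release; its comment suggests reranking each topic on its own thread. One way that idea could be realised with the standard library, sketched here as an assumption rather than debeir's implementation:

    from concurrent.futures import ThreadPoolExecutor

    def rerank_per_topic(rerankers_by_topic, max_workers=4):
        """Run one ReRanker per topic concurrently and collect the ranked lists."""
        results = {}
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = {topic: pool.submit(r.rerank) for topic, r in rerankers_by_topic.items()}
            for topic, future in futures.items():
                results[topic] = future.result()
        return results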
    @@ -494,7 +488,7 @@
    Inherited Members
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -521,7 +515,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers/reranking/use.html b/docs/debeir/rankers/reranking/use.html index ae85216..feec414 100644 --- a/docs/debeir/rankers/reranking/use.html +++ b/docs/debeir/rankers/reranking/use.html @@ -3,7 +3,7 @@ - + debeir.rankers.reranking.use API documentation @@ -72,7 +72,10 @@

    10 super().__init__(*args, **kwargs) 11 12 def _compute_scores(self, document): -13 return self.pre_calc[document.id]["cosine_sum"] +13 if not self.pre_calc_finished: +14 self._compute_scores_helper() +15 +16 return self.pre_calc[document.doc_id]["cosine_sum"]
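The change above switches the cosine sums from eager to lazy computation: the expensive helper only runs on the first call to _compute_scores and its results are memoised in pre_calc. A self-contained sketch of that pattern; the class and the stand-in score are illustrative, not the library's code:

    class LazyScorer:
        """Illustrates the pre_calc / pre_calc_finished idiom used above."""

        def __init__(self, documents):
            self.documents = documents        # {doc_id: text}
            self.pre_calc = {}
            self.pre_calc_finished = False

        def _compute_scores_helper(self):
            # Single expensive pass over every document, run only once.
            for doc_id, text in self.documents.items():
                self.pre_calc[doc_id] = {"cosine_sum": float(len(text))}  # stand-in score
            self.pre_calc_finished = True

        def _compute_scores(self, doc_id):
            if not self.pre_calc_finished:
                self._compute_scores_helper()
            return self.pre_calc[doc_id]["cosine_sum"]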

    @@ -97,7 +100,10 @@

    11 super().__init__(*args, **kwargs) 12 13 def _compute_scores(self, document): -14 return self.pre_calc[document.id]["cosine_sum"] +14 if not self.pre_calc_finished: +15 self._compute_scores_helper() +16 +17 return self.pre_calc[document.doc_id]["cosine_sum"] @@ -128,7 +134,6 @@

    Inherited Members
    @@ -235,7 +240,7 @@
    Inherited Members
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -262,7 +267,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/rankers/transformer_sent_encoder.html b/docs/debeir/rankers/transformer_sent_encoder.html index 14469f1..0b32f73 100644 --- a/docs/debeir/rankers/transformer_sent_encoder.html +++ b/docs/debeir/rankers/transformer_sent_encoder.html @@ -3,7 +3,7 @@ - + debeir.rankers.transformer_sent_encoder API documentation @@ -37,7 +37,7 @@

    API Documentation

    Encoder
  • - encode + encode
  • @@ -63,14 +63,14 @@

    -
     1from typing import List
    - 2
    - 3import sentence_transformers
    - 4import torch
    - 5import torch.nn.functional as F
    - 6import spacy
    - 7from analysis_tools_ir.utils import cache
    - 8from hashlib import md5
    +                        
     1from hashlib import md5
    + 2from typing import List
    + 3
    + 4import sentence_transformers
    + 5import spacy
    + 6import torch
    + 7import torch.nn.functional as F
    + 8from analysis_tools_ir.utils import cache
      9
     10EMBEDDING_DIM_SIZE = 768
     11
    @@ -84,77 +84,78 @@ 

    19 :param spacy_model: the spacy or scispacy model to use for sentence boundary detection. 20 :param max_length: Maximum input length for the spacy nlp model. 21 """ -22 def __init__( -23 self, -24 model_path, -25 normalize=False, -26 spacy_model="en_core_sci_md", -27 max_length=2000000, -28 ): -29 self.model = sentence_transformers.SentenceTransformer(model_path) -30 self.model_path = model_path -31 self.nlp = spacy.load(spacy_model) -32 self.spacy_model = spacy_model -33 self.max_length = max_length -34 self.nlp.max_length = max_length -35 self.normalize = normalize -36 -37 @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/") -38 def encode(self, topic: str) -> List: -39 """ -40 Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. -41 By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this. -42 -43 :param topic: The topic (a list of sentences) to encode. Should be a raw string. -44 :param disable_cache: keyword argument, pass as True to disable encoding caching. -45 :return: -46 Returns a list of encoded tensors is returned. -47 """ -48 sentences = [ -49 " ".join(sent.text.split()) -50 for sent in self.nlp(topic).sents -51 if sent.text.strip() -52 ] -53 -54 embeddings = self.model.encode(sentences, convert_to_tensor=True, -55 show_progress_bar=False) -56 -57 if len(embeddings.size()) == 1: -58 embeddings = torch.unsqueeze(embeddings, dim=0) -59 embeddings = torch.mean(embeddings, axis=0) -60 -61 if self.normalize: -62 embeddings = F.normalize(embeddings, dim=-1) -63 -64 embeddings = embeddings.tolist() -65 -66 if isinstance(embeddings, list) and isinstance(embeddings[0], list): -67 return embeddings[0] -68 -69 return embeddings -70 -71 def __call__(self, topic, *args, **kwargs) -> List: -72 return self.encode(topic) -73 -74 def __eq__(self, other): -75 return ( -76 self.model_path == other.model_path -77 and self.spacy_model == other.spacy_model -78 and self.normalize == other.normalize -79 and self.max_length == other.max_length -80 ) -81 -82 def __hash__(self): -83 return int( -84 md5( -85 (self.model_path -86 + self.spacy_model -87 + str(self.normalize) -88 + str(self.max_length) -89 ).encode() -90 ).hexdigest(), -91 16, -92 ) +22 +23 def __init__( +24 self, +25 model_path, +26 normalize=False, +27 spacy_model="en_core_sci_md", +28 max_length=2000000, +29 ): +30 self.model = sentence_transformers.SentenceTransformer(model_path) +31 self.model_path = model_path +32 self.nlp = spacy.load(spacy_model) +33 self.spacy_model = spacy_model +34 self.max_length = max_length +35 self.nlp.max_length = max_length +36 self.normalize = normalize +37 +38 @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/") +39 def encode(self, topic: str) -> List: +40 """ +41 Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. +42 By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this. +43 +44 :param topic: The topic (a list of sentences) to encode. Should be a raw string. +45 :param disable_cache: keyword argument, pass as True to disable encoding caching. +46 :return: +47 Returns a list of encoded tensors is returned. 
+48 """ +49 sentences = [ +50 " ".join(sent.text.split()) +51 for sent in self.nlp(topic).sents +52 if sent.text.strip() +53 ] +54 +55 embeddings = self.model.encode(sentences, convert_to_tensor=True, +56 show_progress_bar=False) +57 +58 if len(embeddings.size()) == 1: +59 embeddings = torch.unsqueeze(embeddings, dim=0) +60 embeddings = torch.mean(embeddings, axis=0) +61 +62 if self.normalize: +63 embeddings = F.normalize(embeddings, dim=-1) +64 +65 embeddings = embeddings.tolist() +66 +67 if isinstance(embeddings, list) and isinstance(embeddings[0], list): +68 return embeddings[0] +69 +70 return embeddings +71 +72 def __call__(self, topic, *args, **kwargs) -> List: +73 return self.encode(topic) +74 +75 def __eq__(self, other): +76 return ( +77 self.model_path == other.model_path +78 and self.spacy_model == other.spacy_model +79 and self.normalize == other.normalize +80 and self.max_length == other.max_length +81 ) +82 +83 def __hash__(self): +84 return int( +85 md5( +86 (self.model_path +87 + self.spacy_model +88 + str(self.normalize) +89 + str(self.max_length) +90 ).encode() +91 ).hexdigest(), +92 16, +93 )

    @@ -179,77 +180,78 @@

    20 :param spacy_model: the spacy or scispacy model to use for sentence boundary detection. 21 :param max_length: Maximum input length for the spacy nlp model. 22 """ -23 def __init__( -24 self, -25 model_path, -26 normalize=False, -27 spacy_model="en_core_sci_md", -28 max_length=2000000, -29 ): -30 self.model = sentence_transformers.SentenceTransformer(model_path) -31 self.model_path = model_path -32 self.nlp = spacy.load(spacy_model) -33 self.spacy_model = spacy_model -34 self.max_length = max_length -35 self.nlp.max_length = max_length -36 self.normalize = normalize -37 -38 @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/") -39 def encode(self, topic: str) -> List: -40 """ -41 Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. -42 By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this. -43 -44 :param topic: The topic (a list of sentences) to encode. Should be a raw string. -45 :param disable_cache: keyword argument, pass as True to disable encoding caching. -46 :return: -47 Returns a list of encoded tensors is returned. -48 """ -49 sentences = [ -50 " ".join(sent.text.split()) -51 for sent in self.nlp(topic).sents -52 if sent.text.strip() -53 ] -54 -55 embeddings = self.model.encode(sentences, convert_to_tensor=True, -56 show_progress_bar=False) -57 -58 if len(embeddings.size()) == 1: -59 embeddings = torch.unsqueeze(embeddings, dim=0) -60 embeddings = torch.mean(embeddings, axis=0) -61 -62 if self.normalize: -63 embeddings = F.normalize(embeddings, dim=-1) -64 -65 embeddings = embeddings.tolist() -66 -67 if isinstance(embeddings, list) and isinstance(embeddings[0], list): -68 return embeddings[0] -69 -70 return embeddings -71 -72 def __call__(self, topic, *args, **kwargs) -> List: -73 return self.encode(topic) -74 -75 def __eq__(self, other): -76 return ( -77 self.model_path == other.model_path -78 and self.spacy_model == other.spacy_model -79 and self.normalize == other.normalize -80 and self.max_length == other.max_length -81 ) -82 -83 def __hash__(self): -84 return int( -85 md5( -86 (self.model_path -87 + self.spacy_model -88 + str(self.normalize) -89 + str(self.max_length) -90 ).encode() -91 ).hexdigest(), -92 16, -93 ) +23 +24 def __init__( +25 self, +26 model_path, +27 normalize=False, +28 spacy_model="en_core_sci_md", +29 max_length=2000000, +30 ): +31 self.model = sentence_transformers.SentenceTransformer(model_path) +32 self.model_path = model_path +33 self.nlp = spacy.load(spacy_model) +34 self.spacy_model = spacy_model +35 self.max_length = max_length +36 self.nlp.max_length = max_length +37 self.normalize = normalize +38 +39 @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/") +40 def encode(self, topic: str) -> List: +41 """ +42 Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. +43 By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this. +44 +45 :param topic: The topic (a list of sentences) to encode. Should be a raw string. +46 :param disable_cache: keyword argument, pass as True to disable encoding caching. +47 :return: +48 Returns a list of encoded tensors is returned. 
+49 """ +50 sentences = [ +51 " ".join(sent.text.split()) +52 for sent in self.nlp(topic).sents +53 if sent.text.strip() +54 ] +55 +56 embeddings = self.model.encode(sentences, convert_to_tensor=True, +57 show_progress_bar=False) +58 +59 if len(embeddings.size()) == 1: +60 embeddings = torch.unsqueeze(embeddings, dim=0) +61 embeddings = torch.mean(embeddings, axis=0) +62 +63 if self.normalize: +64 embeddings = F.normalize(embeddings, dim=-1) +65 +66 embeddings = embeddings.tolist() +67 +68 if isinstance(embeddings, list) and isinstance(embeddings[0], list): +69 return embeddings[0] +70 +71 return embeddings +72 +73 def __call__(self, topic, *args, **kwargs) -> List: +74 return self.encode(topic) +75 +76 def __eq__(self, other): +77 return ( +78 self.model_path == other.model_path +79 and self.spacy_model == other.spacy_model +80 and self.normalize == other.normalize +81 and self.max_length == other.max_length +82 ) +83 +84 def __hash__(self): +85 return int( +86 md5( +87 (self.model_path +88 + self.spacy_model +89 + str(self.normalize) +90 + str(self.max_length) +91 ).encode() +92 ).hexdigest(), +93 16, +94 )

    @@ -276,20 +278,20 @@

    Parameters
    -
    23    def __init__(
    -24            self,
    -25            model_path,
    -26            normalize=False,
    -27            spacy_model="en_core_sci_md",
    -28            max_length=2000000,
    -29    ):
    -30        self.model = sentence_transformers.SentenceTransformer(model_path)
    -31        self.model_path = model_path
    -32        self.nlp = spacy.load(spacy_model)
    -33        self.spacy_model = spacy_model
    -34        self.max_length = max_length
    -35        self.nlp.max_length = max_length
    -36        self.normalize = normalize
    +            
    24    def __init__(
    +25            self,
    +26            model_path,
    +27            normalize=False,
    +28            spacy_model="en_core_sci_md",
    +29            max_length=2000000,
    +30    ):
    +31        self.model = sentence_transformers.SentenceTransformer(model_path)
    +32        self.model_path = model_path
    +33        self.nlp = spacy.load(spacy_model)
    +34        self.spacy_model = spacy_model
    +35        self.max_length = max_length
    +36        self.nlp.max_length = max_length
    +37        self.normalize = normalize
     
    @@ -297,13 +299,53 @@
    Parameters
    -
    - encode = <analysis_tools_ir.utils.cache._Cache object> + +
    +
    @cache.Cache(hash_self=True, cache_dir='./cache/embedding_cache/')
    + + def + encode(self, topic: str) -> List: + + -
    - +
    39    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
    +40    def encode(self, topic: str) -> List:
    +41        """
    +42        Computes sentence embeddings for a given topic, uses spacy for sentence segmentation.
    +43        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.
    +44
    +45        :param topic: The topic (a list of sentences) to encode. Should be a raw string.
    +46        :param disable_cache: keyword argument, pass as True to disable encoding caching.
    +47        :return:
    +48            Returns a list of encoded tensors.
    +49        """
    +50        sentences = [
    +51            " ".join(sent.text.split())
    +52            for sent in self.nlp(topic).sents
    +53            if sent.text.strip()
    +54        ]
    +55
    +56        embeddings = self.model.encode(sentences, convert_to_tensor=True,
    +57                                       show_progress_bar=False)
    +58
    +59        if len(embeddings.size()) == 1:
    +60            embeddings = torch.unsqueeze(embeddings, dim=0)
    +61            embeddings = torch.mean(embeddings, axis=0)
    +62
    +63        if self.normalize:
    +64            embeddings = F.normalize(embeddings, dim=-1)
    +65
    +66        embeddings = embeddings.tolist()
    +67
    +68        if isinstance(embeddings, list) and isinstance(embeddings[0], list):
    +69            return embeddings[0]
    +70
    +71        return embeddings
    +
    + +

    Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.
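A short usage sketch of the Encoder, assuming the named sentence-transformers checkpoint and the scispacy pipeline en_core_sci_md are installed; both names are assumptions about the environment, not guaranteed by this patch:

    from debeir.rankers.transformer_sent_encoder import Encoder

    encoder = Encoder(
        model_path="pritamdeka/S-PubMedBert-MS-MARCO",  # assumed: any sentence-transformers model
        normalize=True,
        spacy_model="en_core_sci_md",
    )

    # __call__ delegates to the cached encode(); repeated calls with the same topic
    # should be served from ./cache/embedding_cache/ instead of being re-encoded.
    vector = encoder("Adults with type 2 diabetes treated with metformin. HbA1c reduction.")
    print(len(vector))  # one mean-pooled embedding for the whole topic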

    @@ -426,7 +468,7 @@
    Returns
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -453,7 +495,7 @@
    Returns
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/training.html b/docs/debeir/training.html index cd44acc..c35926a 100644 --- a/docs/debeir/training.html +++ b/docs/debeir/training.html @@ -3,7 +3,7 @@ - + debeir.training API documentation @@ -166,7 +166,7 @@

    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -193,7 +193,7 @@

    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/training/evaluate_reranker.html b/docs/debeir/training/evaluate_reranker.html index f17e1ce..60b1310 100644 --- a/docs/debeir/training/evaluate_reranker.html +++ b/docs/debeir/training/evaluate_reranker.html @@ -3,7 +3,7 @@ - + debeir.training.evaluate_reranker API documentation @@ -63,97 +63,96 @@

    -
     1import numpy as np
    - 2
    - 3from collections import defaultdict
    - 4from typing import List, Dict, Union
    - 5from datasets import Dataset
    - 6from sklearn.metrics.pairwise import cosine_similarity
    - 7
    - 8from debeir.evaluation.evaluator import Evaluator
    - 9from debeir.rankers.transformer_sent_encoder import Encoder
    +                        
     1from collections import defaultdict
    + 2from typing import Dict, List, Union
    + 3
    + 4import numpy as np
    + 5from debeir.evaluation.evaluator import Evaluator
    + 6from debeir.rankers.transformer_sent_encoder import Encoder
    + 7from sklearn.metrics.pairwise import cosine_similarity
    + 8
    + 9from datasets import Dataset
     10
    -11
    -12distance_fns = {
    -13    "dot_score": np.dot,
    -14    "cos_sim": cosine_similarity
    -15}
    +11distance_fns = {
    +12    "dot_score": np.dot,
    +13    "cos_sim": cosine_similarity
    +14}
    +15
     16
    -17
    -18class SentenceEvaluator(Evaluator):
    -19    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    -20                 text_cols: List[str], query_cols: List[str], id_col: str,
    -21                 distance_fn: str,
    -22                 qrels: str, metrics: List[str]):
    -23        super().__init__(qrels, metrics)
    -24        self.encoder = model
    -25        self.dataset = dataset
    -26        self.parsed_topics = parsed_topics
    -27        self.distance_fn = distance_fns[distance_fn]
    -28        self.query_cols = query_cols
    -29        self.text_cols = text_cols
    -30
    -31        self._get_topic_embeddings(query_cols)
    -32        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
    -33
    -34    def _get_topic_embeddings(self, query_cols):
    -35        for topic_num, topic in self.parsed_topics.items():
    -36            for query_col in query_cols:
    -37                query = topic[query_col]
    -38                query_eb = self.encoder(query)
    -39
    -40                topic[query_col+"_eb"] = query_eb
    -41
    -42    def _get_document_embedding_and_mapping(self, id_col, text_cols):
    -43        document_ebs = defaultdict(lambda: defaultdict(lambda: []))
    -44
    -45        for datum in self.dataset:
    -46            for text_col in text_cols:
    -47                embedding = self.encoder(datum[text_col])
    -48                topic_num, doc_id = datum[id_col].split("_")
    -49                document_ebs[topic_num][doc_id].append([text_col, embedding])
    -50
    -51        return document_ebs
    -52
    -53    def _get_score(self, a, b, aggregate="sum"):
    -54        scores = []
    -55
    -56        aggs = {
    -57            "max": max,
    -58            "min": min,
    -59            "sum": sum,
    -60            "avg": lambda k: sum(k)/len(k)
    -61        }
    -62
    -63        if not isinstance(a[0], list):
    -64            a = [a]
    -65
    -66        if not isinstance(b[0], list):
    -67            b = [b]
    -68
    -69        for _a in a:
    -70            for _b in b:
    -71                scores.append(float(self.distance_fn(_a, _b)))
    -72
    -73        return aggs[aggregate](scores)
    -74
    -75    def produce_ranked_lists(self):
    -76        # Store the indexes to access
    -77        # For each topic, sort.
    -78
    -79        topics = defaultdict(lambda: [])  # [document_id, score]
    -80
    -81        for topic_num, doc_topics in self.document_ebs.items():
    -82            for doc_id, document_repr in doc_topics.items():
    -83                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    -84
    -85                query_ebs = [self.parsed_topics[text_col+"_eb"] for text_col in self.text_cols]
    -86                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    -87
    -88        for topic_num in topics:
    -89            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    -90
    -91        return topics
    +17class SentenceEvaluator(Evaluator):
    +18    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    +19                 text_cols: List[str], query_cols: List[str], id_col: str,
    +20                 distance_fn: str,
    +21                 qrels: str, metrics: List[str]):
    +22        super().__init__(qrels, metrics)
    +23        self.encoder = model
    +24        self.dataset = dataset
    +25        self.parsed_topics = parsed_topics
    +26        self.distance_fn = distance_fns[distance_fn]
    +27        self.query_cols = query_cols
    +28        self.text_cols = text_cols
    +29
    +30        self._get_topic_embeddings(query_cols)
    +31        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
    +32
    +33    def _get_topic_embeddings(self, query_cols):
    +34        for topic_num, topic in self.parsed_topics.items():
    +35            for query_col in query_cols:
    +36                query = topic[query_col]
    +37                query_eb = self.encoder(query)
    +38
    +39                topic[query_col + "_eb"] = query_eb
    +40
    +41    def _get_document_embedding_and_mapping(self, id_col, text_cols):
    +42        document_ebs = defaultdict(lambda: defaultdict(lambda: []))
    +43
    +44        for datum in self.dataset:
    +45            for text_col in text_cols:
    +46                embedding = self.encoder(datum[text_col])
    +47                topic_num, doc_id = datum[id_col].split("_")
    +48                document_ebs[topic_num][doc_id].append([text_col, embedding])
    +49
    +50        return document_ebs
    +51
    +52    def _get_score(self, a, b, aggregate="sum"):
    +53        scores = []
    +54
    +55        aggs = {
    +56            "max": max,
    +57            "min": min,
    +58            "sum": sum,
    +59            "avg": lambda k: sum(k) / len(k)
    +60        }
    +61
    +62        if not isinstance(a[0], list):
    +63            a = [a]
    +64
    +65        if not isinstance(b[0], list):
    +66            b = [b]
    +67
    +68        for _a in a:
    +69            for _b in b:
    +70                scores.append(float(self.distance_fn(_a, _b)))
    +71
    +72        return aggs[aggregate](scores)
    +73
    +74    def produce_ranked_lists(self):
    +75        # Store the indexes to access
    +76        # For each topic, sort.
    +77
    +78        topics = defaultdict(lambda: [])  # [document_id, score]
    +79
    +80        for topic_num, doc_topics in self.document_ebs.items():
    +81            for doc_id, document_repr in doc_topics.items():
    +82                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    +83
    +84                query_ebs = [self.parsed_topics[text_col + "_eb"] for text_col in self.text_cols]
    +85                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    +86
    +87        for topic_num in topics:
    +88            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    +89
    +90        return topics
     
    @@ -169,80 +168,80 @@

    -
    19class SentenceEvaluator(Evaluator):
    -20    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    -21                 text_cols: List[str], query_cols: List[str], id_col: str,
    -22                 distance_fn: str,
    -23                 qrels: str, metrics: List[str]):
    -24        super().__init__(qrels, metrics)
    -25        self.encoder = model
    -26        self.dataset = dataset
    -27        self.parsed_topics = parsed_topics
    -28        self.distance_fn = distance_fns[distance_fn]
    -29        self.query_cols = query_cols
    -30        self.text_cols = text_cols
    -31
    -32        self._get_topic_embeddings(query_cols)
    -33        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
    -34
    -35    def _get_topic_embeddings(self, query_cols):
    -36        for topic_num, topic in self.parsed_topics.items():
    -37            for query_col in query_cols:
    -38                query = topic[query_col]
    -39                query_eb = self.encoder(query)
    -40
    -41                topic[query_col+"_eb"] = query_eb
    -42
    -43    def _get_document_embedding_and_mapping(self, id_col, text_cols):
    -44        document_ebs = defaultdict(lambda: defaultdict(lambda: []))
    -45
    -46        for datum in self.dataset:
    -47            for text_col in text_cols:
    -48                embedding = self.encoder(datum[text_col])
    -49                topic_num, doc_id = datum[id_col].split("_")
    -50                document_ebs[topic_num][doc_id].append([text_col, embedding])
    -51
    -52        return document_ebs
    -53
    -54    def _get_score(self, a, b, aggregate="sum"):
    -55        scores = []
    -56
    -57        aggs = {
    -58            "max": max,
    -59            "min": min,
    -60            "sum": sum,
    -61            "avg": lambda k: sum(k)/len(k)
    -62        }
    -63
    -64        if not isinstance(a[0], list):
    -65            a = [a]
    -66
    -67        if not isinstance(b[0], list):
    -68            b = [b]
    -69
    -70        for _a in a:
    -71            for _b in b:
    -72                scores.append(float(self.distance_fn(_a, _b)))
    -73
    -74        return aggs[aggregate](scores)
    -75
    -76    def produce_ranked_lists(self):
    -77        # Store the indexes to access
    -78        # For each topic, sort.
    -79
    -80        topics = defaultdict(lambda: [])  # [document_id, score]
    -81
    -82        for topic_num, doc_topics in self.document_ebs.items():
    -83            for doc_id, document_repr in doc_topics.items():
    -84                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    -85
    -86                query_ebs = [self.parsed_topics[text_col+"_eb"] for text_col in self.text_cols]
    -87                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    -88
    -89        for topic_num in topics:
    -90            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    -91
    -92        return topics
    +            
    18class SentenceEvaluator(Evaluator):
    +19    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    +20                 text_cols: List[str], query_cols: List[str], id_col: str,
    +21                 distance_fn: str,
    +22                 qrels: str, metrics: List[str]):
    +23        super().__init__(qrels, metrics)
    +24        self.encoder = model
    +25        self.dataset = dataset
    +26        self.parsed_topics = parsed_topics
    +27        self.distance_fn = distance_fns[distance_fn]
    +28        self.query_cols = query_cols
    +29        self.text_cols = text_cols
    +30
    +31        self._get_topic_embeddings(query_cols)
    +32        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
    +33
    +34    def _get_topic_embeddings(self, query_cols):
    +35        for topic_num, topic in self.parsed_topics.items():
    +36            for query_col in query_cols:
    +37                query = topic[query_col]
    +38                query_eb = self.encoder(query)
    +39
    +40                topic[query_col + "_eb"] = query_eb
    +41
    +42    def _get_document_embedding_and_mapping(self, id_col, text_cols):
    +43        document_ebs = defaultdict(lambda: defaultdict(lambda: []))
    +44
    +45        for datum in self.dataset:
    +46            for text_col in text_cols:
    +47                embedding = self.encoder(datum[text_col])
    +48                topic_num, doc_id = datum[id_col].split("_")
    +49                document_ebs[topic_num][doc_id].append([text_col, embedding])
    +50
    +51        return document_ebs
    +52
    +53    def _get_score(self, a, b, aggregate="sum"):
    +54        scores = []
    +55
    +56        aggs = {
    +57            "max": max,
    +58            "min": min,
    +59            "sum": sum,
    +60            "avg": lambda k: sum(k) / len(k)
    +61        }
    +62
    +63        if not isinstance(a[0], list):
    +64            a = [a]
    +65
    +66        if not isinstance(b[0], list):
    +67            b = [b]
    +68
    +69        for _a in a:
    +70            for _b in b:
    +71                scores.append(float(self.distance_fn(_a, _b)))
    +72
    +73        return aggs[aggregate](scores)
    +74
    +75    def produce_ranked_lists(self):
    +76        # Store the indexes to access
    +77        # For each topic, sort.
    +78
    +79        topics = defaultdict(lambda: [])  # [document_id, score]
    +80
    +81        for topic_num, doc_topics in self.document_ebs.items():
    +82            for doc_id, document_repr in doc_topics.items():
    +83                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    +84
    +85                query_ebs = [self.parsed_topics[text_col + "_eb"] for text_col in self.text_cols]
    +86                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    +87
    +88        for topic_num in topics:
    +89            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    +90
    +91        return topics
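Note the id convention baked into _get_document_embedding_and_mapping: each dataset row's id_col value is split on "_" into a topic number and a document id, and one column per text field is encoded. A toy example of inputs shaped the way this evaluator expects; the column names, ids and the commented constructor arguments are illustrative assumptions:

    from datasets import Dataset

    dataset = Dataset.from_dict({
        "id": ["1_NCT001", "1_NCT002", "2_NCT003"],  # "<topic_num>_<doc_id>"
        "title": ["Metformin trial", "Placebo arm study", "Melanoma immunotherapy"],
    })

    parsed_topics = {
        "1": {"query": "type 2 diabetes treatment"},
        "2": {"query": "advanced melanoma therapy"},
    }

    # SentenceEvaluator(model=encoder, dataset=dataset, parsed_topics=parsed_topics,
    #                   text_cols=["title"], query_cols=["query"], id_col="id",
    #                   distance_fn="cos_sim", qrels="path/to/qrels.txt", metrics=[...])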
     
    @@ -260,20 +259,20 @@

    -
    20    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    -21                 text_cols: List[str], query_cols: List[str], id_col: str,
    -22                 distance_fn: str,
    -23                 qrels: str, metrics: List[str]):
    -24        super().__init__(qrels, metrics)
    -25        self.encoder = model
    -26        self.dataset = dataset
    -27        self.parsed_topics = parsed_topics
    -28        self.distance_fn = distance_fns[distance_fn]
    -29        self.query_cols = query_cols
    -30        self.text_cols = text_cols
    -31
    -32        self._get_topic_embeddings(query_cols)
    -33        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
    +            
    19    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
    +20                 text_cols: List[str], query_cols: List[str], id_col: str,
    +21                 distance_fn: str,
    +22                 qrels: str, metrics: List[str]):
    +23        super().__init__(qrels, metrics)
    +24        self.encoder = model
    +25        self.dataset = dataset
    +26        self.parsed_topics = parsed_topics
    +27        self.distance_fn = distance_fns[distance_fn]
    +28        self.query_cols = query_cols
    +29        self.text_cols = text_cols
    +30
    +31        self._get_topic_embeddings(query_cols)
    +32        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
     
    @@ -291,23 +290,23 @@

    -
    76    def produce_ranked_lists(self):
    -77        # Store the indexes to access
    -78        # For each topic, sort.
    -79
    -80        topics = defaultdict(lambda: [])  # [document_id, score]
    -81
    -82        for topic_num, doc_topics in self.document_ebs.items():
    -83            for doc_id, document_repr in doc_topics.items():
    -84                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    -85
    -86                query_ebs = [self.parsed_topics[text_col+"_eb"] for text_col in self.text_cols]
    -87                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    -88
    -89        for topic_num in topics:
    -90            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    -91
    -92        return topics
    +            
    75    def produce_ranked_lists(self):
    +76        # Store the indexes to access
    +77        # For each topic, sort.
    +78
    +79        topics = defaultdict(lambda: [])  # [document_id, score]
    +80
    +81        for topic_num, doc_topics in self.document_ebs.items():
    +82            for doc_id, document_repr in doc_topics.items():
    +83                doc_txt_cols, doc_embeddings = list(zip(*document_repr))
    +84
    +85                query_ebs = [self.parsed_topics[text_col + "_eb"] for text_col in self.text_cols]
    +86                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])
    +87
    +88        for topic_num in topics:
    +89            topics[topic_num].sort(key=lambda k: k[1], reverse=True)
    +90
    +91        return topics
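_get_score compares every query embedding against every encoded document field and reduces the pairwise values with the chosen aggregate before produce_ranked_lists sorts each topic. A quick worked example with plain vectors; the dot product stands in for whichever distance_fn was configured:

    import numpy as np

    aggs = {"max": max, "min": min, "sum": sum, "avg": lambda k: sum(k) / len(k)}

    query_vectors = [[1.0, 0.0], [0.0, 1.0]]   # e.g. two query columns
    doc_vectors = [[0.5, 0.5]]                 # e.g. one encoded text column

    pairwise = [float(np.dot(q, d)) for q in query_vectors for d in doc_vectors]
    print(aggs["sum"](pairwise), aggs["avg"](pairwise))  # 1.0 0.5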
     
    @@ -428,7 +427,7 @@
    Inherited Members
    } let heading; - switch (result.doc.type) { + switch (result.doc.kind) { case "function": if (doc.fullname.endsWith(".__init__")) { heading = `${doc.fullname.replace(/\.__init__$/, "")}${doc.signature}`; @@ -455,7 +454,7 @@
    Inherited Members
    } html += `
    - ${heading} + ${heading}
    ${doc.doc}
    `; diff --git a/docs/debeir/training/hparm_tuning.html b/docs/debeir/training/hparm_tuning.html index ccbc944..c02e452 100644 --- a/docs/debeir/training/hparm_tuning.html +++ b/docs/debeir/training/hparm_tuning.html @@ -3,7 +3,7 @@ - + debeir.training.hparm_tuning API documentation @@ -49,10 +49,19 @@

    Submodules

    debeir.training.hparm_tuning

    - - - - +

    Hyper parameter tuning library using Optuna and Wandb

    +
    + + + + + +
    1"""
    +2Hyper parameter tuning library using Optuna and Wandb
    +3"""
    +
    + +