From d8324624cb7e58aaf842460241047874499b06b0 Mon Sep 17 00:00:00 2001
From: caufieldjh
Date: Tue, 13 Aug 2024 12:57:12 -0400
Subject: [PATCH 1/5] Add system_message option

---
 src/ontogpt/cli.py | 82 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 80 insertions(+), 2 deletions(-)

diff --git a/src/ontogpt/cli.py b/src/ontogpt/cli.py
index 4ccbbf973..c37ac552b 100644
--- a/src/ontogpt/cli.py
+++ b/src/ontogpt/cli.py
@@ -183,7 +183,7 @@ def write_extraction(
     "--use-pdf/--no-use-pdf",
     default=False,
     show_default=True,
-    help="Load text from a PDF. Without this option, the input is assumed to be plain text."
+    help="Load text from a PDF. Without this option, the input is assumed to be plain text.",
 )
 output_option_wb = click.option(
     "-o", "--output", type=click.File(mode="wb"), default=sys.stdout, help="Output file."
 )
@@ -240,6 +240,11 @@ def write_extraction(
     help="If set, outputs will contain only the first 1000 characters of each input text."
     " This is useful when processing many documents.",
 )
+system_message_option = click.option(
+    "--system-message",
+    help="System message to provide to the LLM, e.g., 'You will extract knowledge from this text.'"
+    " This may be a path to a file containing the message or text of the message itself.",
+)
 
 
 @click.group()
@@ -289,6 +294,7 @@ def main(verbose: int, quiet: bool, cache_db: str):
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @temperature_option
 @cut_input_text_option
 def extract(
@@ -308,6 +314,7 @@ def extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Extract knowledge from text guided by schema, using SPIRES engine.
@@ -376,6 +383,7 @@ def extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -418,6 +426,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("entity")
 def generate_extract(
     model,
@@ -431,6 +440,7 @@ def generate_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Generate text and then extract knowledge from it.
@@ -454,6 +464,7 @@ def generate_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -486,6 +497,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("entity")
 def iteratively_generate_extract(
     model,
@@ -504,6 +516,7 @@ def iteratively_generate_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Iterate through generate-extract."""
@@ -522,6 +535,7 @@ def iteratively_generate_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -563,6 +577,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.option(
     "--max-text-length",
     default=3000,
@@ -584,6 +599,7 @@ def pubmed_annotate(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Retrieve a collection of PubMed IDs for a search term; annotate them using a template.
@@ -607,6 +623,7 @@ def pubmed_annotate(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -640,6 +657,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("article")
 def wikipedia_extract(
     model,
@@ -653,6 +671,7 @@ def wikipedia_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Extract knowledge from a Wikipedia page."""
@@ -671,6 +690,7 @@ def wikipedia_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -703,6 +723,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("topic")
 def wikipedia_search(
     model,
@@ -717,6 +738,7 @@ def wikipedia_search(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Extract knowledge from a Wikipedia page."""
@@ -735,6 +757,7 @@ def wikipedia_search(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
@@ -774,6 +797,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("term_tokens", nargs=-1)
 def search_and_extract(
     model,
@@ -788,6 +812,7 @@ def search_and_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Search for relevant literature through PubMed and extract knowledge from it.
@@ -811,6 +836,7 @@ def search_and_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
@@ -848,6 +874,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("url")
 def web_extract(
     model,
@@ -861,6 +888,7 @@ def web_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Extract knowledge from web page."""
@@ -879,6 +907,7 @@ def web_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -909,6 +938,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("url")
 def recipe_extract(
     model,
@@ -923,6 +953,7 @@ def recipe_extract(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Extract from recipe on the web."""
@@ -945,6 +976,7 @@ def recipe_extract(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     if settings.cache_db:
@@ -984,6 +1016,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("input")
 def convert(
     model,
@@ -996,6 +1029,7 @@ def convert(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Convert output format."""
@@ -1014,6 +1048,7 @@ def convert(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
@@ -1033,12 +1068,22 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.option(
     "-C", "--context", required=True, help="domain e.g. anatomy, industry, health-related"
 )
 @click.argument("term")
 def synonyms(
-    model, term, context, output, temperature, api_base, api_version, model_provider, **kwargs
+    model,
+    term,
+    context,
+    output,
+    temperature,
+    api_base,
+    api_version,
+    model_provider,
+    system_message,
+    **kwargs,
 ):
     """Extract synonyms.
 
@@ -1060,6 +1105,7 @@ def synonyms(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
     out = ke.synonyms(term, context)
@@ -1254,6 +1300,7 @@ def entity_similarity(
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.option("--task-file")
 @click.option("--task-type")
 @click.option("--tsv-output")
@@ -1277,6 +1324,7 @@ def reason(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Reason."""
@@ -1286,6 +1334,7 @@ def reason(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
     )
     if task_file:
         tc = extractor.TaskCollection.load(task_file)
@@ -1326,6 +1375,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("phenopacket_files", nargs=-1)
 def diagnose(
     phenopacket_files,
@@ -1336,6 +1386,7 @@ def diagnose(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Diagnose a clinical case represented as one or more Phenopackets."""
@@ -1352,6 +1403,7 @@ def diagnose(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
     )
     results = engine.evaluate(phenopackets)
     print(dump_minimal_yaml(results))
@@ -1368,6 +1420,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 def run_multilingual_analysis(
     input_data_dir,
     output_directory,
@@ -1377,6 +1430,7 @@ def run_multilingual_analysis(
     api_base,
     api_version,
     model_provider,
+    system_message,
     model,
 ):
     """Call the multilingual analysis function."""
@@ -1396,6 +1450,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 def answer(
     inputfile,
     model,
@@ -1407,6 +1462,7 @@ def answer(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Answer a set of questions defined in YAML."""
@@ -1417,6 +1473,7 @@ def answer(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
     )
     qs = []
     for q in engine.run(qc, template_path=template_path):
@@ -1437,6 +1494,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.option("--task-file")
 @click.option("--task-type")
 @click.option("--tsv-output")
@@ -1460,6 +1518,7 @@ def categorize_mappings(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Categorize a collection of SSSOM mappings."""
@@ -1469,6 +1528,7 @@ def categorize_mappings(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
     )
     if tsv_output:
         tc = mapper.from_sssom(inputfile)
@@ -1509,6 +1569,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.option(
     "--num-tests",
     type=click.INT,
@@ -1535,6 +1596,7 @@ def eval(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Evaluate an extractor."""
@@ -1560,6 +1622,7 @@
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("object")
 def fill(
     model,
@@ -1574,6 +1637,7 @@ def fill(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Fill in missing values, given examples."""
@@ -1595,6 +1659,7 @@ def fill(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
         **kwargs,
     )
@@ -1617,11 +1682,13 @@ def fill(
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @temperature_option
 @cut_input_text_option
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("input", required=False)
 def complete(
     inputfile,
@@ -1635,6 +1702,7 @@ def complete(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Prompt completion.
@@ -1662,6 +1730,7 @@ def complete(
         api_base,
         api_version,
         model_provider,
+        system_message,
     )
     output.write(results + "\n")
@@ -1678,6 +1747,7 @@ def _send_complete_request(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ) -> str:
     """Send a completion request to an LLM endpoint."""
@@ -1691,6 +1761,7 @@ def _send_complete_request(
         api_base=api_base,
         api_version=api_version,
         custom_llm_provider=model_provider,
+        system_message=system_message,
     )
 
     results = c.complete(prompt=input, show_prompt=show_prompt)
@@ -1783,11 +1854,13 @@ def halo(model, input, context, terms, output, **kwargs):
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @temperature_option
 @cut_input_text_option
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 def clinical_notes(
     description,
     sections,
@@ -1800,6 +1873,7 @@ def clinical_notes(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Create mock clinical notes.
@@ -1824,6 +1898,7 @@ def clinical_notes(
         api_base=api_base,
         api_version=api_version,
         custom_llm_provider=model_provider,
+        system_message=system_message,
     )
 
     results = c.complete(prompt=prompt, show_prompt=show_prompt)
@@ -1931,6 +2006,7 @@ def list_models():
 @api_base_option
 @api_version_option
 @model_provider_option
+@system_message_option
 @click.argument("input")
 def suggest_templates(
     input,
@@ -1943,6 +2019,7 @@ def suggest_templates(
     api_base,
     api_version,
     model_provider,
+    system_message,
     **kwargs,
 ):
     """Provide a suggestion for an appropriate template, given a text input.
@@ -1988,6 +2065,7 @@ def suggest_templates(
         api_base=api_base,
         api_version=api_version,
         model_provider=model_provider,
+        system_message=system_message,
     )
 
     output.write(results + "\n")

From b40112be62deb547e130528d0428ac7b8f81cd84 Mon Sep 17 00:00:00 2001
From: caufieldjh
Date: Tue, 13 Aug 2024 12:59:14 -0400
Subject: [PATCH 2/5] For now, system message is just a string

---
 src/ontogpt/cli.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/ontogpt/cli.py b/src/ontogpt/cli.py
index c37ac552b..0019cf5cc 100644
--- a/src/ontogpt/cli.py
+++ b/src/ontogpt/cli.py
@@ -242,8 +242,7 @@ def write_extraction(
 )
 system_message_option = click.option(
     "--system-message",
-    help="System message to provide to the LLM, e.g., 'You will extract knowledge from this text.'"
-    " This may be a path to a file containing the message or text of the message itself.",
+    help="System message to provide to the LLM, e.g., 'You will extract knowledge from this text.'",
 )

From ffeed2a79a01cbbc8e620b5fddfeec593f01c0a5 Mon Sep 17 00:00:00 2001
From: caufieldjh
Date: Tue, 13 Aug 2024 13:10:53 -0400
Subject: [PATCH 3/5] Add system message to KnowledgeEngine and LLMClient

---
 src/ontogpt/clients/llm_client.py       | 10 +++++++++-
 src/ontogpt/engines/knowledge_engine.py |  4 ++++
 src/ontogpt/prompts/qa/__init__.py      |  1 -
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/ontogpt/clients/llm_client.py b/src/ontogpt/clients/llm_client.py
index ae5af59f6..fbd82da36 100644
--- a/src/ontogpt/clients/llm_client.py
+++ b/src/ontogpt/clients/llm_client.py
@@ -33,6 +33,9 @@ class LLMClient:
 
     temperature: float = 1.0
 
+    system_message: str = ""
+    """System message to be provided to the LLM."""
+
     def __post_init__(self):
         # Get appropriate API key for the model source
         # and other details if needed
@@ -65,6 +68,11 @@ def complete(self, prompt, show_prompt: bool = False, **kwargs) -> str:
 
         response = None
 
+        these_messages = [{"content": prompt, "role": "user"}]
+
+        if self.system_message:
+            these_messages.insert(0, {"content": self.system_message, "role": "system"})
+
         try:
             # TODO: expose user prompt to CLI
             response = completion(
@@ -72,7 +80,7 @@ def complete(self, prompt, show_prompt: bool = False, **kwargs) -> str:
                 api_base=self.api_base,
                 api_version=self.api_version,
                 model=self.model,
-                messages=[{"content": prompt, "role": "user"}],
+                messages=these_messages,
                 temperature=self.temperature,
                 caching=True,
                 custom_llm_provider=self.custom_llm_provider,
diff --git a/src/ontogpt/engines/knowledge_engine.py b/src/ontogpt/engines/knowledge_engine.py
index ae40c7537..4494581e8 100644
--- a/src/ontogpt/engines/knowledge_engine.py
+++ b/src/ontogpt/engines/knowledge_engine.py
@@ -144,6 +144,9 @@ class KnowledgeEngine(ABC):
     temperature: float = 1.0
     """Temperature for LLM completions - this is passed to the LLMClient."""
 
+    system_message: str = ""
+    """System message to be provided to the LLM."""
+
     def __post_init__(self):
         if self.template_details:
             (
@@ -167,6 +170,7 @@ def __post_init__(self):
             api_version=self.api_version,
             api_base=self.api_base,
             custom_llm_provider=self.model_provider,
+            system_message=self.system_message,
         )
 
         # We retrieve encoding
diff --git a/src/ontogpt/prompts/qa/__init__.py b/src/ontogpt/prompts/qa/__init__.py
index 0a3620d38..677600612 100644
--- a/src/ontogpt/prompts/qa/__init__.py
+++ b/src/ontogpt/prompts/qa/__init__.py
@@ -1,4 +1,3 @@
 from pathlib import Path
 
 QA_PROMPT_DIR_PATH = Path(__file__).parent
-GENERIC_QA_PROMPT = QA_PROMPT_DIR_PATH / "generic.jinja2"
From df39ba50588ad6c75b546dfa37205193c085bddf Mon Sep 17 00:00:00 2001
From: caufieldjh
Date: Tue, 13 Aug 2024 13:21:02 -0400
Subject: [PATCH 4/5] Update docs

---
 docs/functions.md | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/docs/functions.md b/docs/functions.md
index ed33a065a..011c29e61 100644
--- a/docs/functions.md
+++ b/docs/functions.md
@@ -167,6 +167,25 @@ Use the option `cut-input-text` to truncate all input_text in output to 1000 cha
 
 This can be useful when processing many inputs and/or full texts, as without this option the full input will be included.
 
+### system-message
+
+Use the option `system-message` to pass a system-level message to the LLM.
+
+For example, with an input file named `greeting.txt` containing "How are you today":
+
+```bash
+$ ontogpt complete -m llama-3 -i greeting.txt
+I'm just a language model, I don't have emotions or feelings like humans do, so I don't have good or bad days. I'm always here and ready to chat with you, 24/7! How can I assist you today?
+$ ontogpt complete -m llama-3 -i greeting.txt --system-message "You are very grumpy today"
+*grumble grumble* I'm terrible, thanks for asking. Everything is just so... annoying. The sun is shining too brightly, the birds are singing too loudly, and the air is filled with the scent of... *sigh*... everything. Just, ugh. Can't anyone just leave me alone for once? What's it to you, anyway? *mutter mutter*
+```
+
+Including an instruction like the following anecdotally helps to avoid parsing failures due to the LLM getting creative with result formatting:
+
+```bash
+--system-message "You are going to extract information from text in the specified format. You will not deviate from the format; do not provide results in JSON format."
+```
+
 ## Functions
 
 ### categorize-mappings

From 0449266236251fdfd6a3f19ef3882918552fcc87 Mon Sep 17 00:00:00 2001
From: caufieldjh
Date: Tue, 13 Aug 2024 13:30:44 -0400
Subject: [PATCH 5/5] Skip pubmed client unit tests for now

---
 tests/unit/test_clients/test_pubmed_client_unit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_clients/test_pubmed_client_unit.py b/tests/unit/test_clients/test_pubmed_client_unit.py
index 3484c377b..ea125816d 100644
--- a/tests/unit/test_clients/test_pubmed_client_unit.py
+++ b/tests/unit/test_clients/test_pubmed_client_unit.py
@@ -4,7 +4,7 @@
 
 from ontogpt.clients import PubmedClient
 
-
+@unittest.skip("PubMed API is experiencing difficulties, i.e., 500 errors")
 class TestCompletion(unittest.TestCase):
     """Test annotation."""
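
For illustration, here is a minimal standalone sketch of what the new plumbing does once the series is applied. It mirrors the message-list construction added to `LLMClient.complete()` in PATCH 3/5; the model name, prompt text, and the direct `litellm` call are illustrative placeholders rather than OntoGPT's actual entry point.

```python
from litellm import completion

# Hypothetical inputs; in OntoGPT these arrive via --system-message and -i.
system_message = "You will extract knowledge from this text."
prompt = "One-sentence sample input."

# Same logic as LLMClient.complete() after PATCH 3/5: the user prompt is
# always sent, and the system message, when present, is prepended with
# the "system" role.
these_messages = [{"content": prompt, "role": "user"}]
if system_message:
    these_messages.insert(0, {"content": system_message, "role": "system"})

# The model name here is just an example; any litellm-supported model
# string should work, given the corresponding API credentials.
response = completion(model="gpt-4o", messages=these_messages, temperature=1.0)
print(response.choices[0].message.content)
```

On the CLI, the same option threads through every command that builds a `KnowledgeEngine`, e.g. `ontogpt extract -t gocam -i abstract.txt --system-message "You will extract knowledge from this text."` (the template and input file names here are examples only).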