From 8d9970a240c4c0e4927c63d06314eecc88fc11f0 Mon Sep 17 00:00:00 2001 From: "Sasha :)" <22552600+thatsmesasha@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:29:51 +0200 Subject: [PATCH] Update nlu_evaluation.ipynb (#572) * Update nlu_evaluation.ipynb with some fixes and adding environment name to call DF API * Update nlu_evaluation.ipynb Fix a typo for a return type of the function. --------- Co-authored-by: thatsmesasha --- nlu-evaluation/nlu_evaluation.ipynb | 47 ++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/nlu-evaluation/nlu_evaluation.ipynb b/nlu-evaluation/nlu_evaluation.ipynb index fa698f7d..edb42014 100644 --- a/nlu-evaluation/nlu_evaluation.ipynb +++ b/nlu-evaluation/nlu_evaluation.ipynb @@ -214,6 +214,7 @@ "from google.cloud.dialogflowcx_v3.services.flows import FlowsClient\n", "from google.cloud.dialogflowcx_v3.services.pages import PagesClient\n", "from google.cloud.dialogflowcx_v3.services.sessions import SessionsClient\n", + "from google.cloud.dialogflowcx_v3.services.environments import EnvironmentsClient\n", "from google.cloud.dialogflowcx_v3.types import session\n", "\n", "import warnings\n", @@ -269,7 +270,8 @@ " print(f'Worksheet `{summary_tab}` does not exist. Creating...')\n", " info['summary_worksheet'] = info['spreadsheet'].add_worksheet(summary_tab, rows=summary_rows, cols=len(summary_columns))\n", "\n", - " if not info['summary_worksheet'].get_all_values():\n", + " summary_values = info['summary_worksheet'].get_all_values()\n", + " if len(summary_values) == 1 and not summary_values[0]:\n", " print('Summary worksheet is empty. Initializing...')\n", " cell_list = info['summary_worksheet'].range(f'A1:{chr(ord(\"A\") + len(summary_columns) - 1)}1')\n", " for i, cell in enumerate(cell_list):\n", @@ -322,6 +324,11 @@ " api_endpoint = get_api_endpoint(region)\n", " return PagesClient(client_options={'api_endpoint': api_endpoint})\n", "\n", + "def get_environment_client(region: str) -> EnvironmentsClient:\n", + " '''Returns DF CX EnvironmentClient based on region.'''\n", + " api_endpoint = get_api_endpoint(region)\n", + " return EnvironmentsClient(client_options={'api_endpoint': api_endpoint})\n", + "\n", "\n", "@ratelimit.sleep_and_retry\n", "@ratelimit.limits(calls=MAX_CALLS, period=ONE_SECOND)\n", @@ -330,11 +337,12 @@ " project_id: str,\n", " region: str,\n", " agent_id: str,\n", + " environment_id: str,\n", " language_code: str,\n", " flow_id: str,\n", " page_id: str,\n", " utterance: str) -> tuple[str, float]:\n", - " session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/sessions/{uuid.uuid4()}'\n", + " session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/environments/{environment_id}/sessions/{uuid.uuid4()}'\n", "\n", " current_page_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/flows/{flow_id}/pages/{page_id}'\n", "\n", @@ -360,6 +368,7 @@ " project_id: str,\n", " region: str,\n", " agent_id: str,\n", + " environment_id: str,\n", " language_code: str,\n", " flow_ids: list[str],\n", " page_ids: list[str],\n", @@ -382,6 +391,7 @@ " [project_id] * count,\n", " [region] * count,\n", " [agent_id] * count,\n", + " [environment_id] * count,\n", " [language_code] * count,\n", " flow_ids,\n", " page_ids,\n", @@ -418,20 +428,38 @@ " page_ids[(flow_id, 'End Session')] = 'END_SESSION'\n", " return page_ids\n", "\n", + "def get_environment_id(environment_client: EnvironmentsClient, project_id: str, region: str, agent_id: str, environment_name: str) -> str:\n", + " environments = environment_client.list_environments(parent=f'projects/{project_id}/locations/{region}/agents/{agent_id}')\n", + " for environment in environments:\n", + " if environment.display_name == environment_name:\n", + " return environment.name.split('/')[-1]\n", + " raise ValueError(f'Environment {environment_name} not found. There are following environments available: {\", \".join([\"draft\"] + [environment.display_name for environment in environments])}')\n", + "\n", "\n", "def evaluate_dataset(\n", " session_client: SessionsClient,\n", " flow_client: FlowsClient,\n", " page_client: PagesClient,\n", + " environment_client: EnvironmentsClient,\n", " dataset: pd.DataFrame,\n", " agent_label: str,\n", " project_id: str,\n", " region: str,\n", " agent_id: str,\n", + " environment_name: str,\n", " language_code: str,\n", ") -> pd.DataFrame:\n", " \"\"\"Evaluates datasets and populates new columns inside it.\"\"\"\n", "\n", + " # Map environment\n", + " if environment_name and environment_name != 'draft':\n", + " print('Getting environment mapping...', end='')\n", + " environment_id = get_environment_id(environment_client, project_id, region, agent_id, environment_name)\n", + " print('done')\n", + " print(f'Using environment id {environment_id} for environment {environment_name}.')\n", + " else:\n", + " environment_id = 'draft'\n", + "\n", " # Map flows\n", " print('Getting flow mapping...', end='')\n", " flow_display_name_to_ids = get_flow_ids(flow_client, project_id, region, agent_id)\n", @@ -457,6 +485,7 @@ " project_id,\n", " region,\n", " agent_id,\n", + " environment_id,\n", " language_code,\n", " flow_ids,\n", " page_ids,\n", @@ -495,13 +524,13 @@ " summary = pd.concat([summary, pd.DataFrame(run_values, index=[0])], ignore_index=True)\n", "\n", " summary_worksheet.clear()\n", - " summary_worksheet.update(range_name=\"A:H\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n", + " summary_worksheet.update(range_name=\"A:ZZ\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n", "\n", "def write_dataset(dataset_worksheet: gspread.Worksheet, dataset: pd.DataFrame) -> None:\n", " dataset = dataset.drop(['Flow Id', 'Page Id'], axis=1)\n", "\n", " dataset_worksheet.clear()\n", - " dataset_worksheet.update(range_name=\"A:H\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n", + " dataset_worksheet.update(range_name=\"A:ZZ\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n", "\n", "\n", "def run(\n", @@ -510,6 +539,7 @@ " project_id: str,\n", " region: str,\n", " agent_id: str,\n", + " environment_name: str,\n", " language_code: str,\n", " spreadsheet_url: str,\n", " dataset_tab: str,\n", @@ -519,11 +549,12 @@ " session_client = get_session_client(region)\n", " flow_client = get_flow_client(region)\n", " page_client = get_page_client(region)\n", + " environment_client = get_environment_client(region)\n", "\n", " spreadsheet_info, dataset = setup_spreadsheet(spreadsheet_client, spreadsheet_url, dataset_tab, summary_tab)\n", "\n", " timestamp = datetime.datetime.utcnow().strftime('%Y.%m.%d %H:%M:%S')\n", - " dataset = evaluate_dataset(session_client, flow_client, page_client, dataset, agent_label, project_id, region, agent_id, language_code)\n", + " dataset = evaluate_dataset(session_client, flow_client, page_client, environment_client, dataset, agent_label, project_id, region, agent_id, environment_name, language_code)\n", "\n", " write_dataset(spreadsheet_info['dataset_worksheet'], dataset)\n", " write_summary(spreadsheet_info['summary_worksheet'], dataset, agent_label, timestamp, dataset_tab)\n" @@ -571,6 +602,7 @@ "project_id_1 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n", "region_1 = 'us-central1' #@param {type:\"string\"}\n", "agent_id_1 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n", + "environment_name_1 = 'draft' #@param {type:\"string\"}\n", "language_code_1 = 'en' #@param {type:\"string\"}\n", "\n", "#@markdown  \n", @@ -585,6 +617,7 @@ "project_id_2 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n", "region_2 = 'global' #@param {type:\"string\"}\n", "agent_id_2 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n", + "environment_name_2 = 'prod' #@param {type:\"string\"}\n", "language_code_2 = 'en' #@param {type:\"string\"}\n", "\n", "#@markdown  \n", @@ -597,13 +630,13 @@ "auth.authenticate_user(project_id=project_id_1)\n", "creds, _ = default()\n", "print(f\"### Running evaluation on agent {agent_label_1} ###\\n\")\n", - "run(creds, agent_label_1, project_id_1, region_1, agent_id_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n", + "run(creds, agent_label_1, project_id_1, region_1, agent_id_1, environment_name_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n", "\n", "if use_second_agent:\n", " auth.authenticate_user(project_id=project_id_2)\n", " creds, _ = default()\n", " print(f\"\\n### Running evaluation on agent {agent_label_2} ###\\n\")\n", - " run(creds, agent_label_2, project_id_2, region_2, agent_id_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n" + " run(creds, agent_label_2, project_id_2, region_2, agent_id_2, environment_name_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n" ] } ],