Skip to content

Commit

Permalink
Update nlu_evaluation.ipynb (#572)
Browse files Browse the repository at this point in the history
* Update nlu_evaluation.ipynb with some fixes and adding environment name to call DF API

* Update nlu_evaluation.ipynb

Fix a typo for a return type of the function.

---------

Co-authored-by: thatsmesasha <[email protected]>
  • Loading branch information
thatsmesasha and thatsmesasha authored Jun 6, 2024
1 parent adc9421 commit 8d9970a
Showing 1 changed file with 40 additions and 7 deletions.
47 changes: 40 additions & 7 deletions nlu-evaluation/nlu_evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
"from google.cloud.dialogflowcx_v3.services.flows import FlowsClient\n",
"from google.cloud.dialogflowcx_v3.services.pages import PagesClient\n",
"from google.cloud.dialogflowcx_v3.services.sessions import SessionsClient\n",
"from google.cloud.dialogflowcx_v3.services.environments import EnvironmentsClient\n",
"from google.cloud.dialogflowcx_v3.types import session\n",
"\n",
"import warnings\n",
Expand Down Expand Up @@ -269,7 +270,8 @@
" print(f'Worksheet `{summary_tab}` does not exist. Creating...')\n",
" info['summary_worksheet'] = info['spreadsheet'].add_worksheet(summary_tab, rows=summary_rows, cols=len(summary_columns))\n",
"\n",
" if not info['summary_worksheet'].get_all_values():\n",
" summary_values = info['summary_worksheet'].get_all_values()\n",
" if len(summary_values) == 1 and not summary_values[0]:\n",
" print('Summary worksheet is empty. Initializing...')\n",
" cell_list = info['summary_worksheet'].range(f'A1:{chr(ord(\"A\") + len(summary_columns) - 1)}1')\n",
" for i, cell in enumerate(cell_list):\n",
Expand Down Expand Up @@ -322,6 +324,11 @@
" api_endpoint = get_api_endpoint(region)\n",
" return PagesClient(client_options={'api_endpoint': api_endpoint})\n",
"\n",
"def get_environment_client(region: str) -> EnvironmentsClient:\n",
" '''Returns DF CX EnvironmentClient based on region.'''\n",
" api_endpoint = get_api_endpoint(region)\n",
" return EnvironmentsClient(client_options={'api_endpoint': api_endpoint})\n",
"\n",
"\n",
"@ratelimit.sleep_and_retry\n",
"@ratelimit.limits(calls=MAX_CALLS, period=ONE_SECOND)\n",
Expand All @@ -330,11 +337,12 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_id: str,\n",
" language_code: str,\n",
" flow_id: str,\n",
" page_id: str,\n",
" utterance: str) -> tuple[str, float]:\n",
" session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/sessions/{uuid.uuid4()}'\n",
" session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/environments/{environment_id}/sessions/{uuid.uuid4()}'\n",
"\n",
" current_page_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/flows/{flow_id}/pages/{page_id}'\n",
"\n",
Expand All @@ -360,6 +368,7 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_id: str,\n",
" language_code: str,\n",
" flow_ids: list[str],\n",
" page_ids: list[str],\n",
Expand All @@ -382,6 +391,7 @@
" [project_id] * count,\n",
" [region] * count,\n",
" [agent_id] * count,\n",
" [environment_id] * count,\n",
" [language_code] * count,\n",
" flow_ids,\n",
" page_ids,\n",
Expand Down Expand Up @@ -418,20 +428,38 @@
" page_ids[(flow_id, 'End Session')] = 'END_SESSION'\n",
" return page_ids\n",
"\n",
"def get_environment_id(environment_client: EnvironmentsClient, project_id: str, region: str, agent_id: str, environment_name: str) -> str:\n",
" environments = environment_client.list_environments(parent=f'projects/{project_id}/locations/{region}/agents/{agent_id}')\n",
" for environment in environments:\n",
" if environment.display_name == environment_name:\n",
" return environment.name.split('/')[-1]\n",
" raise ValueError(f'Environment {environment_name} not found. There are following environments available: {\", \".join([\"draft\"] + [environment.display_name for environment in environments])}')\n",
"\n",
"\n",
"def evaluate_dataset(\n",
" session_client: SessionsClient,\n",
" flow_client: FlowsClient,\n",
" page_client: PagesClient,\n",
" environment_client: EnvironmentsClient,\n",
" dataset: pd.DataFrame,\n",
" agent_label: str,\n",
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_name: str,\n",
" language_code: str,\n",
") -> pd.DataFrame:\n",
" \"\"\"Evaluates datasets and populates new columns inside it.\"\"\"\n",
"\n",
" # Map environment\n",
" if environment_name and environment_name != 'draft':\n",
" print('Getting environment mapping...', end='')\n",
" environment_id = get_environment_id(environment_client, project_id, region, agent_id, environment_name)\n",
" print('done')\n",
" print(f'Using environment id {environment_id} for environment {environment_name}.')\n",
" else:\n",
" environment_id = 'draft'\n",
"\n",
" # Map flows\n",
" print('Getting flow mapping...', end='')\n",
" flow_display_name_to_ids = get_flow_ids(flow_client, project_id, region, agent_id)\n",
Expand All @@ -457,6 +485,7 @@
" project_id,\n",
" region,\n",
" agent_id,\n",
" environment_id,\n",
" language_code,\n",
" flow_ids,\n",
" page_ids,\n",
Expand Down Expand Up @@ -495,13 +524,13 @@
" summary = pd.concat([summary, pd.DataFrame(run_values, index=[0])], ignore_index=True)\n",
"\n",
" summary_worksheet.clear()\n",
" summary_worksheet.update(range_name=\"A:H\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n",
" summary_worksheet.update(range_name=\"A:ZZ\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n",
"\n",
"def write_dataset(dataset_worksheet: gspread.Worksheet, dataset: pd.DataFrame) -> None:\n",
" dataset = dataset.drop(['Flow Id', 'Page Id'], axis=1)\n",
"\n",
" dataset_worksheet.clear()\n",
" dataset_worksheet.update(range_name=\"A:H\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n",
" dataset_worksheet.update(range_name=\"A:ZZ\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n",
"\n",
"\n",
"def run(\n",
Expand All @@ -510,6 +539,7 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_name: str,\n",
" language_code: str,\n",
" spreadsheet_url: str,\n",
" dataset_tab: str,\n",
Expand All @@ -519,11 +549,12 @@
" session_client = get_session_client(region)\n",
" flow_client = get_flow_client(region)\n",
" page_client = get_page_client(region)\n",
" environment_client = get_environment_client(region)\n",
"\n",
" spreadsheet_info, dataset = setup_spreadsheet(spreadsheet_client, spreadsheet_url, dataset_tab, summary_tab)\n",
"\n",
" timestamp = datetime.datetime.utcnow().strftime('%Y.%m.%d %H:%M:%S')\n",
" dataset = evaluate_dataset(session_client, flow_client, page_client, dataset, agent_label, project_id, region, agent_id, language_code)\n",
" dataset = evaluate_dataset(session_client, flow_client, page_client, environment_client, dataset, agent_label, project_id, region, agent_id, environment_name, language_code)\n",
"\n",
" write_dataset(spreadsheet_info['dataset_worksheet'], dataset)\n",
" write_summary(spreadsheet_info['summary_worksheet'], dataset, agent_label, timestamp, dataset_tab)\n"
Expand Down Expand Up @@ -571,6 +602,7 @@
"project_id_1 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
"region_1 = 'us-central1' #@param {type:\"string\"}\n",
"agent_id_1 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n",
"environment_name_1 = 'draft' #@param {type:\"string\"}\n",
"language_code_1 = 'en' #@param {type:\"string\"}\n",
"\n",
"#@markdown &nbsp;\n",
Expand All @@ -585,6 +617,7 @@
"project_id_2 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
"region_2 = 'global' #@param {type:\"string\"}\n",
"agent_id_2 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n",
"environment_name_2 = 'prod' #@param {type:\"string\"}\n",
"language_code_2 = 'en' #@param {type:\"string\"}\n",
"\n",
"#@markdown &nbsp;\n",
Expand All @@ -597,13 +630,13 @@
"auth.authenticate_user(project_id=project_id_1)\n",
"creds, _ = default()\n",
"print(f\"### Running evaluation on agent {agent_label_1} ###\\n\")\n",
"run(creds, agent_label_1, project_id_1, region_1, agent_id_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n",
"run(creds, agent_label_1, project_id_1, region_1, agent_id_1, environment_name_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n",
"\n",
"if use_second_agent:\n",
" auth.authenticate_user(project_id=project_id_2)\n",
" creds, _ = default()\n",
" print(f\"\\n### Running evaluation on agent {agent_label_2} ###\\n\")\n",
" run(creds, agent_label_2, project_id_2, region_2, agent_id_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n"
" run(creds, agent_label_2, project_id_2, region_2, agent_id_2, environment_name_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n"
]
}
],
Expand Down

0 comments on commit 8d9970a

Please sign in to comment.