Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update nlu_evaluation.ipynb #572

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions nlu-evaluation/nlu_evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
"from google.cloud.dialogflowcx_v3.services.flows import FlowsClient\n",
"from google.cloud.dialogflowcx_v3.services.pages import PagesClient\n",
"from google.cloud.dialogflowcx_v3.services.sessions import SessionsClient\n",
"from google.cloud.dialogflowcx_v3.services.environments import EnvironmentsClient\n",
"from google.cloud.dialogflowcx_v3.types import session\n",
"\n",
"import warnings\n",
Expand Down Expand Up @@ -269,7 +270,8 @@
" print(f'Worksheet `{summary_tab}` does not exist. Creating...')\n",
" info['summary_worksheet'] = info['spreadsheet'].add_worksheet(summary_tab, rows=summary_rows, cols=len(summary_columns))\n",
"\n",
" if not info['summary_worksheet'].get_all_values():\n",
" summary_values = info['summary_worksheet'].get_all_values()\n",
" if len(summary_values) == 1 and not summary_values[0]:\n",
" print('Summary worksheet is empty. Initializing...')\n",
" cell_list = info['summary_worksheet'].range(f'A1:{chr(ord(\"A\") + len(summary_columns) - 1)}1')\n",
" for i, cell in enumerate(cell_list):\n",
Expand Down Expand Up @@ -322,6 +324,11 @@
" api_endpoint = get_api_endpoint(region)\n",
" return PagesClient(client_options={'api_endpoint': api_endpoint})\n",
"\n",
"def get_environment_client(region: str) -> EnvironmentsClient:\n",
" '''Returns DF CX EnvironmentClient based on region.'''\n",
" api_endpoint = get_api_endpoint(region)\n",
" return EnvironmentsClient(client_options={'api_endpoint': api_endpoint})\n",
"\n",
"\n",
"@ratelimit.sleep_and_retry\n",
"@ratelimit.limits(calls=MAX_CALLS, period=ONE_SECOND)\n",
Expand All @@ -330,11 +337,12 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_id: str,\n",
" language_code: str,\n",
" flow_id: str,\n",
" page_id: str,\n",
" utterance: str) -> tuple[str, float]:\n",
" session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/sessions/{uuid.uuid4()}'\n",
" session_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/environments/{environment_id}/sessions/{uuid.uuid4()}'\n",
"\n",
" current_page_path = f'projects/{project_id}/locations/{region}/agents/{agent_id}/flows/{flow_id}/pages/{page_id}'\n",
"\n",
Expand All @@ -360,6 +368,7 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_id: str,\n",
" language_code: str,\n",
" flow_ids: list[str],\n",
" page_ids: list[str],\n",
Expand All @@ -382,6 +391,7 @@
" [project_id] * count,\n",
" [region] * count,\n",
" [agent_id] * count,\n",
" [environment_id] * count,\n",
" [language_code] * count,\n",
" flow_ids,\n",
" page_ids,\n",
Expand Down Expand Up @@ -418,20 +428,38 @@
" page_ids[(flow_id, 'End Session')] = 'END_SESSION'\n",
" return page_ids\n",
"\n",
"def get_environment_id(environment_client: EnvironmentsClient, project_id: str, region: str, agent_id: str, environment_name: str) -> str:\n",
" environments = environment_client.list_environments(parent=f'projects/{project_id}/locations/{region}/agents/{agent_id}')\n",
" for environment in environments:\n",
" if environment.display_name == environment_name:\n",
" return environment.name.split('/')[-1]\n",
" raise ValueError(f'Environment {environment_name} not found. There are following environments available: {\", \".join([\"draft\"] + [environment.display_name for environment in environments])}')\n",
"\n",
"\n",
"def evaluate_dataset(\n",
" session_client: SessionsClient,\n",
" flow_client: FlowsClient,\n",
" page_client: PagesClient,\n",
" environment_client: EnvironmentsClient,\n",
" dataset: pd.DataFrame,\n",
" agent_label: str,\n",
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_name: str,\n",
" language_code: str,\n",
") -> pd.DataFrame:\n",
" \"\"\"Evaluates datasets and populates new columns inside it.\"\"\"\n",
"\n",
" # Map environment\n",
" if environment_name and environment_name != 'draft':\n",
" print('Getting environment mapping...', end='')\n",
" environment_id = get_environment_id(environment_client, project_id, region, agent_id, environment_name)\n",
" print('done')\n",
" print(f'Using environment id {environment_id} for environment {environment_name}.')\n",
" else:\n",
" environment_id = 'draft'\n",
"\n",
" # Map flows\n",
" print('Getting flow mapping...', end='')\n",
" flow_display_name_to_ids = get_flow_ids(flow_client, project_id, region, agent_id)\n",
Expand All @@ -457,6 +485,7 @@
" project_id,\n",
" region,\n",
" agent_id,\n",
" environment_id,\n",
" language_code,\n",
" flow_ids,\n",
" page_ids,\n",
Expand Down Expand Up @@ -495,13 +524,13 @@
" summary = pd.concat([summary, pd.DataFrame(run_values, index=[0])], ignore_index=True)\n",
"\n",
" summary_worksheet.clear()\n",
" summary_worksheet.update(range_name=\"A:H\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n",
" summary_worksheet.update(range_name=\"A:ZZ\", values=([summary.columns.values.tolist()] + summary.values.tolist()))\n",
"\n",
"def write_dataset(dataset_worksheet: gspread.Worksheet, dataset: pd.DataFrame) -> None:\n",
" dataset = dataset.drop(['Flow Id', 'Page Id'], axis=1)\n",
"\n",
" dataset_worksheet.clear()\n",
" dataset_worksheet.update(range_name=\"A:H\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n",
" dataset_worksheet.update(range_name=\"A:ZZ\", values=([dataset.columns.values.tolist()] + dataset.values.tolist()))\n",
"\n",
"\n",
"def run(\n",
Expand All @@ -510,6 +539,7 @@
" project_id: str,\n",
" region: str,\n",
" agent_id: str,\n",
" environment_name: str,\n",
" language_code: str,\n",
" spreadsheet_url: str,\n",
" dataset_tab: str,\n",
Expand All @@ -519,11 +549,12 @@
" session_client = get_session_client(region)\n",
" flow_client = get_flow_client(region)\n",
" page_client = get_page_client(region)\n",
" environment_client = get_environment_client(region)\n",
"\n",
" spreadsheet_info, dataset = setup_spreadsheet(spreadsheet_client, spreadsheet_url, dataset_tab, summary_tab)\n",
"\n",
" timestamp = datetime.datetime.utcnow().strftime('%Y.%m.%d %H:%M:%S')\n",
" dataset = evaluate_dataset(session_client, flow_client, page_client, dataset, agent_label, project_id, region, agent_id, language_code)\n",
" dataset = evaluate_dataset(session_client, flow_client, page_client, environment_client, dataset, agent_label, project_id, region, agent_id, environment_name, language_code)\n",
"\n",
" write_dataset(spreadsheet_info['dataset_worksheet'], dataset)\n",
" write_summary(spreadsheet_info['summary_worksheet'], dataset, agent_label, timestamp, dataset_tab)\n"
Expand Down Expand Up @@ -571,6 +602,7 @@
"project_id_1 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
"region_1 = 'us-central1' #@param {type:\"string\"}\n",
"agent_id_1 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n",
"environment_name_1 = 'draft' #@param {type:\"string\"}\n",
"language_code_1 = 'en' #@param {type:\"string\"}\n",
"\n",
"#@markdown  \n",
Expand All @@ -585,6 +617,7 @@
"project_id_2 = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
"region_2 = 'global' #@param {type:\"string\"}\n",
"agent_id_2 = 'YOUR_AGENT_ID' #@param {type:\"string\"}\n",
"environment_name_2 = 'prod' #@param {type:\"string\"}\n",
"language_code_2 = 'en' #@param {type:\"string\"}\n",
"\n",
"#@markdown  \n",
Expand All @@ -597,13 +630,13 @@
"auth.authenticate_user(project_id=project_id_1)\n",
"creds, _ = default()\n",
"print(f\"### Running evaluation on agent {agent_label_1} ###\\n\")\n",
"run(creds, agent_label_1, project_id_1, region_1, agent_id_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n",
"run(creds, agent_label_1, project_id_1, region_1, agent_id_1, environment_name_1, language_code_1, spreadsheet_url, dataset_tab, summary_tab)\n",
"\n",
"if use_second_agent:\n",
" auth.authenticate_user(project_id=project_id_2)\n",
" creds, _ = default()\n",
" print(f\"\\n### Running evaluation on agent {agent_label_2} ###\\n\")\n",
" run(creds, agent_label_2, project_id_2, region_2, agent_id_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n"
" run(creds, agent_label_2, project_id_2, region_2, agent_id_2, environment_name_2, language_code_2, spreadsheet_url, dataset_tab, summary_tab)\n"
]
}
],
Expand Down
Loading