Skip to content

Commit 8d7d14d

Browse files
adjust multilabel notebook to adapt prediction (Azure#1972)
data format change
1 parent c2e8945 commit 8d7d14d

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

v1/python-sdk/tutorials/automl-with-azureml/automl-nlp-multilabel/automl-nlp-text-classification-multilabel.ipynb

+4-9
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,7 @@
454454
"outputs": [],
455455
"source": [
456456
"test_data_df = test_dataset.to_pandas_dataframe()\n",
457-
"test_set_predictions_df = pd.read_csv(\"preds_multilabel.csv\")\n",
458-
"test_set_predictions_df[\"label_confidence\"] = test_set_predictions_df[\n",
459-
" \"label_confidence\"\n",
460-
"].apply(lambda x: [float(num) for num in x.split(\",\")])"
457+
"test_set_predictions_df = pd.read_csv(\"preds_multilabel.csv\")"
461458
]
462459
},
463460
{
@@ -507,10 +504,7 @@
507504
"metadata": {},
508505
"outputs": [],
509506
"source": [
510-
"test_pred_probs = []\n",
511-
"for i in range(test_set_predictions_df.shape[0]):\n",
512-
" test_pred_probs.append(test_set_predictions_df.loc[i, \"label_confidence\"])\n",
513-
"test_pred_probs = np.array(test_pred_probs)"
507+
"test_pred_probs = test_set_predictions_df.to_numpy()"
514508
]
515509
},
516510
{
@@ -572,11 +566,12 @@
572566
" y_true = []\n",
573567
" y_pred = []\n",
574568
"\n",
569+
" pred_df = pred_df.to_numpy()\n",
575570
" for row in range(test_df.shape[0]):\n",
576571
" true_labels = y_transformer.transform(\n",
577572
" [ast.literal_eval(test_df.loc[row, label_col])]\n",
578573
" ).toarray()[0]\n",
579-
" pred_labels = pred_df.loc[row, \"label_confidence\"]\n",
574+
" pred_labels = pred_df[row]\n",
580575
" for ind, (label, prob) in enumerate(zip(true_labels, pred_labels)):\n",
581576
" predict_positive = prob >= threshold\n",
582577
" if label or predict_positive:\n",

0 commit comments

Comments
 (0)