|
454 | 454 | "outputs": [],
|
455 | 455 | "source": [
|
456 | 456 | "test_data_df = test_dataset.to_pandas_dataframe()\n",
|
457 |
| - "test_set_predictions_df = pd.read_csv(\"preds_multilabel.csv\")\n", |
458 |
| - "test_set_predictions_df[\"label_confidence\"] = test_set_predictions_df[\n", |
459 |
| - " \"label_confidence\"\n", |
460 |
| - "].apply(lambda x: [float(num) for num in x.split(\",\")])" |
| 457 | + "test_set_predictions_df = pd.read_csv(\"preds_multilabel.csv\")" |
461 | 458 | ]
|
462 | 459 | },
|
463 | 460 | {
|
|
507 | 504 | "metadata": {},
|
508 | 505 | "outputs": [],
|
509 | 506 | "source": [
|
510 |
| - "test_pred_probs = []\n", |
511 |
| - "for i in range(test_set_predictions_df.shape[0]):\n", |
512 |
| - " test_pred_probs.append(test_set_predictions_df.loc[i, \"label_confidence\"])\n", |
513 |
| - "test_pred_probs = np.array(test_pred_probs)" |
| 507 | + "test_pred_probs = test_set_predictions_df.to_numpy()" |
514 | 508 | ]
|
515 | 509 | },
|
516 | 510 | {
|
|
572 | 566 | " y_true = []\n",
|
573 | 567 | " y_pred = []\n",
|
574 | 568 | "\n",
|
| 569 | + " pred_df = pred_df.to_numpy()\n", |
575 | 570 | " for row in range(test_df.shape[0]):\n",
|
576 | 571 | " true_labels = y_transformer.transform(\n",
|
577 | 572 | " [ast.literal_eval(test_df.loc[row, label_col])]\n",
|
578 | 573 | " ).toarray()[0]\n",
|
579 |
| - " pred_labels = pred_df.loc[row, \"label_confidence\"]\n", |
| 574 | + " pred_labels = pred_df[row]\n", |
580 | 575 | " for ind, (label, prob) in enumerate(zip(true_labels, pred_labels)):\n",
|
581 | 576 | " predict_positive = prob >= threshold\n",
|
582 | 577 | " if label or predict_positive:\n",
|
|
0 commit comments