Skip to content

Commit

Permalink
Merge pull request #83 from isi-usc-edu/johnp/miniML-fixes
Browse files Browse the repository at this point in the history
fixes issue #74 and #75
  • Loading branch information
jp7745 authored Dec 16, 2024
2 parents 51b15c5 + c24ee2f commit fa7dbe6
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 218 deletions.
6 changes: 2 additions & 4 deletions BubbleML/miniML/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ options:
Example call:

```bash
./miniML.py --ham_features_file TESTING_ONLY_Hamiltonian_features.csv --config_file miniML_config.json --solver_uuid bd63d0e0-a681-11ef-b4bd-6f1bb9e0689f --solver_labels_file TESTING_ONLY_solver.ccsdt.labels.csv --verbose
./miniML.py -v --ham_features_file ../../Hamiltonian_features/experimental/fast_double_factorization_features/Hamiltonian_features.csv --config_file miniML_config.json --solver_uuid 16537433-9f4c-4eae-a65d-787dc3b35b59 --solver_labels_file ../../scripts/solver_labels.DMRG_Niagara_cluster_lowest_energy.16537433-9f4c-4eae-a65d-787dc3b35b59.csv --verbose
```

Note that the `solver_uuid` is contrived for this example. The solver is CCSDT in this example.
Note that the `solver_uuid` is specific to the DMRG algorithm running on a specific compute platform in this example.

If the `--verbose` flag is included, then a `plot_<solver_uuid, datestamp>.png` plot and the `probs<solver_uuid, datestamp>.csv` file will be generated as artifacts.

TODO: As we are integrating, use the `TESTING_ONLY_Hamiltonian_features.csv` and `TESTING_ONLY_solver.ccsdt.labels.csv` files. As more data is collated, we should point to `../../Hamiltonian_features/experimental/fast_double_factorization/Hamiltonian_features.csv` and an appropriately aligned `solver.<solver_uuid>.labels.csv` files.
18 changes: 16 additions & 2 deletions BubbleML/miniML/miniML.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,30 @@ def trainML(
)
logging.info('Percent of solvable space: ', str(ml_solvability_ratio))

timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")


#explain all the predictions in the test set
plt.figure()
explainer = shap.KernelExplainer(model.predict_proba, X_train)
shap_values = explainer.shap_values(X_train)
class_index = 1
shap.initjs()
shap.summary_plot(shap_values[1],features=X.columns,plot_type="bar")

shap.summary_plot(
shap_values[1],
features=X.columns,
plot_type="bar"
)

#shap.force_plot(explainer.expected_value[class_index], shap_values[class_index], X_train, matplotlib=True, show=False)

plt.savefig(
f"shap_summary_plot_solver={solver_uuid}_{timestamp}.png",
format="png"
)


if verbose:
# print to file
y_pred = model.predict(X_train)
Expand All @@ -164,7 +179,6 @@ def trainML(
"prob_class_1": probs[:,1]
}
)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
probs_file_name = f"probs_solver={solver_uuid}_{timestamp}.csv"
df.to_csv(probs_file_name, index=False)
logging.info(f"wrote probs to file {probs_file_name}.")
Expand Down
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

This file was deleted.

Loading

0 comments on commit fa7dbe6

Please sign in to comment.