Commit 4a8a7f8

bandwidth tuning via classification
1 parent d64a7c0 commit 4a8a7f8

16 files changed, +878 -651 lines changed

.pre-commit-config.yaml

+10-5
@@ -8,13 +8,18 @@ repos:
       - id: requirements-txt-fixer
       - id: trailing-whitespace
 
+  - repo: https://github.com/PyCQA/flake8
+    rev: 3.9.2
+    hooks:
+      - id: flake8
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.6.4
+    hooks:
+      - id: isort
+
   - repo: https://github.com/ambv/black
     rev: 21.9b0
     hooks:
       - id: black
         language_version: python3.8
-
-  - repo: https://github.com/PyCQA/flake8
-    rev: 3.9.2
-    hooks:
-      - id: flake8

AUTHORS

+3
@@ -0,0 +1,3 @@
+Aleksandr Artemenkov <[email protected]>
+Nikita Kotelevskii <[email protected]>
+mephody-bro <[email protected]>

ChangeLog

+21
@@ -0,0 +1,21 @@
+CHANGES
+=======
+
+* hnsw without parallel
+* n\_jobs parameter
+* parallelize
+* local variant
+* refactor: initial commit
+* Fix the typo
+* Make the packcage
+* Cleaner example and little bit more compatible
+* Rename to nuq
+* Initial commit
+* Minor stuff with batch formation is fixed
+* Refined minibatches for uncertainty
+* Added minibatch for uncertainties
+* Added 3 types of uncertainty
+* Minor changes
+* Added small value for numerical stability in uncertainty
+* Added logsumexp computation of the conditional
+* Nadaraya-Watson method for uncertainty quantification

environment.yaml

+38
@@ -0,0 +1,38 @@
+name: ray-env
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - python=3.8
+  - pip
+
+  # Computations
+  - numpy
+  - matplotlib
+  - scikit-learn
+  - scipy
+  - pandas
+  - seaborn
+
+  # Jupyter Lab
+  - ipykernel
+  - ipywidgets
+  - nb_black # black formatter
+  - wurlitzer # C level output
+
+  # Parallelism
+  - ray-dashboard
+
+  # Utilities
+  - tqdm # progress bar
+  - black # formatter
+  - isort # import sorter
+  - flake8 # linter
+  - pre_commit # git pre-commit hooks
+  - pytest # code testing
+  - pybind11 # hnswlib build
+  - cython # KDEpy build
+
+  - pip:
+      - hnswlib
+      - KDEpy

example.ipynb

+349-237
Large diffs are not rendered by default.

nuq/bandwidth_selection.py

-40
This file was deleted.

nuq/kernels.py

-153
This file was deleted.

nuq/misc.py

+13-1
@@ -1,6 +1,8 @@
+import re
+
+import matplotlib.pyplot as plt
 import numpy as np
 from sklearn.datasets import make_blobs
-import matplotlib.pyplot as plt
 
 
 def plot_data(X, y, title=None):
@@ -16,3 +18,13 @@ def plot_data(X, y, title=None):
 def make_data(total_size=5000, centers=np.array([[-4.0, -4.0], [0.0, 4.0]])):
     X, y = make_blobs(n_samples=total_size, n_features=2, centers=centers)
     return X, y
+
+
+def parse_param(param_string):
+    res = param_string.split(":", maxsplit=1)
+    if len(res) == 1:
+        return res[0], {}
+    name, vals_string = res
+    vals_strings = filter(None, vals_string.split(";"))
+    vals = [v.split("=", maxsplit=1) for v in vals_strings]
+    return name, dict(vals)
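
Usage sketch for the new parse_param helper. The function body is copied from the diff. The input strings ("plain", "method:alpha=1;beta=two;") are hypothetical and only illustrate the "name:key=value;key=value" shape the code appears to expect. How the rest of the package calls this helper is not shown in this commit view.

```python
def parse_param(param_string):
    # Copied from the nuq/misc.py diff in this commit.
    res = param_string.split(":", maxsplit=1)
    if len(res) == 1:
        # No ":" present: the whole string is the name, with no options.
        return res[0], {}
    name, vals_string = res
    # Drop empty fragments such as the one left by a trailing ";".
    vals_strings = filter(None, vals_string.split(";"))
    vals = [v.split("=", maxsplit=1) for v in vals_strings]
    return name, dict(vals)


# Hypothetical inputs, not taken from the repository.
bare = parse_param("plain")
with_opts = parse_param("method:alpha=1;beta=two;")

print(bare)       # ('plain', {})
print(with_opts)  # ('method', {'alpha': '1', 'beta': 'two'})
```

Values come back as strings, so a caller would need to convert "1" to a number itself. A fragment without "=" (for example "method:flag") makes dict() raise ValueError, since each fragment must split into exactly a key and a value.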
