Merge pull request #46 from lmcinnes/master
Stream matrix creation in blocks for lower overhead; if only we could do this with tokenization...
jc-healy authored Jun 29, 2020
2 parents 7e454b0 + 50f309e commit fda0a0b
Showing 3 changed files with 58 additions and 19 deletions.
54 changes: 43 additions & 11 deletions vectorizers/_vectorizers.py
@@ -905,7 +905,7 @@ def sequence_tree_skip_grams(
global_counts += global_counts.T
elif window_orientation == "after":
pass
elif window_orientation == 'directional':
elif window_orientation == "directional":
global_counts = scipy.sparse.hstack([global_counts.T, global_counts])
else:
raise ValueError(
@@ -924,6 +924,7 @@ def token_cooccurrence_matrix(
window_args,
kernel_args,
window_orientation="symmetric",
chunk_size=1 << 20,
):
"""Generate a matrix of (weighted) counts of co-occurrences of tokens within
windows in a set of sequences of tokens. Each sequence in the collection of
@@ -966,6 +967,11 @@ def token_cooccurrence_matrix(
symmetric: counts tokens occurring before and after as the same tokens
directional: counts tokens before and after as different and returns both counts.
chunk_size: int (optional, default=1048576)
When processing token sequences, break the list of sequences into
chunks of this size to stream the data through, rather than storing all
the results at once. This saves on peak memory usage.
Returns
-------
cooccurrence_matrix: scipy.sparse.csr_matrix
@@ -976,17 +982,35 @@
if n_unique_tokens == 0:
raise ValueError("Token dictionary is empty; try using less extreme contraints")

raw_coo_data = sequence_skip_grams(
token_sequences, window_function, kernel_function, window_args, kernel_args
)
cooccurrence_matrix = scipy.sparse.coo_matrix(
(
raw_coo_data.T[2],
(raw_coo_data.T[0].astype(np.int64), raw_coo_data.T[1].astype(np.int64)),
),
shape=(n_unique_tokens, n_unique_tokens),
dtype=np.float32,
(n_unique_tokens, n_unique_tokens), dtype=np.float32
)
n_chunks = (len(token_sequences) // chunk_size) + 1

for chunk_index in range(n_chunks):
chunk_start = chunk_index * chunk_size
chunk_end = min(len(token_sequences), chunk_start + chunk_size)

raw_coo_data = sequence_skip_grams(
token_sequences[chunk_start:chunk_end],
window_function,
kernel_function,
window_args,
kernel_args,
)
cooccurrence_matrix += scipy.sparse.coo_matrix(
(
raw_coo_data.T[2],
(
raw_coo_data.T[0].astype(np.int64),
raw_coo_data.T[1].astype(np.int64),
),
),
shape=(n_unique_tokens, n_unique_tokens),
dtype=np.float32,
)
cooccurrence_matrix.sum_duplicates()

if window_orientation == "before":
cooccurrence_matrix = cooccurrence_matrix.transpose()
elif window_orientation == "after":
@@ -1187,6 +1211,11 @@ class TokenCooccurrenceVectorizer(BaseEstimator, TransformerMixin):
symmetric: counts tokens occurring before and after as the same tokens
directional: counts tokens before and after as different and returns both counts.
chunk_size: int (optional, default=1048576)
When processing token sequences, break the list of sequences into
chunks of this size to stream the data through, rather than storing all
the results at once. This saves on peak memory usage.
validate_data: bool (optional, default=True)
Check whether the data is valid (e.g. of homogeneous token type).
"""
@@ -1208,6 +1237,7 @@ def __init__(
kernel_function="flat",
window_radius=5,
window_orientation="directional",
chunk_size=1<<20,
validate_data=True,
):
self.token_dictionary = token_dictionary
@@ -1227,6 +1257,7 @@ def __init__(
self.window_radius = window_radius

self.window_orientation = window_orientation
self.chunk_size = chunk_size
self.validate_data = validate_data

def fit_transform(self, X, y=None, **fit_params):
@@ -1291,6 +1322,7 @@ def fit_transform(self, X, y=None, **fit_params):
window_args=(self._window_size, self._token_frequencies_),
kernel_args=(self.window_radius,),
window_orientation=self.window_orientation,
chunk_size=self.chunk_size,
)
self.cooccurrences_.eliminate_zeros()
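
From the caller's side the new parameter just flows through the constructor. A minimal usage sketch, assuming the vectorizers package at a version that includes this change; the toy corpus and the deliberately small chunk_size are hypothetical and only serve to exercise the streaming path:

```python
from vectorizers import TokenCooccurrenceVectorizer

# Hypothetical toy corpus: a list of token sequences.
corpus = [
    ["wer", "pok", "wer", "sdf"],
    ["pok", "sdf", "pok"],
]

# chunk_size only affects peak memory use, not the resulting counts; the
# default of 1 << 20 sequences per chunk only matters for large corpora.
vectorizer = TokenCooccurrenceVectorizer(window_radius=2, chunk_size=1)
cooccurrences = vectorizer.fit_transform(corpus)

# With the default directional window orientation the matrix should have
# shape (n_unique_tokens, 2 * n_unique_tokens).
print(cooccurrences.shape)
```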

@@ -1599,7 +1631,7 @@ def transform(self, X):
self.token_label_dictionary_,
self.token_index_dictionary_,
self._token_frequencies_,
) = preprocess_tree_sequences(X, flat_sequences, self.token_label_dictionary_, )
) = preprocess_tree_sequences(X, flat_sequences, self.token_label_dictionary_,)

if callable(self.kernel_function):
self._kernel_function = self.kernel_function
2 changes: 1 addition & 1 deletion vectorizers/_version.py
@@ -1 +1 @@
__version__ = "0.0.3"
__version__ = "0.0.3"
21 changes: 14 additions & 7 deletions vectorizers/optimal_transport.py
@@ -39,6 +39,7 @@
## Defaults to double for everything in POT
INFINITY = np.finfo(np.float64).max
MAX = np.finfo(np.float64).max
FLOAT32_MAX = np.finfo(np.float32).max

dummy_cost = np.zeros((2, 2), dtype=np.float64)

@@ -943,10 +944,16 @@ def kantorovich_distance(x, y, cost=dummy_cost, max_iter=100000):
a_sum = a.sum()
b_sum = b.sum()

if not isclose(a_sum, b_sum):
raise ValueError(
"Kantorovich distance inputs must be valid probability distributions."
)
# Handle all zero vectors more gracefully
if a_sum == 0.0 and b_sum == 0.0:
return 0.0
elif a_sum == 0.0 or b_sum == 0.0:
return FLOAT32_MAX

# if not isclose(a_sum, b_sum):
# raise ValueError(
# "Kantorovich distance inputs must be valid probability distributions."
# )

a /= a_sum
b /= b_sum
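
The behavioural change to kantorovich_distance is easiest to see in isolation. The sketch below reproduces only the new degenerate-input branch (the helper name is hypothetical; the real function goes on to normalise and run the network simplex solve whenever this guard returns None):

```python
import numpy as np

FLOAT32_MAX = np.finfo(np.float32).max

def degenerate_mass_guard(x, y):
    # Mirrors the new early-exit logic: all-zero inputs no longer raise.
    a_sum, b_sum = np.sum(x), np.sum(y)
    if a_sum == 0.0 and b_sum == 0.0:
        return 0.0            # two empty distributions: treated as identical
    if a_sum == 0.0 or b_sum == 0.0:
        return FLOAT32_MAX    # exactly one empty distribution: maximally far apart
    return None               # both carry mass: defer to the full solver

print(degenerate_mass_guard(np.zeros(3), np.zeros(3)))                 # 0.0
print(degenerate_mass_guard(np.zeros(3), np.array([0.2, 0.8, 0.0])))   # ~3.4e38
print(degenerate_mass_guard(np.ones(3), np.ones(3)))                   # None
```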
@@ -965,9 +972,9 @@ def kantorovich_distance(x, y, cost=dummy_cost, max_iter=100000):
"Kantorovich distance inputs must be valid probability distributions."
)
solve_status = network_simplex_core(node_arc_data, spanning_tree, graph, max_iter,)
if solve_status == ProblemStatus.MAX_ITER_REACHED:
print("WARNING: RESULT MIGHT BE INACURATE\nMax number of iteration reached!")
elif solve_status == ProblemStatus.INFEASIBLE:
# if solve_status == ProblemStatus.MAX_ITER_REACHED:
# print("WARNING: RESULT MIGHT BE INACCURATE\nMax number of iteration reached!")
if solve_status == ProblemStatus.INFEASIBLE:
raise ValueError(
"Optimal transport problem was INFEASIBLE. Please check " "inputs."
)
