diff --git a/vectorizers/_vectorizers.py b/vectorizers/_vectorizers.py index 0c700d4..ab588a7 100644 --- a/vectorizers/_vectorizers.py +++ b/vectorizers/_vectorizers.py @@ -967,6 +967,11 @@ def token_cooccurrence_matrix( symmetric: counts tokens occurring before and after as the same tokens directional: counts tokens before and after as different and returns both counts. + chunk_size: int (optional, default=1048576) + When processing token sequences, break the list of sequences into + chunks of this size to stream the data through, rather than storing all + the results at once. This saves on peak memory usage. + Returns ------- cooccurrence_matrix: scipyr.sparse.csr_matrix @@ -1004,6 +1009,7 @@ def token_cooccurrence_matrix( shape=(n_unique_tokens, n_unique_tokens), dtype=np.float32, ) + cooccurrence_matrix.sum_duplicates() if window_orientation == "before": cooccurrence_matrix = cooccurrence_matrix.transpose() @@ -1205,6 +1211,11 @@ class TokenCooccurrenceVectorizer(BaseEstimator, TransformerMixin): symmetric: counts tokens occurring before and after as the same tokens directional: counts tokens before and after as different and returns both counts. + chunk_size: int (optional, default=1048576) + When processing token sequences, break the list of sequences into + chunks of this size to stream the data through, rather than storing all + the results at once. This saves on peak memory usage. + validate_data: bool (optional, default=True) Check whether the data is valid (e.g. of homogeneous token type). """ @@ -1226,6 +1237,7 @@ def __init__( kernel_function="flat", window_radius=5, window_orientation="directional", + chunk_size=1<<20, validate_data=True, ): self.token_dictionary = token_dictionary @@ -1245,6 +1257,7 @@ def __init__( self.window_radius = window_radius self.window_orientation = window_orientation + self.chunk_size = chunk_size self.validate_data = validate_data def fit_transform(self, X, y=None, **fit_params): @@ -1309,6 +1322,7 @@ def fit_transform(self, X, y=None, **fit_params): window_args=(self._window_size, self._token_frequencies_), kernel_args=(self.window_radius,), window_orientation=self.window_orientation, + chunk_size=self.chunk_size, ) self.cooccurrences_.eliminate_zeros()