Skip to content

Commit

Permalink
Bug in EncExp addition
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Nov 8, 2024
1 parent da2aa39 commit d9be203
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
6 changes: 4 additions & 2 deletions encexp/tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ def test_build_encexp_tokens():
for txt in tweet_iterator('es-mx-sample.json'):
cnt.update([x for x in tokenize(txt) if x[:2] != 'q:'])
voc = download_seqtm(lang='es', voc_source='noGeo')
words = set(cnt.keys()) - set(voc['counter']['dict'])
words = sorted([word for word in words if cnt[word] >= 8])
# words = set(cnt.keys()) - set(voc['counter']['dict'])
words = [word for word in cnt if cnt[word] >= 8]
words.append('de')
words = sorted(words)
output, cnt = encode(voc, 'es-mx-sample.json', tokens=words)
tokens = feasible_tokens(voc, cnt, tokens=words,
min_pos=8)
Expand Down
16 changes: 9 additions & 7 deletions encexp/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,19 @@ def fit(self, D, y=None):

def force_tokens_weights(self, IDF: bool=False):
"""Set the maximum weight"""
rows = np.arange(len(self.names))
# rows = np.arange(len(self.names))
rows = np.array([i for i, k in enumerate(self.names)
if k in self.bow.token2id])

cols = np.array([self.bow.token2id[x] for x in self.names
if x in self.bow.token2id])
if cols.shape[0] == 0:
return
w = self.weights[:, cols]
if IDF:
w = w * self.bow.weights[cols]
w = self.weights[rows][cols] * self.bow.weights[cols]
_max = (w.max(axis=1) / self.bow.weights[cols]).astype(self.precision)
else:
_max = w.max(axis=1)
_max = self.weights[rows].max(axis=1)
self.weights[rows, cols] = _max

@property
Expand All @@ -356,8 +357,8 @@ def weights(self):
return self._weights
except AttributeError:
if self.EncExp_filename is not None:
data = download_encexp(output=self.EncExp_filename,
precision=self.precision)
data = download_encexp(output=self.EncExp_filename)
# precision=self.precision)
else:
if self.intercept:
assert not self.merge_IDF
Expand Down Expand Up @@ -559,6 +560,7 @@ def __sklearn_clone__(self):
ins.weights = self.weights
ins.bow = self.bow
ins.names = self.names
ins.estimator = clone(self.estimator)
if hasattr(self, '_estimator'):
ins.estimator = clone(self.estimator)
ins.enc_training_size = self.enc_training_size
return ins

0 comments on commit d9be203

Please sign in to comment.