Skip to content

Commit

Permalink
Adapt BytePairEmbeddings class to new version of bpemb
Browse files Browse the repository at this point in the history
Now, we can load custom BPEmb
  • Loading branch information
mauryaland committed Jul 17, 2020
1 parent 17fa344 commit f1aacab
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
26 changes: 23 additions & 3 deletions flair/embeddings/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,18 +1548,38 @@ def __setstate__(self, state):
class BytePairEmbeddings(TokenEmbeddings):
def __init__(
self,
language: str,
language: str = None,
dim: int = 50,
syllables: int = 100000,
cache_dir=Path(flair.cache_root) / "embeddings",
model_file_path: Path = None,
embedding_file_path: Path = None,
**kwargs,
):
"""
Initializes BP embeddings. Constructor downloads required files if not there.
"""
if language:
self.name: str = f"bpe-{language}-{syllables}-{dim}"
else:
assert (
model_file_path is not None and embedding_file_path is not None
), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)"
dim=None

self.embedder = BPEmb(
lang=language,
vs=syllables,
dim=dim,
cache_dir=cache_dir,
model_file=model_file_path,
emb_file=embedding_file_path,
**kwargs,
)

self.name: str = f"bpe-{language}-{syllables}-{dim}"
if not language:
self.name: str = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}"
self.static_embeddings = True
self.embedder = BPEmb(lang=language, vs=syllables, dim=dim, cache_dir=cache_dir)

self.__embedding_length: int = self.embedder.emb.vector_size * 2
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ sqlitedict>=1.6.0
deprecated>=1.2.4
hyperopt>=0.1.1
transformers>=3.0.0
bpemb>=0.2.9
bpemb>=0.3.2
regex
tabulate
langdetect
Expand Down
7 changes: 7 additions & 0 deletions resources/docs/embeddings/BYTE_PAIR_EMBEDDINGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ on the [byte pair embeddings](https://nlp.h-its.org/bpemb/) web page.
# init embedding
embedding = BytePairEmbeddings('multi')
```

You can also load custom `BytePairEmbeddings` by passing file paths via the `model_file_path` and `embedding_file_path` arguments. These point, respectively, to a SentencePiece model file and to an embedding file (Word2Vec plain text or GenSim binary). For example:

```python
# init custom embedding
embedding = BytePairEmbeddings(model_file_path='your/path/m.model', embedding_file_path='your/path/w2v.txt')
```

0 comments on commit f1aacab

Please sign in to comment.