Skip to content

Commit

Permalink
Merge pull request #61 from raminqaf/fix/gensim-size-param
Browse files Browse the repository at this point in the history
Check gensim version and set size parameter name respectively
  • Loading branch information
eliorc authored Apr 2, 2021
2 parents f6623b7 + d61531d commit 81a3c23
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion node2vec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import edges
from .node2vec import Node2Vec

__version__ = '0.4.1'
__version__ = '0.4.2'
17 changes: 11 additions & 6 deletions node2vec/node2vec.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import random
import os
import random
from collections import defaultdict

import numpy as np
import networkx as nx
import gensim
import networkx as nx
import numpy as np
import pkg_resources
from joblib import Parallel, delayed
from tqdm.auto import tqdm

Expand Down Expand Up @@ -166,16 +167,20 @@ def _generate_walks(self) -> list:
def fit(self, **skip_gram_params) -> gensim.models.Word2Vec:
"""
Creates the embeddings using gensim's Word2Vec.
:param skip_gram_params: Parameteres for gensim.models.Word2Vec - do not supply 'size' it is taken from the Node2Vec 'dimensions' parameter
:param skip_gram_params: Parameters for gensim.models.Word2Vec - do not supply 'size' / 'vector_size' it is
taken from the Node2Vec 'dimensions' parameter
:type skip_gram_params: dict
:return: A gensim word2vec model
"""

if 'workers' not in skip_gram_params:
skip_gram_params['workers'] = self.workers

if 'size' not in skip_gram_params:
skip_gram_params['size'] = self.dimensions
# Figure out gensim version, naming of output dimensions changed from size to vector_size in v4.0.0
gensim_version = pkg_resources.get_distribution("gensim").version
size = 'size' if gensim_version < '4.0.0' else 'vector_size'
if size not in skip_gram_params:
skip_gram_params[size] = self.dimensions

if 'sg' not in skip_gram_params:
skip_gram_params['sg'] = 1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name='node2vec',
packages=['node2vec'],
version='0.4.1',
version='0.4.2',
description='Implementation of the node2vec algorithm.',
author='Elior Cohen',
author_email='[email protected]',
Expand Down

0 comments on commit 81a3c23

Please sign in to comment.