diff --git a/nn_meter/utils/import_package.py b/nn_meter/utils/import_package.py index 08a5cb79..f4304d2a 100644 --- a/nn_meter/utils/import_package.py +++ b/nn_meter/utils/import_package.py @@ -28,7 +28,7 @@ def try_import_torch(require_version = ["1.9.0", "1.7.1"]): logging.error(f'You have not install the torch package, please install torch=={require_version[0]} and try again.') exit() -def try_import_tensorflow(require_version = ["1.15.0"]): +def try_import_tensorflow(require_version = ["2.6.0", "1.15.0"]): if isinstance(require_version, str): require_version = [require_version] try: @@ -40,7 +40,7 @@ def try_import_tensorflow(require_version = ["1.15.0"]): logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version[0]} and try again.') exit() -def try_import_nni(require_version = ["2.4", "2.5"]): +def try_import_nni(require_version = ["2.5", "2.4"]): if isinstance(require_version, str): require_version = [require_version] try: @@ -49,7 +49,7 @@ def try_import_nni(require_version = ["2.4", "2.5"]): logging.warning(f'nni=={nni.__version__} is not well tested now, well tested version: nni=={", ".join(require_version)}' ) return nni except ImportError: - logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version[0]} and try again.') + logging.error(f'You have not install the nni package, please install nni=={require_version[0]} and try again.') exit() def try_import_torchvision_models():