Skip to content

Commit 8d66a51

Browse files
authored
Add static type checking with mypy (pinecone-io#240)
## Problem Some parts of this code were including typehints, but we haven't been running any checks to verify the correctness of those annotations. Static typechecking helps to improve overall quality and catch mistakes. ## Solution - Add a dev dependency on `mypy` ([site](https://mypy-lang.org/)) and packages with types for third-party dependencies. - Make adjustments throughout the code to satify the typechecker - Add steps in CI to run mypy checks automatically on every change. - For now, exclude generated code in `pinecone/core` because it has a bunch of issues but is not easy to change. Summary of code changes: - Everywhere that has a default value of `None` needs to be wrapped in the `Optional` type from the `typings` package. Correct: `def my_method(foo: Optional[str] = None)`. Incorrect: `def my_method(foo: str = None)` - Refactored config-related stuff a bit because having config classes with inner `_config` attributes seemed like a smell and also made typing config-related stuff more challening. Ended up converting what was previously the `Config` class into a `ConfigBuilder` and the actual config object is now just a `NamedTuple` which is a container for immutable data. - Mypy doesn't like when you overwrite variables in a way that changes the type. So some places required me to introduce an additional intermediate variable to satisfy the type checker. - Mypy got mad about having a `Pinecone` class in both `pinecone` and `pinecone.grpc`. So I renamed the GRPC one to `PineconeGRPC`. ## Type of Change - [x] Infrastructure change (CI configs, etc) ## Test Plan See new step running in CI.
1 parent 46228c3 commit 8d66a51

21 files changed

+689
-452
lines changed

.github/actions/setup-poetry/action.yml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ inputs:
55
description: 'Install gRPC dependencies'
66
required: true
77
default: 'false'
8+
89
runs:
910
using: 'composite'
1011
steps:

.github/workflows/testing.yaml

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ jobs:
2424

2525
- name: Run unit tests
2626
run: poetry run pytest --cov=pinecone --timeout=120 tests/unit
27+
28+
- name: mypy check
29+
run: |
30+
# Still lots of errors when running on the whole package (especially
31+
# in the generated core module), but we can check these subpackages
32+
# so we don't add new regressions.
33+
poetry run mypy pinecone --exclude pinecone/core --exclude pinecone/grpc
2734
2835
units-grpc:
2936
name: Run tests (GRPC)
@@ -48,6 +55,14 @@ jobs:
4855
- name: Run unit tests (GRPC)
4956
run: poetry run pytest --cov=pinecone --timeout=120 tests/unit_grpc
5057

58+
- name: mypy check
59+
run: |
60+
# Still lots of errors when running on the whole package (especially
61+
# in the generated core module), but we can check these subpackages
62+
# so we don't add new regressions.
63+
poetry run mypy pinecone --exclude pinecone/core
64+
65+
5166
package:
5267
name: Check packaging
5368
runs-on: ubuntu-latest

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ test-grpc-unit:
1717
@echo "Running tests..."
1818
poetry run pytest --cov=pinecone --timeout=120 tests/unit_grpc
1919

20+
make type-check:
21+
poetry run mypy pinecone --exclude pinecone/core
22+
2023
version:
2124
poetry version
2225

pinecone/config/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .config import Config
1+
from .config import ConfigBuilder, Config
22
from .pinecone_config import PineconeConfig

pinecone/config/config.py

+17-32
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
1-
from typing import NamedTuple
1+
from typing import NamedTuple, Optional
22
import os
33

4-
from pinecone.utils import check_kwargs
54
from pinecone.exceptions import PineconeConfigurationError
65
from pinecone.core.client.exceptions import ApiKeyError
76
from pinecone.config.openapi import OpenApiConfigFactory
87
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
98

109

11-
class ConfigBase(NamedTuple):
10+
class Config(NamedTuple):
1211
api_key: str = ""
1312
host: str = ""
14-
openapi_config: OpenApiConfiguration = None
13+
openapi_config: Optional[OpenApiConfiguration] = None
1514

1615

17-
class Config:
16+
class ConfigBuilder:
1817
"""
1918
2019
Configurations are resolved in the following order:
@@ -31,39 +30,25 @@ class Config:
3130
:param openapi_config: Optional. Set OpenAPI client configuration.
3231
"""
3332

34-
def __init__(
35-
self,
36-
api_key: str = None,
37-
host: str = None,
38-
openapi_config: OpenApiConfiguration = None,
33+
@staticmethod
34+
def build(
35+
api_key: Optional[str] = None,
36+
host: Optional[str] = None,
37+
openapi_config: Optional[OpenApiConfiguration] = None,
3938
**kwargs,
40-
):
39+
) -> Config:
4140
api_key = api_key or kwargs.pop("api_key", None) or os.getenv("PINECONE_API_KEY")
4241
host = host or kwargs.pop("host", None)
42+
43+
if not api_key:
44+
raise PineconeConfigurationError("You haven't specified an Api-Key.")
45+
if not host:
46+
raise PineconeConfigurationError("You haven't specified a host.")
47+
4348
openapi_config = (
4449
openapi_config
4550
or kwargs.pop("openapi_config", None)
4651
or OpenApiConfigFactory.build(api_key=api_key, host=host)
4752
)
4853

49-
check_kwargs(self.__init__, kwargs)
50-
self._config = ConfigBase(api_key, host, openapi_config)
51-
self.validate()
52-
53-
def validate(self):
54-
if not self._config.api_key:
55-
raise PineconeConfigurationError("You haven't specified an Api-Key.")
56-
if not self._config.host:
57-
raise PineconeConfigurationError("You haven't specified a host.")
58-
59-
@property
60-
def API_KEY(self):
61-
return self._config.api_key
62-
63-
@property
64-
def HOST(self):
65-
return self._config.host
66-
67-
@property
68-
def OPENAPI_CONFIG(self):
69-
return self._config.openapi_config
54+
return Config(api_key, host, openapi_config)

pinecone/config/openapi.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import List
2+
from typing import List, Optional
33

44
import certifi
55
import requests
@@ -16,7 +16,7 @@
1616

1717
class OpenApiConfigFactory:
1818
@classmethod
19-
def build(cls, api_key: str, host: str = None, **kwargs):
19+
def build(cls, api_key: str, host: Optional[str] = None, **kwargs):
2020
openapi_config = OpenApiConfiguration()
2121
openapi_config.host = host
2222
openapi_config.ssl_ca_cert = certifi.where()
@@ -26,6 +26,7 @@ def build(cls, api_key: str, host: str = None, **kwargs):
2626

2727
@classmethod
2828
def _get_socket_options(
29+
self,
2930
do_keep_alive: bool = True,
3031
keep_alive_idle_sec: int = TCP_KEEPIDLE,
3132
keep_alive_interval_sec: int = TCP_KEEPINTVL,

pinecone/config/pinecone_config.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Optional
12
import os
2-
from .config import Config
3+
from .config import ConfigBuilder, Config
34

45
DEFAULT_CONTROLLER_HOST = "https://api.pinecone.io"
56

67

7-
class PineconeConfig(Config):
8-
def __init__(self, api_key: str = None, host: str = None, **kwargs):
8+
class PineconeConfig():
9+
@staticmethod
10+
def build(api_key: Optional[str] = None, host: Optional[str] = None, **kwargs) -> Config:
911
host = host or kwargs.get("host") or os.getenv("PINECONE_CONTROLLER_HOST") or DEFAULT_CONTROLLER_HOST
10-
super().__init__(api_key=api_key, host=host, **kwargs)
12+
return ConfigBuilder.build(api_key=api_key, host=host, **kwargs)

pinecone/control/index_host_store.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from typing import Dict
12
from pinecone.config import Config
23
from pinecone.core.client.api.index_operations_api import IndexOperationsApi
34

45

56
class SingletonMeta(type):
6-
_instances = {}
7+
_instances: Dict[str, str] = {}
78

89
def __call__(cls, *args, **kwargs):
910
if cls not in cls._instances:
@@ -17,7 +18,7 @@ def __init__(self):
1718
self._indexHosts = {}
1819

1920
def _key(self, config: Config, index_name: str) -> str:
20-
return ":".join([config.API_KEY, index_name])
21+
return ":".join([config.api_key, index_name])
2122

2223
def delete_host(self, config: Config, index_name: str):
2324
key = self._key(config, index_name)

pinecone/control/pinecone.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Optional
2+
from typing import Optional, Dict, Any, cast
33

44
from .index_host_store import IndexHostStore
55

@@ -18,21 +18,25 @@
1818
class Pinecone:
1919
def __init__(
2020
self,
21-
api_key: str = None,
22-
host: str = None,
23-
config: Config = None,
24-
index_api: IndexOperationsApi = None,
21+
api_key: Optional[str] = None,
22+
host: Optional[str] = None,
23+
config: Optional[Config] = None,
24+
index_api: Optional[IndexOperationsApi] = None,
2525
**kwargs,
2626
):
2727
if config or kwargs.get("config"):
28-
self.config = config or kwargs.get("config")
28+
configKwarg = config or kwargs.get("config")
29+
if not isinstance(configKwarg, Config):
30+
raise TypeError("config must be of type pinecone.config.Config")
31+
else:
32+
self.config = configKwarg
2933
else:
30-
self.config = PineconeConfig(api_key=api_key, host=host, **kwargs)
31-
34+
self.config = PineconeConfig.build(api_key=api_key, host=host, **kwargs)
35+
3236
if index_api:
3337
self.index_api = index_api
3438
else:
35-
api_client = ApiClient(configuration=self.config.OPENAPI_CONFIG)
39+
api_client = ApiClient(configuration=self.config.openapi_config)
3640
api_client.user_agent = get_user_agent()
3741
self.index_api = IndexOperationsApi(api_client)
3842

@@ -45,15 +49,15 @@ def create_index(
4549
cloud: str,
4650
region: str,
4751
capacity_mode: str,
48-
timeout: int = None,
52+
timeout: Optional[int] = None,
4953
index_type: str = "approximated",
5054
metric: str = "cosine",
5155
replicas: int = 1,
5256
shards: int = 1,
5357
pods: int = 1,
5458
pod_type: str = "p1",
55-
index_config: dict = None,
56-
metadata_config: dict = None,
59+
index_config: Optional[dict] = None,
60+
metadata_config: Optional[dict] = None,
5761
source_collection: str = "",
5862
):
5963
"""Creates a Pinecone index.
@@ -138,7 +142,7 @@ def is_ready():
138142
)
139143
)
140144

141-
def delete_index(self, name: str, timeout: int = None):
145+
def delete_index(self, name: str, timeout: Optional[int] = None):
142146
"""Deletes a Pinecone index.
143147
144148
:param name: the name of the index.
@@ -210,7 +214,7 @@ def configure_index(self, name: str, replicas: Optional[int] = None, pod_type: O
210214
:param: pod_type: the new pod_type for the index.
211215
"""
212216
api_instance = self.index_api
213-
config_args = {}
217+
config_args: Dict[str, Any] = {}
214218
if pod_type != "":
215219
config_args.update(pod_type=pod_type)
216220
if replicas:
@@ -267,4 +271,4 @@ def _get_status(self, name: str):
267271

268272
def Index(self, name: str):
269273
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
270-
return Index(api_key=self.config.API_KEY, host=index_host)
274+
return Index(api_key=self.config.api_key, host=index_host)

pinecone/data/index.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable
44
from typing import Union, List, Tuple, Optional, Dict, Any
55

6-
from pinecone.config import Config
6+
from pinecone.config import ConfigBuilder
77

88
from pinecone.core.client.models import SparseValues
99
from pinecone.core.client import ApiClient
@@ -76,13 +76,13 @@ class Index():
7676
"""
7777

7878
def __init__(self, api_key: str, host: str, pool_threads=1, **kwargs):
79-
api_key = api_key or kwargs.get("api_key")
80-
host = host or kwargs.get('host')
79+
api_key = api_key or kwargs.get("api_key", None)
80+
host = host or kwargs.get('host', None)
8181
pool_threads = pool_threads or kwargs.get("pool_threads")
8282

83-
self._config = Config(api_key=api_key, host=host, **kwargs)
83+
self._config = ConfigBuilder.build(api_key=api_key, host=host, **kwargs)
8484

85-
api_client = ApiClient(configuration=self._config.OPENAPI_CONFIG, pool_threads=pool_threads)
85+
api_client = ApiClient(configuration=self._config.openapi_config, pool_threads=pool_threads)
8686
api_client.user_agent = get_user_agent()
8787
self._api_client = api_client
8888
self._vector_api = VectorOperationsApi(api_client=api_client)
@@ -180,7 +180,7 @@ def upsert(
180180
return UpsertResponse(upserted_count=total_upserted)
181181

182182
def _upsert_batch(
183-
self, vectors: List[Vector], namespace: Optional[str], _check_type: bool, **kwargs
183+
self, vectors: Union[List[Vector], List[tuple], List[dict]], namespace: Optional[str], _check_type: bool, **kwargs
184184
) -> UpsertResponse:
185185
args_dict = self._parse_non_empty_args([("namespace", namespace)])
186186
vec_builder = lambda v: VectorFactory.build(v, check_type=_check_type)
@@ -202,7 +202,7 @@ def _iter_dataframe(df, batch_size):
202202
yield batch
203203

204204
def upsert_from_dataframe(
205-
self, df, namespace: str = None, batch_size: int = 500, show_progress: bool = True
205+
self, df, namespace: Optional[str] = None, batch_size: int = 500, show_progress: bool = True
206206
) -> UpsertResponse:
207207
"""Upserts a dataframe into the index.
208208
@@ -416,7 +416,8 @@ def _query_transform(item):
416416
),
417417
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
418418
)
419-
return parse_query_response(response, vector is not None or id)
419+
unary_query = True if vector is not None or id else False
420+
return parse_query_response(response, unary_query)
420421

421422
@validate_and_convert_errors
422423
def update(

pinecone/data/vector_factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _dict_to_vector(item, check_type: bool) -> Vector:
9191
try:
9292
return Vector(**item, _check_type=check_type)
9393
except TypeError as e:
94-
if not isinstance(item["values"], Iterable) or not isinstance(item["values"][0], numbers.Real):
94+
if not isinstance(item["values"], Iterable) or not isinstance(item["values"].__iter__().__next__(), numbers.Real):
9595
raise TypeError(f"Column `values` is expected to be a list of floats")
9696
raise e
9797

pinecone/grpc/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .index_grpc import GRPCIndex
2-
from .pinecone import Pinecone
2+
from .pinecone import PineconeGRPC

pinecone/grpc/base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
22
from abc import ABC, abstractmethod
33
from functools import wraps
4-
from typing import Dict
4+
from typing import Dict, Optional
55

66
import certifi
77
import grpc
8-
from grpc._channel import _InactiveRpcError
8+
from grpc._channel import _InactiveRpcError, Channel
99
import json
1010

1111
from .retry import RetryConfig
@@ -29,15 +29,15 @@ def __init__(
2929
self,
3030
index_name: str,
3131
config: Config,
32-
channel=None,
33-
grpc_config: GRPCClientConfig = None,
34-
_endpoint_override: str = None,
32+
channel: Optional[Channel] =None,
33+
grpc_config: Optional[GRPCClientConfig] = None,
34+
_endpoint_override: Optional[str] = None,
3535
):
3636
self.name = index_name
3737

3838
self.grpc_client_config = grpc_config or GRPCClientConfig()
3939
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()
40-
self.fixed_metadata = {"api-key": config.API_KEY, "service-name": index_name, "client-version": CLIENT_VERSION}
40+
self.fixed_metadata = {"api-key": config.api_key, "service-name": index_name, "client-version": CLIENT_VERSION}
4141
self._endpoint_override = _endpoint_override
4242

4343
self.method_config = json.dumps(

pinecone/grpc/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class GRPCClientConfig(NamedTuple):
2525
conn_timeout: int = 1
2626
reuse_channel: bool = True
2727
retry_config: Optional[RetryConfig] = None
28-
grpc_channel_options: Dict[str, str] = None
28+
grpc_channel_options: Optional[Dict[str, str]] = None
2929

3030
@classmethod
3131
def _from_dict(cls, kwargs: dict):

0 commit comments

Comments
 (0)