Merge jamalex#370
vzhd1701 authored May 9, 2022
2 parents c9223c0 + 19f7285 commit 697f290
Showing 3 changed files with 64 additions and 37 deletions.
20 changes: 12 additions & 8 deletions notion/client.py
@@ -5,7 +5,7 @@

from requests import Session, HTTPError
from requests.cookies import cookiejar_from_dict
-from urllib.parse import urljoin
+from urllib.parse import unquote, urljoin
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from getpass import getpass
@@ -53,6 +53,7 @@ def create_session(client_specified_retry=None):
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)
session.headers.update({"content-type": "application/json"})
return session


@@ -130,6 +131,10 @@ def _update_user_info(self):
self._store.store_recordmap(records)
self.current_user = self.get_user(list(records["notion_user"].keys())[0])
self.current_space = self.get_space(list(records["space"].keys())[0])

self.session.headers.update({"x-notion-active-user-header":
json.loads(unquote(self.session.cookies.get("notion_users")))[1]})

return records

def get_email_uid(self):
@@ -140,7 +145,6 @@ def get_email_uid(self):
}

def set_user_by_uid(self, user_id):
self.session.headers.update({"x-notion-active-user-header": user_id})
self._update_user_info()

def set_user_by_email(self, email):
@@ -158,15 +162,15 @@ def get_top_level_pages(self):
records = self._update_user_info()
return [self.get_block(bid) for bid in records["block"].keys()]

-def get_record_data(self, table, id, force_refresh=False):
-return self._store.get(table, id, force_refresh=force_refresh)
+def get_record_data(self, table, id, force_refresh=False, limit=100):
+return self._store.get(table, id, force_refresh=force_refresh, limit=limit)

-def get_block(self, url_or_id, force_refresh=False):
+def get_block(self, url_or_id, force_refresh=False, limit=100):
"""
Retrieve an instance of a subclass of Block that maps to the block/page identified by the URL or ID passed in.
"""
block_id = extract_id(url_or_id)
-block = self.get_record_data("block", block_id, force_refresh=force_refresh)
+block = self.get_record_data("block", block_id, force_refresh=force_refresh, limit=limit)
if not block:
return None
if block.get("parent_table") == "collection":
@@ -306,11 +310,11 @@ def in_transaction(self):
"""
return hasattr(self, "_transaction_operations")

-def search_pages_with_parent(self, parent_id, search=""):
+def search_pages_with_parent(self, parent_id, search="", limit=100):
data = {
"query": search,
"parentId": parent_id,
"limit": 10000,
"limit": limit,
"spaceId": self.current_space.id,
}
response = self.post("searchPagesWithParent", data).json()
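Not part of the committed diff: a minimal usage sketch of the new limit parameter that get_record_data, get_block and search_pages_with_parent gain above. The token and URL are placeholders; the default of 100 matches the signatures in this change.

from notion.client import NotionClient

# Placeholder credentials and page, for illustration only.
client = NotionClient(token_v2="<token_v2 cookie value>")

# loadPageChunk is now capped by the caller-supplied limit (default 100)
# rather than the previous hard-coded 100000.
page = client.get_block("<page URL or block id>", limit=200)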
40 changes: 28 additions & 12 deletions notion/collection.py
@@ -360,6 +360,7 @@ def __init__(
sort=[],
calendar_by="",
group_by="",
+limit=100
):
assert not (
aggregate and aggregations
@@ -374,25 +375,40 @@
self.sort = _normalize_query_data(sort, collection)
self.calendar_by = _normalize_property_name(calendar_by, collection)
self.group_by = _normalize_property_name(group_by, collection)
+self.limit = limit
self._client = collection._client

def execute(self):

result_class = QUERY_RESULT_TYPES.get(self.type, QueryResult)

+kwargs = {
+'collection_id':self.collection.id,
+'collection_view_id':self.collection_view.id,
+'search':self.search,
+'type':self.type,
+'aggregate':self.aggregate,
+'aggregations':self.aggregations,
+'filter':self.filter,
+'sort':self.sort,
+'calendar_by':self.calendar_by,
+'group_by':self.group_by,
+'limit':0
+}

+if self.limit == -1:
+# fetch remote total
+result = self._client.query_collection(
+**kwargs
+)
+self.limit = result.get("total",-1)

+kwargs['limit'] = self.limit

return result_class(
self.collection,
self._client.query_collection(
-collection_id=self.collection.id,
-collection_view_id=self.collection_view.id,
-search=self.search,
-type=self.type,
-aggregate=self.aggregate,
-aggregations=self.aggregations,
-filter=self.filter,
-sort=self.sort,
-calendar_by=self.calendar_by,
-group_by=self.group_by,
+**kwargs
),
self,
)
@@ -704,14 +720,15 @@ def __init__(self, collection, result, query):
self.collection = collection
self._client = collection._client
self._block_ids = self._get_block_ids(result)
+self.total = result.get("total", -1)
self.aggregates = result.get("aggregationResults", [])
self.aggregate_ids = [
agg.get("id") for agg in (query.aggregate or query.aggregations)
]
self.query = query

def _get_block_ids(self, result):
return result["blockIds"]
return result['reducerResults']['collection_group_results']["blockIds"]

def _get_block(self, id):
block = CollectionRowBlock(self._client, id)
@@ -754,7 +771,6 @@ def __contains__(self, item):
return False
return item_id in self._block_ids


class TableQueryResult(QueryResult):

_type = "table"
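Again not part of the diff: a sketch of how the new limit flows through a collection query, continuing the placeholder client from the sketch above and assuming the existing notion-py helpers get_collection_view, default_query and build_query.

cv = client.get_collection_view("<collection view URL>")

# Default behaviour: at most 100 rows, per limit=100 in CollectionQuery.__init__.
some_rows = cv.default_query().execute()

# limit=-1: execute() first issues the query with limit=0 to read the remote
# "total", then re-issues it asking for that many rows.
all_rows = cv.build_query(limit=-1).execute()
print(all_rows.total)  # row count reported by the server, new in this commit
for row in all_rows:
    print(row.title)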
41 changes: 24 additions & 17 deletions notion/store.py
@@ -174,14 +174,14 @@ def get_role(self, table, id, force_refresh=False):
self.get(table, id, force_refresh=force_refresh)
return self._role[table].get(id, None)

-def get(self, table, id, force_refresh=False):
+def get(self, table, id, force_refresh=False, limit=100):
id = extract_id(id)
# look up the record in the current local dataset
result = self._get(table, id)
# if it's not found, try refreshing the record from the server
if result is Missing or force_refresh:
if table == "block":
-self.call_load_page_chunk(id)
+self.call_load_page_chunk(id,limit=limit)
else:
self.call_get_record_values(**{table: id})
result = self._get(table, id)
@@ -269,15 +269,17 @@ def get_current_version(self, table, id):
else:
return -1

-def call_load_page_chunk(self, page_id):
+def call_load_page_chunk(self, page_id, limit=100):

if self._client.in_transaction():
self._pages_to_refresh.append(page_id)
return

data = {
"pageId": page_id,
"limit": 100000,
"page": {
"id": page_id,
},
"limit": limit,
"cursor": {"stack": []},
"chunkNumber": 0,
"verticalColumns": False,
@@ -310,6 +312,7 @@ def call_query_collection(
sort=[],
calendar_by="",
group_by="",
+limit=50
):

assert not (
@@ -323,21 +326,25 @@
sort = [sort]

data = {
"collectionId": collection_id,
"collectionViewId": collection_view_id,
"collection": {
"id": collection_id,
"spaceId": self._client.current_space.id
},
"collectionView": {
"id": collection_view_id,
"spaceId": self._client.current_space.id
},
"loader": {
"limit": 10000,
"loadContentCover": True,
'reducers': {
'collection_group_results': {
'limit': limit,
'type': 'results',
},
},
"searchQuery": search,
"userLocale": "en",
+'sort': sort,
"userTimeZone": str(get_localzone()),
"type": type,
},
"query": {
"aggregate": aggregate,
"aggregations": aggregations,
"filter": filter,
"sort": sort,
"type": 'reducer',
},
}

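For reference, and inferred only from the accessors introduced above rather than from the live API: a sketch of the queryCollection result shape the reducer-based code expects. Field values are placeholders.

# QueryResult._get_block_ids() reads
#   result['reducerResults']['collection_group_results']['blockIds']
# and CollectionQuery.execute() / QueryResult.__init__ read result.get("total", -1).
result = {
    "reducerResults": {
        "collection_group_results": {
            "type": "results",
            "blockIds": ["<row-block-id-1>", "<row-block-id-2>"],
        },
    },
    "total": 2,
}

block_ids = result["reducerResults"]["collection_group_results"]["blockIds"]
total = result.get("total", -1)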
