diff --git a/intake_metabase/source.py b/intake_metabase/source.py index 28ddf30..45753eb 100644 --- a/intake_metabase/source.py +++ b/intake_metabase/source.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timedelta from urllib.parse import urljoin +from collections import ChainMap import requests from intake.catalog import Catalog @@ -10,19 +11,26 @@ from . import __version__ +def _merge_dicts(*dicts): + return dict(ChainMap(*reversed(dicts))) + + class MetabaseCatalog(Catalog): name = 'metabase_catalog' version = __version__ # partition_access = False - def __init__(self, domain, username=None, password=None, token=None, metadata=None, name=None): + def __init__(self, domain, username=None, password=None, token=None, + extra_headers=None, + metadata=None, name=None): self.name = name self.domain = domain self.username = username self.password = password self.token = token + self.extra_headers = extra_headers - self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token) + self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers) super().__init__(name='metabase', metadata=metadata) @@ -46,7 +54,8 @@ def _load(self): 'username': self.username, 'password': self.password, 'token': self.token, - 'question': question + 'question': question, + 'extra_headers': self.extra_headers } ) e._plugin = [MetabaseQuestionSource] @@ -65,7 +74,8 @@ def _load(self): 'password': self.password, 'token': self.token, 'database': db['id'], - 'table': table['id'] + 'table': table['id'], + 'extra_headers': self.extra_headers } ) e._plugin = [MetabaseTableSource] @@ -78,15 +88,17 @@ class MetabaseQuestionSource(DataSource): version = __version__ partition_access = True - def __init__(self, domain, question, username=None, password=None, token=None, metadata=None): + def __init__(self, domain, question, username=None, password=None, token=None, + extra_headers=None, metadata=None): self.domain = domain self.username = username self.password = password self.token = token self.question = question self._df = None + self.extra_headers = extra_headers - self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token) + self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers) super(MetabaseQuestionSource, self).__init__(metadata=metadata) @@ -121,7 +133,8 @@ class MetabaseTableSource(DataSource): version = __version__ partition_access = True - def __init__(self, domain, database, table=None, query=None, username=None, password=None, token=None, metadata=None): + def __init__(self, domain, database, table=None, query=None, username=None, password=None, token=None, + extra_headers=None, metadata=None): self.domain = domain self.username = username self.password = password @@ -130,8 +143,9 @@ def __init__(self, domain, database, table=None, query=None, username=None, pass self.table = table self.query = query self._df = None + self.extra_headers = extra_headers - self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token) + self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers) super(MetabaseTableSource, self).__init__(metadata=metadata) @@ -161,7 +175,7 @@ def _close(self): class MetabaseAPI(): - def __init__(self, domain, username=None, password=None, token=None): + def __init__(self, domain, username=None, password=None, token=None, extra_headers=None): self.domain = domain self.password = password @@ -174,14 +188,18 @@ def __init__(self, domain, username=None, password=None, token=None): self._token = None self._token_expiration = datetime.now() + self.extra_headers = {} if extra_headers is None else extra_headers + def _create_or_refresh_token(self): if self._token: if (self._token_expiration is None) or (datetime.now() < self._token_expiration): return + headers=_merge_dicts({'Content-Type': 'application/json'}, self.extra_headers) + res = requests.post( urljoin(self.domain, '/api/session'), - headers={'Content-Type': 'application/json'}, + headers=headers, data=json.dumps({ 'username': self.username, 'password': self.password @@ -195,9 +213,7 @@ def _create_or_refresh_token(self): def get_databases(self): self._create_or_refresh_token() - headers = { - 'X-Metabase-Session': self._token - } + headers = _merge_dicts({'X-Metabase-Session': self._token}, self.extra_headers) params = {'include': 'tables', 'saved': True} res = requests.get( @@ -210,9 +226,7 @@ def get_databases(self): def get_metadata(self, table): self._create_or_refresh_token() - headers = { - 'X-Metabase-Session': self._token - } + headers = _merge_dicts({'X-Metabase-Session': self._token}, self.extra_headers) res = requests.get( urljoin(self.domain, f'/api/table/{table}/query_metadata'), @@ -232,10 +246,10 @@ def get_card(self, question): date_fields = [f['display_name'] for f in card_metadata['fields'] if 'date' in f['base_type'].lower()] - headers = { + headers = _merge_dicts({ 'Content-Type': 'application/x-www-form-urlencoded', 'X-Metabase-Session': self._token - } + }, self.extra_headers) res = requests.post( urljoin(self.domain, f'/api/card/{question}/query/csv'), @@ -280,10 +294,10 @@ def get_table(self, database, table=None, query=None): body['type'] = 'native' body['native'] = {'query': query} - headers = { + headers = _merge_dicts({ 'Content-Type': 'application/x-www-form-urlencoded', 'X-Metabase-Session': self._token - } + }, self.extra_headers) res = requests.post( urljoin(self.domain, '/api/dataset/csv'),