Skip to content

Commit

Permalink
Merge pull request #7 from ContinuumIO/extra-headers
Browse files Browse the repository at this point in the history
allow extra headers in open methods
  • Loading branch information
AlbertDeFusco authored Mar 29, 2023
2 parents c28d36e + fcfedac commit 4922c59
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions intake_metabase/source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from datetime import datetime, timedelta
from urllib.parse import urljoin
from collections import ChainMap

import requests
from intake.catalog import Catalog
Expand All @@ -10,19 +11,26 @@
from . import __version__


def _merge_dicts(*dicts):
return dict(ChainMap(*reversed(dicts)))


class MetabaseCatalog(Catalog):
name = 'metabase_catalog'
version = __version__
# partition_access = False

def __init__(self, domain, username=None, password=None, token=None, metadata=None, name=None):
def __init__(self, domain, username=None, password=None, token=None,
extra_headers=None,
metadata=None, name=None):
self.name = name
self.domain = domain
self.username = username
self.password = password
self.token = token
self.extra_headers = extra_headers

self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token)
self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers)

super().__init__(name='metabase', metadata=metadata)

Expand All @@ -46,7 +54,8 @@ def _load(self):
'username': self.username,
'password': self.password,
'token': self.token,
'question': question
'question': question,
'extra_headers': self.extra_headers
}
)
e._plugin = [MetabaseQuestionSource]
Expand All @@ -65,7 +74,8 @@ def _load(self):
'password': self.password,
'token': self.token,
'database': db['id'],
'table': table['id']
'table': table['id'],
'extra_headers': self.extra_headers
}
)
e._plugin = [MetabaseTableSource]
Expand All @@ -78,15 +88,17 @@ class MetabaseQuestionSource(DataSource):
version = __version__
partition_access = True

def __init__(self, domain, question, username=None, password=None, token=None, metadata=None):
def __init__(self, domain, question, username=None, password=None, token=None,
extra_headers=None, metadata=None):
self.domain = domain
self.username = username
self.password = password
self.token = token
self.question = question
self._df = None
self.extra_headers = extra_headers

self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token)
self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers)

super(MetabaseQuestionSource, self).__init__(metadata=metadata)

Expand Down Expand Up @@ -121,7 +133,8 @@ class MetabaseTableSource(DataSource):
version = __version__
partition_access = True

def __init__(self, domain, database, table=None, query=None, username=None, password=None, token=None, metadata=None):
def __init__(self, domain, database, table=None, query=None, username=None, password=None, token=None,
extra_headers=None, metadata=None):
self.domain = domain
self.username = username
self.password = password
Expand All @@ -130,8 +143,9 @@ def __init__(self, domain, database, table=None, query=None, username=None, pass
self.table = table
self.query = query
self._df = None
self.extra_headers = extra_headers

self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token)
self._metabase = MetabaseAPI(self.domain, self.username, self.password, self.token, self.extra_headers)

super(MetabaseTableSource, self).__init__(metadata=metadata)

Expand Down Expand Up @@ -161,7 +175,7 @@ def _close(self):


class MetabaseAPI():
def __init__(self, domain, username=None, password=None, token=None):
def __init__(self, domain, username=None, password=None, token=None, extra_headers=None):
self.domain = domain

self.password = password
Expand All @@ -174,14 +188,18 @@ def __init__(self, domain, username=None, password=None, token=None):
self._token = None
self._token_expiration = datetime.now()

self.extra_headers = {} if extra_headers is None else extra_headers

def _create_or_refresh_token(self):
if self._token:
if (self._token_expiration is None) or (datetime.now() < self._token_expiration):
return

headers=_merge_dicts({'Content-Type': 'application/json'}, self.extra_headers)

res = requests.post(
urljoin(self.domain, '/api/session'),
headers={'Content-Type': 'application/json'},
headers=headers,
data=json.dumps({
'username': self.username,
'password': self.password
Expand All @@ -195,9 +213,7 @@ def _create_or_refresh_token(self):
def get_databases(self):
self._create_or_refresh_token()

headers = {
'X-Metabase-Session': self._token
}
headers = _merge_dicts({'X-Metabase-Session': self._token}, self.extra_headers)
params = {'include': 'tables', 'saved': True}

res = requests.get(
Expand All @@ -210,9 +226,7 @@ def get_databases(self):
def get_metadata(self, table):
self._create_or_refresh_token()

headers = {
'X-Metabase-Session': self._token
}
headers = _merge_dicts({'X-Metabase-Session': self._token}, self.extra_headers)

res = requests.get(
urljoin(self.domain, f'/api/table/{table}/query_metadata'),
Expand All @@ -232,10 +246,10 @@ def get_card(self, question):
date_fields = [f['display_name'] for f in card_metadata['fields']
if 'date' in f['base_type'].lower()]

headers = {
headers = _merge_dicts({
'Content-Type': 'application/x-www-form-urlencoded',
'X-Metabase-Session': self._token
}
}, self.extra_headers)

res = requests.post(
urljoin(self.domain, f'/api/card/{question}/query/csv'),
Expand Down Expand Up @@ -280,10 +294,10 @@ def get_table(self, database, table=None, query=None):
body['type'] = 'native'
body['native'] = {'query': query}

headers = {
headers = _merge_dicts({
'Content-Type': 'application/x-www-form-urlencoded',
'X-Metabase-Session': self._token
}
}, self.extra_headers)

res = requests.post(
urljoin(self.domain, '/api/dataset/csv'),
Expand Down

0 comments on commit 4922c59

Please sign in to comment.