
Initial commit for BigQuery glossary term ingestion
treff7es committed Oct 7, 2024
1 parent 0187fc6 commit 567a387
Showing 4 changed files with 191 additions and 11 deletions.

@@ -181,6 +181,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config):
self.bq_schema_extractor = BigQuerySchemaGenerator(
self.config,
self.report,
self.ctx.graph,
self.bigquery_data_dictionary,
self.domain_registry,
self.sql_parser_schema_resolver,

@@ -455,6 +455,12 @@ def have_table_data_read_permission(self) -> bool:
),
)

extract_policy_tags_as_glossary_term: bool = Field(
default=False,
hidden_from_docs=True,
description="This flag enables the extraction of policy tags as glossary terms. When enabled, the extractor will create a glossary term for each policy tag associated with BigQuery table columns.",
)
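
A minimal recipe sketch showing how this flag would be enabled. The project ID and
sink settings are illustrative, and extract_policy_tags_from_catalog must also be
enabled for the glossary-term path to run:

    from datahub.ingestion.run.pipeline import Pipeline

    pipeline = Pipeline.create(
        {
            "source": {
                "type": "bigquery",
                "config": {
                    "project_ids": ["my-project"],  # hypothetical project
                    "extract_policy_tags_from_catalog": True,
                    "extract_policy_tags_as_glossary_term": True,
                },
            },
            "sink": {
                "type": "datahub-rest",
                "config": {"server": "http://localhost:8080"},
            },
        }
    )
    pipeline.run()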

scheme: str = "bigquery"

log_page_size: PositiveInt = Field(

@@ -14,6 +14,7 @@
TimePartitioning,
TimePartitioningType,
)
from google.cloud.datacatalog_v1 import PolicyTag

from datahub.ingestion.api.source import SourceReport
from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier
@@ -39,7 +40,7 @@ class BigqueryColumn(BaseColumn):
field_path: str
is_partition_column: bool
cluster_column_position: Optional[int]
policy_tags: Optional[List[str]] = None
policy_tags: Optional[List[PolicyTag]] = None


RANGE_PARTITION_NAME: str = "RANGE"
@@ -442,6 +443,17 @@ def _make_bigquery_view(view: bigquery.Row) -> BigqueryView:
labels=parse_labels(view.labels) if view.get("labels") else None,
)

    @lru_cache()
    def get_policy_tags(self, taxonomy_name: str) -> Dict[str, PolicyTag]:
        # taxonomy_name is the full taxonomy resource name (see taxonomy_path()
        # below); results are cached per taxonomy.
        assert self.datacatalog_client
        taxonomy = self.datacatalog_client.get_taxonomy(name=taxonomy_name)
        if taxonomy:
            policy_tags = list(
                self.datacatalog_client.list_policy_tags(parent=taxonomy.name)
            )
            return {tag.name: tag for tag in policy_tags}
        return {}

def get_policy_tags_for_column(
self,
project_id: str,
@@ -450,7 +462,7 @@ def get_policy_tags_for_column(
column_name: str,
report: BigQueryV2Report,
rate_limiter: Optional[RateLimiter] = None,
) -> Iterable[str]:
) -> Iterable[PolicyTag]:
assert self.datacatalog_client

try:
@@ -476,7 +488,7 @@ def get_policy_tags_for_column(
policy_tag = self.datacatalog_client.get_policy_tag(
name=policy_tag_name
)
yield policy_tag.display_name
yield policy_tag
except Exception as e:
report.warning(
title="Failed to retrieve policy tag",
@@ -492,6 +504,24 @@ def get_policy_tags_for_column(
exc=e,
)

    def get_policy_terms(self, policy_tag: PolicyTag) -> Dict[str, PolicyTag]:
        # Resolve the taxonomy that owns this policy tag and return all of that
        # taxonomy's policy tags, keyed by resource name.
        assert self.datacatalog_client
        policy_tag_path = self.datacatalog_client.parse_policy_tag_path(
            policy_tag.name
        )
        taxonomy = policy_tag_path.get("taxonomy")
        location = policy_tag_path.get("location")
        project = policy_tag_path.get("project")

        if taxonomy and location and project:
            taxonomy_path = self.datacatalog_client.taxonomy_path(
                project,
                location,
                taxonomy,
            )
            return self.get_policy_tags(taxonomy_path)
        return {}
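
For reference, a sketch of the Data Catalog resource-name helpers these methods rely
on (the project, location, and taxonomy values are made up): taxonomy_path builds the
name get_policy_tags expects, and parse_policy_tag_path inverts a policy tag's
resource name as get_policy_terms does.

    from google.cloud.datacatalog_v1 import PolicyTagManagerClient

    # Build the taxonomy resource name accepted by get_policy_tags().
    name = PolicyTagManagerClient.taxonomy_path("my-project", "us", "123456789")
    # -> "projects/my-project/locations/us/taxonomies/123456789"

    # Invert a policy tag resource name, as get_policy_terms() does.
    parts = PolicyTagManagerClient.parse_policy_tag_path(
        "projects/my-project/locations/us/taxonomies/123456789/policyTags/987"
    )
    # -> {"project": "my-project", "location": "us",
    #     "taxonomy": "123456789", "policy_tag": "987"}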

def get_columns_for_dataset(
self,
project_id: str,
@@ -1,9 +1,13 @@
import logging
import re
import time
from base64 import b32decode
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Set, Type, Union, cast

from cachetools.func import lru_cache
from google.cloud.bigquery.table import TableListItem
from google.cloud.datacatalog_v1 import PolicyTag

from datahub.configuration.pattern_utils import is_schema_allowed, is_tag_allowed
from datahub.emitter.mce_builder import make_tag_urn
@@ -15,6 +19,7 @@
ClassificationHandler,
classification_workunit_processor,
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.bigquery_v2.bigquery_audit import (
BigqueryTableIdentifier,
BigQueryTableRef,
@@ -54,7 +59,11 @@
METADATA_EXTRACTION,
PROFILING,
)
from datahub.metadata._schema_classes import AuditStampClass, MetadataAttributionClass
from datahub.metadata._urns.urn_defs import GlossaryNodeUrn, GlossaryTermUrn
from datahub.metadata.com.linkedin.pegasus2avro.common import (
GlossaryTermAssociation,
GlossaryTerms,
Status,
SubTypes,
TimeStamp,
@@ -63,6 +72,10 @@
DatasetProperties,
ViewProperties,
)
from datahub.metadata.com.linkedin.pegasus2avro.glossary import (
GlossaryNodeInfo,
GlossaryTermInfo,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
ArrayType,
BooleanType,
@@ -89,7 +102,6 @@
HiveColumnToAvroConverter,
get_schema_fields_for_hive_column,
)
from datahub.utilities.mapping import Constants
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.ratelimiter import RateLimiter
from datahub.utilities.registries.domain_registry import DomainRegistry
@@ -153,6 +165,7 @@ def __init__(
self,
config: BigQueryV2Config,
report: BigQueryV2Report,
graph: Optional[DataHubGraph],
bigquery_data_dictionary: BigQuerySchemaApi,
domain_registry: Optional[DomainRegistry],
sql_parser_schema_resolver: SchemaResolver,
@@ -161,6 +174,7 @@
):
self.config = config
self.report = report
self.graph = graph
self.schema_api = bigquery_data_dictionary
self.domain_registry = domain_registry
self.sql_parser_schema_resolver = sql_parser_schema_resolver
@@ -662,6 +676,18 @@ def _process_snapshot(
dataset_name=dataset_name,
)

    def modified_base32decode(self, text_to_decode: str) -> str:
        # When we sync from DataHub to BigQuery, we encode tags as modified base32
        # strings. BigQuery labels only support lowercase letters, international
        # characters, numbers, or underscores, so the encoding replaces the base32
        # padding character `=` with `_` and converts to lowercase.
        if not text_to_decode.startswith("urn_li_encoded_tag_"):
            return text_to_decode
        text_to_decode = (
            text_to_decode.replace("urn_li_encoded_tag_", "").upper().replace("_", "=")
        )
        return b32decode(text_to_decode.encode("utf-8")).decode("utf-8")
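
A round-trip sketch of the encoding described above; modified_base32encode is a
hypothetical counterpart for illustration, not part of this commit:

    from base64 import b32decode, b32encode

    def modified_base32encode(text: str) -> str:
        # Lowercase the standard base32 output and swap the `=` padding
        # for `_` so the result is a legal BigQuery label value.
        encoded = b32encode(text.encode("utf-8")).decode("utf-8")
        return "urn_li_encoded_tag_" + encoded.lower().replace("=", "_")

    label = modified_base32encode("urn:li:tag:pii")
    # label == "urn_li_encoded_tag_ovzg4otmne5hiylhhjygs2i_"

    # Decoding mirrors modified_base32decode above:
    decoded = b32decode(
        label.replace("urn_li_encoded_tag_", "").upper().replace("_", "=").encode("utf-8")
    ).decode("utf-8")
    assert decoded == "urn:li:tag:pii"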

def gen_table_dataset_workunits(
self,
table: BigqueryTable,
@@ -708,6 +734,8 @@ def gen_table_dataset_workunits(
tags_to_add.extend(
[
make_tag_urn(f"""{k}:{v}""")
if not v.startswith("urn_li_encoded_tag_")
else self.modified_base32decode(v)
for k, v in table.labels.items()
if is_tag_allowed(self.config.capture_table_label_as_tag, k)
]
@@ -733,7 +761,9 @@ def gen_view_dataset_workunits(
tags_to_add = None
if table.labels and self.config.capture_view_label_as_tag:
tags_to_add = [
make_tag_urn(f"{k}:{v}")
make_tag_urn(f"""{k}:{v}""")
if not v.startswith("urn_li_encoded_tag_")
else self.modified_base32decode(v)
for k, v in table.labels.items()
if is_tag_allowed(self.config.capture_view_label_as_tag, k)
]
@@ -785,6 +815,14 @@ def gen_snapshot_dataset_workunits(
custom_properties=custom_properties,
)

def gen_glossary_terms(
self, columns: List[BigqueryColumn]
) -> Iterable[MetadataWorkUnit]:
for column in columns:
if column.policy_tags:
for policy_tag in column.policy_tags:
yield from self.create_and_get_glossary_term(policy_tag)

def gen_dataset_workunits(
self,
table: Union[BigqueryTable, BigqueryView, BigqueryTableSnapshot],
@@ -808,6 +846,12 @@ def gen_dataset_workunits(
project_id, dataset_name, table.name
)

if (
self.config.extract_policy_tags_from_catalog
and self.config.extract_policy_tags_as_glossary_term
):
yield from self.gen_glossary_terms(columns)

yield self.gen_schema_metadata(
dataset_urn, table, columns, datahub_dataset_name
)
@@ -879,6 +923,74 @@ def gen_tags_aspect_workunit(
entityUrn=dataset_urn, aspect=tags
).as_workunit()

    @lru_cache()
    def glossary_node_exists(self, urn: str) -> bool:
        assert self.graph, "graph is not set"
        # get_aspects_for_entity returns a dict of aspect name -> aspect value
        # (None when the aspect is absent), so check for presence rather than
        # comparing the aspect object to True.
        aspects = self.graph.get_aspects_for_entity(
            entity_urn=urn,
            aspects=["glossaryNodeInfo"],
            aspect_types=[GlossaryNodeInfo],
        )
        return aspects.get("glossaryNodeInfo") is not None

def create_and_get_glossary_term(
self, policy_tag: PolicyTag
) -> Iterable[MetadataWorkUnit]:
policy_tag_dict = self.schema_api.get_policy_terms(policy_tag)

parents = []

parent: Optional[str] = policy_tag.parent_policy_tag
while parent:
parent_policy_tag = policy_tag_dict.get(parent)
if parent_policy_tag is None:
parent = None
continue
parents.append(parent_policy_tag)
parent = parent_policy_tag.parent_policy_tag

parent_urn: Optional[GlossaryNodeUrn] = None
# Walk from the taxonomy root down, creating any missing glossary nodes.
# parents already holds PolicyTag objects, so no extra dict lookup is needed.
for p_tag in reversed(parents):
if not self.glossary_node_exists(
self.gen_glossary_node_urn_from_policy_tag(p_tag).urn()
):
yield MetadataChangeProposalWrapper(
entityUrn=GlossaryNodeUrn(
name=p_tag.display_name,
).urn(),
aspect=GlossaryNodeInfo(
name=p_tag.display_name,
definition=p_tag.description,
parentNode=parent_urn.urn() if parent_urn else None,
),
).as_workunit()
parent_urn = self.gen_glossary_node_urn_from_policy_tag(p_tag)

yield MetadataChangeProposalWrapper(
entityUrn=self.gen_glossary_term_urn_from_policy_tag(policy_tag).urn(),
aspect=GlossaryTermInfo(
name=policy_tag.display_name,
definition=policy_tag.description,
sourceRef=policy_tag.name,
termSource="EXTERNAL",
parentNode=parent_urn.urn() if parent_urn else None,
),
).as_workunit()

def gen_glossary_term_urn_from_policy_tag(
self, policy_tag: PolicyTag
) -> GlossaryTermUrn:
return GlossaryTermUrn(name=policy_tag.display_name)

def gen_glossary_node_urn_from_policy_tag(
self, policy_tag: PolicyTag
) -> GlossaryNodeUrn:
return GlossaryNodeUrn(name=policy_tag.display_name)
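
For illustration, the URNs these helpers produce for hypothetical policy tags with
display names "email" and "pii":

    from datahub.metadata._urns.urn_defs import GlossaryNodeUrn, GlossaryTermUrn

    GlossaryTermUrn(name="email").urn()  # -> "urn:li:glossaryTerm:email"
    GlossaryNodeUrn(name="pii").urn()    # -> "urn:li:glossaryNode:pii"

Note that only the policy tag's display name feeds the URN, so identically named tags
in different taxonomies would map to the same glossary term.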

def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
schema_fields: List[SchemaField] = []

Expand Down Expand Up @@ -922,11 +1034,6 @@ def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
break
else:
tags = []
if col.is_partition_column:
tags.append(
TagAssociationClass(make_tag_urn(Constants.TAG_PARTITION_KEY))
)

if col.cluster_column_position is not None:
tags.append(
TagAssociationClass(
@@ -936,18 +1043,54 @@ def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
)
)

term_associations: List[GlossaryTermAssociation] = []
if col.policy_tags:
for policy_tag in col.policy_tags:
tags.append(TagAssociationClass(make_tag_urn(policy_tag)))
if not self.config.extract_policy_tags_as_glossary_term:
tags.append(
TagAssociationClass(
make_tag_urn(policy_tag.display_name)
)
)
else:
glossary_term_urn = (
self.gen_glossary_term_urn_from_policy_tag(policy_tag)
)
term_associations.append(
GlossaryTermAssociation(
urn=glossary_term_urn.urn(),
attribution=MetadataAttributionClass(
sourceDetail={
"source": "bigquery",
"sourceType": "policyTag",
"sourceId": policy_tag.name,
"policy_tag_name": policy_tag.name,
},
source="urn:li:platform:bigquery",
time=int(time.time() * 1000),
actor="urn:li:corpuser:ingestion",
),
)
)
now = int(time.time() * 1000)
current_timestamp = AuditStampClass(
time=now, actor="urn:li:corpuser:ingestion"
)
field = SchemaField(
fieldPath=col.name,
type=SchemaFieldDataType(
self.BIGQUERY_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)()
),
isPartitioningKey=col.is_partition_column,
nativeDataType=col.data_type,
description=col.comment,
nullable=col.is_nullable,
globalTags=GlobalTagsClass(tags=tags),
glossaryTerms=GlossaryTerms(
terms=term_associations, auditStamp=current_timestamp
)
if term_associations
else None,
)
schema_fields.append(field)
last_id = col.ordinal_position
