Skip to content

Commit

Permalink
Format with black
Browse files Browse the repository at this point in the history
  • Loading branch information
b-per committed Jun 5, 2024
1 parent 0dcedd1 commit e5af0d8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
8 changes: 7 additions & 1 deletion src/changeset/change_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ def to_table(self) -> Table:
table.add_column("Env ID", style="red")

for change in self.__root__:
table.add_row(change.action.upper(), string.capwords(change.type), change.identifier, str(change.proj_id), str(change.env_id))
table.add_row(
change.action.upper(),
string.capwords(change.type),
change.identifier,
str(change.proj_id),
str(change.env_id),
)

return table

Expand Down
21 changes: 11 additions & 10 deletions src/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def _check_for_creds(self):

def build_mapping_job_identifier_job_id(self, cloud_jobs: List[JobDefinition] = None):


if cloud_jobs is None:
# TODO, we should filter things here at least if we call it often
cloud_jobs = self.get_jobs()
Expand Down Expand Up @@ -124,7 +123,7 @@ def delete_job(self, job: JobDefinition) -> None:

logger.success("Job deleted successfully.")

def get_jobs(self, project_id=None , environment_id=None) -> List[JobDefinition]:
def get_jobs(self, project_id=None, environment_id=None) -> List[JobDefinition]:
"""Return a list of Jobs for all the dbt Cloud jobs in an environment."""

self._check_for_creds()
Expand All @@ -150,13 +149,13 @@ def get_jobs(self, project_id=None , environment_id=None) -> List[JobDefinition]

while True:
parameters = {"offset": offset}
parameters["environment_id"]=env_id
parameters["environment_id"] = env_id

if len(project_id) == 1:
parameters["project_id"]=project_id[0]
parameters["project_id"] = project_id[0]
elif len(project_id) > 1:
project_id_str = [str(i) for i in project_id]
parameters["project_id__in"]=f"[{','.join(project_id_str)}]"
parameters["project_id__in"] = f"[{','.join(project_id_str)}]"

response = requests.get(
url=f"{self.base_url}/api/v2/accounts/{self.account_id}/jobs/",
Expand All @@ -173,25 +172,26 @@ def get_jobs(self, project_id=None , environment_id=None) -> List[JobDefinition]
jobs.extend(job_data["data"])

if (
job_data["extra"]["filters"]["limit"] + job_data["extra"]["filters"]["offset"]
job_data["extra"]["filters"]["limit"]
+ job_data["extra"]["filters"]["offset"]
>= job_data["extra"]["pagination"]["total_count"]
):
break

offset += job_data["extra"]["filters"]["limit"]

else:
else:
# In this case, there are no multiple environments ID's.. Invoke the API once
while True:
parameters = {"offset": offset}
if len(project_id) == 1:
parameters["project_id"]=project_id[0]
parameters["project_id"] = project_id[0]
elif len(project_id) > 1:
project_id_str = [str(i) for i in project_id]
parameters["project_id__in"]=f"[{','.join(project_id_str)}]"
parameters["project_id__in"] = f"[{','.join(project_id_str)}]"

if len(environment_id) == 1:
parameters["environment_id"]=environment_id[0]
parameters["environment_id"] = environment_id[0]

logger.debug(f"Request parameters {parameters}")
response = requests.get(
Expand Down Expand Up @@ -298,6 +298,7 @@ def update_env_var(

# handle the case where the job was not created when we queued the function call
if yml_job_identifier and not job_id:
# TODO - we shouldn't have to call the API so many times
mapping_job_identifier_job_id = self.build_mapping_job_identifier_job_id()
job_id = mapping_job_identifier_job_id[yml_job_identifier]
custom_env_var.job_definition_id = job_id
Expand Down
18 changes: 12 additions & 6 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def build_change_set(config, disable_ssl_verification, project_id, environment_i
)
dbt_cloud_change_set.append(dbt_cloud_change)

# Filtering out the change set, if project_id(s), environment_id(s) are passed as arguments to function - Desired functionality?
# Filtering out the change set, if project_id(s), environment_id(s) are passed as arguments to function - Desired functionality?
dbt_cloud_change_set_filtered = ChangeSet()
logger.debug(f"dbt cloud change set: {dbt_cloud_change_set}")

Expand All @@ -204,7 +204,7 @@ def build_change_set(config, disable_ssl_verification, project_id, environment_i
if dbt_cloud_change.env_id in environment_id or dbt_cloud_change.proj_id in project_id:
dbt_cloud_change_set_filtered.append(dbt_cloud_change)

dbt_cloud_change_set = dbt_cloud_change_set_filtered
dbt_cloud_change_set = dbt_cloud_change_set_filtered

return dbt_cloud_change_set

Expand Down Expand Up @@ -246,7 +246,9 @@ def sync(config, project_id, environment_id, disable_ssl_verification):
cloud_environment_id = environment_id

logger.info("-- SYNC -- Invoking build_change_set")
change_set = build_change_set(config, disable_ssl_verification, cloud_project_id, cloud_environment_id )
change_set = build_change_set(
config, disable_ssl_verification, cloud_project_id, cloud_environment_id
)
if len(change_set) == 0:
logger.success("-- SYNC -- No changes detected.")
else:
Expand Down Expand Up @@ -288,7 +290,9 @@ def plan(config, project_id, environment_id, disable_ssl_verification):
if environment_id:
cloud_environment_id = environment_id

change_set = build_change_set(config, disable_ssl_verification, cloud_project_id, cloud_environment_id )
change_set = build_change_set(
config, disable_ssl_verification, cloud_project_id, cloud_environment_id
)
if len(change_set) == 0:
logger.success("-- PLAN -- No changes detected.")
else:
Expand Down Expand Up @@ -433,7 +437,7 @@ def import_jobs(config, account_id, project_id, environment_id, job_id, disable_

cloud_project_id = []
cloud_environment_id = []

if project_id:
cloud_project_id = project_id

Expand All @@ -446,7 +450,9 @@ def import_jobs(config, account_id, project_id, environment_id, job_id, disable_
base_url=os.environ.get("DBT_BASE_URL", "https://cloud.getdbt.com"),
disable_ssl_verification=disable_ssl_verification,
)
cloud_jobs = dbt_cloud.get_jobs(project_id=cloud_project_id, environment_id=cloud_environment_id)
cloud_jobs = dbt_cloud.get_jobs(
project_id=cloud_project_id, environment_id=cloud_environment_id
)
logger.info(f"Getting the jobs definition from dbt Cloud")

if job_id:
Expand Down

0 comments on commit e5af0d8

Please sign in to comment.