Skip to content

Commit

Permalink
Merge pull request #3611 from opsmill/pog-repository-typing
Browse files Browse the repository at this point in the history
Typing cleanup for git.repository
  • Loading branch information
ogenstad authored Jun 8, 2024
2 parents 2eb4ebc + 2a2ec20 commit 02a8839
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions backend/infrahub/git/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,15 +516,7 @@ async def create_locally(
return True

@classmethod
async def new(cls, service: Optional[InfrahubServices] = None, **kwargs):
service = service or InfrahubServices()
self = cls(service=service, **kwargs)
await self.create_locally()
log.info("Created the new project locally.", repository=self.name)
return self

@classmethod
async def init(cls, service: Optional[InfrahubServices] = None, **kwargs):
async def init(cls, service: Optional[InfrahubServices] = None, **kwargs: Any):
service = service or InfrahubServices()
self = cls(service=service, **kwargs)
self.validate_local_directories()
Expand Down Expand Up @@ -861,7 +853,7 @@ async def get_conflicts(self, source_branch: str, dest_branch: str) -> List[str]

return conflict_files

async def import_objects_from_files(self, branch_name: str, commit: Optional[str] = None):
async def import_objects_from_files(self, branch_name: str, commit: Optional[str] = None) -> None:
if not commit:
commit = self.get_commit_value(branch_name=branch_name)

Expand All @@ -878,7 +870,9 @@ async def import_objects_from_files(self, branch_name: str, commit: Optional[str
await self.import_jinja2_transforms(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_artifact_definitions(branch_name=branch_name, commit=commit, config_file=config_file)

async def import_jinja2_transforms(self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig):
async def import_jinja2_transforms(
self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig
) -> None:
log.debug("Importing all Jinja2 transforms", repository=self.name, branch=branch_name, commit=commit)

schema = await self.sdk.schema.get(kind=InfrahubKind.TRANSFORMJINJA2, branch=branch_name)
Expand Down Expand Up @@ -986,7 +980,9 @@ async def update_jinja2_transform(

await existing_transform.save()

async def import_artifact_definitions(self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig):
async def import_artifact_definitions(
self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig
) -> None:
log.debug("Importing all Artifact Definitions", repository=self.name, branch=branch_name, commit=commit)

schema = await self.sdk.schema.get(kind=InfrahubKind.ARTIFACTDEFINITION, branch=branch_name)
Expand Down Expand Up @@ -1770,7 +1766,9 @@ async def compare_python_transform(
return False
return True

async def import_all_python_files(self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig):
async def import_all_python_files(
self, branch_name: str, commit: str, config_file: InfrahubRepositoryConfig
) -> None:
await self.import_python_check_definitions(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_python_transforms(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_generator_definitions(branch_name=branch_name, commit=commit, config_file=config_file)
Expand Down Expand Up @@ -1820,7 +1818,7 @@ async def get_file(self, commit: str, location: str) -> str:

return path.read_text(encoding="UTF-8")

async def render_jinja2_template(self, commit: str, location: str, data: dict):
async def render_jinja2_template(self, commit: str, location: str, data: dict) -> str:
commit_worktree = self.get_commit_worktree(commit=commit)

self.validate_location(commit=commit, worktree_directory=commit_worktree.directory, file_path=location)
Expand All @@ -1831,7 +1829,7 @@ async def render_jinja2_template(self, commit: str, location: str, data: dict):
template = templateEnv.get_template(location)
return template.render(**data)
except Exception as exc:
log.error(exc, exc_info=True, repository=self.name, commit=commit, location=location)
log.error(str(exc), exc_info=True, repository=self.name, commit=commit, location=location)
raise TransformError(repository_name=self.name, commit=commit, location=location, message=str(exc)) from exc

async def execute_python_check(
Expand Down Expand Up @@ -1894,7 +1892,7 @@ async def execute_python_check(

except Exception as exc:
log.critical(
exc,
str(exc),
exc_info=True,
repository=self.name,
branch=branch_name,
Expand Down Expand Up @@ -2090,7 +2088,7 @@ class InfrahubRepository(InfrahubRepositoryBase):
"""

def get_commit_value(self, branch_name: str, remote: bool = False) -> str:
branches = None
branches = {}
if remote:
branches = self.get_branches_from_remote()
else:
Expand Down Expand Up @@ -2265,6 +2263,14 @@ async def rebase(self, branch_name: str, source_branch: str = "main", push_remot

return response

@classmethod
async def new(cls, service: Optional[InfrahubServices] = None, **kwargs: Any) -> InfrahubRepository:
service = service or InfrahubServices()
self = cls(service=service, **kwargs)
await self.create_locally()
log.info("Created the new project locally.", repository=self.name)
return self


class InfrahubReadOnlyRepository(InfrahubRepositoryBase):
"""
Expand All @@ -2278,7 +2284,7 @@ class InfrahubReadOnlyRepository(InfrahubRepositoryBase):
)

@classmethod
async def new(cls, service: Optional[InfrahubServices] = None, **kwargs):
async def new(cls, service: Optional[InfrahubServices] = None, **kwargs: Any) -> InfrahubReadOnlyRepository:
service = service or InfrahubServices()

if "ref" not in kwargs or "infrahub_branch_name" not in kwargs:
Expand Down

0 comments on commit 02a8839

Please sign in to comment.