Skip to content

Commit

Permalink
typechecks
Browse files Browse the repository at this point in the history
  • Loading branch information
lilatomic committed Jul 5, 2024
1 parent b1f7f70 commit 2c8402e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
19 changes: 17 additions & 2 deletions llamazure/azrest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel, BeforeValidator, Field

Ret_T = TypeVar("Ret_T")
Ret_T0 = TypeVar("Ret_T0")
ReadOnly = Optional[Ret_T]


Expand Down Expand Up @@ -61,9 +62,9 @@ def add_params(self, params: Dict[str, str]) -> Req:
def add_param(self, name: str, value: str) -> Req:
return dataclasses.replace(self, params={**self.params, **{name: value}})

Check warning on line 63 in llamazure/azrest/models.py

View check run for this annotation

Codecov / codecov/patch

llamazure/azrest/models.py#L63

Added line #L63 was not covered by tests

def with_ret_t(self, ret_t: Type[Ret_T]) -> Req:
def with_ret_t(self, ret_t: Type[Ret_T0]) -> Req[Ret_T0]:
"""Override the return type"""
return dataclasses.replace(self, ret_t=ret_t)
return dataclasses.replace(self, ret_t=ret_t) # type: ignore

Check warning on line 67 in llamazure/azrest/models.py

View check run for this annotation

Codecov / codecov/patch

llamazure/azrest/models.py#L67

Added line #L67 was not covered by tests


@dataclass
Expand Down Expand Up @@ -156,6 +157,20 @@ class AzureErrorAdditionInfo(BaseModel):


def ensure(a: Optional[T]) -> T:
"""Ensure the result is not None"""
if a is None:
raise TypeError("value was None")

Check warning on line 162 in llamazure/azrest/models.py

View check run for this annotation

Codecov / codecov/patch

llamazure/azrest/models.py#L162

Added line #L162 was not covered by tests
return a


P0 = TypeVar("P0", bound=BaseModel)
P1 = TypeVar("P1", bound=BaseModel)


def cast_as(obj: P0, cls: Type[P1]) -> P1:
"""
Cast one model into another.
Useful for turning a Foo into a FooUpdateParameters.
"""
return cls.model_validate(obj.model_dump())

Check warning on line 176 in llamazure/azrest/models.py

View check run for this annotation

Codecov / codecov/patch

llamazure/azrest/models.py#L176

Added line #L176 was not covered by tests
15 changes: 7 additions & 8 deletions llamazure_tools/migrate/dashboard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Migrate an Azure Dashboard to a different Log Analytics Workspace"""
import dataclasses
import json
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -9,10 +8,11 @@
from azure.identity import DefaultAzureCredential

from llamazure.azrest.azrest import AzRest
from llamazure.azrest.models import cast_as
from llamazure.rid import rid
from llamazure.rid.rid import Resource
from llamazure_tools.migrate.portal.r.m.portal.portal import AzDashboards, Dashboard, PatchableDashboard # pylint: disable=E0611,E0401
from llamazure_tools.migrate.util import JSONTraverser
from llamazure_tools.migrate.util import JSONTraverser, rid_params


@dataclass
Expand All @@ -35,7 +35,7 @@ def migrate(self):

def get_dashboard(self) -> dict:
"""Retrieve the current dashboard data from Azure."""
return self.az.call(dataclasses.replace(AzDashboards.Get(self.dashboard.sub.uuid, self.dashboard.rg.name, self.dashboard.name), ret_t=dict))
return self.az.call(AzDashboards.Get(*rid_params(self.dashboard)).with_ret_t(dict))

def transform(self, dashboard: dict) -> dict:
"""Transform the dashboard data using the provided transformer."""
Expand All @@ -44,12 +44,9 @@ def transform(self, dashboard: dict) -> dict:
def put_dashboard(self, transformed: dict):
"""Update the dashboard in Azure with the transformed data."""
d = Dashboard(**transformed)
p = PatchableDashboard(
properties=d.properties.model_dump(),
tags=d.tags,
)
p = cast_as(d, PatchableDashboard)
self.az.call(
AzDashboards.Update(self.dashboard.sub.uuid, self.dashboard.rg.name, self.dashboard.name, p),
AzDashboards.Update(*rid_params(self.dashboard), p),
)

def make_backup(self, dashboard: dict):
Expand All @@ -68,7 +65,9 @@ def migrate(resource_id: str, replacements: str, backup_directory: str):
az = AzRest.from_credential(DefaultAzureCredential())

replacements = json.loads(replacements)
assert isinstance(replacements, dict)
resource = rid.parse(resource_id)
assert isinstance(resource, rid.Resource)
transformer = JSONTraverser(replacements)
migrator = Migrator(az, resource, transformer, Path(backup_directory))

Expand Down
9 changes: 8 additions & 1 deletion llamazure_tools/migrate/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Utils"""
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Dict, Tuple

from llamazure.azrest.models import ensure
from llamazure.rid import rid


@dataclass
Expand All @@ -19,3 +22,7 @@ def traverse(self, obj: Any) -> Any:
return self.replacements.get(obj, obj)
else:
return obj


def rid_params(res: rid.Resource) -> Tuple[str, str, str]:
return ensure(res.sub).uuid, ensure(res.rg).name, res.name
11 changes: 7 additions & 4 deletions llamazure_tools/migrate/workbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from azure.identity import DefaultAzureCredential

from llamazure.azrest.azrest import AzRest
from llamazure.azrest.models import cast_as
from llamazure.rid import rid
from llamazure.rid.rid import Resource
from llamazure_tools.migrate.applicationinsights.r.m.insights.workbooks import AzWorkbooks, Workbook # pylint: disable=E0611,E0401
from llamazure_tools.migrate.util import JSONTraverser
from llamazure_tools.migrate.applicationinsights.r.m.insights.workbooks import AzWorkbooks, Workbook, WorkbookUpdateParameters # pylint: disable=E0611,E0401
from llamazure_tools.migrate.util import JSONTraverser, rid_params


@dataclass
Expand All @@ -33,7 +34,7 @@ def migrate(self):

def get_workbook(self) -> Workbook:
"""Retrieve the current workbook data from Azure."""
return self.az.call(AzWorkbooks.Get(self.workbook.sub.uuid, self.workbook.rg.name, self.workbook.name, canFetchContent=True))
return self.az.call(AzWorkbooks.Get(*rid_params(self.workbook), canFetchContent=True))

def transform(self, workbook: Workbook) -> Workbook:
"""Transform the workbook data using the provided transformer."""
Expand All @@ -43,7 +44,7 @@ def transform(self, workbook: Workbook) -> Workbook:
def put_workbook(self, transformed: Workbook):
"""Update the dashboard in Azure with the transformed data."""
self.az.call(
AzWorkbooks.Update(self.workbook.sub.uuid, self.workbook.rg.name, self.workbook.name, transformed),
AzWorkbooks.Update(*rid_params(self.workbook), cast_as(transformed, WorkbookUpdateParameters)),
)

def make_backup(self, workbook: Workbook):
Expand All @@ -62,7 +63,9 @@ def migrate(resource_id: str, replacements: str, backup_directory: str):
az = AzRest.from_credential(DefaultAzureCredential())

replacements = json.loads(replacements)
assert isinstance(replacements, dict)
resource = rid.parse(resource_id)
assert isinstance(resource, rid.Resource)
transformer = JSONTraverser(replacements)
migrator = Migrator(az, resource, transformer, Path(backup_directory))

Expand Down

0 comments on commit 2c8402e

Please sign in to comment.