Skip to content

Commit

Permalink
retry all steps during self delete if the delete was not successful a…
Browse files Browse the repository at this point in the history
…nd change self-delete permission assignment to happen during container creation and update benchmarks to set more random seeds
  • Loading branch information
paulbkoch committed Oct 7, 2024
1 parent 9fc4eb5 commit 1dc0d4f
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 97 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ We also build on top of many great packages. Please check them out!

# Papers that use or compare EBMs

- [Challenging the Performance-Interpretability Trade-off: An Evaluation of Interpretable Machine Learning Models](https://arxiv.org/pdf/2409.14429)
- [Data Science with LLMs and Interpretable Models](https://arxiv.org/pdf/2402.14474v1.pdf)
- [DimVis: Interpreting Visual Clusters in Dimensionality Reduction With Explainable Boosting Machine](https://arxiv.org/pdf/2402.06885.pdf)
- [Distill knowledge of additive tree models into generalized linear models](https://detralytics.com/wp-content/uploads/2023/10/Detra-Note_Additive-tree-ensembles.pdf)
Expand Down
8 changes: 6 additions & 2 deletions docs/benchmarks/ebm-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"force_recreate = False\n",
"exist_ok = True\n",
"TIMEOUT_SEC = 60 * 60 * 24 * 180 # 180 days\n",
"wheel_filepaths = ['interpret_core-0.6.3-py3-none-any.whl', 'powerlift-0.1.11-py3-none-any.whl']\n",
"wheel_filepaths = ['interpret_core-0.6.4-py3-none-any.whl', 'powerlift-0.1.12-py3-none-any.whl']\n",
"\n",
"import datetime\n",
"experiment_name = datetime.datetime.now().strftime('%Y_%m_%d_%H%M__') + 'myexperiment'\n",
Expand Down Expand Up @@ -230,6 +230,10 @@
" import warnings\n",
" import gc\n",
" import re\n",
" import random\n",
"\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
"\n",
" X, y = trial.task.data()\n",
"\n",
Expand Down Expand Up @@ -851,7 +855,7 @@
" executor = AzureContainerInstance(\n",
" store, azure_tenant_id, subscription_id, azure_client_id, credential,\n",
" resource_group=resource_group,\n",
" pip_install= requirements + \" interpret-core\",\n",
" pip_install=requirements + \" interpret-core\",\n",
" wheel_filepaths=wheel_filepaths,\n",
" n_running_containers=n_containers\n",
" )\n",
Expand Down
224 changes: 129 additions & 95 deletions python/powerlift/powerlift/run_azure/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,80 @@
"""This is called to run a trial by worker nodes (local / remote)."""


def assign_delete_permissions(
aci_client,
auth_client,
max_undead_containers,
credential,
subscription_id,
client_id,
resource_group_name,
container_groups,
):
from heapq import heappush, heappop
from datetime import datetime
import time
import uuid
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.authorization import AuthorizationManagementClient
from azure.mgmt.authorization.models import RoleAssignmentCreateParameters
from azure.core.exceptions import HttpResponseError

# Contributor Role
role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c"

while max_undead_containers < len(container_groups):
_, container_group_name, started = heappop(container_groups)
try:
if started is not None:
if not started.done():
heappush(
container_groups,
(datetime.now(), container_group_name, started),
)
time.sleep(1)
continue
started = None

if aci_client is None:
aci_client = ContainerInstanceManagementClient(
credential, subscription_id
)

container_group = aci_client.container_groups.get(
resource_group_name, container_group_name
)

role_assignment_params1 = RoleAssignmentCreateParameters(
role_definition_id=role_definition_id,
principal_id=container_group.identity.principal_id,
principal_type="ServicePrincipal",
)
role_assignment_params2 = RoleAssignmentCreateParameters(
role_definition_id=role_definition_id,
principal_id=client_id,
principal_type="User",
)
scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}"

if auth_client is None:
auth_client = AuthorizationManagementClient(credential, subscription_id)

auth_client.role_assignments.create(
scope, str(uuid.uuid4()), role_assignment_params1
)
auth_client.role_assignments.create(
scope, str(uuid.uuid4()), role_assignment_params2
)
except HttpResponseError:
aci_client = None
auth_client = None
heappush(container_groups, (datetime.now(), container_group_name, started))
time.sleep(1)

return aci_client, auth_client


def run_azure_process(
experiment_id,
n_runners,
Expand All @@ -16,46 +90,33 @@ def run_azure_process(
):
startup_script = """
self_delete() {
echo "Attempt to self-delete this container group. Exit code was $1."
echo "Attempt to self-delete this container group. Exit code was: $1"
if [ $1 -ne 0 ]; then
echo "Waiting 10 mintues to allow inspection of the logs..."
sleep 600
fi
SUBSCRIPTION_ID=${SUBSCRIPTION_ID}
RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME}
CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME}
retry_count=0
while true; do
echo "Downloading azure tools."
curl -sL https://aka.ms/InstallAzureCLIDeb -o install_script.sh
exit_code=$?
if [ $exit_code -eq 0 ]; then
break
if [ $exit_code -ne 0 ]; then
echo "curl failed with exit code $exit_code."
fi
echo "curl failed with exit code $exit_code."
if [ $retry_count -ge 300 ]; then
echo "Maximum number of retries reached. Command failed."
exit 62
bash install_script.sh
exit_code=$?
if [ $exit_code -ne 0 ]; then
echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway."
fi
retry_count=$((retry_count + 1))
echo "Sleeping."
sleep 300
echo "Retrying."
done
bash install_script.sh
exit_code=$?
if [ $exit_code -ne 0 ]; then
echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway."
fi
SUBSCRIPTION_ID=${SUBSCRIPTION_ID}
RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME}
CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME}
if [ $1 -ne 0 ]; then
echo "Waiting 10 mintues to allow inspection of the logs..."
sleep 600
fi
retry_count=0
while true; do
echo "Logging into azure to delete this container."
az login --identity
exit_code=$?
Expand All @@ -78,8 +139,10 @@ def run_azure_process(
break
fi
retry_count=$((retry_count + 1))
done
echo "Retrying."
done
exit $1 # failed to self-kill the container we are running this on.
}
Expand Down Expand Up @@ -298,11 +361,10 @@ def run_azure_process(
import time
import uuid
from multiprocessing.pool import MaybeEncodingError

from heapq import heappush
from datetime import datetime
from azure.core.exceptions import HttpResponseError
from azure.identity import ClientSecretCredential
from azure.mgmt.authorization import AuthorizationManagementClient
from azure.mgmt.authorization.models import RoleAssignmentCreateParameters
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
Expand All @@ -315,6 +377,8 @@ def run_azure_process(
)
from azure.mgmt.resource import ResourceManagementClient

max_undead_containers = 5

client_id = azure_json["client_id"]

if credential is None:
Expand All @@ -327,11 +391,15 @@ def run_azure_process(
resource_group_name = azure_json["resource_group"]
subscription_id = azure_json["subscription_id"]

aci_client = ContainerInstanceManagementClient(credential, subscription_id)
aci_client = None
auth_client = None
container_groups = []
res_client = ResourceManagementClient(credential, subscription_id)

# If this first call fails, then allow the Exception to propagate.
resource_group = res_client.resource_groups.get(resource_group_name)
resource_group_location = res_client.resource_groups.get(
resource_group_name
).location

container_resource_requests = ResourceRequests(
cpu=num_cores,
Expand All @@ -342,7 +410,6 @@ def run_azure_process(
)

container_group_names = set()
starts = []
for runner_id in range(n_runners):
container_group_name = f"powerlift-container-group-{batch_id}-{runner_id:04}"

Expand All @@ -366,85 +433,52 @@ def run_azure_process(
environment_variables=env_vars,
)
container_group = ContainerGroup(
location=resource_group.location,
location=resource_group_location,
containers=[container],
os_type=OperatingSystemTypes.linux,
restart_policy=ContainerGroupRestartPolicy.never,
identity={"type": "SystemAssigned"},
)

if aci_client is None:
aci_client = ContainerInstanceManagementClient(credential, subscription_id)

while True:
try:
# begin_create_or_update returns LROPoller,
# but this is only indicates when the containter is started
started = aci_client.container_groups.begin_create_or_update(
resource_group.name, container_group_name, container_group
resource_group_name, container_group_name, container_group
)
break
except HttpResponseError:
time.sleep(1)

starts.append(started)

container_group_names.add(container_group_name)
heappush(container_groups, (datetime.now(), container_group_name, started))
aci_client, auth_client = assign_delete_permissions(
aci_client,
auth_client,
max_undead_containers,
credential,
subscription_id,
client_id,
resource_group_name,
container_groups,
)

# make sure they have all started before exiting the process
for started in starts:
while True:
try:
while not started.done():
time.sleep(1)
break
except HttpResponseError:
time.sleep(1)
assign_delete_permissions(
aci_client,
auth_client,
0,
credential,
subscription_id,
client_id,
resource_group_name,
container_groups,
)

if delete_group_container_on_complete:
auth_client = AuthorizationManagementClient(credential, subscription_id)

# Contributor Role
role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c"

for container_group_name in container_group_names:
while True:
try:
container_group = aci_client.container_groups.get(
resource_group_name, container_group_name
)
break
except HttpResponseError:
time.sleep(1)

scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}"
role_assignment_params = RoleAssignmentCreateParameters(
role_definition_id=role_definition_id,
principal_id=container_group.identity.principal_id,
principal_type="ServicePrincipal",
)

while True:
try:
auth_client.role_assignments.create(
scope, str(uuid.uuid4()), role_assignment_params
)
break
except HttpResponseError:
time.sleep(1)

role_assignment_params = RoleAssignmentCreateParameters(
role_definition_id=role_definition_id,
principal_id=client_id,
principal_type="User",
)

while True:
try:
auth_client.role_assignments.create(
scope, str(uuid.uuid4()), role_assignment_params
)
break
except HttpResponseError:
time.sleep(1)

deletes = []
while len(container_group_names) != 0:
remove_after = []
Expand Down

0 comments on commit 1dc0d4f

Please sign in to comment.