diff --git a/README.md b/README.md index 11ff2075e..694a47c17 100644 --- a/README.md +++ b/README.md @@ -622,6 +622,7 @@ We also build on top of many great packages. Please check them out! # Papers that use or compare EBMs +- [Challenging the Performance-Interpretability Trade-off: An Evaluation of Interpretable Machine Learning Models](https://arxiv.org/pdf/2409.14429) - [Data Science with LLMs and Interpretable Models](https://arxiv.org/pdf/2402.14474v1.pdf) - [DimVis: Interpreting Visual Clusters in Dimensionality Reduction With Explainable Boosting Machine](https://arxiv.org/pdf/2402.06885.pdf) - [Distill knowledge of additive tree models into generalized linear models](https://detralytics.com/wp-content/uploads/2023/10/Detra-Note_Additive-tree-ensembles.pdf) diff --git a/docs/benchmarks/ebm-benchmark.ipynb b/docs/benchmarks/ebm-benchmark.ipynb index b8306311d..5adc67e32 100644 --- a/docs/benchmarks/ebm-benchmark.ipynb +++ b/docs/benchmarks/ebm-benchmark.ipynb @@ -17,7 +17,7 @@ "force_recreate = False\n", "exist_ok = True\n", "TIMEOUT_SEC = 60 * 60 * 24 * 180 # 180 days\n", - "wheel_filepaths = ['interpret_core-0.6.3-py3-none-any.whl', 'powerlift-0.1.11-py3-none-any.whl']\n", + "wheel_filepaths = ['interpret_core-0.6.4-py3-none-any.whl', 'powerlift-0.1.12-py3-none-any.whl']\n", "\n", "import datetime\n", "experiment_name = datetime.datetime.now().strftime('%Y_%m_%d_%H%M__') + 'myexperiment'\n", @@ -230,6 +230,10 @@ " import warnings\n", " import gc\n", " import re\n", + " import random\n", + "\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", "\n", " X, y = trial.task.data()\n", "\n", @@ -851,7 +855,7 @@ " executor = AzureContainerInstance(\n", " store, azure_tenant_id, subscription_id, azure_client_id, credential,\n", " resource_group=resource_group,\n", - " pip_install= requirements + \" interpret-core\",\n", + " pip_install=requirements + \" interpret-core\",\n", " wheel_filepaths=wheel_filepaths,\n", " n_running_containers=n_containers\n", " )\n", diff --git a/python/powerlift/powerlift/run_azure/__main__.py b/python/powerlift/powerlift/run_azure/__main__.py index d7c38e154..34cd00c68 100644 --- a/python/powerlift/powerlift/run_azure/__main__.py +++ b/python/powerlift/powerlift/run_azure/__main__.py @@ -1,6 +1,80 @@ """This is called to run a trial by worker nodes (local / remote).""" +def assign_delete_permissions( + aci_client, + auth_client, + max_undead_containers, + credential, + subscription_id, + client_id, + resource_group_name, + container_groups, +): + from heapq import heappush, heappop + from datetime import datetime + import time + import uuid + from azure.mgmt.containerinstance import ContainerInstanceManagementClient + from azure.mgmt.authorization import AuthorizationManagementClient + from azure.mgmt.authorization.models import RoleAssignmentCreateParameters + from azure.core.exceptions import HttpResponseError + + # Contributor Role + role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c" + + while max_undead_containers < len(container_groups): + _, container_group_name, started = heappop(container_groups) + try: + if started is not None: + if not started.done(): + heappush( + container_groups, + (datetime.now(), container_group_name, started), + ) + time.sleep(1) + continue + started = None + + if aci_client is None: + aci_client = ContainerInstanceManagementClient( + credential, subscription_id + ) + + container_group = aci_client.container_groups.get( + resource_group_name, container_group_name + ) + + role_assignment_params1 = RoleAssignmentCreateParameters( + role_definition_id=role_definition_id, + principal_id=container_group.identity.principal_id, + principal_type="ServicePrincipal", + ) + role_assignment_params2 = RoleAssignmentCreateParameters( + role_definition_id=role_definition_id, + principal_id=client_id, + principal_type="User", + ) + scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}" + + if auth_client is None: + auth_client = AuthorizationManagementClient(credential, subscription_id) + + auth_client.role_assignments.create( + scope, str(uuid.uuid4()), role_assignment_params1 + ) + auth_client.role_assignments.create( + scope, str(uuid.uuid4()), role_assignment_params2 + ) + except HttpResponseError: + aci_client = None + auth_client = None + heappush(container_groups, (datetime.now(), container_group_name, started)) + time.sleep(1) + + return aci_client, auth_client + + def run_azure_process( experiment_id, n_runners, @@ -16,7 +90,16 @@ def run_azure_process( ): startup_script = """ self_delete() { - echo "Attempt to self-delete this container group. Exit code was $1." + echo "Attempt to self-delete this container group. Exit code was: $1" + + if [ $1 -ne 0 ]; then + echo "Waiting 10 mintues to allow inspection of the logs..." + sleep 600 + fi + + SUBSCRIPTION_ID=${SUBSCRIPTION_ID} + RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME} + CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME} retry_count=0 while true; do @@ -24,38 +107,16 @@ def run_azure_process( curl -sL https://aka.ms/InstallAzureCLIDeb -o install_script.sh exit_code=$? - if [ $exit_code -eq 0 ]; then - break + if [ $exit_code -ne 0 ]; then + echo "curl failed with exit code $exit_code." fi - echo "curl failed with exit code $exit_code." - if [ $retry_count -ge 300 ]; then - echo "Maximum number of retries reached. Command failed." - exit 62 + bash install_script.sh + exit_code=$? + if [ $exit_code -ne 0 ]; then + echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway." fi - retry_count=$((retry_count + 1)) - echo "Sleeping." - sleep 300 - echo "Retrying." - done - - bash install_script.sh - exit_code=$? - if [ $exit_code -ne 0 ]; then - echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway." - fi - - SUBSCRIPTION_ID=${SUBSCRIPTION_ID} - RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME} - CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME} - if [ $1 -ne 0 ]; then - echo "Waiting 10 mintues to allow inspection of the logs..." - sleep 600 - fi - - retry_count=0 - while true; do echo "Logging into azure to delete this container." az login --identity exit_code=$? @@ -78,8 +139,10 @@ def run_azure_process( break fi retry_count=$((retry_count + 1)) - done + echo "Retrying." + done + exit $1 # failed to self-kill the container we are running this on. } @@ -298,11 +361,10 @@ def run_azure_process( import time import uuid from multiprocessing.pool import MaybeEncodingError - + from heapq import heappush + from datetime import datetime from azure.core.exceptions import HttpResponseError from azure.identity import ClientSecretCredential - from azure.mgmt.authorization import AuthorizationManagementClient - from azure.mgmt.authorization.models import RoleAssignmentCreateParameters from azure.mgmt.containerinstance import ContainerInstanceManagementClient from azure.mgmt.containerinstance.models import ( Container, @@ -315,6 +377,8 @@ def run_azure_process( ) from azure.mgmt.resource import ResourceManagementClient + max_undead_containers = 5 + client_id = azure_json["client_id"] if credential is None: @@ -327,11 +391,15 @@ def run_azure_process( resource_group_name = azure_json["resource_group"] subscription_id = azure_json["subscription_id"] - aci_client = ContainerInstanceManagementClient(credential, subscription_id) + aci_client = None + auth_client = None + container_groups = [] res_client = ResourceManagementClient(credential, subscription_id) # If this first call fails, then allow the Exception to propagate. - resource_group = res_client.resource_groups.get(resource_group_name) + resource_group_location = res_client.resource_groups.get( + resource_group_name + ).location container_resource_requests = ResourceRequests( cpu=num_cores, @@ -342,7 +410,6 @@ def run_azure_process( ) container_group_names = set() - starts = [] for runner_id in range(n_runners): container_group_name = f"powerlift-container-group-{batch_id}-{runner_id:04}" @@ -366,85 +433,52 @@ def run_azure_process( environment_variables=env_vars, ) container_group = ContainerGroup( - location=resource_group.location, + location=resource_group_location, containers=[container], os_type=OperatingSystemTypes.linux, restart_policy=ContainerGroupRestartPolicy.never, identity={"type": "SystemAssigned"}, ) + if aci_client is None: + aci_client = ContainerInstanceManagementClient(credential, subscription_id) + while True: try: # begin_create_or_update returns LROPoller, # but this is only indicates when the containter is started started = aci_client.container_groups.begin_create_or_update( - resource_group.name, container_group_name, container_group + resource_group_name, container_group_name, container_group ) break except HttpResponseError: time.sleep(1) - starts.append(started) - container_group_names.add(container_group_name) + heappush(container_groups, (datetime.now(), container_group_name, started)) + aci_client, auth_client = assign_delete_permissions( + aci_client, + auth_client, + max_undead_containers, + credential, + subscription_id, + client_id, + resource_group_name, + container_groups, + ) - # make sure they have all started before exiting the process - for started in starts: - while True: - try: - while not started.done(): - time.sleep(1) - break - except HttpResponseError: - time.sleep(1) + assign_delete_permissions( + aci_client, + auth_client, + 0, + credential, + subscription_id, + client_id, + resource_group_name, + container_groups, + ) if delete_group_container_on_complete: - auth_client = AuthorizationManagementClient(credential, subscription_id) - - # Contributor Role - role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c" - - for container_group_name in container_group_names: - while True: - try: - container_group = aci_client.container_groups.get( - resource_group_name, container_group_name - ) - break - except HttpResponseError: - time.sleep(1) - - scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}" - role_assignment_params = RoleAssignmentCreateParameters( - role_definition_id=role_definition_id, - principal_id=container_group.identity.principal_id, - principal_type="ServicePrincipal", - ) - - while True: - try: - auth_client.role_assignments.create( - scope, str(uuid.uuid4()), role_assignment_params - ) - break - except HttpResponseError: - time.sleep(1) - - role_assignment_params = RoleAssignmentCreateParameters( - role_definition_id=role_definition_id, - principal_id=client_id, - principal_type="User", - ) - - while True: - try: - auth_client.role_assignments.create( - scope, str(uuid.uuid4()), role_assignment_params - ) - break - except HttpResponseError: - time.sleep(1) - deletes = [] while len(container_group_names) != 0: remove_after = []