Skip to content

Commit

Permalink
add test cases for tgpt validator
Browse files Browse the repository at this point in the history
  • Loading branch information
RockfordMankiniUCSD committed Oct 7, 2024
1 parent 3babee6 commit 8191929
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 19 deletions.
8 changes: 5 additions & 3 deletions src/dsmlp/app/tritongpt_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ def __init__(self, kube: KubeClient, logger: Logger) -> None:

def validate_pod(self, request: Request):

permitted_uids = self.kube.get_tgpt_uids()
namespace = self.kube.get_namespace(request.namespace)

permitted_uids = self.kube.get_tgpt_uids(namespace)
requested_uid = request.object.spec.securityContext.runAsUser

# if request.uid is not in kube.get_tgpt_uids
# return validationfailure
if requested_uid not in permitted_uids:
raise ValidationFailure(f"TritonGPT Validator: user with {permitted_uids} attempted to run a pod as {requested_uid}. Pod denied.")
if str(requested_uid) not in permitted_uids:
raise ValidationFailure(f"TritonGPT Validator: user with access to UIDs {permitted_uids} attempted to run a pod as {requested_uid}. Pod denied.")
10 changes: 7 additions & 3 deletions src/dsmlp/app/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ def handle_request(self, request: Request):

def validate_pod(self, request: Request):

### if tgpt-validator == enabled
### run special tritongpt validator that gets permitted UIDs from namespace instead of sicad
try:
if(self.kube.get_tgpt_label(request.namespace) == "enabled"):
namespace = self.kube.get_namespace(request.namespace)

if(self.kube.get_tgpt_label(namespace) == "enabled"):
self.logger.info("Triton GPT Mode Activated. Only running TritonGPT Validator.")
TritonGPTValidator(self.kube, self.logger).validate_pod(request)
return
except:
self.logger.info("Failed to evaluate TGPT label logic. Falling back on regular validator components.")
except Exception as err:
self.logger.exception(err)

for component_validator in self.component_validators:
component_validator.validate_pod(request)
Expand Down
17 changes: 4 additions & 13 deletions src/dsmlp/ext/kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,14 @@ def get_gpus_in_namespace(self, name: str) -> int:

return gpu_count

def get_tgpt_label(self, name: str) -> str:
api = self.get_policy_api()
v1namespace: V1Namespace = api.read_namespace(name=name)
metadata: V1ObjectMeta = v1namespace.metadata

if metadata is not None and metadata.labels is not None and "tgpt-validator" in metadata.labels:
return metadata.labels["tgpt-validator"]
def get_tgpt_label(self, namespace) -> str:
return namespace.labels.get("tgt-validator","")

# TODO: make arbitrary function of getting namespace labels.
def get_tgpt_uids(self, name: str) -> str:
api = self.get_policy_api()
v1namespace: V1Namespace = api.read_namespace(name=name)
metadata: V1ObjectMeta = v1namespace.metadata
def get_tgpt_uids(self, namespace) -> str:

# should be comma delimited, i.e. 2000,100,2,20
if metadata is not None and metadata.labels is not None and "permitted-uids" in metadata.labels:
return metadata.labels["permitted-uids"].split(',')
return namespace.labels.get("permitted-uids", "").split(',')

# noinspection PyMethodMayBeStatic

Expand Down
165 changes: 165 additions & 0 deletions tests/app/test_tgpt_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import inspect
from operator import contains
from dsmlp.app.validator import Validator
from dsmlp.plugin.awsed import ListTeamsResponse, TeamJson, UserResponse
from dsmlp.plugin.kube import Namespace
from hamcrest import assert_that, contains_inanyorder, equal_to, has_item
from tests.fakes import FakeAwsedClient, FakeLogger, FakeKubeClient


class TestTGPTValidator:
def setup_method(self) -> None:
self.logger = FakeLogger()
self.awsed_client = FakeAwsedClient()
self.kube_client = FakeKubeClient()

self.awsed_client.add_user(
'user10', UserResponse(uid=30, enrollments=[]))
self.awsed_client.add_teams('user10', ListTeamsResponse(
teams=[TeamJson(gid=1000)]
))

self.kube_client.add_namespace('user10', Namespace(
name='user10', labels={'k8s-sync': 'true', 'tgpt-validator': 'enabled', 'permitted-uids': '30,3000'}, gpu_quota=10))

self.awsed_client.add_user(
'user100', UserResponse(uid=10, enrollments=[]))
self.awsed_client.add_teams('user10', ListTeamsResponse(
teams=[TeamJson(gid=1000)]
))

self.kube_client.add_namespace('user100', Namespace(
name='user100', labels={'k8s-sync': 'true', 'tgpt-validator': 'disabled', 'permitted-uids': '10'}, gpu_quota=10))

def test_good_request(self):
self.when_validate(
{
"request": {
"uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
"namespace": "user10",
"userInfo": {
"username": "system:kube-system"
},
"object": {
"metadata": {
"labels": {}
},
"spec": {
"containers": [{}],
"securityContext": {"runAsUser": 30},
},
}
}
}
)

assert_that(self.logger.messages, has_item(
f"INFO Allowed request username=system:kube-system namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002"))

def test_good_request_2(self):
self.when_validate(
{
"request": {
"uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
"namespace": "user10",
"userInfo": {
"username": "system:kube-system"
},
"object": {
"metadata": {
"labels": {}
},
"spec": {
"containers": [{}],
"securityContext": {"runAsUser": 3000},
},
}
}
}
)

assert_that(self.logger.messages, has_item(
f"INFO Allowed request username=system:kube-system namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002"))

def test_bad_request(self):
self.when_validate(
{
"request": {
"uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
"namespace": "user10",
"userInfo": {
"username": "system:kube-system"
},
"object": {
"metadata": {
"labels": {}
},
"spec": {
"containers": [{}],
"securityContext": {"runAsUser": 300},
},
}
}
}
)

assert_that(self.logger.messages, has_item(
f"EXCEPTION TritonGPT Validator: user with access to UIDs ['30', '3000'] attempted to run a pod as 300. Pod denied."))

def test_good_request_not_enabled_permitted_on(self):
self.when_validate(
{
"request": {
"uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
"namespace": "user100",
"userInfo": {
"username": "system:kube-system"
},
"object": {
"metadata": {
"labels": {}
},
"spec": {
"containers": [{}],
"securityContext": {"runAsUser": 10},
},
}
}
}
)

assert_that(self.logger.messages, has_item(
f"INFO Allowed request username=system:kube-system namespace=user100 uid=705ab4f5-6393-11e8-b7cc-42010a800002"))

#assert_that(self.logger.messages, has_item(
#"INFO Allowed request username=user10 namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002"))

# def test_gpu_quota_request(self):
# self.awsed_client.add_user_gpu_quota('user10', 10)
# self.awsed_client.get_user_gpu_quota('user10')

# response = self.when_validate(
# {
# "request": {
# "uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
# "namespace": "user10",
# "userInfo": {
# "username": "user10"
# },
# "object": {
# "metadata": {
# "labels": {}
# },
# "spec": {
# "containers": [{}]
# }
# }
# }
# }
# )

def when_validate(self, json):
validator = Validator(self.awsed_client, self.kube_client, self.logger)
response = validator.validate_request(json)

return response
12 changes: 12 additions & 0 deletions tests/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ def add_namespace(self, name: str, namespace: Namespace):
def set_existing_gpus(self, name: str, gpus: int):
self.existing_gpus[name] = gpus

def get_tgpt_label(self, namespace) -> str:
try:
return namespace.labels.get("tgpt-validator", "")
except KeyError:
raise UnsuccessfulRequest()

def get_tgpt_uids(self, namespace) -> str:
try:
return namespace.labels.get("permitted-uids").split(',')
except KeyError:
raise UnsuccessfulRequest()


class FakeLogger(Logger):
def __init__(self) -> None:
Expand Down

0 comments on commit 8191929

Please sign in to comment.