Skip to content

UID2-4808 change policy to enum #378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ public class AzureCCCoreAttestationService implements ICoreAttestationService {

private final IPolicyValidator policyValidator;

private final String azureCcProtocol;
private final Protocol azureCcProtocol;

public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl, String azureCcProtocol) {
public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl, Protocol azureCcProtocol) {
this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl), azureCcProtocol);
}

public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl) {
this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl), Protocol.AZURE_CC_ACI);
}

// used in UT
protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator, String azureCcProtocol) {
protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator, Protocol azureCcProtocol) {
this.tokenSignatureValidator = tokenSignatureValidator;
this.policyValidator = policyValidator;
this.azureCcProtocol = azureCcProtocol;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler<AsyncRes

var enclaveId = this.validate(tokenPayload);
if (enclaveId != null) {
LOGGER.info("Successfully attested gcp-oidc against registered enclaves, enclave id: " + enclaveId);
LOGGER.info("Successfully attested {} against registered enclaves, enclave id: {}", Protocol.GCP_OIDC, enclaveId);
handler.handle(Future.succeededFuture(new AttestationResult(publicKey, enclaveId)));
} else {
LOGGER.warn("Can not find registered gcp-oidc enclave id.");
LOGGER.warn("Can not find registered {} enclave id.", Protocol.GCP_OIDC);
handler.handle(Future.succeededFuture(new AttestationResult(AttestationFailure.FORBIDDEN_ENCLAVE)));
}
}
Expand Down Expand Up @@ -93,10 +93,10 @@ private String validate(TokenPayload tokenPayload) throws Exception {
LOGGER.info("Validator version: " + policyValidator.getVersion() + ", result: " + enclaveId);

if (allowedEnclaveIds.contains(enclaveId)) {
LOGGER.info("Successfully attested gcp-oidc against registered enclaves");
LOGGER.info("Successfully attested {} against registered enclaves", Protocol.GCP_OIDC);
return enclaveId;
} else {
LOGGER.warn("Got unsupported gcp-oidc enclave id: " + enclaveId);
LOGGER.warn("Got unsupported {} enclave id: {}", Protocol.GCP_OIDC, enclaveId);
}
} catch (Exception ex) {
lastException = ex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ private AttestationResult attestInternal(byte[] publicKey, AttestationRequest aR
return new AttestationResult(AttestationFailure.FORBIDDEN_ENCLAVE);
}

LOGGER.info("Successfully attested aws-nitro against registered enclaves, enclave id: " + id.toString());
LOGGER.info("Successfully attested {} against registered enclaves, enclave id: {}", Protocol.AWS_NITRO, id);
return new AttestationResult(aDoc.getPublicKey(), id.toString());
}

Expand Down
27 changes: 27 additions & 0 deletions src/main/java/com/uid2/shared/secure/Protocol.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.uid2.shared.secure;

public enum Protocol {
GCP_OIDC,
GCP_VMID,
AWS_NITRO,
AZURE_CC_ACI,
AZURE_CC_AKS;

public String toString() {
switch(this) {
case GCP_OIDC:
return "gcp-oidc";
case GCP_VMID:
return "gcp-vmid";
case AWS_NITRO:
return "aws-nitro";
case AZURE_CC_ACI:
return "azure-cc";
case AZURE_CC_AKS:
return "azure-cc-aks";
default:
return "unknown-protocol";
}

}
Comment on lines +4 to +26
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
GCP_OIDC,
GCP_VMID,
AWS_NITRO,
AZURE_CC_ACI,
AZURE_CC_AKS;
public String toString() {
switch(this) {
case GCP_OIDC:
return "gcp-oidc";
case GCP_VMID:
return "gcp-vmid";
case AWS_NITRO:
return "aws-nitro";
case AZURE_CC_ACI:
return "azure-cc";
case AZURE_CC_AKS:
return "azure-cc-aks";
default:
return "unknown-protocol";
}
}
GCP_OIDC("gcp-oidc"),
GCP_VMID("gcp-vmid"),
AWS_NITRO("aws-nitro"),
AZURE_CC_ACI("azure-cc"),
AZURE_CC_AKS("azure-cc-aks");
private final String value;
Protocol(String value) {
this.value = value;
}
@Override
public String toString() {
return this.value;
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.Protocol;

public interface IMaaTokenSignatureValidator {
/**
Expand All @@ -10,5 +11,5 @@ public interface IMaaTokenSignatureValidator {
* @return Parsed token payload.
* @throws AttestationException
*/
MaaTokenPayload validate(String tokenString, String protocol) throws AttestationException;
MaaTokenPayload validate(String tokenString, Protocol protocol) throws AttestationException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.AttestationFailure;
import com.uid2.shared.secure.Protocol;
import lombok.Builder;
import lombok.Value;

@Value
@Builder(toBuilder = true)
public class MaaTokenPayload {
public static final String SEV_SNP_VM_TYPE = "sevsnpvm";
public static final String AZURE_CC_ACI_PROTOCOL = "azure-cc";
public static final String AZURE_CC_AKS_PROTOCOL = "azure-cc-aks";
// the `x-ms-compliance-status` value for ACI CC
public static final String AZURE_COMPLIANT_UVM = "azure-compliant-uvm";
// the `x-ms-compliance-status` value for AKS CC
public static final String AZURE_COMPLIANT_UVM_AKS = "azure-signed-katacc-uvm";

private String azureProtocol;
private Protocol azureProtocol;
private String attestationType;
private String complianceStatus;
private boolean vmDebuggable;
Expand All @@ -30,9 +29,9 @@ public boolean isSevSnpVM(){
}

public boolean isUtilityVMCompliant() throws AttestationClientException {
if (azureProtocol == AZURE_CC_ACI_PROTOCOL) {
if (azureProtocol == Protocol.AZURE_CC_ACI) {
return AZURE_COMPLIANT_UVM.equalsIgnoreCase(complianceStatus);
} else if (azureProtocol == AZURE_CC_AKS_PROTOCOL) {
} else if (azureProtocol == Protocol.AZURE_CC_AKS) {
return AZURE_COMPLIANT_UVM_AKS.equalsIgnoreCase(complianceStatus);
} else {
throw new AttestationClientException(String.format("Azure protocol: %s not supported", azureProtocol), AttestationFailure.INVALID_PROTOCOL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.AttestationFailure;
import com.uid2.shared.secure.Protocol;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -51,7 +52,7 @@ private TokenVerifier buildTokenVerifier(String kid) throws AttestationException
}

@Override
public MaaTokenPayload validate(String tokenString, String protocol) throws AttestationException {
public MaaTokenPayload validate(String tokenString, Protocol protocol) throws AttestationException {
if (Strings.isNullOrEmpty(tokenString)) {
throw new IllegalArgumentException("tokenString can not be null or empty");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.uid2.shared.Utils;
import com.uid2.shared.secure.Protocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -116,14 +117,14 @@ public VmConfigId getVmConfigId(InstanceDocument id) {
String templatizedConfig = templatizeVmConfig(cloudInitConfig);
str.append(getSha256Base64Encoded(templatizedConfig));
} else if (forbiddenMetadataKeys.contains(metadataItem.getKey())) {
LOGGER.debug("gcp-vmid attestation got forbidden metadata key: " + metadataItem.getKey());
LOGGER.debug("{} attestation got forbidden metadata key: {}", Protocol.GCP_VMID, metadataItem.getKey());
return VmConfigId.failure("forbidden metadata key: " + metadataItem.getKey(), id.getProjectId());
}
}

String badAuditLog = findUnauthorizedAuditLog(id);
if (badAuditLog != null) {
LOGGER.debug("attestation failed because of audit log: " + badAuditLog);
LOGGER.debug("attestation failed because of audit log: {}", badAuditLog);
return VmConfigId.failure("bad audit log: " + badAuditLog, id.getProjectId());
}

Expand Down Expand Up @@ -205,7 +206,7 @@ private boolean validateAuditLog(AuditLog auditLog) {
if (allowedMethodsFromInstanceAuditLogs.contains(auditLog.getMethodName())) {
return true;
} else {
LOGGER.warn("gcp-vmid attestation receives unauthorized method: " + auditLog.getMethodName());
LOGGER.warn("{} attestation receives unauthorized method: {}", Protocol.GCP_VMID, auditLog.getMethodName());
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
Expand Down Expand Up @@ -65,8 +62,8 @@ public void setup() throws AttestationException {
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testHappyPath(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testHappyPath(Protocol azureProtocol) throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysPassTokenValidator, alwaysPassPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
Expand All @@ -76,8 +73,8 @@ public void testHappyPath(String azureProtocol) throws AttestationException {
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testSignatureCheckFailed_ClientError(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testSignatureCheckFailed_ClientError(Protocol azureProtocol) throws AttestationException {
var errorStr = "token signature validation failed";
when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
Expand All @@ -90,8 +87,8 @@ public void testSignatureCheckFailed_ClientError(String azureProtocol) throws At
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testSignatureCheckFailed_ServerError(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testSignatureCheckFailed_ServerError(Protocol azureProtocol) throws AttestationException {
when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
Expand All @@ -102,8 +99,8 @@ public void testSignatureCheckFailed_ServerError(String azureProtocol) throws At
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testPolicyCheckSuccess_ClientError(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testPolicyCheckSuccess_ClientError(Protocol azureProtocol) throws AttestationException {
var errorStr = "policy validation failed";
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol);
Expand All @@ -116,8 +113,8 @@ public void testPolicyCheckSuccess_ClientError(String azureProtocol) throws Atte
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testPolicyCheckFailed_ServerError(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testPolicyCheckFailed_ServerError(Protocol azureProtocol) throws AttestationException {
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
Expand All @@ -128,8 +125,8 @@ public void testPolicyCheckFailed_ServerError(String azureProtocol) throws Attes
}

@ParameterizedTest
@MethodSource("argumentProvider")
public void testEnclaveNotRegistered(String azureProtocol) throws AttestationException {
@EnumSource(value = Protocol.class, names = {"AZURE_CC_ACI", "AZURE_CC_AKS"})
public void testEnclaveNotRegistered(Protocol azureProtocol) throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
attest(provider, ar -> {
assertTrue(ar.succeeded());
Expand All @@ -144,11 +141,4 @@ private static void attest(ICoreAttestationService provider, Handler<AsyncResult
PUBLIC_KEY.getBytes(StandardCharsets.UTF_8),
handler);
}

static Stream<Arguments> argumentProvider() {
return Stream.of(
Arguments.of(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL),
Arguments.of(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.Protocol;
import com.uid2.shared.secure.TestClock;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -16,7 +17,7 @@
public class MaaTokenSignatureValidatorTest {
@ParameterizedTest
@MethodSource("argumentProvider")
public void testPayload(String payloadPath, String protocol) throws Exception {
public void testPayload(String payloadPath, Protocol protocol) throws Exception {
// expire at 1695313895
var payload = loadFromJson(payloadPath);
var clock = new TestClock();
Expand All @@ -41,13 +42,13 @@ public void testE2E() throws AttestationException {
var maaToken = "<Placeholder>";
var maaServerUrl = "https://sharedeus.eus.attest.azure.net";
var validator = new MaaTokenSignatureValidator(maaServerUrl);
var token = validator.validate(maaToken, MaaTokenPayload.AZURE_CC_ACI_PROTOCOL);
var token = validator.validate(maaToken, Protocol.AZURE_CC_ACI);
}

static Stream<Arguments> argumentProvider() {
return Stream.of(
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aci.json", MaaTokenPayload.AZURE_CC_ACI_PROTOCOL),
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aks.json", MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aci.json", Protocol.AZURE_CC_ACI),
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aks.json", Protocol.AZURE_CC_AKS)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.gson.JsonObject;
import com.uid2.shared.Const;
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.Protocol;

import java.security.KeyPairGenerator;
import java.security.PublicKey;
Expand All @@ -14,7 +15,7 @@
public class MaaTokenUtils {
public static final String MAA_BASE_URL = "https://sharedeus.eus.attest.azure.net";

public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock, String protocol) throws Exception{
public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock, Protocol protocol) throws Exception{
var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass);
gen.initialize(2048, new SecureRandom());
var keyPair = gen.generateKeyPair();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.AttestationFailure;
import com.uid2.shared.secure.Protocol;
import org.junit.jupiter.api.Test;

import java.nio.ByteBuffer;
Expand Down Expand Up @@ -97,7 +98,7 @@ private MaaTokenPayload generateBasicPayload() {
.vmDebuggable(false)
.runtimeData(generateBasicRuntimeData())
.ccePolicyDigest(CCE_POLICY_DIGEST)
.azureProtocol(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL)
.azureProtocol(Protocol.AZURE_CC_ACI)
.build();
}

Expand Down Expand Up @@ -145,7 +146,7 @@ public void testValidationSuccess_AksWithAzureSignedKataccUvm() throws Attestati
var aksPayload = generateBasicPayload()
.toBuilder()
.complianceStatus("azure-signed-katacc-uvm")
.azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
.azureProtocol(Protocol.AZURE_CC_AKS)
.build();
var enclaveId = validator.validate(aksPayload, PUBLIC_KEY);
assertEquals(CCE_POLICY_DIGEST, enclaveId);
Expand All @@ -157,22 +158,11 @@ public void testValidationFailure_AksWithOtherUvm() {
var aksPayload = generateBasicPayload()
.toBuilder()
.complianceStatus("fake-compliance")
.azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
.azureProtocol(Protocol.AZURE_CC_AKS)
.build();
Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY));
assertEquals("Not run in Azure Compliance Utility VM", t.getMessage());
assertEquals(AttestationFailure.BAD_FORMAT, ((AttestationClientException)t).getAttestationFailure());
}

@Test
public void testValidationFailure_InvalidProtocol() {
var validator = new PolicyValidator(ATTESTATION_URL);
var aksPayload = generateBasicPayload()
.toBuilder()
.azureProtocol("fake-protocol")
.build();
Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY));
assertEquals("Azure protocol: fake-protocol not supported", t.getMessage());
assertEquals(AttestationFailure.INVALID_PROTOCOL, ((AttestationClientException)t).getAttestationFailure());
}
}
Loading