Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify AWS Date Override Behavior #121

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ private async Task<string> PostAuth(Server.VerifyResult result)

private async Task<string> MigrateShardToEnclave(Server.VerifyResult authResult)
{
// TODO: For recovery code, allow old encryption keys as overrides to migrate sharded custom auth?
var (address, encryptedPrivateKeyB64, ivB64, kmsCiphertextB64) = await this.EmbeddedWallet
.GenerateEncryptionDataAsync(authResult.AuthToken, this.LegacyEncryptionKey ?? authResult.RecoveryCode)
.ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,66 +53,18 @@ private static async Task<AwsCredentials> GetTemporaryCredentialsAsync(string id
};
}

private static async Task<JToken> GenerateDataKey(AwsCredentials credentials, IThirdwebHttpClient httpClient, DateTime? dateOverride = null)
private static async Task<JToken> GenerateDataKey(AwsCredentials credentials, IThirdwebHttpClient httpClient)
{
var client = Utils.ReconstructHttpClient(httpClient);
var endpoint = $"https://kms.{AWS_REGION}.amazonaws.com/";

var payloadForGenerateDataKey = new { KeyId = _migrationKeyId, KeySpec = "AES_256" };
var requestBodyString = JsonConvert.SerializeObject(payloadForGenerateDataKey);

var content = new StringContent(JsonConvert.SerializeObject(payloadForGenerateDataKey), Encoding.UTF8, "application/x-amz-json-1.1");
var contentType = "application/x-amz-json-1.1";

client.AddHeader("X-Amz-Target", "TrentService.GenerateDataKey");
var extraHeaders = new Dictionary<string, string> { { "X-Amz-Target", "TrentService.GenerateDataKey" } };

var dateTimeNow = dateOverride ?? DateTime.UtcNow;
var dateStamp = dateTimeNow.ToString("yyyyMMdd");
var amzDateFormat = "yyyyMMddTHHmmssZ";
var amzDate = dateTimeNow.ToString(amzDateFormat);
var canonicalUri = "/";

var canonicalHeaders = $"host:kms.{AWS_REGION}.amazonaws.com\nx-amz-date:{amzDate}\n";
var signedHeaders = "host;x-amz-date";

#if NETSTANDARD
using var sha256 = SHA256.Create();
var payloadHash = ToHexString(sha256.ComputeHash(Encoding.UTF8.GetBytes(await content.ReadAsStringAsync())));
#else
var payloadHash = ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(await content.ReadAsStringAsync())));
#endif

var canonicalRequest = $"POST\n{canonicalUri}\n\n{canonicalHeaders}\n{signedHeaders}\n{payloadHash}";

var algorithm = "AWS4-HMAC-SHA256";
var credentialScope = $"{dateStamp}/{AWS_REGION}/kms/aws4_request";

#if NETSTANDARD
var stringToSign = $"{algorithm}\n{amzDate}\n{credentialScope}\n{ToHexString(sha256.ComputeHash(Encoding.UTF8.GetBytes(canonicalRequest)))}";
#else
var stringToSign = $"{algorithm}\n{amzDate}\n{credentialScope}\n{ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(canonicalRequest)))}";
#endif

var signingKey = GetSignatureKey(credentials.SecretAccessKey, dateStamp, AWS_REGION, "kms");
var signature = ToHexString(HMACSHA256(signingKey, stringToSign));

var authorizationHeader = $"{algorithm} Credential={credentials.AccessKeyId}/{credentialScope}, SignedHeaders={signedHeaders}, Signature={signature}";

client.AddHeader("x-amz-date", amzDate);
client.AddHeader("Authorization", authorizationHeader);
client.AddHeader("x-amz-security-token", credentials.SessionToken);

var response = await client.PostAsync(endpoint, content).ConfigureAwait(false);
var responseContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

if (!response.IsSuccessStatusCode)
{
if (dateOverride == null && responseContent.Contains("InvalidSignatureException"))
{
var parsedTime = responseContent.Substring(responseContent.LastIndexOf('(') + 1, amzDate.Length);
return await GenerateDataKey(credentials, httpClient, DateTime.ParseExact(parsedTime, amzDateFormat, System.Globalization.CultureInfo.InvariantCulture).ToUniversalTime())
.ConfigureAwait(false);
}
throw new Exception($"Failed to generate data key: {responseContent}");
}
var responseContent = await PostAwsRequestWithDateOverride(credentials, httpClient, AWS_REGION, "kms", endpoint, "/", "", requestBodyString, contentType, extraHeaders).ConfigureAwait(false);

var responseObject = JToken.Parse(responseContent);
var plaintextKeyBlob = responseObject["Plaintext"];
Expand All @@ -129,54 +81,131 @@ private static async Task<JToken> GenerateDataKey(AwsCredentials credentials, IT
private static async Task<MemoryStream> InvokeLambdaWithTemporaryCredentialsAsync(AwsCredentials credentials, string invokePayload, IThirdwebHttpClient httpClient, string lambdaFunction)
{
var endpoint = $"https://lambda.{AWS_REGION}.amazonaws.com/2015-03-31/functions/{lambdaFunction}/invocations";
var requestBody = new StringContent(invokePayload, Encoding.UTF8, "application/json");
var contentType = "application/json";

var canonicalUri = $"/2015-03-31/functions/{Uri.EscapeDataString(lambdaFunction)}/invocations";
var canonicalQueryString = "";

var extraHeaders = new Dictionary<string, string>();

var responseContent = await PostAwsRequestWithDateOverride(
credentials,
httpClient,
AWS_REGION,
"lambda",
endpoint,
canonicalUri,
canonicalQueryString,
invokePayload,
contentType,
extraHeaders
)
.ConfigureAwait(false);

var memoryStream = new MemoryStream(Encoding.UTF8.GetBytes(responseContent));
return memoryStream;
}

private static async Task<string> PostAwsRequestWithDateOverride(
AwsCredentials credentials,
IThirdwebHttpClient httpClient,
string region,
string service,
string endpoint,
string canonicalUri,
string canonicalQueryString,
string requestBodyString,
string contentType,
Dictionary<string, string> extraHeaders,
DateTime? dateOverride = null
)
{
var client = Utils.ReconstructHttpClient(httpClient);

var dateTimeNow = DateTime.UtcNow;
if (extraHeaders != null)
{
foreach (var kvp in extraHeaders)
{
client.AddHeader(kvp.Key, kvp.Value);
}
}

var dateTimeNow = dateOverride ?? DateTime.UtcNow;
var amzDateFormat = "yyyyMMddTHHmmssZ";
var amzDate = dateTimeNow.ToString(amzDateFormat);
var dateStamp = dateTimeNow.ToString("yyyyMMdd");
var amzDate = dateTimeNow.ToString("yyyyMMddTHHmmssZ");

var canonicalUri = "/2015-03-31/functions/" + Uri.EscapeDataString(lambdaFunction) + "/invocations";
var canonicalQueryString = "";
var canonicalHeaders = $"host:lambda.{AWS_REGION}.amazonaws.com\nx-amz-date:{amzDate}\n";
var canonicalHeaders = $"host:{new Uri(endpoint).Host}\n" + $"x-amz-date:{amzDate}\n";
var signedHeaders = "host;x-amz-date";

#if NETSTANDARD
using var sha256 = SHA256.Create();
var payloadHash = ToHexString(sha256.ComputeHash(Encoding.UTF8.GetBytes(invokePayload)));
var payloadHash = ToHexString(sha256.ComputeHash(Encoding.UTF8.GetBytes(requestBodyString)));
#else
var payloadHash = ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(invokePayload)));
var payloadHash = ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(requestBodyString)));
#endif

var canonicalRequest = $"POST\n{canonicalUri}\n{canonicalQueryString}\n{canonicalHeaders}\n{signedHeaders}\n{payloadHash}";

var algorithm = "AWS4-HMAC-SHA256";
var credentialScope = $"{dateStamp}/{AWS_REGION}/lambda/aws4_request";
var credentialScope = $"{dateStamp}/{region}/{service}/aws4_request";
#if NETSTANDARD
var stringToSign = $"{algorithm}\n{amzDate}\n{credentialScope}\n{ToHexString(sha256.ComputeHash(Encoding.UTF8.GetBytes(canonicalRequest)))}";
#else
var stringToSign = $"{algorithm}\n{amzDate}\n{credentialScope}\n{ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(canonicalRequest)))}";
#endif

var signingKey = GetSignatureKey(credentials.SecretAccessKey, dateStamp, AWS_REGION, "lambda");
var signingKey = GetSignatureKey(credentials.SecretAccessKey, dateStamp, region, service);
var signature = ToHexString(HMACSHA256(signingKey, stringToSign));

var authorizationHeader = $"{algorithm} Credential={credentials.AccessKeyId}/{credentialScope}, SignedHeaders={signedHeaders}, Signature={signature}";

client.AddHeader("x-amz-date", amzDate);
client.AddHeader("Authorization", authorizationHeader);
client.AddHeader("x-amz-security-token", credentials.SessionToken);

var response = await client.PostAsync(endpoint, requestBody).ConfigureAwait(false);
if (!string.IsNullOrEmpty(credentials.SessionToken))
{
client.AddHeader("x-amz-security-token", credentials.SessionToken);
}

var content = new StringContent(requestBodyString, Encoding.UTF8, contentType);

var response = await client.PostAsync(endpoint, content).ConfigureAwait(false);
var responseContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

if (!response.IsSuccessStatusCode)
{
throw new Exception($"Lambda invocation failed: {responseContent}");
if (dateOverride == null && responseContent.Contains("Signature expired"))
{
var idx = responseContent.LastIndexOf('(');
if (idx > -1)
{
var parsedTimeString = responseContent.Substring(idx + 1, amzDate.Length);
var serverTime = DateTime.ParseExact(parsedTimeString, amzDateFormat, System.Globalization.CultureInfo.InvariantCulture).ToUniversalTime();

Console.WriteLine($"Server time: {serverTime}");

return await PostAwsRequestWithDateOverride(
credentials,
httpClient,
region,
service,
endpoint,
canonicalUri,
canonicalQueryString,
requestBodyString,
contentType,
extraHeaders,
serverTime
)
.ConfigureAwait(false);
}
}

throw new Exception($"AWS request failed: {responseContent}");
}

var memoryStream = new MemoryStream(Encoding.UTF8.GetBytes(responseContent));
return memoryStream;
return responseContent;
}

private static byte[] HMACSHA256(byte[] key, string data)
Expand Down
Loading