diff --git a/src/AppInstallerCLITests/DownloadFlow.cpp b/src/AppInstallerCLITests/DownloadFlow.cpp index 3e76830a33..77cc42750f 100644 --- a/src/AppInstallerCLITests/DownloadFlow.cpp +++ b/src/AppInstallerCLITests/DownloadFlow.cpp @@ -1,10 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #include "pch.h" +#include "TestHooks.h" +#include "AppInstallerRuntime.h" #include "WorkflowCommon.h" #include using namespace TestCommon; +using namespace AppInstaller; +using namespace AppInstaller::Authentication; using namespace AppInstaller::CLI; TEST_CASE("DownloadFlow_DownloadCommandProhibited", "[DownloadFlow][workflow]") @@ -20,5 +24,112 @@ TEST_CASE("DownloadFlow_DownloadCommandProhibited", "[DownloadFlow][workflow]") // Verify AppInfo is printed REQUIRE_TERMINATED_WITH(context, APPINSTALLER_CLI_ERROR_DOWNLOAD_COMMAND_PROHIBITED); - REQUIRE(downloadOutput.str().find(Resource::LocString(Resource::String::InstallerDownloadCommandProhibited).get()) != std::string::npos); + REQUIRE(downloadOutput.str().find(CLI::Resource::LocString(CLI::Resource::String::InstallerDownloadCommandProhibited).get()) != std::string::npos); +} + +AppInstaller::Utility::DownloadResult ValidateAzureBlobStorageAuthHeaders( + const std::string&, + const std::filesystem::path&, + AppInstaller::Utility::DownloadType, + AppInstaller::IProgressCallback&, + std::optional info) +{ + REQUIRE(info); + REQUIRE(info->RequestHeaders.size() > 0); + REQUIRE(info->RequestHeaders[0].IsAuth); + REQUIRE(info->RequestHeaders[0].Name == "Authorization"); + REQUIRE(info->RequestHeaders[0].Value == "Bearer TestToken"); + REQUIRE_FALSE(info->RequestHeaders[1].IsAuth); + REQUIRE(info->RequestHeaders[1].Name == "x-ms-version"); + // Not validating x-ms-version value + + AppInstaller::Utility::DownloadResult result; + result.Sha256Hash = AppInstaller::Utility::SHA256::ConvertToBytes("65DB2F2AC2686C7F2FD69D4A4C6683B888DC55BFA20A0E32CA9F838B51689A3B"); + return result; +} + +TEST_CASE("DownloadFlow_DownloadWithInstallerAuthenticationSuccess", "[DownloadFlow][workflow]") +{ + if (Runtime::IsRunningAsSystem()) + { + WARN("Test does not support running as system. Skipped."); + return; + } + + // Set authentication success result override + std::string expectedToken = "TestToken"; + AuthenticationResult authResultOverride; + authResultOverride.Status = S_OK; + authResultOverride.Token = expectedToken; + TestHook::SetAuthenticationResult_Override setAuthenticationResultOverride(authResultOverride); + + // Set auth header validation override + TestHook::SetDownloadResult_Function_Override downloadFunctionOverride({ &ValidateAzureBlobStorageAuthHeaders }); + + std::ostringstream downloadOutput; + TestContext context{ downloadOutput, std::cin }; + auto previousThreadGlobals = context.SetForCurrentThread(); + context.Args.AddArg(Execution::Args::Type::Manifest, TestDataFile("ManifestV1_10-InstallerAuthentication.yaml").GetPath().u8string()); + + DownloadCommand download({}); + download.Execute(context); + INFO(downloadOutput.str()); + + // Verify success + REQUIRE_FALSE(context.IsTerminated()); + REQUIRE(context.GetTerminationHR() == S_OK); +} + +TEST_CASE("DownloadFlow_DownloadWithInstallerAuthenticationNotSupported", "[DownloadFlow][workflow]") +{ + if (Runtime::IsRunningAsSystem()) + { + WARN("Test does not support running as system. Skipped."); + return; + } + + // Set authentication failed result + AuthenticationResult authResultOverride; + authResultOverride.Status = APPINSTALLER_CLI_ERROR_AUTHENTICATION_TYPE_NOT_SUPPORTED; + TestHook::SetAuthenticationResult_Override setAuthenticationResultOverride(authResultOverride); + + std::ostringstream downloadOutput; + TestContext context{ downloadOutput, std::cin }; + auto previousThreadGlobals = context.SetForCurrentThread(); + context.Args.AddArg(Execution::Args::Type::Manifest, TestDataFile("ManifestV1_10-InstallerAuthentication.yaml").GetPath().u8string()); + + DownloadCommand download({}); + download.Execute(context); + INFO(downloadOutput.str()); + + // Verify AppInfo is printed + REQUIRE_TERMINATED_WITH(context, APPINSTALLER_CLI_ERROR_AUTHENTICATION_TYPE_NOT_SUPPORTED); + REQUIRE(downloadOutput.str().find(CLI::Resource::LocString(CLI::Resource::String::InstallerDownloadAuthenticationNotSupported).get()) != std::string::npos); +} + +TEST_CASE("DownloadFlow_DownloadWithInstallerAuthenticationFailed", "[DownloadFlow][workflow]") +{ + if (Runtime::IsRunningAsSystem()) + { + WARN("Test does not support running as system. Skipped."); + return; + } + + // Set authentication failed result + AuthenticationResult authResultOverride; + authResultOverride.Status = APPINSTALLER_CLI_ERROR_AUTHENTICATION_FAILED; + TestHook::SetAuthenticationResult_Override setAuthenticationResultOverride(authResultOverride); + + std::ostringstream downloadOutput; + TestContext context{ downloadOutput, std::cin }; + auto previousThreadGlobals = context.SetForCurrentThread(); + context.Args.AddArg(Execution::Args::Type::Manifest, TestDataFile("ManifestV1_10-InstallerAuthentication.yaml").GetPath().u8string()); + + DownloadCommand download({}); + download.Execute(context); + INFO(downloadOutput.str()); + + // Verify AppInfo is printed + REQUIRE_TERMINATED_WITH(context, APPINSTALLER_CLI_ERROR_AUTHENTICATION_FAILED); + REQUIRE(downloadOutput.str().find(CLI::Resource::LocString(CLI::Resource::String::InstallerDownloadAuthenticationFailed).get()) != std::string::npos); } diff --git a/src/AppInstallerCLITests/TestHooks.h b/src/AppInstallerCLITests/TestHooks.h index 3917c8d1c2..3d862c8270 100644 --- a/src/AppInstallerCLITests/TestHooks.h +++ b/src/AppInstallerCLITests/TestHooks.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -93,6 +94,16 @@ namespace AppInstaller void SetLicensingHttpPipelineStage_Override(std::shared_ptr value); } + + namespace Utility::TestHooks + { + void SetDownloadResult_Function_Override(std::function info)>* value); + } } namespace TestHook @@ -301,4 +312,30 @@ namespace TestHook AppInstaller::MSStore::TestHooks::SetLicensingHttpPipelineStage_Override(nullptr); } }; + + struct SetDownloadResult_Function_Override + { + SetDownloadResult_Function_Override(std::function info)> value) : m_downloadFunction(std::move(value)) + { + AppInstaller::Utility::TestHooks::SetDownloadResult_Function_Override(&m_downloadFunction); + } + + ~SetDownloadResult_Function_Override() + { + AppInstaller::Utility::TestHooks::SetDownloadResult_Function_Override(nullptr); + } + + private: + std::function info)> m_downloadFunction; + }; } diff --git a/src/AppInstallerCommonCore/Downloader.cpp b/src/AppInstallerCommonCore/Downloader.cpp index 037eb08c37..13c1be3b8c 100644 --- a/src/AppInstallerCommonCore/Downloader.cpp +++ b/src/AppInstallerCommonCore/Downloader.cpp @@ -102,6 +102,28 @@ namespace AppInstaller::Utility } } +#ifndef AICLI_DISABLE_TEST_HOOKS + namespace TestHooks + { + static std::function info)>* s_Download_Function_Override = nullptr; + + void SetDownloadResult_Function_Override(std::function info)>* value) + { + s_Download_Function_Override = value; + } + } +#endif + DownloadResult WinINetDownloadToStream( const std::string& url, std::ostream& dest, @@ -320,6 +342,13 @@ namespace AppInstaller::Utility IProgressCallback& progress, std::optional info) { +#ifndef AICLI_DISABLE_TEST_HOOKS + if (TestHooks::s_Download_Function_Override) + { + return (*TestHooks::s_Download_Function_Override)(url, dest, type, progress, info); + } +#endif + THROW_HR_IF(E_INVALIDARG, url.empty()); THROW_HR_IF(E_INVALIDARG, dest.empty());