diff --git a/src/common/service/NetRemoteService.cxx b/src/common/service/NetRemoteService.cxx index 21daf5d3..066a9f00 100644 --- a/src/common/service/NetRemoteService.cxx +++ b/src/common/service/NetRemoteService.cxx @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -8,8 +9,6 @@ #include #include -#include "NetRemoteApiTrace.hxx" -#include "NetRemoteWifiApiTrace.hxx" #include #include #include @@ -20,10 +19,14 @@ #include #include #include +#include #include #include #include +#include "NetRemoteApiTrace.hxx" +#include "NetRemoteWifiApiTrace.hxx" + using namespace Microsoft::Net::Remote::Service; using namespace Microsoft::Net::Remote::Service::Tracing; using namespace Microsoft::Net::Remote::Wifi; @@ -109,9 +112,6 @@ HandleFailure(RequestT& request, ResultT& result, const AccessPointOperationStat return HandleFailure(request, result, operationStatus.Code, operationStatus.ToString(), returnValue); } -std::shared_ptr -TryGetAccessPoint(std::string_view accessPointId); - /** * @brief Attempt to obtain an IAccessPoint instance for the access point in the specified request message. * @@ -300,14 +300,39 @@ NetRemoteService::WifiAccessPointEnable([[maybe_unused]] grpc::ServerContext* co { const NetRemoteWifiApiTrace traceMe{ request->accesspointid(), result->mutable_status() }; + // Create an AP controller for the requested AP. + auto accessPointController = detail::TryGetAccessPointController(request, result, m_accessPointManager); + if (accessPointController == nullptr) { + return grpc::Status::OK; + } + + // Obtain current operational state. + AccessPointOperationalState operationalState{}; + auto operationStatus = accessPointController->GetOperationalState(operationalState); + if (!operationStatus) { + return HandleFailure(request, result, operationStatus.Code, std::format("Failed to get operational state for access point {}", request->accesspointid())); + } + WifiAccessPointOperationStatus status{}; - // Validate request is well-formed and has all required parameters. - if (ValidateWifiAccessPointEnableRequest(request, status)) { - // TODO: Enable the access point. - status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + // Enable the access point if it's not already enabled. + if (operationalState != AccessPointOperationalState::Enabled) { + // Validate request is well-formed and has all required parameters. + if (ValidateWifiAccessPointEnableRequest(request, status)) { + // TODO: Enable the access point. + + // Set the operational state to 'enabled' now that initial configuration has been set. + operationStatus = accessPointController->SetOperationalState(AccessPointOperationalState::Enabled); + if (!operationStatus) { + return HandleFailure(request, result, operationStatus.Code, std::format("Failed to set operational state to 'enabled' for access point {}", request->accesspointid())); + } + } + } else { + LOGI << std::format("Access point {} is already enabled", request->accesspointid()); } + status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + result->set_accesspointid(request->accesspointid()); *result->mutable_status() = std::move(status); @@ -364,67 +389,16 @@ NetRemoteService::WifiAccessPointSetPhyType([[maybe_unused]] grpc::ServerContext return grpc::Status::OK; } -/* static */ -bool -NetRemoteService::ValidateWifiSetFrequencyBandsRequest(const WifiAccessPointSetFrequencyBandsRequest* request, WifiAccessPointSetFrequencyBandsResult* result) -{ - const auto& frequencyBands = request->frequencybands(); - - if (std::empty(frequencyBands)) { - return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter, "No frequency band provided", false); - } - if (std::ranges::find(frequencyBands, Dot11FrequencyBand::Dot11FrequencyBandUnknown) != std::cend(frequencyBands)) { - return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter, "Invalid frequency band provided", false); - } - - return true; -} - grpc::Status NetRemoteService::WifiAccessPointSetFrequencyBands([[maybe_unused]] grpc::ServerContext* context, const WifiAccessPointSetFrequencyBandsRequest* request, WifiAccessPointSetFrequencyBandsResult* result) { const NetRemoteWifiApiTrace traceMe{ request->accesspointid(), result->mutable_status() }; - // Validate basic parameters in the request. - if (!ValidateWifiSetFrequencyBandsRequest(request, result)) { - return grpc::Status::OK; - } - - // Create an AP controller for the requested AP. - auto accessPointController = detail::TryGetAccessPointController(request, result, m_accessPointManager); - if (accessPointController == nullptr) { - return grpc::Status::OK; - } - - // Convert dot11 bands to ieee80211 bands. - auto ieee80211FrequencyBands = FromDot11SetFrequencyBandsRequest(*request); - - // Obtain capabilities of the access point. - Ieee80211AccessPointCapabilities accessPointCapabilities{}; - auto operationStatus = accessPointController->GetCapabilities(accessPointCapabilities); - if (!operationStatus) { - return HandleFailure(request, result, operationStatus.Code, std::format("Failed to get capabilities for access point {}", request->accesspointid())); - } - - // Check if requested bands are supported by the AP. - for (const auto& requestedFrequencyBand : ieee80211FrequencyBands) { - if (std::ranges::find(accessPointCapabilities.FrequencyBands, requestedFrequencyBand) == std::cend(accessPointCapabilities.FrequencyBands)) { - return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported, std::format("Frequency band {} not supported by access point {}", magic_enum::enum_name(requestedFrequencyBand), request->accesspointid())); - } - } - - // Attempt to set the frequency bands. - operationStatus = accessPointController->SetFrequencyBands(std::move(ieee80211FrequencyBands)); - if (!operationStatus) { - return HandleFailure(request, result, operationStatus.Code, std::format("Failed to set frequency bands for access point {} ({})", request->accesspointid(), operationStatus.ToString())); - } - - // Prepare result with success indication. - WifiAccessPointOperationStatus status{}; - status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + auto dot11FrequencyBands = ToDot11FrequencyBands(*request); + auto wifiOperationStatus = WifiAccessPointSetFrequencyBandsImpl(request->accesspointid(), dot11FrequencyBands); result->set_accesspointid(request->accesspointid()); - *result->mutable_status() = std::move(status); + *result->mutable_status() = std::move(wifiOperationStatus); return grpc::Status::OK; } @@ -498,8 +472,9 @@ NetRemoteService::TryGetAccessPoint(std::string_view accessPointId, std::shared_ return operationStatus; } +/* static */ AccessPointOperationStatus -NetRemoteService::TryGetAccessPointController(std::shared_ptr accessPoint, std::shared_ptr& accessPointController) +NetRemoteService::TryGetAccessPointController(const std::shared_ptr& accessPoint, std::shared_ptr& accessPointController) { AccessPointOperationStatus operationStatus{ accessPoint->GetInterfaceName() }; @@ -544,7 +519,7 @@ NetRemoteService::WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, auto operationStatus = TryGetAccessPointController(accessPointId, accessPointController); if (!operationStatus.Succeeded() || accessPointController == nullptr) { wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); - wifiOperationStatus.set_message(std::format("Failed to create access point controller - {}", operationStatus.ToString())); + wifiOperationStatus.set_message(std::format("Failed to create access point controller for access point {} - {}", accessPointId, operationStatus.ToString())); return wifiOperationStatus; } @@ -579,3 +554,64 @@ NetRemoteService::WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, return wifiOperationStatus; } + +WifiAccessPointOperationStatus +NetRemoteService::WifiAccessPointSetFrequencyBandsImpl(std::string_view accessPointId, std::vector& dot11FrequencyBands) +{ + WifiAccessPointOperationStatus wifiOperationStatus{}; + + // Validate basic parameters in the request. + if (std::empty(dot11FrequencyBands)) { + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter); + wifiOperationStatus.set_message("No frequency band provided"); + return wifiOperationStatus; + } + if (std::ranges::find(dot11FrequencyBands, Dot11FrequencyBand::Dot11FrequencyBandUnknown) != std::cend(dot11FrequencyBands)) { + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter); + wifiOperationStatus.set_message("Invalid frequency band provided"); + return wifiOperationStatus; + } + + // Create an AP controller for the requested AP. + std::shared_ptr accessPointController{}; + auto operationStatus = TryGetAccessPointController(accessPointId, accessPointController); + if (!operationStatus.Succeeded() || accessPointController == nullptr) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to create access point controller for access point - {}", accessPointId, operationStatus.ToString())); + return wifiOperationStatus; + } + + // Convert dot11 bands to ieee80211 bands. + std::vector ieee80211FrequencyBands(static_cast(std::size(dot11FrequencyBands))); + std::ranges::transform(dot11FrequencyBands, std::begin(ieee80211FrequencyBands), FromDot11FrequencyBand); + + // Obtain capabilities of the access point. + Ieee80211AccessPointCapabilities accessPointCapabilities{}; + operationStatus = accessPointController->GetCapabilities(accessPointCapabilities); + if (!operationStatus.Succeeded()) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to get capabilities for access point {} - {}", accessPointId, operationStatus.ToString())); + return wifiOperationStatus; + } + + // Check if requested bands are supported by the AP. + for (const auto& requestedFrequencyBand : ieee80211FrequencyBands) { + if (std::ranges::find(accessPointCapabilities.FrequencyBands, requestedFrequencyBand) == std::cend(accessPointCapabilities.FrequencyBands)) { + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported); + wifiOperationStatus.set_message(std::format("Frequency band '{}' not supported by access point {}", magic_enum::enum_name(requestedFrequencyBand), accessPointId)); + return wifiOperationStatus; + } + } + + // Attempt to set the frequency bands. + operationStatus = accessPointController->SetFrequencyBands(std::move(ieee80211FrequencyBands)); + if (!operationStatus.Succeeded()) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to set frequency bands for access point {} - {}", accessPointId, operationStatus.ToString())); + return wifiOperationStatus; + } + + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + + return wifiOperationStatus; +} diff --git a/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx b/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx index fd013b2b..6e2ec703 100644 --- a/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx +++ b/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx @@ -124,8 +124,8 @@ protected: * @param accessPointController * @return Microsoft::Net::Wifi::AccessPointOperationStatus */ - Microsoft::Net::Wifi::AccessPointOperationStatus - TryGetAccessPointController(std::shared_ptr accessPoint, std::shared_ptr& accessPointController); + static Microsoft::Net::Wifi::AccessPointOperationStatus + TryGetAccessPointController(const std::shared_ptr& accessPoint, std::shared_ptr& accessPointController); /** * @brief Set the active PHY type or protocol of the access point. The access point must be enabled. This will cause @@ -138,28 +138,28 @@ protected: Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, Microsoft::Net::Wifi::Dot11PhyType dot11PhyType); -protected: /** - * @brief Validate the basic input parameters for the WifiAccessPointEnable request. + * @brief Set the active frequency bands of the access point. The access point must be enabled. This will cause the + * access point to temporarily go offline while the change is being applied. * - * @param request The request to validate. - * @param status The status to populate with failure information. - * @return true - * @return false + * @param accessPointId The access point identifier. + * @param dot11FrequencyBands The new frequency bands to set. + * @return Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus */ - static bool - ValidateWifiAccessPointEnableRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointEnableRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus& status); + Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus + WifiAccessPointSetFrequencyBandsImpl(std::string_view accessPointId, std::vector& dot11FrequencyBands); +protected: /** - * @brief Validate the basic input parameters for the WifiAccessPointSetFrequencyBands reqest. + * @brief Validate the basic input parameters for the WifiAccessPointEnable request. * * @param request The request to validate. - * @param result The result to populate with failure information. + * @param status The status to populate with failure information. * @return true * @return false */ static bool - ValidateWifiSetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsResult* result); + ValidateWifiAccessPointEnableRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointEnableRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus& status); private: std::shared_ptr m_accessPointManager; diff --git a/src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx b/src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx index 0c7d9a17..22d9ddd2 100644 --- a/src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx +++ b/src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx @@ -153,6 +153,23 @@ FromDot11FrequencyBand(const Dot11FrequencyBand dot11FrequencyBand) noexcept using Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest; +std::vector +ToDot11FrequencyBands(const WifiAccessPointSetFrequencyBandsRequest& request) noexcept +{ + // protobuf encodes enums in repeated fields as 'int' instead of the enum type itself. So, the below is a simple + // function to convert the repeated field of int to the enum type. + constexpr auto toDot11FrequencyBand = [](const auto& frequencyBand) { + return static_cast(frequencyBand); + }; + + std::vector dot11FrequencyBands(static_cast(std::size(request.frequencybands()))); + std::ranges::transform(request.frequencybands(), std::begin(dot11FrequencyBands), toDot11FrequencyBand); + + return dot11FrequencyBands; + // TODO: for some reason, std::ranges::to is not being found for clang. Once resolved, update to the following: + // return request.frequencybands() | std::views::transform(toDot11FrequencyBand) | std::ranges::to(); +} + std::vector FromDot11SetFrequencyBandsRequest(const WifiAccessPointSetFrequencyBandsRequest& request) { diff --git a/src/common/wifi/dot11/adapter/include/microsoft/net/wifi/Ieee80211Dot11Adapters.hxx b/src/common/wifi/dot11/adapter/include/microsoft/net/wifi/Ieee80211Dot11Adapters.hxx index 275b7fc0..c756fe9a 100644 --- a/src/common/wifi/dot11/adapter/include/microsoft/net/wifi/Ieee80211Dot11Adapters.hxx +++ b/src/common/wifi/dot11/adapter/include/microsoft/net/wifi/Ieee80211Dot11Adapters.hxx @@ -57,6 +57,15 @@ FromDot11PhyType(Microsoft::Net::Wifi::Dot11PhyType dot11PhyType) noexcept; Microsoft::Net::Wifi::Dot11FrequencyBand ToDot11FrequencyBand(Microsoft::Net::Wifi::Ieee80211FrequencyBand ieee80211FrequencyBand) noexcept; +/** + * @brief Obtain a vector of Dot11FrequencyBands from the specified WifiAccessPointSetFrequencyBandsRequest. + * + * @param request The request to extract the Dot11FrequencyBands from. + * @return std::vector + */ +std::vector +ToDot11FrequencyBands(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request) noexcept; + /** * @brief Convert the specified Dot11FrequencyBand to the equivalent IEEE 802.11 frequency band. * @@ -76,6 +85,9 @@ FromDot11FrequencyBand(Microsoft::Net::Wifi::Dot11FrequencyBand dot11FrequencyBa std::vector FromDot11SetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request); +std::vector +FromDot11SetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request); + /** * @brief Convert the specified IEEE 802.11 authentication algorithm to the equivalent Dot11AuthenticationAlgorithm. * diff --git a/tests/unit/TestNetRemoteServiceClient.cxx b/tests/unit/TestNetRemoteServiceClient.cxx index 46f470e7..1b08d251 100644 --- a/tests/unit/TestNetRemoteServiceClient.cxx +++ b/tests/unit/TestNetRemoteServiceClient.cxx @@ -86,14 +86,29 @@ TEST_CASE("WifiAccessPointEnable API", "[basic][rpc][client][remote]") { using namespace Microsoft::Net::Remote; using namespace Microsoft::Net::Remote::Service; + using namespace Microsoft::Net::Remote::Test; using namespace Microsoft::Net::Remote::Wifi; using namespace Microsoft::Net::Wifi; + using namespace Microsoft::Net::Wifi::Test; constexpr auto SsidName{ "TestWifiAccessPointEnable" }; + constexpr auto InterfaceName1{ "TestWifiAccessPointEnable1" }; + constexpr auto InterfaceName2{ "TestWifiAccessPointEnable2" }; + constexpr auto InterfaceNameInvalid{ "TestWifiAccessPointEnableInvalid" }; + + auto apManagerTest = std::make_shared(); + const Ieee80211AccessPointCapabilities apCapabilities{ + .Protocols{ std::cbegin(AllProtocols), std::cend(AllProtocols) } + }; + + auto apTest1 = std::make_shared(InterfaceName1, apCapabilities); + auto apTest2 = std::make_shared(InterfaceName2, apCapabilities); + apManagerTest->AddAccessPoint(apTest1); + apManagerTest->AddAccessPoint(apTest2); const NetRemoteServerConfiguration Configuration{ .ServerAddress = RemoteServiceAddressHttp, - .AccessPointManager = AccessPointManager::Create(), + .AccessPointManager = apManagerTest, }; NetRemoteServer server{ Configuration }; @@ -113,7 +128,7 @@ TEST_CASE("WifiAccessPointEnable API", "[basic][rpc][client][remote]") apConfiguration.mutable_bands()->Add(Dot11FrequencyBand::Dot11FrequencyBand5_0GHz); WifiAccessPointEnableRequest request{}; - request.set_accesspointid("TestWifiAccessPointEnable"); + request.set_accesspointid(InterfaceName1); *request.mutable_configuration() = std::move(apConfiguration); WifiAccessPointEnableResult result{}; @@ -127,6 +142,70 @@ TEST_CASE("WifiAccessPointEnable API", "[basic][rpc][client][remote]") REQUIRE(result.status().message().empty()); REQUIRE(result.status().has_details() == false); } + + SECTION("Fails with invalid access point") + { + WifiAccessPointEnableRequest request{}; + *request.mutable_configuration() = {}; + request.set_accesspointid(InterfaceNameInvalid); + + WifiAccessPointEnableResult result{}; + grpc::ClientContext clientContext{}; + + grpc::Status status; + REQUIRE_NOTHROW(status = client->WifiAccessPointEnable(&clientContext, request, &result)); + REQUIRE(status.ok()); + REQUIRE(result.accesspointid() == request.accesspointid()); + REQUIRE(result.has_status()); + REQUIRE(result.status().code() == WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeAccessPointInvalid); + } + + SECTION("Succeeds without access point configuration if already configured") + { + // Perform initial enable with configuration. + Dot11AccessPointConfiguration apConfiguration{}; + apConfiguration.mutable_ssid()->set_name(SsidName); + apConfiguration.set_phytype(Dot11PhyType::Dot11PhyTypeA); + apConfiguration.set_authenticationalgorithm(Dot11AuthenticationAlgorithm::Dot11AuthenticationAlgorithmSharedKey); + apConfiguration.set_ciphersuite(Dot11CipherSuite::Dot11CipherSuiteCcmp256); + apConfiguration.mutable_bands()->Add(Dot11FrequencyBand::Dot11FrequencyBand2_4GHz); + + WifiAccessPointEnableRequest request{}; + request.set_accesspointid(InterfaceName1); + *request.mutable_configuration() = std::move(apConfiguration); + + WifiAccessPointEnableResult result{}; + grpc::ClientContext clientContext{}; + + auto status = client->WifiAccessPointEnable(&clientContext, request, &result); + REQUIRE(status.ok()); + REQUIRE(result.accesspointid() == request.accesspointid()); + REQUIRE(result.has_status()); + REQUIRE(result.status().code() == WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + + // Disable. + { + WifiAccessPointDisableRequest disableRequest{}; + disableRequest.set_accesspointid(InterfaceName1); + WifiAccessPointDisableResult disableResult{}; + grpc::ClientContext disableClientContext{}; + + grpc::Status disableStatus; + REQUIRE_NOTHROW(disableStatus = client->WifiAccessPointDisable(&disableClientContext, disableRequest, &disableResult)); + + REQUIRE(disableResult.has_status()); + REQUIRE(disableResult.status().code() == WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + } + + // Perform second enable without configuration. + request.clear_configuration(); + result.Clear(); + grpc::ClientContext clientContextReenable{}; + status = client->WifiAccessPointEnable(&clientContextReenable, request, &result); + REQUIRE(result.accesspointid() == request.accesspointid()); + REQUIRE(result.has_status()); + REQUIRE(result.status().code() == WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + } } TEST_CASE("WifiAccessPointDisable API", "[basic][rpc][client][remote]")