diff --git a/src/common/service/NetRemoteService.cxx b/src/common/service/NetRemoteService.cxx index 6a0a5894..21daf5d3 100644 --- a/src/common/service/NetRemoteService.cxx +++ b/src/common/service/NetRemoteService.cxx @@ -75,6 +75,7 @@ HandleFailure(RequestT& request, ResultT& result, WifiAccessPointOperationStatus * @param code The error code to set in the result message. * @param message The error message to set in the result message. * @param returnValue The value to return from the function. + * @return ReturnT */ template < typename RequestT, @@ -86,6 +87,31 @@ HandleFailure(RequestT& request, ResultT& result, AccessPointOperationStatusCode return HandleFailure(request, result, ToDot11AccessPointOperationStatusCode(code), message, returnValue); } +/** + * @brief Wrapper for HandleFailure that converts a AccessPointOperationStatus to a WifiAccessPointOperationStatus and error details message. + * + * @tparam RequestT The request type. This must contain an access point id (trait). + * @tparam ResultT The result type. This must contain an access point id and a status (traits). + * @return ReturnT The type of the return value. Defaults to grpc::Status with a value of grpc::OK. + * @param request A reference to the request. + * @param result A reference to the result. + * @param operationStatus The AccessPointOperationStatus to derive the failure code and details message from. + * @param returnValue The value to return from the function. + * @return ReturnT + */ +template < + typename RequestT, + typename ResultT, + typename ReturnT = grpc::Status> +ReturnT +HandleFailure(RequestT& request, ResultT& result, const AccessPointOperationStatus& operationStatus, ReturnT returnValue = {}) +{ + return HandleFailure(request, result, operationStatus.Code, operationStatus.ToString(), returnValue); +} + +std::shared_ptr +TryGetAccessPoint(std::string_view accessPointId); + /** * @brief Attempt to obtain an IAccessPoint instance for the access point in the specified request message. * @@ -331,43 +357,9 @@ NetRemoteService::WifiAccessPointSetPhyType([[maybe_unused]] grpc::ServerContext { const NetRemoteWifiApiTrace traceMe{ request->accesspointid(), result->mutable_status() }; - WifiAccessPointOperationStatus status{}; - - // Check if PHY type is provided. - if (request->phytype() == Dot11PhyType::Dot11PhyTypeUnknown) { - return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter, "No PHY type provided"); - } - - // Create an AP controller for the requested AP. - auto accessPointController = detail::TryGetAccessPointController(request, result, m_accessPointManager); - if (accessPointController == nullptr) { - return grpc::Status::OK; - } - - // Convert PHY type to Ieee80211 protocol. - auto ieee80211Protocol = FromDot11PhyType(request->phytype()); - - // Check if Ieee80211 protocol is supported by AP. - Ieee80211AccessPointCapabilities accessPointCapabilities{}; - auto operationStatus = accessPointController->GetCapabilities(accessPointCapabilities); - if (!operationStatus) { - return HandleFailure(request, result, operationStatus.Code, std::format("Failed to get capabilities for access point {} ({})", request->accesspointid(), operationStatus.ToString())); - } - - const auto& supportedIeee80211Protocols = accessPointCapabilities.Protocols; - if (std::ranges::find(supportedIeee80211Protocols, ieee80211Protocol) == std::cend(supportedIeee80211Protocols)) { - return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported, std::format("PHY type not supported by access point {}", request->accesspointid())); - } - - // Set the Ieee80211 protocol. - operationStatus = accessPointController->SetProtocol(ieee80211Protocol); - if (!operationStatus) { - return HandleFailure(request, result, operationStatus.Code, std::format("Failed to set PHY type for access point {} ({})", request->accesspointid(), operationStatus.ToString())); - } - - status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + auto wifiOperationStatus = WifiAccessPointSetPhyTypeImpl(request->accesspointid(), request->phytype()); result->set_accesspointid(request->accesspointid()); - *result->mutable_status() = std::move(status); + *result->mutable_status() = std::move(wifiOperationStatus); return grpc::Status::OK; } @@ -478,3 +470,112 @@ NetRemoteService::ValidateWifiAccessPointEnableRequest(const WifiAccessPointEnab return true; } + +AccessPointOperationStatus +NetRemoteService::TryGetAccessPoint(std::string_view accessPointId, std::shared_ptr& accessPoint) +{ + AccessPointOperationStatus operationStatus{ accessPointId }; + + // Find the requested AP. + auto accessPointOpt = m_accessPointManager->GetAccessPoint(accessPointId); + if (!accessPointOpt.has_value()) { + operationStatus.Code = AccessPointOperationStatusCode::AccessPointInvalid; + operationStatus.Details = "access point not found"; + return operationStatus; + } + + // Attempt to promote the weak reference to a strong reference. + auto accessPointWeak{ accessPointOpt.value() }; + accessPoint = accessPointWeak.lock(); + if (accessPoint == nullptr) { + operationStatus.Code = AccessPointOperationStatusCode::AccessPointInvalid; + operationStatus.Details = "access point no longer valid"; + return operationStatus; + } + + operationStatus.Code = AccessPointOperationStatusCode::Succeeded; + + return operationStatus; +} + +AccessPointOperationStatus +NetRemoteService::TryGetAccessPointController(std::shared_ptr accessPoint, std::shared_ptr& accessPointController) +{ + AccessPointOperationStatus operationStatus{ accessPoint->GetInterfaceName() }; + + accessPointController = accessPoint->CreateController(); + if (accessPointController == nullptr) { + operationStatus.Code = AccessPointOperationStatusCode::InternalError; + operationStatus.Details = "failed to create access point controller"; + return operationStatus; + } + + operationStatus.Code = AccessPointOperationStatusCode::Succeeded; + + return operationStatus; +} + +AccessPointOperationStatus +NetRemoteService::TryGetAccessPointController(std::string_view accessPointId, std::shared_ptr& accessPointController) +{ + std::shared_ptr accessPoint{}; + auto operationStatus = TryGetAccessPoint(accessPointId, accessPoint); + if (!operationStatus.Succeeded()) { + return operationStatus; + } + + return TryGetAccessPointController(accessPoint, accessPointController); +} + +WifiAccessPointOperationStatus +NetRemoteService::WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, Dot11PhyType dot11PhyType) +{ + WifiAccessPointOperationStatus wifiOperationStatus{}; + + // Check if PHY type is provided. + if (dot11PhyType == Dot11PhyType::Dot11PhyTypeUnknown) { + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter); + wifiOperationStatus.set_message("No PHY type provided"); + return wifiOperationStatus; + } + + // Create an AP controller for the requested AP. + std::shared_ptr accessPointController{}; + auto operationStatus = TryGetAccessPointController(accessPointId, accessPointController); + if (!operationStatus.Succeeded() || accessPointController == nullptr) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to create access point controller - {}", operationStatus.ToString())); + return wifiOperationStatus; + } + + // Convert PHY type to Ieee80211 protocol. + auto ieee80211Protocol = FromDot11PhyType(dot11PhyType); + + // Check if Ieee80211 protocol is supported by AP. + Ieee80211AccessPointCapabilities accessPointCapabilities{}; + operationStatus = accessPointController->GetCapabilities(accessPointCapabilities); + if (!operationStatus) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to get capabilities for access point {} - {}", accessPointId, operationStatus.ToString())); + return wifiOperationStatus; + } + + const auto& supportedIeee80211Protocols = accessPointCapabilities.Protocols; + if (std::ranges::find(supportedIeee80211Protocols, ieee80211Protocol) == std::cend(supportedIeee80211Protocols)) { + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported); + wifiOperationStatus.set_message(std::format("PHY type '{}' not supported by access point {}", magic_enum::enum_name(ieee80211Protocol), accessPointId)); + return wifiOperationStatus; + } + + // Set the Ieee80211 protocol. + operationStatus = accessPointController->SetProtocol(ieee80211Protocol); + if (!operationStatus) { + wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code)); + wifiOperationStatus.set_message(std::format("Failed to set PHY type for access point {} - {}", accessPointId, operationStatus.ToString())); + return wifiOperationStatus; + } + + wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded); + + return wifiOperationStatus; +} diff --git a/src/common/service/NetRemoteWifiApiTrace.cxx b/src/common/service/NetRemoteWifiApiTrace.cxx index 675fc9cc..265cd846 100644 --- a/src/common/service/NetRemoteWifiApiTrace.cxx +++ b/src/common/service/NetRemoteWifiApiTrace.cxx @@ -30,13 +30,16 @@ NetRemoteWifiApiTrace::NetRemoteWifiApiTrace(std::optional accessPo NetRemoteWifiApiTrace::~NetRemoteWifiApiTrace() { - if (m_operationStatus != nullptr) { - AddReturnValue(ArgNameStatus, std::string(magic_enum::enum_name(m_operationStatus->code()))); - if (m_operationStatus->code() != WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded) { - AddReturnValue(ArgNameErrorMessage, m_operationStatus->message()); - SetFailed(); - } else { - SetSucceeded(); - } + if (m_operationStatus == nullptr) { + return; + } + + AddReturnValue(ArgNameStatus, std::string(magic_enum::enum_name(m_operationStatus->code()))); + + if (m_operationStatus->code() != WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded) { + AddReturnValue(ArgNameErrorMessage, m_operationStatus->message()); + SetFailed(); + } else { + SetSucceeded(); } } diff --git a/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx b/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx index 32b1f956..fd013b2b 100644 --- a/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx +++ b/src/common/service/include/microsoft/net/remote/NetRemoteService.hxx @@ -3,12 +3,16 @@ #define NET_REMOTE_SERVICE_HXX #include +#include #include #include #include #include #include +#include +#include +#include namespace Microsoft::Net::Remote::Service { @@ -92,6 +96,48 @@ private: grpc::Status WifiAccessPointSetFrequencyBands(grpc::ServerContext* context, const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsResult* result) override; +protected: + /** + * @brief Attempt to obtain an IAccessPoint instance for the specified access point identifier. + * + * @param accessPointId The access point identifier. + * @param accessPoint Output variable to receive the access point instance. + * @return Microsoft::Net::Wifi::AccessPointOperationStatus + */ + Microsoft::Net::Wifi::AccessPointOperationStatus + TryGetAccessPoint(std::string_view accessPointId, std::shared_ptr& accessPoint); + + /** + * @brief Attempt to obtain an IAccessPointController instance for the access point with the specified identifier. + * + * @param accessPointId The access point identifier. + * @param accessPointController Output variable to receive the access point controller instance. + * @return Microsoft::Net::Wifi::AccessPointOperationStatus + */ + Microsoft::Net::Wifi::AccessPointOperationStatus + TryGetAccessPointController(std::string_view accessPointId, std::shared_ptr& accessPointController); + + /** + * @brief Attempt to obtain an IAccessPointController instance for the specified access point. + * + * @param accessPoint The access point to obtain a controller for. + * @param accessPointController + * @return Microsoft::Net::Wifi::AccessPointOperationStatus + */ + Microsoft::Net::Wifi::AccessPointOperationStatus + TryGetAccessPointController(std::shared_ptr accessPoint, std::shared_ptr& accessPointController); + + /** + * @brief Set the active PHY type or protocol of the access point. The access point must be enabled. This will cause + * the access point to temporarily go offline while the change is being applied. + * + * @param accessPointId The access point identifier. + * @param dot11PhyType The new PHY type to set. + * @return Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus + */ + Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus + WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, Microsoft::Net::Wifi::Dot11PhyType dot11PhyType); + protected: /** * @brief Validate the basic input parameters for the WifiAccessPointEnable request.