Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine WifiAccessPointSetFrequencyBands into top-level and implementation functions #201

Merged
merged 5 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 101 additions & 65 deletions src/common/service/NetRemoteService.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include <algorithm>
#include <cstddef>
#include <format>
#include <iterator>
#include <memory>
Expand All @@ -8,8 +9,6 @@
#include <utility>
#include <vector>

#include "NetRemoteApiTrace.hxx"
#include "NetRemoteWifiApiTrace.hxx"
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/server_context.h>
#include <magic_enum.hpp>
Expand All @@ -20,10 +19,14 @@
#include <microsoft/net/wifi/AccessPointOperationStatus.hxx>
#include <microsoft/net/wifi/IAccessPoint.hxx>
#include <microsoft/net/wifi/IAccessPointController.hxx>
#include <microsoft/net/wifi/Ieee80211.hxx>
#include <microsoft/net/wifi/Ieee80211AccessPointCapabilities.hxx>
#include <microsoft/net/wifi/Ieee80211Dot11Adapters.hxx>
#include <plog/Log.h>

#include "NetRemoteApiTrace.hxx"
#include "NetRemoteWifiApiTrace.hxx"

using namespace Microsoft::Net::Remote::Service;
using namespace Microsoft::Net::Remote::Service::Tracing;
using namespace Microsoft::Net::Remote::Wifi;
Expand Down Expand Up @@ -109,9 +112,6 @@ HandleFailure(RequestT& request, ResultT& result, const AccessPointOperationStat
return HandleFailure(request, result, operationStatus.Code, operationStatus.ToString(), returnValue);
}

std::shared_ptr<IAccessPoint>
TryGetAccessPoint(std::string_view accessPointId);

/**
* @brief Attempt to obtain an IAccessPoint instance for the access point in the specified request message.
*
Expand Down Expand Up @@ -300,14 +300,39 @@ NetRemoteService::WifiAccessPointEnable([[maybe_unused]] grpc::ServerContext* co
{
const NetRemoteWifiApiTrace traceMe{ request->accesspointid(), result->mutable_status() };

// Create an AP controller for the requested AP.
auto accessPointController = detail::TryGetAccessPointController(request, result, m_accessPointManager);
if (accessPointController == nullptr) {
return grpc::Status::OK;
}

// Obtain current operational state.
AccessPointOperationalState operationalState{};
auto operationStatus = accessPointController->GetOperationalState(operationalState);
if (!operationStatus) {
return HandleFailure(request, result, operationStatus.Code, std::format("Failed to get operational state for access point {}", request->accesspointid()));
}

WifiAccessPointOperationStatus status{};

// Validate request is well-formed and has all required parameters.
if (ValidateWifiAccessPointEnableRequest(request, status)) {
// TODO: Enable the access point.
status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded);
// Enable the access point if it's not already enabled.
if (operationalState != AccessPointOperationalState::Enabled) {
// Validate request is well-formed and has all required parameters.
if (ValidateWifiAccessPointEnableRequest(request, status)) {
// TODO: Enable the access point.

// Set the operational state to 'enabled' now that initial configuration has been set.
operationStatus = accessPointController->SetOperationalState(AccessPointOperationalState::Enabled);
if (!operationStatus) {
return HandleFailure(request, result, operationStatus.Code, std::format("Failed to set operational state to 'enabled' for access point {}", request->accesspointid()));
}
}
} else {
LOGI << std::format("Access point {} is already enabled", request->accesspointid());
}

status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded);

result->set_accesspointid(request->accesspointid());
*result->mutable_status() = std::move(status);

Expand Down Expand Up @@ -364,67 +389,16 @@ NetRemoteService::WifiAccessPointSetPhyType([[maybe_unused]] grpc::ServerContext
return grpc::Status::OK;
}

/* static */
bool
NetRemoteService::ValidateWifiSetFrequencyBandsRequest(const WifiAccessPointSetFrequencyBandsRequest* request, WifiAccessPointSetFrequencyBandsResult* result)
{
const auto& frequencyBands = request->frequencybands();

if (std::empty(frequencyBands)) {
return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter, "No frequency band provided", false);
}
if (std::ranges::find(frequencyBands, Dot11FrequencyBand::Dot11FrequencyBandUnknown) != std::cend(frequencyBands)) {
return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter, "Invalid frequency band provided", false);
}

return true;
}

grpc::Status
NetRemoteService::WifiAccessPointSetFrequencyBands([[maybe_unused]] grpc::ServerContext* context, const WifiAccessPointSetFrequencyBandsRequest* request, WifiAccessPointSetFrequencyBandsResult* result)
{
const NetRemoteWifiApiTrace traceMe{ request->accesspointid(), result->mutable_status() };

// Validate basic parameters in the request.
if (!ValidateWifiSetFrequencyBandsRequest(request, result)) {
return grpc::Status::OK;
}

// Create an AP controller for the requested AP.
auto accessPointController = detail::TryGetAccessPointController(request, result, m_accessPointManager);
if (accessPointController == nullptr) {
return grpc::Status::OK;
}

// Convert dot11 bands to ieee80211 bands.
auto ieee80211FrequencyBands = FromDot11SetFrequencyBandsRequest(*request);

// Obtain capabilities of the access point.
Ieee80211AccessPointCapabilities accessPointCapabilities{};
auto operationStatus = accessPointController->GetCapabilities(accessPointCapabilities);
if (!operationStatus) {
return HandleFailure(request, result, operationStatus.Code, std::format("Failed to get capabilities for access point {}", request->accesspointid()));
}

// Check if requested bands are supported by the AP.
for (const auto& requestedFrequencyBand : ieee80211FrequencyBands) {
if (std::ranges::find(accessPointCapabilities.FrequencyBands, requestedFrequencyBand) == std::cend(accessPointCapabilities.FrequencyBands)) {
return HandleFailure(request, result, WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported, std::format("Frequency band {} not supported by access point {}", magic_enum::enum_name(requestedFrequencyBand), request->accesspointid()));
}
}

// Attempt to set the frequency bands.
operationStatus = accessPointController->SetFrequencyBands(std::move(ieee80211FrequencyBands));
if (!operationStatus) {
return HandleFailure(request, result, operationStatus.Code, std::format("Failed to set frequency bands for access point {} ({})", request->accesspointid(), operationStatus.ToString()));
}

// Prepare result with success indication.
WifiAccessPointOperationStatus status{};
status.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded);
auto dot11FrequencyBands = ToDot11FrequencyBands(*request);

auto wifiOperationStatus = WifiAccessPointSetFrequencyBandsImpl(request->accesspointid(), dot11FrequencyBands);
result->set_accesspointid(request->accesspointid());
*result->mutable_status() = std::move(status);
*result->mutable_status() = std::move(wifiOperationStatus);

return grpc::Status::OK;
}
Expand Down Expand Up @@ -498,8 +472,9 @@ NetRemoteService::TryGetAccessPoint(std::string_view accessPointId, std::shared_
return operationStatus;
}

/* static */
AccessPointOperationStatus
NetRemoteService::TryGetAccessPointController(std::shared_ptr<IAccessPoint> accessPoint, std::shared_ptr<IAccessPointController>& accessPointController)
NetRemoteService::TryGetAccessPointController(const std::shared_ptr<IAccessPoint>& accessPoint, std::shared_ptr<IAccessPointController>& accessPointController)
{
AccessPointOperationStatus operationStatus{ accessPoint->GetInterfaceName() };

Expand Down Expand Up @@ -544,7 +519,7 @@ NetRemoteService::WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId,
auto operationStatus = TryGetAccessPointController(accessPointId, accessPointController);
if (!operationStatus.Succeeded() || accessPointController == nullptr) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to create access point controller - {}", operationStatus.ToString()));
wifiOperationStatus.set_message(std::format("Failed to create access point controller for access point {} - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}

Expand Down Expand Up @@ -579,3 +554,64 @@ NetRemoteService::WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId,

return wifiOperationStatus;
}

WifiAccessPointOperationStatus
NetRemoteService::WifiAccessPointSetFrequencyBandsImpl(std::string_view accessPointId, std::vector<Dot11FrequencyBand>& dot11FrequencyBands)
{
WifiAccessPointOperationStatus wifiOperationStatus{};

// Validate basic parameters in the request.
if (std::empty(dot11FrequencyBands)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter);
wifiOperationStatus.set_message("No frequency band provided");
return wifiOperationStatus;
}
if (std::ranges::find(dot11FrequencyBands, Dot11FrequencyBand::Dot11FrequencyBandUnknown) != std::cend(dot11FrequencyBands)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter);
wifiOperationStatus.set_message("Invalid frequency band provided");
return wifiOperationStatus;
}

// Create an AP controller for the requested AP.
std::shared_ptr<IAccessPointController> accessPointController{};
auto operationStatus = TryGetAccessPointController(accessPointId, accessPointController);
if (!operationStatus.Succeeded() || accessPointController == nullptr) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to create access point controller for access point - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}

// Convert dot11 bands to ieee80211 bands.
std::vector<Ieee80211FrequencyBand> ieee80211FrequencyBands(static_cast<std::size_t>(std::size(dot11FrequencyBands)));
std::ranges::transform(dot11FrequencyBands, std::begin(ieee80211FrequencyBands), FromDot11FrequencyBand);

// Obtain capabilities of the access point.
Ieee80211AccessPointCapabilities accessPointCapabilities{};
operationStatus = accessPointController->GetCapabilities(accessPointCapabilities);
if (!operationStatus.Succeeded()) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to get capabilities for access point {} - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}

// Check if requested bands are supported by the AP.
for (const auto& requestedFrequencyBand : ieee80211FrequencyBands) {
if (std::ranges::find(accessPointCapabilities.FrequencyBands, requestedFrequencyBand) == std::cend(accessPointCapabilities.FrequencyBands)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeOperationNotSupported);
wifiOperationStatus.set_message(std::format("Frequency band '{}' not supported by access point {}", magic_enum::enum_name(requestedFrequencyBand), accessPointId));
return wifiOperationStatus;
}
}

// Attempt to set the frequency bands.
operationStatus = accessPointController->SetFrequencyBands(std::move(ieee80211FrequencyBands));
if (!operationStatus.Succeeded()) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to set frequency bands for access point {} - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}

wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded);

return wifiOperationStatus;
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ protected:
* @param accessPointController
* @return Microsoft::Net::Wifi::AccessPointOperationStatus
*/
Microsoft::Net::Wifi::AccessPointOperationStatus
TryGetAccessPointController(std::shared_ptr<Microsoft::Net::Wifi::IAccessPoint> accessPoint, std::shared_ptr<Microsoft::Net::Wifi::IAccessPointController>& accessPointController);
static Microsoft::Net::Wifi::AccessPointOperationStatus
TryGetAccessPointController(const std::shared_ptr<Microsoft::Net::Wifi::IAccessPoint>& accessPoint, std::shared_ptr<Microsoft::Net::Wifi::IAccessPointController>& accessPointController);

/**
* @brief Set the active PHY type or protocol of the access point. The access point must be enabled. This will cause
Expand All @@ -138,28 +138,28 @@ protected:
Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
WifiAccessPointSetPhyTypeImpl(std::string_view accessPointId, Microsoft::Net::Wifi::Dot11PhyType dot11PhyType);

protected:
/**
* @brief Validate the basic input parameters for the WifiAccessPointEnable request.
* @brief Set the active frequency bands of the access point. The access point must be enabled. This will cause the
* access point to temporarily go offline while the change is being applied.
*
* @param request The request to validate.
* @param status The status to populate with failure information.
* @return true
* @return false
* @param accessPointId The access point identifier.
* @param dot11FrequencyBands The new frequency bands to set.
* @return Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
*/
static bool
ValidateWifiAccessPointEnableRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointEnableRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus& status);
Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
WifiAccessPointSetFrequencyBandsImpl(std::string_view accessPointId, std::vector<Microsoft::Net::Wifi::Dot11FrequencyBand>& dot11FrequencyBands);

protected:
/**
* @brief Validate the basic input parameters for the WifiAccessPointSetFrequencyBands reqest.
* @brief Validate the basic input parameters for the WifiAccessPointEnable request.
*
* @param request The request to validate.
* @param result The result to populate with failure information.
* @param status The status to populate with failure information.
* @return true
* @return false
*/
static bool
ValidateWifiSetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsResult* result);
ValidateWifiAccessPointEnableRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointEnableRequest* request, Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus& status);

private:
std::shared_ptr<Microsoft::Net::Wifi::AccessPointManager> m_accessPointManager;
Expand Down
17 changes: 17 additions & 0 deletions src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ FromDot11FrequencyBand(const Dot11FrequencyBand dot11FrequencyBand) noexcept

using Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest;

std::vector<Dot11FrequencyBand>
ToDot11FrequencyBands(const WifiAccessPointSetFrequencyBandsRequest& request) noexcept
{
// protobuf encodes enums in repeated fields as 'int' instead of the enum type itself. So, the below is a simple
// function to convert the repeated field of int to the enum type.
constexpr auto toDot11FrequencyBand = [](const auto& frequencyBand) {
return static_cast<Dot11FrequencyBand>(frequencyBand);
};

std::vector<Dot11FrequencyBand> dot11FrequencyBands(static_cast<std::size_t>(std::size(request.frequencybands())));
std::ranges::transform(request.frequencybands(), std::begin(dot11FrequencyBands), toDot11FrequencyBand);

return dot11FrequencyBands;
// TODO: for some reason, std::ranges::to is not being found for clang. Once resolved, update to the following:
// return request.frequencybands() | std::views::transform(toDot11FrequencyBand) | std::ranges::to<std::vector>();
}

std::vector<Ieee80211FrequencyBand>
FromDot11SetFrequencyBandsRequest(const WifiAccessPointSetFrequencyBandsRequest& request)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ FromDot11PhyType(Microsoft::Net::Wifi::Dot11PhyType dot11PhyType) noexcept;
Microsoft::Net::Wifi::Dot11FrequencyBand
ToDot11FrequencyBand(Microsoft::Net::Wifi::Ieee80211FrequencyBand ieee80211FrequencyBand) noexcept;

/**
* @brief Obtain a vector of Dot11FrequencyBands from the specified WifiAccessPointSetFrequencyBandsRequest.
*
* @param request The request to extract the Dot11FrequencyBands from.
* @return std::vector<Microsoft::Net::Wifi::Dot11FrequencyBand>
*/
std::vector<Microsoft::Net::Wifi::Dot11FrequencyBand>
ToDot11FrequencyBands(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request) noexcept;

/**
* @brief Convert the specified Dot11FrequencyBand to the equivalent IEEE 802.11 frequency band.
*
Expand All @@ -76,6 +85,9 @@ FromDot11FrequencyBand(Microsoft::Net::Wifi::Dot11FrequencyBand dot11FrequencyBa
std::vector<Microsoft::Net::Wifi::Ieee80211FrequencyBand>
FromDot11SetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request);

std::vector<Microsoft::Net::Wifi::Ieee80211FrequencyBand>
FromDot11SetFrequencyBandsRequest(const Microsoft::Net::Remote::Wifi::WifiAccessPointSetFrequencyBandsRequest& request);

/**
* @brief Convert the specified IEEE 802.11 authentication algorithm to the equivalent Dot11AuthenticationAlgorithm.
*
Expand Down
Loading