Skip to content

Commit

Permalink
Merge pull request #240 from microsoft/apenablesetakms
Browse files Browse the repository at this point in the history
Allow settings AKMs from WifiAccessPointEnable API
  • Loading branch information
abeltrano authored Mar 25, 2024
2 parents 875c774 + 20679c5 commit 5e322c9
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,20 @@

namespace Microsoft::Net::Remote::Protocol
{
static constexpr uint32_t NetRemotePortDefault = 5047;
static constexpr std::string_view NetRemoteAddressHttpDefault = "localhost:5047";

/**
* @brief Static NetRemote protocol information.
*/
struct NetRemoteProtocol
{
#define IP_DEFAULT "localhost"
#define PORT_DEFAULT 5047
#define PORT_SEPARATOR ""
#define PORT_SEPARATOR ":"
#define xstr(s) str(s)
#define str(s) #s

static constexpr uint32_t PortDefault{ 5047 };
static constexpr std::string_view PortSeparator{ ":" };
static constexpr std::string_view IpDefault{ "localhost" };
static constexpr uint32_t PortDefault{ PORT_DEFAULT };
static constexpr std::string_view PortSeparator{ PORT_SEPARATOR };
static constexpr std::string_view IpDefault{ IP_DEFAULT };
static constexpr std::string_view AddressDefault{ IP_DEFAULT PORT_SEPARATOR xstr(PORT_DEFAULT) };

#undef IP_DEFAULT
Expand Down
6 changes: 4 additions & 2 deletions protocol/protos/WifiCore.proto
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ enum Dot11AkmSuite
Dot11AkmSuiteOwe = 19;
Dot11AkmSuiteFtPskSha384 = 20;
Dot11AkmSuitePskSha384 = 21;
Dot11AkmSuitePasn = 22;
}

// 802.11 Cipher suites.
Expand Down Expand Up @@ -163,8 +164,9 @@ message Dot11AccessPointConfiguration
Dot11MacAddress Bssid = 2;
Dot11PhyType PhyType = 3;
repeated Dot11CipherSuiteConfiguration PairwiseCipherSuites = 4;
repeated Dot11AuthenticationAlgorithm AuthenticationAlgorithms = 5;
repeated Dot11FrequencyBand FrequencyBands = 6;
repeated Dot11AkmSuite AkmSuites = 5;
repeated Dot11AuthenticationAlgorithm AuthenticationAlgorithms = 6;
repeated Dot11FrequencyBand FrequencyBands = 7;
}

message Dot11AccessPointCapabilities
Expand Down
59 changes: 59 additions & 0 deletions src/common/service/NetRemoteService.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,14 @@ NetRemoteService::WifiAccessPointEnableImpl(std::string_view accessPointId, cons
}
}

if (dot11AccessPointConfiguration->akmsuites_size() > 0) {
auto dot11AkmSuites = ToDot11AkmSuites(*dot11AccessPointConfiguration);
wifiOperationStatus = WifiAccessPointSetAkmSuitesImpl(accessPointId, dot11AkmSuites, accessPointController);
if (wifiOperationStatus.code() != WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded) {
return wifiOperationStatus;
}
}

if (dot11AccessPointConfiguration->pairwiseciphersuites_size() > 0) {
auto dot11PairwiseCipherSuites = ToDot11CipherSuiteConfigurations(dot11AccessPointConfiguration->pairwiseciphersuites());
wifiOperationStatus = WifiAccessPointSetPairwiseCipherSuitesImpl(accessPointId, dot11PairwiseCipherSuites, accessPointController);
Expand Down Expand Up @@ -704,6 +712,57 @@ NetRemoteService::WifiAccessPointSetAuthenticationAlgorithsmImpl(std::string_vie
return wifiOperationStatus;
}

WifiAccessPointOperationStatus
NetRemoteService::WifiAccessPointSetAkmSuitesImpl(std::string_view accessPointId, std::vector<Dot11AkmSuite>& dot11AkmSuites, std::shared_ptr<IAccessPointController> accessPointController)
{
WifiAccessPointOperationStatus wifiOperationStatus{};

// Validate basic parameters in the request.
if (std::empty(dot11AkmSuites)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter);
wifiOperationStatus.set_message("No akm suites provided");
return wifiOperationStatus;
}
if (std::ranges::contains(dot11AkmSuites, Dot11AkmSuite::Dot11AkmSuiteUnknown)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter);
wifiOperationStatus.set_message("Invalid akm suite provided");
return wifiOperationStatus;
}

AccessPointOperationStatus operationStatus{ accessPointId };

// Create an AP controller for the requested AP if one wasn't specified.
if (accessPointController == nullptr) {
operationStatus = TryGetAccessPointController(accessPointId, accessPointController);
if (!operationStatus.Succeeded() || accessPointController == nullptr) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to create access point controller for access point {} - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}
}

// Convert to 802.11 neutral type.
std::vector<Ieee80211AkmSuite> ieee80211AkmSuites(static_cast<std::size_t>(std::size(dot11AkmSuites)));
std::ranges::transform(dot11AkmSuites, std::begin(ieee80211AkmSuites), FromDot11AkmSuite);
if (std::ranges::contains(ieee80211AkmSuites, Ieee80211AkmSuite::Unknown)) {
wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeInvalidParameter);
wifiOperationStatus.set_message("Invalid akm suite provided");
return wifiOperationStatus;
}

// Set the algorithms.
operationStatus = accessPointController->SetAkmSuites(std::move(ieee80211AkmSuites));
if (!operationStatus.Succeeded()) {
wifiOperationStatus.set_code(ToDot11AccessPointOperationStatusCode(operationStatus.Code));
wifiOperationStatus.set_message(std::format("Failed to set akm suites for access point {} - {}", accessPointId, operationStatus.ToString()));
return wifiOperationStatus;
}

wifiOperationStatus.set_code(WifiAccessPointOperationStatusCode::WifiAccessPointOperationStatusCodeSucceeded);

return wifiOperationStatus;
}

WifiAccessPointOperationStatus
NetRemoteService::WifiAccessPointSetPairwiseCipherSuitesImpl(std::string_view accessPointId, std::unordered_map<Dot11SecurityProtocol, std::vector<Dot11CipherSuite>>& dot11PairwiseCipherSuites, std::shared_ptr<IAccessPointController> accessPointController)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ protected:
Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
WifiAccessPointSetAuthenticationAlgorithsmImpl(std::string_view accessPointId, std::vector<Microsoft::Net::Wifi::Dot11AuthenticationAlgorithm>& dot11AuthenticationAlgorithms, std::shared_ptr<Microsoft::Net::Wifi::IAccessPointController> accessPointController = nullptr);

/**
* @brief Set the active AKM suites of the access point. If the access point is online, this will cause it to
* temporarily go offline while the change is being applied.
*
* @param accessPointId The access point identifier.
* @param dot11AkmSuites The new AKM suites to set.
* @param accessPointController The access point controller for the specified access point (optional).
* @return Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
*/
Microsoft::Net::Remote::Wifi::WifiAccessPointOperationStatus
WifiAccessPointSetAkmSuitesImpl(std::string_view accessPointId, std::vector<Microsoft::Net::Wifi::Dot11AkmSuite>& dot11AkmSuites, std::shared_ptr<Microsoft::Net::Wifi::IAccessPointController> accessPointController = nullptr);

/**
* @brief Set the active cipher suites of the access point. If the access point is online, this will cause it to
* temporarily go offline while the change is being applied.
Expand Down
53 changes: 52 additions & 1 deletion src/common/tools/cli/NetRemoteCli.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ NetRemoteCli::CreateParser() noexcept

app->require_subcommand();

const std::string serverAddressDefault{ Protocol::NetRemoteProtocol::AddressDefault };
auto* optionServer = app->add_option_function<std::string>("-s,--server", [this](const std::string& serverAddress) {
OnServerAddressChanged(serverAddress);
});
optionServer->description("The address of the netremote server to connect to, with format '<hostname>[:port]");
optionServer->description("The address of the netremote server to connect to, with format '<hostname>[:port]'");
optionServer->default_val(serverAddressDefault)->run_callback_for_default()->force_callback();
optionServer->default_str(serverAddressDefault);

m_cliAppServerAddress = optionServer;
m_cliAppWifi = AddSubcommandWifi(app.get());
Expand Down Expand Up @@ -174,8 +177,10 @@ Ieee80211AuthenticationAlgorithmNames()
{
try {
static const std::map<std::string, Ieee80211AuthenticationAlgorithm> ieee80211AuthenticationAlgorithmNames{
{ "o", Ieee80211AuthenticationAlgorithm::OpenSystem },
{ "open", Ieee80211AuthenticationAlgorithm::OpenSystem },
{ "open-system", Ieee80211AuthenticationAlgorithm::OpenSystem },
{ "s", Ieee80211AuthenticationAlgorithm::SharedKey },
{ "shared", Ieee80211AuthenticationAlgorithm::SharedKey },
{ "shared-key", Ieee80211AuthenticationAlgorithm::SharedKey },
{ "skey", Ieee80211AuthenticationAlgorithm::SharedKey },
Expand All @@ -185,6 +190,49 @@ Ieee80211AuthenticationAlgorithmNames()
throw std::runtime_error{ "Failed to create authentication algorithm names" };
}
}

const std::map<std::string, Ieee80211AkmSuite>&
Ieee80211AkmSuiteNames()
{
try {
static const std::map<std::string, Ieee80211AkmSuite> ieee80211AkmSuiteNames{
{ "8021x", Ieee80211AkmSuite::Ieee8021x },
{ "dot1x", Ieee80211AkmSuite::Ieee8021x },
{ "psk", Ieee80211AkmSuite::Psk },
{ "ft8021x", Ieee80211AkmSuite::Ft8021x },
{ "ftdot1x", Ieee80211AkmSuite::Ft8021x },
{ "ftpsk", Ieee80211AkmSuite::FtPsk },
{ "8021xsha256", Ieee80211AkmSuite::Ieee8021xSha256 },
{ "dot1xsha256", Ieee80211AkmSuite::Ieee8021xSha256 },
{ "psksha256", Ieee80211AkmSuite::PskSha256 },
{ "tdls", Ieee80211AkmSuite::Tdls },
{ "sae", Ieee80211AkmSuite::Sae },
{ "ftsae", Ieee80211AkmSuite::FtSae },
{ "appeerkey", Ieee80211AkmSuite::ApPeerKey },
{ "8021xsuiteb", Ieee80211AkmSuite::Ieee8021xSuiteB },
{ "8021xsuiteb192", Ieee80211AkmSuite::Ieee8011xSuiteB192 },
{ "dot1xsuiteb", Ieee80211AkmSuite::Ieee8021xSuiteB },
{ "dot1xsuiteb192", Ieee80211AkmSuite::Ieee8011xSuiteB192 },
{ "8021xb", Ieee80211AkmSuite::Ieee8021xSuiteB },
{ "8021xb192", Ieee80211AkmSuite::Ieee8011xSuiteB192 },
{ "dot11b", Ieee80211AkmSuite::Ieee8021xSuiteB },
{ "dot11b192", Ieee80211AkmSuite::Ieee8011xSuiteB192 },
{ "ft8021xsha384", Ieee80211AkmSuite::Ft8021xSha384 },
{ "ftdot1xsha384", Ieee80211AkmSuite::Ft8021xSha384 },
{ "filssha256", Ieee80211AkmSuite::FilsSha256 },
{ "filssha384", Ieee80211AkmSuite::FilsSha384 },
{ "ftfilssha256", Ieee80211AkmSuite::FtFilsSha256 },
{ "ftfilssha384", Ieee80211AkmSuite::FtFilsSha384 },
{ "owe", Ieee80211AkmSuite::Owe },
{ "ftpsksha384", Ieee80211AkmSuite::FtPskSha384 },
{ "psksha384", Ieee80211AkmSuite::PskSha384 },
{ "pasn", Ieee80211AkmSuite::Pasn },
};
return ieee80211AkmSuiteNames;
} catch (...) {
throw std::runtime_error{ "Failed to create AKM suite names" };
}
}
} // namespace detail

CLI::App*
Expand All @@ -200,6 +248,8 @@ NetRemoteCli::AddSubcommandWifiAccessPointEnable(CLI::App* parent)
->transform(CLI::CheckedTransformer(detail::Ieee80211FrequencyBandNames()));
cliAppWifiAccessPointEnable->add_option("--auth,--auths,--authAlg,--authAlgs,--authentication,--authenticationAlgorithm,--authenticationAlgorithms", m_cliData->WifiAccessPointAuthenticationAlgorithms, "The authentication algorithms of the access point to enable")
->transform(CLI::CheckedTransformer(detail::Ieee80211AuthenticationAlgorithmNames(), CLI::ignore_case));
cliAppWifiAccessPointEnable->add_option("--akm,--akms,--akmSuite,--akmSuites,--keyManagement,--keyManagements", m_cliData->WifiAccessPointAkmSuites, "The AKM suites of the access point to enable")
->transform(CLI::CheckedTransformer(detail::Ieee80211AkmSuiteNames(), CLI::ignore_case));
cliAppWifiAccessPointEnable->callback([this] {
WifiAccessPointEnableCallback();
});
Expand Down Expand Up @@ -232,6 +282,7 @@ NetRemoteCli::OnServerAddressChanged(const std::string& serverAddressArg)
serverAddress += std::format("{}{}", NetRemoteProtocol::PortSeparator, NetRemoteProtocol::PortDefault);
}

LOGI << std::format("Connecting to server {}", serverAddress);
m_cliData->ServerAddress = std::move(serverAddress);

auto connection = NetRemoteServerConnection::TryEstablishConnection(m_cliData->ServerAddress);
Expand Down
9 changes: 9 additions & 0 deletions src/common/tools/cli/NetRemoteCliHandlerOperations.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ NetRemoteCliHandlerOperations::WifiAccessPointEnable(std::string_view accessPoin
dot11AccessPointConfiguration.set_phytype(dot11PhyType);
}

// Populate AKM suites if present.
if (!std::empty(ieee80211AccessPointConfiguration->AkmSuites)) {
auto dot11AkmSuites = ToDot11AkmSuites(ieee80211AccessPointConfiguration->AkmSuites);
*dot11AccessPointConfiguration.mutable_akmsuites() = {
std::make_move_iterator(std::begin(dot11AkmSuites)),
std::make_move_iterator(std::end(dot11AkmSuites))
};
}

// Populate pairwise cipher suites if present.
if (!std::empty(ieee80211AccessPointConfiguration->PairwiseCipherSuites)) {
auto dot11PairwiseCipherSuites = ToDot11CipherSuiteConfigurations(ieee80211AccessPointConfiguration->PairwiseCipherSuites);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct NetRemoteCliData
Microsoft::Net::Wifi::Ieee80211PhyType WifiAccessPointPhyType{ Microsoft::Net::Wifi::Ieee80211PhyType::Unknown };
std::vector<Microsoft::Net::Wifi::Ieee80211FrequencyBand> WifiAccessPointFrequencyBands{};
std::vector<Microsoft::Net::Wifi::Ieee80211AuthenticationAlgorithm> WifiAccessPointAuthenticationAlgorithms{};
std::vector<Microsoft::Net::Wifi::Ieee80211AkmSuite> WifiAccessPointAkmSuites{};
};
} // namespace Microsoft::Net::Remote

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct Ieee80211AccessPointConfiguration
std::optional<std::string> Ssid;
std::optional<Ieee80211Bssid> Bssid;
std::optional<Ieee80211PhyType> PhyType;
std::vector<Ieee80211AkmSuite> AkmSuites;
std::unordered_map<Ieee80211SecurityProtocol, std::vector<Ieee80211CipherSuite>> PairwiseCipherSuites;
std::vector<Ieee80211AuthenticationAlgorithm> AuthenticationAlgorithms;
std::vector<Ieee80211FrequencyBand> FrequencyBands;
Expand Down
31 changes: 30 additions & 1 deletion src/common/wifi/dot11/adapter/Ieee80211Dot11Adapters.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ constexpr auto toDot11FrequencyBand = [](const auto& frequencyBand) {
constexpr auto toDot11CipherSuite = [](const auto& cipherSuite) {
return static_cast<Dot11CipherSuite>(cipherSuite);
};

/**
* @brief Convert an int-typed Dot11AkmSuite to its proper enum type.
*/
constexpr auto toDot11AkmSuite = [](const auto& akmSuite) {
return static_cast<Dot11AkmSuite>(akmSuite);
};
} // namespace detail

std::vector<Dot11AuthenticationAlgorithm>
Expand Down Expand Up @@ -363,11 +370,31 @@ ToDot11AkmSuite(const Ieee80211AkmSuite ieee80211AkmSuite) noexcept
return Dot11AkmSuite::Dot11AkmSuiteFtPskSha384;
case Ieee80211AkmSuite::PskSha384:
return Dot11AkmSuite::Dot11AkmSuitePskSha384;
case Ieee80211AkmSuite::Pasn:
return Dot11AkmSuite::Dot11AkmSuitePasn;
default:
return Dot11AkmSuite::Dot11AkmSuiteUnknown;
}
}

std::vector<Dot11AkmSuite>
ToDot11AkmSuites(const Dot11AccessPointConfiguration& dot11AccessPointConfiguration) noexcept
{
std::vector<Dot11AkmSuite> dot11AkmSuites(static_cast<std::size_t>(std::size(dot11AccessPointConfiguration.akmsuites())));
std::ranges::transform(dot11AccessPointConfiguration.akmsuites(), std::begin(dot11AkmSuites), detail::toDot11AkmSuite);

return dot11AkmSuites;
}

std::vector<Dot11AkmSuite>
ToDot11AkmSuites(const std::vector<Ieee80211AkmSuite>& ieee80211AkmSuites) noexcept
{
std::vector<Dot11AkmSuite> dot11AkmSuites(static_cast<std::size_t>(std::size(ieee80211AkmSuites)));
std::ranges::transform(ieee80211AkmSuites, std::begin(dot11AkmSuites), ToDot11AkmSuite);

return dot11AkmSuites;
}

Ieee80211AkmSuite
FromDot11AkmSuite(const Dot11AkmSuite dot11AkmSuite) noexcept
{
Expand Down Expand Up @@ -414,8 +441,10 @@ FromDot11AkmSuite(const Dot11AkmSuite dot11AkmSuite) noexcept
return Ieee80211AkmSuite::FtPskSha384;
case Dot11AkmSuite::Dot11AkmSuitePskSha384:
return Ieee80211AkmSuite::PskSha384;
case Dot11AkmSuite::Dot11AkmSuitePasn:
return Ieee80211AkmSuite::Pasn;
default:
return Ieee80211AkmSuite::Reserved0; // FIXME: this needs to be an invalid value instead
return Ieee80211AkmSuite::Unknown;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,24 @@ FromDot11AuthenticationAlgorithm(Dot11AuthenticationAlgorithm dot11Authenticatio
Dot11AkmSuite
ToDot11AkmSuite(Ieee80211AkmSuite ieee80211AkmSuite) noexcept;

/**
* @brief Obtain a vector of Dot11AkmSuites from the specified Dot11AccessPointConfiguration.
*
* @param dot11AccessPointConfiguration The Dot11AccessPointConfiguration to extract the Dot11AkmSuites from.
* @return std::vector<Dot11AkmSuite>
*/
std::vector<Dot11AkmSuite>
ToDot11AkmSuites(const Dot11AccessPointConfiguration& dot11AccessPointConfiguration) noexcept;

/**
* @brief Convert the specified IEEE 802.11 AKM suite algorithms to the equivalent Dot11AkmSuites.
*
* @param ieee80211AkmSuites The IEEE 802.11 AKM suite algorithms to convert.
* @return std::vector<Dot11AkmSuite>
*/
std::vector<Dot11AkmSuite>
ToDot11AkmSuites(const std::vector<Ieee80211AkmSuite>& ieee80211AkmSuites) noexcept;

/**
* @brief Convert the specified Dot11AkmSuite to the equivalent IEEE 802.11 AKM suite algorithm.
*
Expand Down

0 comments on commit 5e322c9

Please sign in to comment.