Skip to content

Commit

Permalink
Merge pull request #119 from microsoft/apdiscoveryif
Browse files Browse the repository at this point in the history
Change access point discovery output from interface name to IAccessPoint
  • Loading branch information
abeltrano authored Jan 22, 2024
2 parents fe60c2d + b61f71d commit cf28679
Show file tree
Hide file tree
Showing 19 changed files with 153 additions and 126 deletions.
11 changes: 5 additions & 6 deletions src/common/wifi/apmanager/AccessPointDiscoveryAgent.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ AccessPointDiscoveryAgent::RegisterDiscoveryEventCallback(AccessPointPresenceEve
}

void
AccessPointDiscoveryAgent::DevicePresenceChanged(AccessPointPresenceEvent presence, std::string interfaceName) const noexcept
AccessPointDiscoveryAgent::DevicePresenceChanged(AccessPointPresenceEvent presence, std::shared_ptr<IAccessPoint> accessPoint) const noexcept
{
std::shared_lock<std::shared_mutex> onDevicePresenceChangedLock{ m_onDevicePresenceChangedGate };
if (m_onDevicePresenceChanged) {
LOGD << "Access point discovery agent detected a device presence change";
m_onDevicePresenceChanged(presence, std::move(interfaceName));
m_onDevicePresenceChanged(presence, std::move(accessPoint));
}
}

Expand All @@ -57,10 +57,10 @@ AccessPointDiscoveryAgent::Start()
bool expected = false;
if (m_started.compare_exchange_weak(expected, true)) {
LOGD << "Access point discovery agent starting";
m_operations->Start([weakThis = std::weak_ptr<AccessPointDiscoveryAgent>(GetInstance())](auto&& presence, auto&& interfaceName) {
m_operations->Start([weakThis = std::weak_ptr<AccessPointDiscoveryAgent>(GetInstance())](auto&& presence, auto&& accessPoint) {
// Attempt to promote the weak pointer to a shared pointer to ensure this instance is still valid.
if (auto strongThis = weakThis.lock(); strongThis) {
strongThis->DevicePresenceChanged(presence, std::move(interfaceName));
strongThis->DevicePresenceChanged(presence, std::move(accessPoint));
}
});
}
Expand All @@ -76,9 +76,8 @@ AccessPointDiscoveryAgent::Stop()
}
}

std::future<std::vector<std::string>>
std::future<std::vector<std::shared_ptr<IAccessPoint>>>
AccessPointDiscoveryAgent::ProbeAsync()
{
LOGD << "Access point discovery agent probing for devices";
return m_operations->ProbeAsync();
}
26 changes: 11 additions & 15 deletions src/common/wifi/apmanager/AccessPointManager.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,21 @@
#include <chrono>
#include <iterator>

#include <magic_enum.hpp>
#include <microsoft/net/wifi/AccessPoint.hxx>
#include <microsoft/net/wifi/AccessPointDiscoveryAgent.hxx>
#include <microsoft/net/wifi/AccessPointManager.hxx>
#include <microsoft/net/wifi/IAccessPoint.hxx>
#include <notstd/Memory.hxx>
#include <plog/Log.h>

using namespace Microsoft::Net::Wifi;

AccessPointManager::AccessPointManager(std::unique_ptr<IAccessPointFactory> accessPointFactory) :
m_accessPointFactory(std::move(accessPointFactory))
{}

/* static */
std::shared_ptr<AccessPointManager>
AccessPointManager::Create(std::unique_ptr<IAccessPointFactory> accessPointFactory)
AccessPointManager::Create()
{
return std::make_shared<notstd::enable_make_protected<AccessPointManager>>(std::move(accessPointFactory));
return std::make_shared<notstd::enable_make_protected<AccessPointManager>>();
}

std::shared_ptr<AccessPointManager>
Expand Down Expand Up @@ -98,9 +96,8 @@ AccessPointManager::AddDiscoveryAgent(std::shared_ptr<AccessPointDiscoveryAgent>
// be safely destroyed prior to the discovery agent. This allows the
// callback to be registered indefinitely, safely checking whether this
// instance is still valid upon each callback invocation.
discoveryAgent->RegisterDiscoveryEventCallback([discoveryAgentPtr = discoveryAgent.get(), weakThis = std::weak_ptr<AccessPointManager>(GetInstance())](auto&& presence, auto&& interfaceName) {
discoveryAgent->RegisterDiscoveryEventCallback([discoveryAgentPtr = discoveryAgent.get(), weakThis = std::weak_ptr<AccessPointManager>(GetInstance())](auto&& presence, auto&& accessPointChanged) {
if (auto strongThis = weakThis.lock()) {
auto accessPointChanged = strongThis->m_accessPointFactory->Create(interfaceName);
strongThis->OnAccessPointPresenceChanged(discoveryAgentPtr, presence, std::move(accessPointChanged));
}
});
Expand All @@ -111,29 +108,28 @@ AccessPointManager::AddDiscoveryAgent(std::shared_ptr<AccessPointDiscoveryAgent>
}

// Kick off a probe to ensure any access points already present will be added to this manager.
auto existingAccessPointInterfaceNamesProbe = discoveryAgent->ProbeAsync();
auto existingAccessPointsProbe = discoveryAgent->ProbeAsync();

// Add the agent.
{
std::unique_lock<std::shared_mutex> discoveryAgentLock{ m_discoveryAgentsGate };
m_discoveryAgents.push_back(std::move(discoveryAgent));
}

if (existingAccessPointInterfaceNamesProbe.valid()) {
if (existingAccessPointsProbe.valid()) {
static constexpr auto ProbeTimeout = 3s;

// Wait for the operation to complete.
const auto waitResult = existingAccessPointInterfaceNamesProbe.wait_for(ProbeTimeout);
const auto waitResult = existingAccessPointsProbe.wait_for(ProbeTimeout);

// If the operation completed, get the results and add those access points.
if (waitResult == std::future_status::ready) {
auto existingAccessPointInterfaceNames = existingAccessPointInterfaceNamesProbe.get();
for (const auto& existingAccessPointInterfaceName : existingAccessPointInterfaceNames) {
auto existingAccessPoint = m_accessPointFactory->Create(existingAccessPointInterfaceName);
auto existingAccessPoints = existingAccessPointsProbe.get();
for (const auto& existingAccessPoint : existingAccessPoints) {
AddAccessPoint(std::move(existingAccessPoint));
}
} else {
// TODO: log error
LOGE << std::format("Access point discovery agent probe failed ({})", magic_enum::enum_name(waitResult));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <future>
#include <memory>
#include <shared_mutex>
#include <string>

#include <microsoft/net/wifi/IAccessPointDiscoveryAgentOperations.hxx>

Expand Down Expand Up @@ -75,11 +74,11 @@ struct AccessPointDiscoveryAgent :
Stop();

/**
* @brief Probe for all existing devices.
* @brief Perform an asynchronous discovery probe.
*
* @return std::future<std::vector<std::string>>
* @return std::future<std::vector<std::shared_ptr<IAccessPoint>>>
*/
std::future<std::vector<std::string>>
std::future<std::vector<std::shared_ptr<IAccessPoint>>>
ProbeAsync();

protected:
Expand All @@ -94,10 +93,10 @@ protected:
* @brief Wrapper for safely invoking any device presence changed registered callback.
*
* @param presence The presence change that occurred.
* @param interfaceName The name of the network interface of the access point that changed.
* @param accessPoint The access point instance that changed.
*/
void
DevicePresenceChanged(AccessPointPresenceEvent presence, std::string interfaceName) const noexcept;
DevicePresenceChanged(AccessPointPresenceEvent presence, std::shared_ptr<IAccessPoint> accessPoint) const noexcept;

private:
std::unique_ptr<IAccessPointDiscoveryAgentOperations> m_operations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public:
* @return std::shared_ptr<AccessPointManager>
*/
[[nodiscard]] static std::shared_ptr<AccessPointManager>
Create(std::unique_ptr<IAccessPointFactory> accessPointFactory);
Create();

/**
* @brief Get an instance of this access point manager.
Expand Down Expand Up @@ -84,20 +84,11 @@ public:

protected:
/**
* @brief Default constructor.
* @brief Construct a new AccessPointManager object.
*
* It's intentional that this is *declared* here and default-implemented
* in the source file. This is required because IAccessPoint
* and AccessPointDiscoveryAgent are used as incomplete
* types with std::unique_ptr and std::shared_ptr. In case an exception is
* thrown in the constructor, their destructors may be called, and the
* wrapped type must be complete at that time. As such, defining the
* constructor implementation as default here would require the type to be
* complete, which is impossible due to the forward declaration.
* Consequently, the = default implementation is done in the source file
* instead.
* @param accessPointFactory
*/
AccessPointManager(std::unique_ptr<IAccessPointFactory> accessPointFactory);
AccessPointManager() = default;

private:
/**
Expand Down Expand Up @@ -127,7 +118,7 @@ private:
RemoveAccessPoint(std::shared_ptr<IAccessPoint> accessPoint);

private:
std::unique_ptr<IAccessPointFactory> m_accessPointFactory;
std::shared_ptr<IAccessPointFactory> m_accessPointFactory;

mutable std::mutex m_accessPointGate;
std::vector<std::shared_ptr<IAccessPoint>> m_accessPoints{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <vector>

namespace Microsoft::Net::Wifi
Expand All @@ -20,7 +19,7 @@ struct IAccessPoint;
/**
* @brief Prototype for the callback invoked when an access point is discovered or removed.
*/
using AccessPointPresenceEventCallback = std::function<void(AccessPointPresenceEvent, std::string interfaceName)>;
using AccessPointPresenceEventCallback = std::function<void(AccessPointPresenceEvent, std::shared_ptr<IAccessPoint> accessPoint)>;

/**
* @brief Operations used to perform discovery of access points, used by
Expand All @@ -47,9 +46,9 @@ struct IAccessPointDiscoveryAgentOperations
/**
* @brief Perform an asynchronous discovery probe.
*
* @return std::future<std::vector<std::string>>
* @return std::future<std::vector<std::shared_ptr<IAccessPoint>>>
*/
virtual std::future<std::vector<std::string>>
virtual std::future<std::vector<std::shared_ptr<IAccessPoint>>>
ProbeAsync() = 0;
};

Expand Down
11 changes: 6 additions & 5 deletions src/linux/server/Main.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ main(int argc, char *argv[])

// Create an access point manager and discovery agent.
{
auto accessPointControllerFactory = std::make_unique<AccessPointControllerHostapdFactory>();
auto accessPointFactory = std::make_unique<AccessPointFactoryLinux>(std::move(accessPointControllerFactory));
configuration.AccessPointManager = AccessPointManager::Create(std::move(accessPointFactory));
configuration.AccessPointManager = AccessPointManager::Create();

auto &accessPointManager = configuration.AccessPointManager;
auto accessPointDiscoveryAgentOperationsNetlink = std::make_unique<AccessPointDiscoveryAgentOperationsNetlink>();
auto accessPointControllerFactory = std::make_unique<AccessPointControllerHostapdFactory>();
auto accessPointFactory = std::make_shared<AccessPointFactoryLinux>(std::move(accessPointControllerFactory));
auto accessPointDiscoveryAgentOperationsNetlink = std::make_unique<AccessPointDiscoveryAgentOperationsNetlink>(accessPointFactory);
auto accessPointDiscoveryAgent = AccessPointDiscoveryAgent::Create(std::move(accessPointDiscoveryAgentOperationsNetlink));
auto &accessPointManager = configuration.AccessPointManager;

accessPointManager->AddDiscoveryAgent(std::move(accessPointDiscoveryAgent));
}

Expand Down
1 change: 1 addition & 0 deletions src/linux/tools/apmonitor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ target_link_libraries(apmonitor-cli-linux
PRIVATE
plog::plog
wifi-apmanager-linux
wifi-core-linux
)

set_target_properties(apmonitor-cli-linux
Expand Down
14 changes: 8 additions & 6 deletions src/linux/tools/apmonitor/Main.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
#include <format>

#include <magic_enum.hpp>
#include <microsoft/net/wifi/AccessPointControllerHostapd.hxx>
#include <microsoft/net/wifi/AccessPointDiscoveryAgent.hxx>
#include <microsoft/net/wifi/AccessPointDiscoveryAgentOperationsNetlink.hxx>
#include <microsoft/net/wifi/AccessPointLinux.hxx>
#include <microsoft/net/wifi/IAccessPoint.hxx>
#include <plog/Appenders/ColorConsoleAppender.h>
#include <plog/Formatters/MessageOnlyFormatter.h>
#include <plog/Init.h>
#include <plog/Log.h>
#include <signal.h>

using Microsoft::Net::Wifi::AccessPointDiscoveryAgent;
using Microsoft::Net::Wifi::AccessPointDiscoveryAgentOperationsNetlink;
using Microsoft::Net::Wifi::IAccessPointDiscoveryAgentOperations;
using namespace Microsoft::Net::Wifi;

int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
Expand All @@ -23,10 +23,12 @@ main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
plog::init(plog::verbose, &colorConsoleAppender);

// Configure monitoring with the netlink protocol.
auto accessPointDiscoveryAgentOperationsNetlink{ std::make_unique<AccessPointDiscoveryAgentOperationsNetlink>() };
auto accessPointControllerFactory = std::make_unique<AccessPointControllerHostapdFactory>();
auto accessPointFactory = std::make_shared<AccessPointFactoryLinux>(std::move(accessPointControllerFactory));
auto accessPointDiscoveryAgentOperationsNetlink{ std::make_unique<AccessPointDiscoveryAgentOperationsNetlink>(accessPointFactory) };
auto accessPointDiscoveryAgent{ AccessPointDiscoveryAgent::Create(std::move(accessPointDiscoveryAgentOperationsNetlink)) };
accessPointDiscoveryAgent->RegisterDiscoveryEventCallback([](auto&& presence, auto&& interfaceName) {
PLOG_INFO << std::format("{} -> {}", interfaceName, magic_enum::enum_name(presence));
accessPointDiscoveryAgent->RegisterDiscoveryEventCallback([](auto&& presence, auto&& accessPointChanged) {
LOGI << std::format("{} -> {}", accessPointChanged->GetInterfaceName(), magic_enum::enum_name(presence));
});

LOG_INFO << "starting access point discovery agent";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <linux/if.h>
#include <magic_enum.hpp>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211Interface.hxx>
#include <microsoft/net/wifi/AccessPointDiscoveryAgentOperationsNetlink.hxx>
#include <microsoft/net/wifi/IAccessPoint.hxx>
#include <netlink/genl/genl.h>
Expand All @@ -31,7 +30,8 @@ using Microsoft::Net::Netlink::NetlinkMessage;
using Microsoft::Net::Netlink::NetlinkSocket;
using Microsoft::Net::Netlink::Nl80211::Nl80211ProtocolState;

AccessPointDiscoveryAgentOperationsNetlink::AccessPointDiscoveryAgentOperationsNetlink() :
AccessPointDiscoveryAgentOperationsNetlink::AccessPointDiscoveryAgentOperationsNetlink(std::shared_ptr<IAccessPointFactory> accessPointFactory) :
m_accessPointFactory(std::move(accessPointFactory)),
m_cookie(CookieValid),
m_netlink80211ProtocolState(Nl80211ProtocolState::Instance())
{}
Expand Down Expand Up @@ -123,7 +123,7 @@ namespace detail
/**
* @brief Helper function to determine if an nl80211 interface is an AP. To be used in range expressions.
*
* @param nl80211Interface
* @param nl80211Interface The nl80211 interface to check.
* @return true
* @return false
*/
Expand All @@ -134,28 +134,33 @@ IsNl80211InterfaceTypeAp(const Nl80211Interface &nl80211Interface)
}

/**
* @brief Helper function returning the name of an nl80211 interface. To be used in range expressions.
* @brief Helper function to create an access point instance from an nl80211 interface. To be used in range expressions.
*
* @param nl80211Interface
* @return std::string
* @param accessPointFactory The factory to use for creating the access point instance.
* @param nl80211Interface The nl80211 interface to create the access point instance from.
* @return std::shared_ptr<IAccessPoint>
*/
std::string
Nl80211InterfaceName(const Nl80211Interface &nl80211Interface)
std::shared_ptr<IAccessPoint>
MakeAccessPoint(std::shared_ptr<IAccessPointFactory> accessPointFactory, const Nl80211Interface &nl80211Interface)
{
return nl80211Interface.Name;
return accessPointFactory->Create(nl80211Interface.Name);
}
} // namespace detail

std::future<std::vector<std::string>>
std::future<std::vector<std::shared_ptr<IAccessPoint>>>
AccessPointDiscoveryAgentOperationsNetlink::ProbeAsync()
{
std::promise<std::vector<std::string>> probePromise{};
const auto MakeAccessPoint = [this](const Nl80211Interface &nl80211Interface) {
return detail::MakeAccessPoint(m_accessPointFactory, nl80211Interface);
};

std::promise<std::vector<std::shared_ptr<IAccessPoint>>> probePromise{};
auto probeFuture = probePromise.get_future();

// Enumerate all nl80211 interfaces and filter out those that are not APs.
auto nl80211Interfaces{ Nl80211Interface::Enumerate() };
auto nl80211ApInterfaceNames = nl80211Interfaces | std::views::filter(detail::IsNl80211InterfaceTypeAp) | std::views::transform(detail::Nl80211InterfaceName);
std::vector<std::string> accessPoints(std::make_move_iterator(std::begin(nl80211ApInterfaceNames)), std::make_move_iterator(std::end(nl80211ApInterfaceNames)));
auto accessPointsView = nl80211Interfaces | std::views::filter(detail::IsNl80211InterfaceTypeAp) | std::views::transform(MakeAccessPoint);
std::vector<std::shared_ptr<IAccessPoint>> accessPoints(std::make_move_iterator(std::begin(accessPointsView)), std::make_move_iterator(std::end(accessPointsView)));

// Clear the vector since most of the items were moved out.
nl80211Interfaces.clear();
Expand Down Expand Up @@ -237,7 +242,8 @@ AccessPointDiscoveryAgentOperationsNetlink::ProcessNetlinkMessage(struct nl_msg
// Invoke presence event callback if present.
if (accessPointPresenceEventCallback != nullptr) {
LOGV << std::format("Invoking access point presence event callback with event args 'interface={}, presence={}'", interfaceName, magic_enum::enum_name(accessPointPresenceEvent));
accessPointPresenceEventCallback(accessPointPresenceEvent, interfaceName);
auto accessPoint = m_accessPointFactory->Create(interfaceName);
accessPointPresenceEventCallback(accessPointPresenceEvent, accessPoint);
}

return NL_OK;
Expand Down
Loading

0 comments on commit cf28679

Please sign in to comment.