Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ProbeAsync for AccessPointDiscoveryAgentOperationsNetlink #107

Merged
merged 8 commits into from
Jan 12, 2024
Merged
2 changes: 2 additions & 0 deletions src/linux/libnl-helpers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX ${LIBNL_HELPERS_PUBLIC_INCLUDE}/${LIBNL_
target_sources(libnl-helpers
PRIVATE
Netlink80211.cxx
Netlink80211Interface.cxx
Netlink80211ProtocolState.cxx
NetlinkMessage.cxx
NetlinkSocket.cxx
Expand All @@ -18,6 +19,7 @@ target_sources(libnl-helpers
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkMessage.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkSocket.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211Interface.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211ProtocolState.hxx
)

Expand Down
26 changes: 26 additions & 0 deletions src/linux/libnl-helpers/Netlink80211.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

#include <format>

#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <netlink/genl/genl.h>
#include <plog/Log.h>

namespace Microsoft::Net::Netlink::Nl80211
{
Expand Down Expand Up @@ -447,4 +451,26 @@ Nl80211InterfaceTypeToString(nl80211_iftype interfaceType) noexcept
}
}

using Microsoft::Net::Netlink::NetlinkSocket;

std::optional<NetlinkSocket>
CreateNl80211Socket()
{
// Allocate a new netlink socket.
auto netlinkSocket{ NetlinkSocket::Allocate() };
if (netlinkSocket == nullptr) {
LOGE << "Failed to allocate new netlink socket for nl control";
return std::nullopt;
}

// Connect the socket to the generic netlink family.
int ret = genl_connect(netlinkSocket);
if (ret < 0) {
LOGE << std::format("Failed to connect netlink socket for nl control with error {} ({})", ret, nl_geterror(ret));
return std::nullopt;
}

return std::move(netlinkSocket);
}

} // namespace Microsoft::Net::Netlink::Nl80211
154 changes: 154 additions & 0 deletions src/linux/libnl-helpers/Netlink80211Interface.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@

#include <array>
#include <format>

#include <linux/genetlink.h>
#include <magic_enum.hpp>
#include <microsoft/net/netlink/NetlinkMessage.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211Interface.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211ProtocolState.hxx>
#include <netlink/genl/genl.h>
#include <plog/Log.h>

using namespace Microsoft::Net::Netlink::Nl80211;

using Microsoft::Net::Netlink::NetlinkMessage;
using Microsoft::Net::Netlink::NetlinkSocket;

Nl80211Interface::Nl80211Interface(std::string_view name, nl80211_iftype type, uint32_t index) noexcept :
Name(name),
Type(type),
Index(index)
{
}

std::string
Nl80211Interface::ToString() const
{
return std::format("[{}] {} {}", Index, Name, magic_enum::enum_name(Type));
}

/* static */
std::optional<Nl80211Interface>
Nl80211Interface::Parse(struct nl_msg* nl80211Message) noexcept
{
// Ensure the message is valid.
if (nl80211Message == nullptr) {
LOGE << "Received null nl80211 message";
return std::nullopt;
}

// Ensure the message has a valid genl header.
auto *nl80211MessageHeader{ static_cast<struct nlmsghdr *>(nlmsg_hdr(nl80211Message)) };
if (genlmsg_valid_hdr(nl80211MessageHeader, 1) < 0) {
LOGE << "Received invalid nl80211 message header";
return std::nullopt;
}

// Extract the nl80211 (genl) message header.
const auto *genl80211MessageHeader{ static_cast<struct genlmsghdr *>(nlmsg_data(nl80211MessageHeader)) };

// Parse the message.
std::array<struct nlattr *, NL80211_ATTR_MAX + 1> newInterfaceMessageAttributes{};
int ret = nla_parse(std::data(newInterfaceMessageAttributes), std::size(newInterfaceMessageAttributes), genlmsg_attrdata(genl80211MessageHeader, 0), genlmsg_attrlen(genl80211MessageHeader, 0), nullptr);
if (ret < 0) {
LOG_ERROR << std::format("Failed to parse netlink message attributes with error {} ({})", ret, strerror(-ret));
return std::nullopt;
}

// Tease out parameters to populate the Nl80211Interface instance.
auto *interfaceName = static_cast<const char *>(nla_data(newInterfaceMessageAttributes[NL80211_ATTR_IFNAME]));
auto interfaceType = static_cast<nl80211_iftype>(nla_get_u32(newInterfaceMessageAttributes[NL80211_ATTR_IFTYPE]));
auto interfaceIndex = static_cast<uint32_t>(nla_get_u32(newInterfaceMessageAttributes[NL80211_ATTR_IFINDEX]));

return Nl80211Interface(interfaceName, interfaceType, interfaceIndex);
}

namespace detail
{
/**
* @brief Handler function for NL80211_CMD_GET_INTERFACE responses.
*
* @param nl80211Message The response message to a NL80211_CMD_GET_INTERFACE dump request.
* @param context The context pointer provided to nl_socket_modify_cb. This must be a std::vector<Nl80211Interface>*.
* @return int
*/
int
HandleNl80211InterfaceDumpResponse(struct nl_msg *nl80211Message, void *context)
{
if (context == nullptr) {
LOGE << "Received nl80211 interface dump response with null context";
return NL_SKIP;
}

// Extract vector to populate with interfaces.
auto &nl80211Interfaces = *static_cast<std::vector<Nl80211Interface> *>(context);

// Attempt to parse the message.
auto nl80211Interface = Nl80211Interface::Parse(nl80211Message);
if (!nl80211Interface.has_value()) {
LOGW << "Failed to parse nl80211 interface dump response";
return NL_SKIP;
} else {
LOGD << std::format("Parsed nl80211 interface dump response: {}", nl80211Interface->ToString());
}

nl80211Interfaces.push_back(std::move(nl80211Interface.value()));

return NL_OK;
}
} // namespace detail

/* static */
std::vector<Nl80211Interface>
Nl80211Interface::Enumerate()
{
// Allocate a new netlink socket.
auto nl80211SocketOpt{ CreateNl80211Socket() };
if (!nl80211SocketOpt.has_value()) {
LOGE << "Failed to create nl80211 socket";
return {};
}

// Allocate a new nl80211 message for sending the dump request for all interfaces.
auto nl80211Socket{ std::move(nl80211SocketOpt.value()) };
auto nl80211MessageGetInterfaces{ NetlinkMessage::Allocate() };
if (nl80211MessageGetInterfaces == nullptr) {
LOGE << "Failed to allocate nl80211 message for interface dump request";
return {};
}

// Populate the genl message for the interface dump request.
const int nl80211DriverId = Nl80211ProtocolState::Instance().DriverId;
const auto *genlMessageGetInterfaces = genlmsg_put(nl80211MessageGetInterfaces, NL_AUTO_PID, NL_AUTO_SEQ, nl80211DriverId, 0, NLM_F_DUMP, NL80211_CMD_GET_INTERFACE, 0);
if (genlMessageGetInterfaces == nullptr) {
LOGE << "Failed to populate genl message for interface dump request";
return {};
}

// Modify the socket callback to handle the response, providing a pointer to the vector to populate with interfaces.
std::vector<Nl80211Interface> nl80211Interfaces{};
int ret = nl_socket_modify_cb(nl80211Socket, NL_CB_VALID, NL_CB_CUSTOM, detail::HandleNl80211InterfaceDumpResponse, &nl80211Interfaces);
if (ret < 0) {
LOGE << std::format("Failed to modify socket callback with error {} ({})", ret, nl_geterror(ret));
return {};
}

// Send the request.
ret = nl_send_auto(nl80211Socket, nl80211MessageGetInterfaces);
if (ret < 0) {
LOGE << std::format("Failed to send interface dump request with error {} ({})", ret, nl_geterror(ret));
return {};
}

// Receive the response, which will invoke the configured callback for each message.
ret = nl_recvmsgs_default(nl80211Socket);
if (ret < 0) {
LOGE << std::format("Failed to receive interface dump response with error {} ({})", ret, nl_geterror(ret));
return {};
}

return nl80211Interfaces;
}
8 changes: 8 additions & 0 deletions src/linux/libnl-helpers/NetlinkMessage.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

using namespace Microsoft::Net::Netlink;

/* static */
NetlinkMessage
NetlinkMessage::Allocate()
{
auto message = nlmsg_alloc();
return NetlinkMessage{ message };
}

NetlinkMessage::NetlinkMessage(struct nl_msg* message) :
Message(message)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ struct NetlinkMessage
*/
struct nl_msg* Message{ nullptr };

/**
* @brief Allocate a new struct nl_msg, and wrap it in a NetlinkMessage.
*
* @return NetlinkMessage
*/
static NetlinkMessage
Allocate();

/**
* @brief Construct a new NetlinkMessage object that does not own a netlink
* message object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ struct NetlinkSocket
static NetlinkSocket
Allocate();

static NetlinkSocket
Create();

/**
* @brief Construct a default NetlinkSocket object that does not own a
* netlink socket object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#ifndef NETLINK_82011_HXX
#define NETLINK_82011_HXX

#include <optional>
#include <string_view>
#include <unordered_map>

#include <linux/nl80211.h>
#include <microsoft/net/netlink/NetlinkSocket.hxx>

namespace Microsoft::Net::Netlink::Nl80211
{
Expand Down Expand Up @@ -46,6 +48,16 @@ Nl80211CommandToString(nl80211_commands command) noexcept;
std::string_view
Nl80211InterfaceTypeToString(nl80211_iftype type) noexcept;

/**
* @brief Create a netlink socket for use with Nl80211.
*
* This creates a netlink socket and connects it to the nl80211 generic netlink family.
*
* @return std::optional<Microsoft::Net::Netlink::NetlinkSocket>
*/
std::optional<Microsoft::Net::Netlink::NetlinkSocket>
CreateNl80211Socket();

} // namespace Microsoft::Net::Netlink::Nl80211

#endif // NETLINK_82011_HXX
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@

#ifndef NETLINK_82011_INTERFACE_HXX
#define NETLINK_82011_INTERFACE_HXX

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include <linux/netlink.h>
#include <linux/nl80211.h>
#include <netlink/msg.h>

namespace Microsoft::Net::Netlink::Nl80211
{
/**
* @brief Represents a netlink 802.11 interface.
*/
struct Nl80211Interface
{
std::string Name;
nl80211_iftype Type{ nl80211_iftype::NL80211_IFTYPE_UNSPECIFIED};
uint32_t Index;

/**
* @brief Parse a netlink message into an Nl80211Interface. The netlink message must contain a response to the
* NL80211_CMD_GET_INTERFACE command, which is encoded as a NL80211_CMD_NEW_INTERFACE.
*
* @param nl80211Message The message to parse.
* @return std::optional<Nl80211Interface> Will contain a valid Nl80211Interface if the message was parsed
* successfully, otherwise has no value, indicating the message did not contain a valid NL80211_CMD_NEW_INTERFACE
* response
*/
static std::optional<Nl80211Interface>
Parse(struct nl_msg* nl80211Message) noexcept;

/**
* @brief Enumerate all netlink 802.11 interfaces on the system.
*
* @return std::vector<Nl80211Interface>
*/
static std::vector<Nl80211Interface>
Enumerate();

/**
* @brief Convert the interface to a string representation.
*
* @return std::string
*/
std::string
ToString() const;

private:
/**
* @brief Construct a new Nl80211Interface object with the specified attributes.
*
* @param name The name of the interface.
* @param type The nl80211_iftype of the interface.
* @param index The interface index in the kernel.
*/
Nl80211Interface(std::string_view name, nl80211_iftype type, uint32_t index) noexcept;
};

} // namespace Microsoft::Net::Netlink::Nl80211

#endif // NETLINK_82011_INTERFACE_HXX
Loading
Loading