Skip to content

Commit

Permalink
Merge pull request #162 from microsoft/strongexcepttype
Browse files Browse the repository at this point in the history
Signal netlink errors with exceptions
  • Loading branch information
abeltrano authored Feb 22, 2024
2 parents 40fcb9c + 9d6937f commit 3c2e603
Show file tree
Hide file tree
Showing 16 changed files with 264 additions and 109 deletions.
1 change: 0 additions & 1 deletion src/common/shared/notstd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ target_sources(notstd
${NOTSTD_PUBLIC_INCLUDE_PREFIX}/Exceptions.hxx
${NOTSTD_PUBLIC_INCLUDE_PREFIX}/Memory.hxx
${NOTSTD_PUBLIC_INCLUDE_PREFIX}/Scope.hxx
${NOTSTD_PUBLIC_INCLUDE_PREFIX}/Utility.hxx
)

install(
Expand Down
32 changes: 0 additions & 32 deletions src/common/shared/notstd/include/notstd/Utility.hxx

This file was deleted.

8 changes: 6 additions & 2 deletions src/linux/libnl-helpers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ target_sources(libnl-helpers
PRIVATE
Netlink80211.cxx
Netlink80211Interface.cxx
NetlinkException.cxx
Netlink80211ProtocolState.cxx
Netlink80211Wiphy.cxx
NetlinkMessage.cxx
NetlinkSocket.cxx
Netlink80211WiphyBand.cxx
Netlink80211WiphyBandFrequency.cxx
NetlinkErrorCategory.cxx
NetlinkMessage.cxx
NetlinkSocket.cxx
PUBLIC
FILE_SET HEADERS
BASE_DIRS ${LIBNL_HELPERS_PUBLIC_INCLUDE}
FILES
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkErrorCategory.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkMessage.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkSocket.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/NetlinkException.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211Interface.hxx
${LIBNL_HELPERS_PUBLIC_INCLUDE_PREFIX}/nl80211/Netlink80211ProtocolState.hxx
Expand Down
21 changes: 9 additions & 12 deletions src/linux/libnl-helpers/Netlink80211.cxx
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@

#include <cstdint>
#include <format>
#include <optional>
#include <string_view>
#include <utility>
#include <system_error>

#include <linux/nl80211.h>
#include <microsoft/net/netlink/NetlinkErrorCategory.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <netlink/errno.h>
#include <netlink/genl/genl.h>
#include <plog/Log.h>

Expand Down Expand Up @@ -506,24 +505,22 @@ Nl80211CipherSuiteToString(uint32_t cipherSuite) noexcept

using Microsoft::Net::Netlink::NetlinkSocket;

std::optional<NetlinkSocket>
NetlinkSocket
CreateNl80211Socket()
{
// Allocate a new netlink socket.
auto netlinkSocket{ NetlinkSocket::Allocate() };
if (netlinkSocket == nullptr) {
LOGE << "Failed to allocate new netlink socket for nl control";
return std::nullopt;
}

// Connect the socket to the generic netlink family.
int ret = genl_connect(netlinkSocket);
const int ret = genl_connect(netlinkSocket);
if (ret < 0) {
LOGE << std::format("Failed to connect netlink socket for nl control with error {} ({})", ret, nl_geterror(ret));
return std::nullopt;
const auto errorCode = MakeNetlinkErrorCode(-ret);
const auto message = std::format("Failed to connect netlink socket for nl control with error {}", errorCode.value());
LOGE << message;
throw std::system_error(errorCode, message);
}

return std::move(netlinkSocket);
return netlinkSocket;
}

} // namespace Microsoft::Net::Netlink::Nl80211
12 changes: 2 additions & 10 deletions src/linux/libnl-helpers/Netlink80211Interface.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <linux/nl80211.h>
#include <magic_enum.hpp>
#include <microsoft/net/netlink/NetlinkMessage.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211Interface.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211ProtocolState.hxx>
Expand Down Expand Up @@ -108,7 +107,7 @@ HandleNl80211InterfaceDumpResponse(struct nl_msg *nl80211Message, void *context)
LOGW << "Failed to parse nl80211 interface dump response";
return NL_SKIP;
}

LOGD << std::format("Parsed nl80211 interface dump response: {}", nl80211Interface->ToString());
nl80211Interfaces.push_back(std::move(nl80211Interface.value()));

Expand All @@ -120,15 +119,7 @@ HandleNl80211InterfaceDumpResponse(struct nl_msg *nl80211Message, void *context)
std::vector<Nl80211Interface>
Nl80211Interface::Enumerate()
{
// Allocate a new netlink socket.
auto nl80211SocketOpt{ CreateNl80211Socket() };
if (!nl80211SocketOpt.has_value()) {
LOGE << "Failed to create nl80211 socket";
return {};
}

// Allocate a new nl80211 message for sending the dump request for all interfaces.
auto nl80211Socket{ std::move(nl80211SocketOpt.value()) };
auto nl80211MessageGetInterfaces{ NetlinkMessage::Allocate() };
if (nl80211MessageGetInterfaces == nullptr) {
LOGE << "Failed to allocate nl80211 message for interface dump request";
Expand All @@ -144,6 +135,7 @@ Nl80211Interface::Enumerate()
}

// Modify the socket callback to handle the response, providing a pointer to the vector to populate with interfaces.
auto nl80211Socket{ CreateNl80211Socket() };
std::vector<Nl80211Interface> nl80211Interfaces{};
int ret = nl_socket_modify_cb(nl80211Socket, NL_CB_VALID, NL_CB_CUSTOM, detail::HandleNl80211InterfaceDumpResponse, &nl80211Interfaces);
if (ret < 0) {
Expand Down
17 changes: 5 additions & 12 deletions src/linux/libnl-helpers/Netlink80211ProtocolState.cxx
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@

#include <cerrno>
#include <cstring>
#include <format>
#include <iterator>

#include <linux/nl80211.h>
#include <microsoft/net/netlink/NetlinkException.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211ProtocolState.hxx>
#include <netlink/errno.h>
#include <netlink/genl/ctrl.h>
#include <netlink/genl/genl.h>
#include <plog/Log.h>

using namespace Microsoft::Net::Netlink::Nl80211;

Expand All @@ -25,24 +22,20 @@ Nl80211ProtocolState::Nl80211ProtocolState()
// Connect the socket to the generic netlink family.
const int ret = genl_connect(netlinkSocket);
if (ret < 0) {
const auto err = errno;
LOGE << std::format("Failed to connect netlink socket for nl control with error {} ({})", err, strerror(err)); // NOLINT(concurrency-mt-unsafe)
throw err;
throw NetlinkException::CreateLogged(-ret, "Failed to connect netlink socket for nl control");
}

// Look up the nl80211 driver id.
DriverId = genl_ctrl_resolve(netlinkSocket, NL80211_GENL_NAME);
if (DriverId < 0) {
LOGE << std::format("Failed to resolve nl80211 netlink id with error {} ({})", DriverId, nl_geterror(DriverId));
throw DriverId;
throw NetlinkException::CreateLogged(-DriverId, "Failed to resolve nl80211 netlink id");
}

// Lookup the ids for the nl80211 multicast groups.
for (const auto& [multicastGroup, multicastGroupName] : Nl80211MulticastGroupNames) {
int multicastGroupId = genl_ctrl_resolve_grp(netlinkSocket, NL80211_GENL_NAME, std::data(multicastGroupName));
const int multicastGroupId = genl_ctrl_resolve_grp(netlinkSocket, NL80211_GENL_NAME, std::data(multicastGroupName));
if (multicastGroupId < 0) {
LOGE << std::format("Failed to resolve nl80211 {} multicast group id with error {} ({})", multicastGroupName, multicastGroupId, nl_geterror(multicastGroupId));
throw multicastGroupId;
throw NetlinkException::CreateLogged(-multicastGroupId, std::format("Failed to resolve nl80211 {} multicast group id", multicastGroupName));
}
MulticastGroupId[multicastGroup] = multicastGroupId;
}
Expand Down
12 changes: 2 additions & 10 deletions src/linux/libnl-helpers/Netlink80211Wiphy.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <linux/nl80211.h>
#include <magic_enum.hpp>
#include <microsoft/net/netlink/NetlinkMessage.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211ProtocolState.hxx>
#include <microsoft/net/netlink/nl80211/Netlink80211Wiphy.hxx>
Expand Down Expand Up @@ -76,17 +75,9 @@ HandleNl80211GetWiphyResponse(struct nl_msg *nl80211Message, void *context) noex

/* static */
std::optional<Nl80211Wiphy>
Nl80211Wiphy::FromId(const std::function<void(Microsoft::Net::Netlink::NetlinkMessage &)> &addWiphyIdentifier)
Nl80211Wiphy::FromId(const std::function<void(NetlinkMessage &)> &addWiphyIdentifier)
{
// Allocate a new netlink socket.
auto nl80211SocketOpt{ CreateNl80211Socket() };
if (!nl80211SocketOpt.has_value()) {
LOGE << "Failed to create nl80211 socket";
return std::nullopt;
}

// Allocate a new nl80211 message for sending the dump request for all interfaces.
auto nl80211Socket{ std::move(nl80211SocketOpt.value()) };
auto nl80211MessageGetWiphy{ NetlinkMessage::Allocate() };
if (nl80211MessageGetWiphy == nullptr) {
LOGE << "Failed to allocate nl80211 message for wiphy request";
Expand All @@ -104,6 +95,7 @@ Nl80211Wiphy::FromId(const std::function<void(Microsoft::Net::Netlink::NetlinkMe
// Add the identifier to the message so nl80211 knows what to lookup.
addWiphyIdentifier(nl80211MessageGetWiphy);

auto nl80211Socket{ CreateNl80211Socket() };
std::optional<Nl80211Wiphy> nl80211Wiphy{};
int ret = nl_socket_modify_cb(nl80211Socket, NL_CB_VALID, NL_CB_CUSTOM, detail::HandleNl80211GetWiphyResponse, &nl80211Wiphy);
if (ret < 0) {
Expand Down
17 changes: 9 additions & 8 deletions src/linux/libnl-helpers/Netlink80211WiphyBand.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Nl80211WiphyBand::Nl80211WiphyBand(std::vector<WiphyBandFrequency> frequencies,
Bitrates(std::move(bitRates)),
HtCapabilities(htCapabilities),
VhtCapabilities(VhtCapabilities),
VhtMcsSet(std::move(vhtMcsSetOpt))
VhtMcsSet(vhtMcsSetOpt)
{
}

Expand Down Expand Up @@ -58,29 +58,30 @@ Nl80211WiphyBand::Parse(struct nlattr *wiphyBand) noexcept

std::vector<WiphyBandFrequency> frequencies{};
if (wiphyBandAttributes[NL80211_BAND_ATTR_FREQS] != nullptr) {
struct nlattr *wiphyBandFrequency;
int remainingBandFrequencies;
struct nlattr *wiphyBandFrequency = nullptr;
int remainingBandFrequencies = 0;
nla_for_each_nested(wiphyBandFrequency, wiphyBandAttributes[NL80211_BAND_ATTR_FREQS], remainingBandFrequencies)
{
auto frequency = WiphyBandFrequency::Parse(wiphyBandFrequency);
if (frequency.has_value()) {
frequencies.emplace_back(std::move(frequency.value()));
frequencies.emplace_back(frequency.value());
}
}
}

std::vector<uint32_t> bitRates{};
if (wiphyBandAttributes[NL80211_BAND_ATTR_RATES] != nullptr) {
int remainingBitRates;
struct nlattr *bitRate;
int remainingBitRates = 0;
struct nlattr *bitRate = nullptr;
nla_for_each_nested(bitRate, wiphyBandAttributes[NL80211_BAND_ATTR_RATES], remainingBitRates)
{
std::array<struct nlattr *, NL80211_BITRATE_ATTR_MAX + 1> bitRateAttributes{};
ret = nla_parse(std::data(bitRateAttributes), std::size(bitRateAttributes), static_cast<struct nlattr *>(nla_data(bitRate)), nla_len(bitRate), nullptr);
if (ret < 0) {
LOGW << std::format("Failed to parse wiphy band bit rate attributes with error {} ({})", ret, nl_geterror(ret));
return std::nullopt;
} else if (bitRateAttributes[NL80211_BITRATE_ATTR_RATE] == nullptr) {
}
if (bitRateAttributes[NL80211_BITRATE_ATTR_RATE] == nullptr) {
continue;
}

Expand All @@ -89,7 +90,7 @@ Nl80211WiphyBand::Parse(struct nlattr *wiphyBand) noexcept
}
}

return Nl80211WiphyBand(std::move(frequencies), std::move(bitRates), htCapabilities, vhtCapabilities, std::move(vhtMcsSetOpt));
return Nl80211WiphyBand(std::move(frequencies), std::move(bitRates), htCapabilities, vhtCapabilities, vhtMcsSetOpt);
}

std::string
Expand Down
47 changes: 47 additions & 0 deletions src/linux/libnl-helpers/NetlinkErrorCategory.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

#include <stdexcept>
#include <string>
#include <system_error>

#include <microsoft/net/netlink/NetlinkErrorCategory.hxx>
#include <netlink/errno.h>

namespace Microsoft::Net::Netlink
{

const char*
NetlinkErrorCategory::name() const noexcept
{
static constexpr const char* Name = "Netlink";
return Name;
}

std::string
NetlinkErrorCategory::message(int error) const
{
return nl_geterror(error);
}

std::error_condition
NetlinkErrorCategory::default_error_condition(int error) const noexcept
{
return std::error_condition(error, *this);
}

std::error_code
make_netlink_error_code(int error)
{
return std::error_code(error, NetlinkErrorCategory());
}

std::error_code
MakeNetlinkErrorCode(int error)
{
if (error < 0) {
throw std::runtime_error("Netlink error codes must be non-negative; this is a programming error");
}

return make_netlink_error_code(error);
}

} // namespace Microsoft::Net::Netlink
34 changes: 34 additions & 0 deletions src/linux/libnl-helpers/NetlinkException.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

#include <format>
#include <string>
#include <system_error>

#include <microsoft/net/netlink/NetlinkErrorCategory.hxx>
#include <microsoft/net/netlink/NetlinkException.hxx>
#include <plog/Log.h>

using namespace Microsoft::Net::Netlink;

NetlinkException::NetlinkException(int error, const char *what) :
std::system_error(MakeNetlinkErrorCode(error), what)
{}

NetlinkException::NetlinkException(int error, const std::string &what) :
NetlinkException(error, what.c_str())
{}

/* static */
NetlinkException
NetlinkException::CreateLogged(int error, const char *what)
{
NetlinkException netlinkException(error, what);
LOGE << std::format("Netlink error ({}): {} ({})", netlinkException.code().value(), netlinkException.what(), netlinkException.code().message());
return netlinkException;
}

/* static */
NetlinkException
NetlinkException::CreateLogged(int error, const std::string &what)
{
return CreateLogged(error, what.c_str());
}
8 changes: 8 additions & 0 deletions src/linux/libnl-helpers/NetlinkSocket.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

#include <system_error>

#include <microsoft/net/netlink/NetlinkErrorCategory.hxx>
#include <microsoft/net/netlink/NetlinkSocket.hxx>
#include <netlink/errno.h>
#include <netlink/handlers.h>
#include <netlink/socket.h>

Expand All @@ -10,6 +14,10 @@ NetlinkSocket
NetlinkSocket::Allocate()
{
auto* socket = nl_socket_alloc();
if (socket == nullptr) {
throw std::system_error(MakeNetlinkErrorCode(NLE_NOMEM), "Failed to allocate netlink socket");
}

return NetlinkSocket{ socket };
}

Expand Down
Loading

0 comments on commit 3c2e603

Please sign in to comment.