Skip to content

Commit

Permalink
Fix build and C++ tests for FreeBSD (#10480)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Jun 28, 2024
1 parent e8a9625 commit 09d32f1
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 11 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/freebsd.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# CI workflow: builds the project and runs the C++ tests inside a FreeBSD VM.
name: FreeBSD

on: [push, pull_request]

permissions:
  contents: read  # to fetch code (actions/checkout)

# Only one run per pull request (or per ref when there is no pull request);
# a newer run cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  test:
    # The job itself is scheduled on an Ubuntu runner; the FreeBSD steps are
    # delegated to the vmactions/freebsd-vm action below.
    runs-on: ubuntu-latest
    name: A job to run test in FreeBSD
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: 'true'
      - name: Test in FreeBSD
        id: test
        uses: vmactions/freebsd-vm@v1
        with:
          usesh: true
          # Packages needed to configure, build, and run the C++ tests.
          prepare: |
            pkg install -y cmake git ninja googletest
          # Out-of-source Ninja build with the googletest suite enabled,
          # followed by the resulting test binary.
          run: |
            mkdir build
            cd build
            cmake .. -GNinja -DGOOGLE_TEST=ON
            ninja -v
            ./testxgboost
4 changes: 3 additions & 1 deletion include/xgboost/collective/poll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ namespace utils {

template <typename PollFD>
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
// For Windows and Linux, a negative timeout means infinite timeout. For FreeBSD,
// INFTIM (-1) should be used instead.
#if defined(_WIN32)

#if IS_MINGW()
Expand All @@ -87,7 +89,7 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
#endif // IS_MINGW()

#else
return poll(pfd, nfds, std::chrono::milliseconds(timeout).count());
return poll(pfd, nfds, timeout.count() < 0 ? -1 : std::chrono::milliseconds(timeout).count());
#endif // IS_MINGW()
}

Expand Down
21 changes: 18 additions & 3 deletions src/c_api/coll_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;

void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
: kDft};
std::int64_t timeout_clipped = kDft;
if (collective::HasTimeout(timeout)) {
timeout_clipped = std::min(kDft, static_cast<std::int64_t>(timeout.count()));
}
std::chrono::seconds wait_for{timeout_clipped};

common::Timer timer;
timer.Start();
Expand Down Expand Up @@ -171,7 +174,19 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
common::Timer timer;
timer.Start();
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {

// Quote from https://en.cppreference.com/w/cpp/memory/shared_ptr/use_count#Notes:
//
// In multithreaded environment, `use_count() == 1` does not imply that the object is
// safe to modify because accesses to the managed object by former shared owners may not
// have completed, and because new shared owners may be introduced concurrently.
//
// - We don't have the first case since we never access the raw pointer.
//
// - We don't have the second case for most of the scenarios since the tracker is a unique
// object. If the free function is called before another function call, it's likely
// to be a bug in the user code. The use_count should only decrease in this function.
while (ptr->first.use_count() != 1) {
auto ela = timer.Duration().count();
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
Expand Down
4 changes: 3 additions & 1 deletion src/collective/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ namespace xgboost::collective {
SockAddress MakeSockAddress(StringView host, in_port_t port) {
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
hints.ai_protocol = SOCK_STREAM;
hints.ai_socktype = SOCK_STREAM;
struct addrinfo *res = nullptr;
int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res);
if (sig != 0) {
LOG(FATAL) << "Failed to get addr info for: " << host
<< ", error: " << gai_strerror(sig);
return {};
}
if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV4)) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/collective/test_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers,
config["port"] = Integer{0};
config["n_workers"] = Integer{n_workers};
config["sortby"] = Integer{static_cast<std::int32_t>(Tracker::SortBy::kHost)};
config["timeout"] = timeout.count();
config["timeout"] = static_cast<std::int64_t>(timeout.count());
return config;
}

Expand Down
10 changes: 8 additions & 2 deletions tests/cpp/common/test_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,20 @@ TEST(ColumnSampler, GPUTest) {
// Test if different threads using the same seed produce the same result
TEST(ColumnSampler, ThreadSynchronisation) {
Context ctx;
const int64_t num_threads = 100;
// NOLINTBEGIN(clang-analyzer-deadcode.DeadStores)
#if defined(__linux__)
std::int64_t const n_threads = std::thread::hardware_concurrency() * 128;
#else
std::int64_t const n_threads = std::thread::hardware_concurrency();
#endif
// NOLINTEND(clang-analyzer-deadcode.DeadStores)
int n = 128;
size_t iterations = 10;
size_t levels = 5;
std::vector<bst_feature_t> reference_result;
std::vector<float> feature_weights;
bool success = true; // Cannot use google test asserts in multithreaded region
#pragma omp parallel num_threads(num_threads)
#pragma omp parallel num_threads(n_threads)
{
for (auto j = 0ull; j < iterations; j++) {
ColumnSampler cs(j);
Expand Down
6 changes: 5 additions & 1 deletion tests/cpp/test_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ TEST(DMatrixCache, MultiThread) {
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();

auto n = std::thread::hardware_concurrency() * 128u;
#if defined(__linux__)
auto const n = std::thread::hardware_concurrency() * 128;
#else
auto const n = std::thread::hardware_concurrency();
#endif
CHECK_NE(n, 0);
std::vector<std::shared_ptr<CacheForTest>> results(n);

Expand Down
10 changes: 8 additions & 2 deletions tests/cpp/test_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,14 @@ TEST(Learner, MultiThreadedPredict) {
learner->Configure();

std::vector<std::thread> threads;
for (uint32_t thread_id = 0;
thread_id < 2 * std::thread::hardware_concurrency(); ++thread_id) {

#if defined(__linux__)
auto n_threads = std::thread::hardware_concurrency() * 4u;
#else
auto n_threads = std::thread::hardware_concurrency();
#endif

for (decltype(n_threads) thread_id = 0; thread_id < n_threads; ++thread_id) {
threads.emplace_back([learner, p_data] {
size_t constexpr kIters = 10;
auto &entry = learner->GetThreadLocal().prediction_entry;
Expand Down

0 comments on commit 09d32f1

Please sign in to comment.