Skip to content

Commit

Permalink
Tidy random
Browse files Browse the repository at this point in the history
  • Loading branch information
mlund committed Jun 16, 2024
1 parent 6c2dca1 commit 1501171
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
37 changes: 18 additions & 19 deletions src/random.cpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
#include <doctest/doctest.h>
#include "random.h"
#include <nlohmann/json.hpp>
#include <stdexcept>
#include <iostream>
#include <string>
#include <sstream>

namespace Faunus {

void from_json(const nlohmann::json &j, Random &random) {
void from_json(const nlohmann::json &j, Random & rng) {
if (j.is_object()) {
auto seed = j.value("seed", std::string());
try {
if (seed == "default" or seed == "fixed") { // use default seed, i.e. do nothing
return;
} else if (seed == "hardware") { // use hardware seed
random.engine = decltype(random.engine)(std::random_device()());
rng.engine = decltype(rng.engine)(std::random_device()());
} else if (!seed.empty()) { // read engine state
std::stringstream stream(seed);
stream.exceptions(std::ios::badbit | std::ios::failbit);
stream >> random.engine;
stream >> rng.engine;
}
} catch (std::exception &e) {
std::cerr << "could not initialize random engine - falling back to fixed seed." << std::endl;
std::cerr << "could not initialize rng engine - falling back to fixed seed." << std::endl;
}
}
}
Expand Down Expand Up @@ -52,7 +51,7 @@ TEST_CASE("[Faunus] Random") {
using namespace Faunus;
Random slump, slump2; // local instances

CHECK(slump() == slump2()); // deterministic initialization by default; the global random variable cannot
CHECK_EQ(slump(), slump2()); // deterministic initialization by default; the global random variable cannot
// be used for comparison as its state is not reset at the beginning of
// each test case

Expand All @@ -66,51 +65,51 @@ TEST_CASE("[Faunus] Random") {
max = j;
x += j;
}
CHECK(min == 0);
CHECK(max == 9);
CHECK(std::fabs(x / N) == doctest::Approx(4.5).epsilon(0.01));
CHECK_EQ(min, 0);
CHECK_EQ(max, 9);
CHECK_EQ(std::fabs(x / N), doctest::Approx(4.5).epsilon(0.01));

Random r1 = R"( {"seed" : "hardware"} )"_json; // non-deterministic seed
Random r2; // default is a deterministic seed
CHECK(r1() != r2());
CHECK((r1() != r2()));
Random r3 = nlohmann::json(r1); // r1 --> json --> r3
CHECK(r1() == r3());
CHECK_EQ(r1(), r3());

// check if random_device works
Random a, b;
CHECK(a() == b());
CHECK_EQ(a(), b());
a.seed();
b.seed();
CHECK(a() != b());
CHECK((a() != b()));
}

TEST_CASE("[Faunus] WeightedDistribution") {
using namespace Faunus;
WeightedDistribution<double> v;

v.push_back(0.5);
CHECK(v.getLastIndex() == 0);
CHECK(v.size() == 1);
CHECK_EQ(v.getLastIndex(), 0);
CHECK_EQ(v.size(), 1);

v.push_back(0.1, 4);
CHECK(v.getLastIndex() == 1);
CHECK(v.size() == 2);
CHECK_EQ(v.getLastIndex(), 1);
CHECK_EQ(v.size(), 2);
CHECK(not v.empty());

int N = 1e4;
double sum = 0;
for (int i = 0; i < N; i++) {
sum += v.sample(Faunus::random.engine);
}
CHECK(sum / N == doctest::Approx((0.5 * 1 + 0.1 * 4) / (1 + 4)).epsilon(0.05));
CHECK_EQ(sum / N, doctest::Approx((0.5 * 1 + 0.1 * 4) / (1 + 4)).epsilon(0.05));

std::vector<int> weights = {2, 1};
v.setWeight(weights.begin(), weights.end());
sum = 0;
for (int i = 0; i < N; i++) {
sum += v.sample(Faunus::random.engine);
}
CHECK(sum / N == doctest::Approx((0.5 * 2 + 0.1 * 1) / (2 + 1)).epsilon(0.05));
CHECK_EQ(sum / N, doctest::Approx((0.5 * 2 + 0.1 * 1) / (2 + 1)).epsilon(0.05));

weights = {2, 1, 1};
CHECK_THROWS(v.setWeight(weights.begin(), weights.end()));
Expand Down
4 changes: 2 additions & 2 deletions src/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ template <typename T> class WeightedDistribution {
public:
std::vector<T> data; //!< raw vector of T
auto size() const { return data.size(); } //!< Number of data points
bool empty() const { return data.empty(); } //!< True if no data points
size_t getLastIndex() const {
[[nodiscard]] bool empty() const { return data.empty(); } //!< True if no data points
[[nodiscard]] size_t getLastIndex() const {
assert(!data.empty());
return latest_index;
} //!< index of last `get()` or `addGroup()` element
Expand Down

0 comments on commit 1501171

Please sign in to comment.