
Commit a8dd7df

Address my own minor comments
hugary1995 committed Feb 18, 2025
1 parent 9366d59 commit a8dd7df
Showing 3 changed files with 12 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/neml2/dispatchers/StaticHybridScheduler.h
@@ -90,9 +90,9 @@ class StaticHybridScheduler : public WorkScheduler
  * one of a certain type. The device index is optional, and in its defaulted state represents
  * (abstractly) "the current device". Further, there are two constraints on the value of the
  * device index, if one is explicitly stored:
- * 0. A negative index represents the current device, a non-negative index
+ * 1. A negative index represents the current device, a non-negative index
  *    represents a specific, concrete device,
- * 1. When the device type is CPU, the device index must be zero.
+ * 2. When the device type is CPU, the device index must be zero.
  */
 void setup() override;

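The renumbered constraints above are straightforward to codify. As a minimal, hypothetical sketch (not part of this commit), a validation helper using only the torch/c10 Device API might look like:

    // Hypothetical validation sketch for the device-index rules documented above.
    #include <torch/torch.h>
    #include <stdexcept>

    void validate_device_index(const torch::Device & device)
    {
      // Rule 1: a negative index abstractly means "the current device",
      // so there is nothing further to check.
      if (device.index() < 0)
        return;

      // Rule 2: a concrete CPU device must have index zero.
      if (device.is_cpu() && device.index() != 0)
        throw std::invalid_argument("CPU device index must be zero");
    }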
16 changes: 7 additions & 9 deletions python/neml2/types.h
@@ -32,29 +32,28 @@
 #include "neml2/misc/types.h"
 #include "neml2/base/LabeledAxisAccessor.h"

-#define NEML2_TENSOR_OPTIONS_VARGS \
-  const torch::Dtype &dtype, const Device &device, bool requires_grad
+#define NEML2_TENSOR_OPTIONS_VARGS const Dtype &dtype, const Device &device, bool requires_grad

 #define NEML2_TENSOR_OPTIONS \
   torch::TensorOptions().dtype(dtype).device(device).requires_grad(requires_grad)

 #define PY_ARG_TENSOR_OPTIONS \
-  pybind11::arg("dtype") = torch::Dtype(torch::kFloat64), \
-  pybind11::arg("device") = Device(torch::kCPU), pybind11::arg("requires_grad") = false
+  pybind11::arg("dtype") = Dtype(torch::kFloat64), pybind11::arg("device") = Device(torch::kCPU), \
+  pybind11::arg("requires_grad") = false

 namespace pybind11
 {
 namespace detail
 {
 #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
 /**
- * @brief This specialization enables type conversion between Python object <--> torch::Dtype
+ * @brief This specialization enables type conversion between Python object <--> Dtype
  */
 template <>
-struct type_caster<torch::Dtype>
+struct type_caster<Dtype>
 {
 public:
-  PYBIND11_TYPE_CASTER(torch::Dtype, _("torch.dtype"));
+  PYBIND11_TYPE_CASTER(Dtype, _("torch.dtype"));

 /**
  * PYBIND11_TYPE_CASTER defines a member field called value. Since at::Dtype cannot be
@@ -77,8 +76,7 @@ struct type_caster<torch::Dtype>
     return false;
   }

-  static handle
-  cast(const torch::Dtype & src, return_value_policy /* policy */, handle /* parent */)
+  static handle cast(const Dtype & src, return_value_policy /* policy */, handle /* parent */)
   {
     return handle(reinterpret_cast<PyObject *>(torch::getTHPDtype(src)));
   }
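For context, here is a minimal, hypothetical sketch (not from this commit; the function name "zeros" and module handle "m" are assumptions) of how the three macros compose in a pybind11 binding: NEML2_TENSOR_OPTIONS_VARGS supplies the trailing dtype/device/requires_grad parameters, NEML2_TENSOR_OPTIONS rebuilds a torch::TensorOptions from them, and PY_ARG_TENSOR_OPTIONS provides the matching Python-side defaults. The type_caster above is what lets a Python torch.dtype flow into the Dtype argument on torch builds older than 2.4.

    // Hypothetical binding sketch showing the intended use of the macros.
    m.def(
        "zeros",
        [](std::size_t n, NEML2_TENSOR_OPTIONS_VARGS)
        { return torch::zeros({static_cast<std::int64_t>(n)}, NEML2_TENSOR_OPTIONS); },
        pybind11::arg("n"),
        PY_ARG_TENSOR_OPTIONS);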
6 changes: 3 additions & 3 deletions src/neml2/dispatchers/SimpleScheduler.cxx
@@ -34,7 +34,7 @@ OptionSet
 SimpleScheduler::expected_options()
 {
   OptionSet options = WorkScheduler::expected_options();
-  options.doc() = "Dispatch work to a single device in given batch std::size_ts.";
+  options.doc() = "Dispatch work to a single device in given batch sizes.";

   options.set<std::string>("device");
   options.set("device").doc() = "Torch device to run on";
@@ -68,13 +68,13 @@ SimpleScheduler::schedule_work(Device & device, std::size_t & batch_size) const
 }

 void
-SimpleScheduler::dispatched_work(Device, size_t n)
+SimpleScheduler::dispatched_work(Device, std::size_t n)
 {
   _load += n;
 }

 void
-SimpleScheduler::completed_work(Device, size_t n)
+SimpleScheduler::completed_work(Device, std::size_t n)
 {
   neml_assert(_load >= n, "Load underflow");
   _load -= n;
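As a usage note, this pair of callbacks keeps a running count of in-flight work: dispatched_work adds the batch size to _load, and completed_work subtracts it, with an assertion guarding against underflow. A hypothetical driver loop might look like the sketch below (the boolean return convention of schedule_work and the run_batch_on helper are assumptions, not part of this commit):

    // Hypothetical driver sketch; assumes schedule_work reports whether work remains.
    Device device(torch::kCPU);
    std::size_t batch_size = 0;
    while (scheduler.schedule_work(device, batch_size))
    {
      scheduler.dispatched_work(device, batch_size); // _load += batch_size
      run_batch_on(device, batch_size);              // hypothetical work execution
      scheduler.completed_work(device, batch_size);  // _load -= batch_size (asserts no underflow)
    }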
