Skip to content

Commit 6b60fee

Browse files
committed
Clang-format
1 parent 2999c79 commit 6b60fee

File tree

4 files changed

+36
-42
lines changed

4 files changed

+36
-42
lines changed

include/lbann/layers/operator_layer.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class OperatorLayer final : public data_type_layer<InputT, OutputT>
8383

8484
#ifdef LBANN_HAS_ONNX
8585
void fill_onnx_node(onnx::GraphProto& graph) const override;
86-
#endif //LBANN_HAS_ONNX
86+
#endif // LBANN_HAS_ONNX
8787

8888
void fp_compute() final;
8989
void bp_compute() final;

include/lbann/operators/math/binary_with_constant.hpp

+23-24
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ inline onnx::NodeProto get_constant_node(float val)
176176
}
177177

178178
template <typename T, El::Device D>
179-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
180-
AddConstantOperator<T, D> const op)
179+
std::vector<onnx::NodeProto>
180+
get_onnx_nodes_impl(AddConstantOperator<T, D> const op)
181181
{
182182
std::vector<onnx::NodeProto> nodes(2UL);
183183
nodes.front().set_op_type("Add");
@@ -187,8 +187,7 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
187187
}
188188

189189
template <typename T, El::Device D>
190-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
191-
ScaleOperator<T, D> const op)
190+
std::vector<onnx::NodeProto> get_onnx_nodes_impl(ScaleOperator<T, D> const op)
192191
{
193192
std::vector<onnx::NodeProto> nodes(2UL);
194193
nodes.front().set_op_type("Mul");
@@ -198,8 +197,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
198197
}
199198

200199
template <typename T, El::Device D>
201-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
202-
SubtractConstantOperator<T, D> const op)
200+
std::vector<onnx::NodeProto>
201+
get_onnx_nodes_impl(SubtractConstantOperator<T, D> const op)
203202
{
204203
std::vector<onnx::NodeProto> nodes(2UL);
205204
nodes.front().set_op_type("Sub");
@@ -209,8 +208,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
209208
}
210209

211210
template <typename T, El::Device D>
212-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
213-
ConstantSubtractOperator<T, D> const op)
211+
std::vector<onnx::NodeProto>
212+
get_onnx_nodes_impl(ConstantSubtractOperator<T, D> const op)
214213
{
215214
std::vector<onnx::NodeProto> nodes(2UL);
216215
nodes.front().set_op_type("Sub");
@@ -220,8 +219,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
220219
}
221220

222221
template <typename T, El::Device D>
223-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
224-
MaxConstantOperator<T, D> const op)
222+
std::vector<onnx::NodeProto>
223+
get_onnx_nodes_impl(MaxConstantOperator<T, D> const op)
225224
{
226225
std::vector<onnx::NodeProto> nodes(2UL);
227226
nodes.front().set_op_type("Max");
@@ -231,8 +230,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
231230
}
232231

233232
template <typename T, El::Device D>
234-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
235-
MinConstantOperator<T, D> const op)
233+
std::vector<onnx::NodeProto>
234+
get_onnx_nodes_impl(MinConstantOperator<T, D> const op)
236235
{
237236
std::vector<onnx::NodeProto> nodes(2UL);
238237
nodes.front().set_op_type("Min");
@@ -242,8 +241,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
242241
}
243242

244243
template <typename T, El::Device D>
245-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
246-
EqualConstantOperator<T, D> const op)
244+
std::vector<onnx::NodeProto>
245+
get_onnx_nodes_impl(EqualConstantOperator<T, D> const op)
247246
{
248247
std::vector<onnx::NodeProto> nodes(2UL);
249248
nodes.front().set_op_type("Equal");
@@ -253,8 +252,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
253252
}
254253

255254
template <typename T, El::Device D>
256-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
257-
NotEqualConstantOperator<T, D> const op)
255+
std::vector<onnx::NodeProto>
256+
get_onnx_nodes_impl(NotEqualConstantOperator<T, D> const op)
258257
{
259258
std::vector<onnx::NodeProto> nodes(3UL);
260259
nodes.front().set_op_type("Equal");
@@ -265,8 +264,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
265264
}
266265

267266
template <typename T, El::Device D>
268-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
269-
LessConstantOperator<T, D> const op)
267+
std::vector<onnx::NodeProto>
268+
get_onnx_nodes_impl(LessConstantOperator<T, D> const op)
270269
{
271270
std::vector<onnx::NodeProto> nodes(2UL);
272271
nodes.front().set_op_type("Less");
@@ -276,8 +275,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
276275
}
277276

278277
template <typename T, El::Device D>
279-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
280-
LessEqualConstantOperator<T, D> const op)
278+
std::vector<onnx::NodeProto>
279+
get_onnx_nodes_impl(LessEqualConstantOperator<T, D> const op)
281280
{
282281
std::vector<onnx::NodeProto> nodes(2UL);
283282
nodes.front().set_op_type("LessOrEqual");
@@ -287,8 +286,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
287286
}
288287

289288
template <typename T, El::Device D>
290-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
291-
GreaterConstantOperator<T, D> const op)
289+
std::vector<onnx::NodeProto>
290+
get_onnx_nodes_impl(GreaterConstantOperator<T, D> const op)
292291
{
293292
std::vector<onnx::NodeProto> nodes(2UL);
294293
nodes.front().set_op_type("Greater");
@@ -298,8 +297,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
298297
}
299298

300299
template <typename T, El::Device D>
301-
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
302-
GreaterEqualConstantOperator<T, D> const op)
300+
std::vector<onnx::NodeProto>
301+
get_onnx_nodes_impl(GreaterEqualConstantOperator<T, D> const op)
303302
{
304303
std::vector<onnx::NodeProto> nodes(2UL);
305304
nodes.front().set_op_type("GreaterOrEqual");

include/lbann/operators/operator.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ void Operator<InputT, OutputT, D>::serialize(ArchiveT& ar)
217217

218218
#ifdef LBANN_HAS_ONNX
219219
template <typename InputT, typename OutputT, El::Device D>
220-
std::vector<onnx::NodeProto> Operator<InputT, OutputT, D>::get_onnx_nodes() const
220+
std::vector<onnx::NodeProto>
221+
Operator<InputT, OutputT, D>::get_onnx_nodes() const
221222
{
222223
// The default assumption is that we don't know how to represent
223224
// this operator in ONNX terms yet.

src/layers/operator_layer.cpp

+10-16
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ namespace lbann {
4545

4646
#ifdef LBANN_HAS_ONNX
4747
template <typename T, typename O, data_layout L, El::Device D>
48-
void OperatorLayer<T, O, L, D>::fill_onnx_node(
49-
onnx::GraphProto& graph) const
48+
void OperatorLayer<T, O, L, D>::fill_onnx_node(onnx::GraphProto& graph) const
5049
{
5150
const auto& parents = this->get_parent_layers();
5251
auto nodes = m_ops.front()->get_onnx_nodes();
@@ -58,27 +57,22 @@ void OperatorLayer<T, O, L, D>::fill_onnx_node(
5857
op_node->set_domain("");
5958
op_node->set_doc_string(this->get_name());
6059

61-
//binary operators
62-
if(nodes.size() == 1)
63-
{
64-
for(auto* parent : parents)
65-
{
60+
// binary operators
61+
if (nodes.size() == 1) {
62+
for (auto* parent : parents) {
6663
size_t idx = parent->find_child_layer_index(*this);
6764
op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
6865
}
6966
}
7067
// Binary w/ constant operators
71-
else if(nodes.size() == 2 || nodes.size() == 3)
72-
{
68+
else if (nodes.size() == 2 || nodes.size() == 3) {
7369
auto* const_node = graph.add_node();
7470
*const_node = nodes.back();
75-
if(const_node->op_type() == "PostConstant")
76-
{
71+
if (const_node->op_type() == "PostConstant") {
7772
op_node->add_input(parents[0]->get_name() + "_0");
7873
op_node->add_input(const_node->output(0));
7974
}
80-
else if(const_node->op_type() == "PreConstant")
81-
{
75+
else if (const_node->op_type() == "PreConstant") {
8276
op_node->add_input(const_node->output(0));
8377
op_node->add_input(parents[0]->get_name() + "_0");
8478
}
@@ -88,11 +82,11 @@ void OperatorLayer<T, O, L, D>::fill_onnx_node(
8882
const_node->set_op_type("Constant");
8983
}
9084
else
91-
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size());
85+
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ",
86+
nodes.size());
9287

9388
// Not equal operator
94-
if(nodes.size() == 3)
95-
{
89+
if (nodes.size() == 3) {
9690
op_node->add_output("EqualOperator");
9791
auto* not_node = graph.add_node();
9892
not_node->add_input(op_node->output(0));

0 commit comments

Comments
 (0)