forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnamed_value.h
81 lines (66 loc) · 2.34 KB
/
named_value.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/utils/variadic.h>
namespace torch::jit {
struct Value;
/**
* A value with optional extra name and location information. Used during
* schema matching to provide extra error information and resolve kwargs.
*/
struct NamedValue {
NamedValue(const SourceRange& loc, const std::string& name, Value* value)
: loc_(loc), name_(name), value_(value) {}
NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
/* implicit */ NamedValue(Value* value) : value_(value) {}
NamedValue(const std::string& name, Value* value)
: name_(name), value_(value) {}
/* implicit */ NamedValue(IValue value) : ivalue_(std::move(value)) {}
NamedValue(const std::string& name, IValue value)
: name_(name), ivalue_(std::move(value)) {}
template <
typename T,
typename = std::enable_if_t<
(!std::is_same_v<std::decay_t<T>, NamedValue> &&
!std::is_same_v<std::decay_t<T>, Value*> &&
!std::is_same_v<std::decay_t<T>, IValue>)>>
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
template <
typename T,
typename = std::enable_if_t<
(!std::is_same_v<std::decay_t<T>, Value*> &&
!std::is_same_v<std::decay_t<T>, IValue>)>>
NamedValue(const std::string& name, T&& t)
: NamedValue(name, IValue(std::forward<T>(t))) {}
SourceRange locOr(const SourceRange& backup_location) const {
if (!loc_)
return backup_location;
return loc();
}
// note: this will insert a constant node into the graph at the current
// insert point if this NamedValue is actually a constant
Value* value(Graph& g) const {
if (!value_)
return insertConstant(
g, ivalue_); // use insertConstant to remove need to include ir.h here
return value_;
}
const std::string& name() const {
AT_ASSERT(name_);
return *name_;
}
const SourceRange& loc() const {
AT_ASSERT(loc_);
return *loc_;
}
at::TypePtr type() const;
private:
std::optional<SourceRange> loc_;
std::optional<std::string> name_;
Value* value_{nullptr};
// only valid if value_ == nullptr;
IValue ivalue_;
};
} // namespace torch::jit