forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimport_method.cpp
131 lines (116 loc) · 5.07 KB
/
import_method.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "torch/csrc/jit/import_method.h"
#include "torch/csrc/jit/script/parser.h"
namespace torch { namespace jit {
// this is a much simpler accessor that only handles modules, parameters, and
// and methods. It does not depend on python to work.
struct ModuleAccessorValue : public script::SugaredValue {
ModuleAccessorValue(std::shared_ptr<script::Module> module)
: module(std::move(module)) {}
std::string kind() const override {
return "module";
}
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(SourceRange loc, script::Method & m, const std::string& field) override {
if(script::NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
} else if(script::NamedParameter* v = module->find_parameter(field)) {
return std::make_shared<script::SimpleValue>(m.get_or_add_parameter(v->slot()));
} else if(script::Method* m = module->find_method(field)) {
return std::make_shared<script::MethodValue>(module, *m);
} else {
throw script::ErrorReport(loc) << "unknown attr: " << field;
}
}
private:
std::shared_ptr<script::Module> module;
};
struct OpsValue : public script::SugaredValue {
OpsValue(size_t version)
: version_(version) {}
std::string kind() const override {
return "ops";
}
std::shared_ptr<SugaredValue> attr(SourceRange loc, script::Method & m, const std::string& field) override {
return std::make_shared<script::BuiltinModule>(field, version_);
}
size_t version_;
};
struct ConstantValue : public script::SugaredValue {
ConstantValue(IValue value)
: value_(std::move(value)) {}
IValue value_;
std::string kind() const override { return "constant"; }
Value * asValue(SourceRange loc, script::Method & m) override {
return m.graph()->insertConstant(value_);
}
};
// This value maps attributes CONSTANTS.c0 CONSTANTS.c1 to entries
// in the 'constants' vector. This table is will be stored in a container format
// and given to the import_method when restoring the code.
struct ConstantTableValue : public script::SugaredValue {
ConstantTableValue(ArrayRef<at::Tensor> constants)
: constants_(constants) {}
std::string kind() const override {
return "CONSTANTS";
}
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(SourceRange loc, script::Method & m, const std::string& field) override {
const char* field_s = field.c_str();
char* end;
int64_t offset = std::strtoll(field_s + 1, &end, 10);
if(field.size() < 2 || *end != 0)
throw script::ErrorReport(loc) << "invalid constant specifier: " << field;
if (offset < 0 || size_t(offset) >= constants_.size()) {
throw script::ErrorReport(loc) << "constant index " << offset
<< " is out of bounds (constant table has "
<< constants_.size() << " entries).";
}
Value* value = m.graph()->insertConstant(constants_[offset], loc);
return std::make_shared<script::SimpleValue>(value);
}
private:
ArrayRef<at::Tensor> constants_;
};
// Parses the header line `op_version_set = <N>` from the front of the source.
// It consumes the identifier, '=', the number, and the trailing newline from
// `L`, then returns N.
// Throws script::ErrorReport (anchored at the start of the line) when the
// target is not `op_version_set` or N is not integral. Token-level mismatches
// are reported by L.expect itself.
static size_t parseVersionNumber(script::Lexer& L) {
  auto range = L.cur().range;
  auto name = L.expect(script::TK_IDENT).text();
  // Check the name before consuming the rest of the line, so a wrong header
  // reports this specific message instead of a later token error.
  if (name != "op_version_set")
    throw script::ErrorReport(range) << "expected an assignment to op_version_set";
  L.expect('=');
  auto version_tok = L.expect(script::TK_NUMBER);
  L.expect(script::TK_NEWLINE);
  // Anchor the constant at the number token itself. After the newline has been
  // consumed, L.cur() already refers to the following token.
  auto version = script::Const::create(version_tok.range, version_tok.text());
  if (!version.isIntegral())
    throw script::ErrorReport(range) << "expected an integral version but found " << version.text();
  return size_t(version.asIntegral());
}
// Defines on `mod` every method found in the serialized source `src`.
//
// `src` must begin with an `op_version_set = <N>` line (read by
// parseVersionNumber), followed by `def` blocks until EOF. Each block is parsed
// as a method, and all of them are handed to defineMethodsInModule in a single
// call, with a ModuleAccessorValue over `mod` serving as `self`.
//
// Free names in the source are looked up only in the fixed table built below
// (torch, ops, CONSTANTS, fork, annotate, inf, nan); the resolver returns
// nullptr for anything else. `constant_table` backs `CONSTANTS.c<i>` and is
// viewed through an ArrayRef rather than copied.
void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
  script::Parser parser(src);
  script::Lexer& lexer = parser.lexer();
  const size_t op_version = parseVersionNumber(lexer);

  std::unordered_map<std::string, std::shared_ptr<script::SugaredValue>> globals;
  globals["torch"] = std::make_shared<script::BuiltinModule>("aten", op_version);
  globals["ops"] = std::make_shared<OpsValue>(op_version);
  globals["CONSTANTS"] = std::make_shared<ConstantTableValue>(constant_table);
  globals["fork"] = std::make_shared<script::ForkValue>();
  globals["annotate"] = std::make_shared<script::AnnotateValue>();
  globals["inf"] = std::make_shared<ConstantValue>(std::numeric_limits<double>::infinity());
  globals["nan"] = std::make_shared<ConstantValue>(std::numeric_limits<double>::quiet_NaN());

  // Same resolver for every method: plain table lookup, nullptr on a miss.
  auto lookup = [&globals](const std::string& name, script::Method& /*m*/, const SourceRange& /*loc*/)
      -> std::shared_ptr<script::SugaredValue> {
    const auto hit = globals.find(name);
    if (hit != globals.end())
      return hit->second;
    return nullptr;
  };

  std::vector<script::Def> defs;
  std::vector<script::Resolver> lookups;
  for (; lexer.cur().kind != script::TK_EOF;) {
    defs.push_back(script::Def(parser.parseFunction(/*is_method=*/true)));
    lookups.push_back(lookup);
  }

  script::defineMethodsInModule(mod, defs, lookups, std::make_shared<ModuleAccessorValue>(mod));
}
}}