Skip to content

Commit

Permalink
Update python.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 14, 2024
1 parent 082c7e0 commit 982ff60
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 29 deletions.
157 changes: 140 additions & 17 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ python.Execution = class {
this.register('argparse');
this._enum = this.register('enum');
this.register('collections');
const copy = this.register('copy');
this.register('copy_reg');
const ast = this.register('ast');
this.ast = ast;
Expand Down Expand Up @@ -4330,6 +4331,7 @@ python.Execution = class {
// Minimal collections.defaultdict shim: yields a plain object and ignores
// the default factory, so missing keys simply read as 'undefined'.
this.registerFunction('collections.defaultdict', (/* default_factory */) => ({}));
this.registerFunction('copy.deepcopy');
this.registerFunction('copy_reg._reconstructor', (cls, base, state) => {
// copyreg._reconstructor in Python 3
if (base === '__builtin__.object' || base === builtins.object) {
Expand Down Expand Up @@ -6980,13 +6982,16 @@ python.Execution = class {
});
this.registerType('torch.ClassType', class extends torch.Type {
constructor(qualified_name, cu, is_module) {
super('ClassType', qualified_name);
super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
this._is_module = is_module;
this._attributes = new Map();
this._methods = new Map();
this._staticmethods = new Map();
this._constants = new Map();
}
// Factory mirroring c10::ClassType::create; the trailing arguments
// (doc_string, unresolved_class_attributes) are accepted by the C++ API
// but not used by this port.
static create(qualifiedName, cu, is_module /*, doc_string, unresolved_class_attributes */) {
    return new torch.ClassType(qualifiedName, cu, is_module);
}
// Fully-qualified class name (e.g. '__torch__.Foo'); the torch.Type base
// stores the name passed to the constructor as 'annotation_str'.
qualified_name() {
    return this.annotation_str;
}
Expand Down Expand Up @@ -7655,7 +7660,7 @@ python.Execution = class {
while (L.eat('.')) {
name = `${name}.${L.expect('id')}`;
}
real_value = new torch.ClassType(name); // getCustomClass
real_value = torch.ClassType.create(name); // getCustomClass
fake_value = real_value;
} else {
real_value = this.parseBaseType();
Expand Down Expand Up @@ -8805,6 +8810,8 @@ python.Execution = class {
let name = null;
if (args.length === 1 && typeof args[0] === 'string') {
[name] = args;
} else if (args.length === 1 && Array.isArray(args[0]) && args[0].every((arg) => typeof arg === 'string')) {
name = args[0].join('.');
} else {
name = `${args[0].qualifiedName()}.${args[1]}`;
}
Expand All @@ -8822,6 +8829,9 @@ python.Execution = class {
// Last component of the dotted name, e.g. 'foo.bar.baz' -> 'baz'.
name() {
    return this._name; // "baz"
}
atoms() {
return this._qualifiedName.split('.');
}
});
this.registerType('torch.jit.SourceImporter', class {
constructor(cu, constant_table, source_loader, version) {
Expand Down Expand Up @@ -8887,7 +8897,7 @@ python.Execution = class {
const pre_hook_def_map = new Map();
const hook_names = new Set();
const hook_def_map = new Map();
const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module);
const class_type = torch.ClassType.create(qualified_classname.qualifiedName(), this._cu, is_module);
for (const stmt of class_def.body) {
if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) {
let target = null;
Expand Down Expand Up @@ -8940,8 +8950,9 @@ python.Execution = class {
break;
}
}
} else if (target instanceof ast.Subscript) {
// not implemented
} else if (target instanceof ast.Subscript && target.value instanceof ast.Name && target.value.id === '__annotations__') {
const name = target.slice.elts[0].value;
attributes.push({ name, value, annotation: stmt.value });
continue;
} else {
throw new python.Error('Unexpected statement kind in module metadata.');
Expand Down Expand Up @@ -9020,11 +9031,11 @@ python.Execution = class {
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
this._pickle_dir_prefix = pickle_dir_prefix || '';
this._tensor_dir_prefix = tensor_dir_prefix || '';
this._constant_table = [];
const SourceLoader = (qualifier) => {
return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
};
this._source_importer = new torch.jit.SourceImporter(
this._compilation_unit, this._constants_table, SourceLoader, reader.version());
this._source_importer = new torch.jit.SourceImporter(this._compilation_unit, this._constant_table, SourceLoader, reader.version());
}
deserialize() {
const execution = this._compilation_unit.execution;
Expand Down Expand Up @@ -9061,7 +9072,7 @@ python.Execution = class {
];
for (const known_type of known_types) {
const prefix = new torch.jit.QualifiedName(known_type.name);
const type = new torch.ClassType(known_type.name, this._compilation_unit, false);
const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
for (const known_method of known_type.methods || []) {
const schema = new torch.FunctionSchema(known_method);
const name = new torch.jit.QualifiedName(prefix, schema.name);
Expand All @@ -9078,7 +9089,8 @@ python.Execution = class {
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
}
const module = this.readArchive('data');
const type = new torch.ClassType(`${module.__class__.__module__}.${module.__class__.__name__}`, null, true);
const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
const type = torch.ClassType.create(name, null, true);
const result = new torch.ScriptModule(type);
result.data = module;
return result;
Expand All @@ -9101,7 +9113,7 @@ python.Execution = class {
['INT32', 'Int'],
['INT64', 'Long']
]);
const constants = (model.tensors || []).map((constant) => {
const tensor_table = (model.tensors || []).map((constant) => {
const key = constant.data.key;
if (!tensorTypeMap.has(constant.dataType)) {
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
Expand All @@ -9126,8 +9138,8 @@ python.Execution = class {
return tensor;
});
execution.builtins.CONSTANTS = {};
for (let i = 0; i < constants.length; i++) {
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
for (let i = 0; i < tensor_table.length; i++) {
execution.builtins.CONSTANTS[`c${i}`] = tensor_table[i];
}
const attributes = [];
if (this._reader.has_record('attributes.pkl')) {
Expand All @@ -9137,6 +9149,14 @@ python.Execution = class {
const obj = unpickler.load();
attributes.push(...obj);
}

this._LEGACY_moduleStack = ['__torch__'];
// const module_def = model.mainModule;
for (const tensor of tensor_table) {
this._constant_table.push(tensor);
}
// this.LEGACY_convertModule(module_def);

while (queue.length > 0) {
const module = queue.shift();
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
Expand All @@ -9161,7 +9181,7 @@ python.Execution = class {
delete module.arguments;
}
for (const parameter of parameters) {
const tensor = constants[parameter.tensorId];
const tensor = tensor_table[parameter.tensorId];
module[parameter.name] = tensor;
parameter.__class__ = parameter.__class__ || { __module__: 'torch', __name__: 'Tensor' };
}
Expand All @@ -9182,11 +9202,71 @@ python.Execution = class {
data.forward = module.forward;
}
}
const class_type = new torch.ClassType(data.name);
const class_type = torch.ClassType.create(data.name);
const result = new torch.ScriptModule(class_type);
result.data = data;
return result;
}
// Port of ScriptModuleDeserializer::LEGACY_convertModule (PyTorch
// torch/csrc/jit/serialization/import_legacy.cpp): rebuilds a ScriptModule
// tree from a legacy (pre-pickle) protobuf ModuleDef. Tensors referenced by
// index come from this._constant_table.
LEGACY_convertModule(module_def) {
    const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
    const numPushed = atoms.length;
    // Push this module's name parts onto the qualified-name stack; a purely
    // numeric atom is not a valid identifier, so it gets a '_' prefix.
    for (const atom of atoms) {
        const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
        this._LEGACY_moduleStack.push(sanitized);
    }
    const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
    // Depth-first conversion of submodules.
    for (const sub_def of module_def.submodules || []) {
        const submodule = this.LEGACY_convertModule(sub_def);
        module.register_module(sub_def.name, submodule);
    }
    // Parameters and buffers reference tensors by index into the constant table.
    for (const param_def of module_def.parameters || []) {
        const tensor = this._constant_table[Number(param_def.tensorId)];
        if (param_def.isBuffer) {
            module.register_buffer(param_def.name, tensor);
        } else {
            module.register_parameter(param_def.name, tensor, false);
        }
    }
    // const typeParser = new torch.jit.ScriptTypeParser(this._source_importer);
    // Attribute restoration is not ported yet: the loop mirrors the C++
    // control flow so the remaining work (commented below) slots in here.
    for (const attr_def of module_def.attributes || []) {
        if (module.hasattr(attr_def.name)) {
            continue;
        }
        // IValue ivalue;
        // if (attr_def.id() >= 0) {
        //     ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
        // }
        // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
    }
    /*
    Remaining C++ from import_legacy.cpp, kept here as a porting reference:
    std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
    if (module_def.has_torchscript_debug_arena()) {
        auto [data, size] = reader_->getRecord(module_def.torchscript_debug_arena().key());
        gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
    }
    if (module_def.has_torchscript_arena()) {
        auto [data, size] =
            reader_->getRecord(module_def.torchscript_arena().key());
        std::string data_str(static_cast<const char*>(data.get()), size);
        auto src = std::make_shared<Source>(std::string(static_cast<const char*>(data.get()), size), module_def.torchscript_arena().key(), 1, std::move(gen_ranges));
        source_importer_.LEGACY_import_methods(module, src);
    }
    if (module_def.has_get_state_attribute_id()) {
        LEGACY_moduleSetState(module, LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
    }
    const ClassTypePtr& module_type = module._ivalue()->type();
    for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
        const IValue& v = module._ivalue()->getSlot(i);
        if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
            TORCH_CHECK(!v.isNone(), "The field '", module_type->getAttributeName(i), "' was left unitialized after __setstate__, but expected a ", "value of type '", v.type()->repr_str(), "'");
        }
    }
    */
    // Restore the qualified-name stack to its state on entry.
    for (let i = 0; i < numPushed; i++) {
        this._LEGACY_moduleStack.pop();
    }
    return module;
}
readArchive(archive_name) {
const type_resolver = null;
const obj_loader = null;
Expand Down Expand Up @@ -9478,6 +9558,14 @@ python.Execution = class {
}
});
this.registerType('torch.ScriptModule', class extends torch.ScriptObject {
constructor(...args) {
if (args[0] instanceof torch.jit.QualifiedName && args[1] instanceof torch.jit.CompilationUnit) {
const [class_name, cu, shouldMangle] = args;
super(...torch.ScriptModule.create_module_object(class_name, cu, shouldMangle));
} else {
super(...args);
}
}
// Fully-qualified name of this module's ClassType.
get qualified_name() {
    return this._type.qualified_name();
}
Expand Down Expand Up @@ -9579,6 +9667,24 @@ python.Execution = class {
}
return this._graph;
}
static create_module_object(class_name, cu, shouldMangle) {
shouldMangle = shouldMangle || false;
if (!class_name.prefix()) {
class_name = new torch.jit.QualifiedName('__torch__', class_name.name());
}
if (shouldMangle && cu.get_class(class_name)) {
class_name = cu.mangle(class_name);
}
const cls = torch.ClassType.create(class_name, cu, true);
cu.register_type(cls);
return [cls, cu];
}
// The registration methods below are currently no-ops: the deserializer only
// needs ScriptModule instances to exist; children, buffers and parameters
// are not tracked on the module object itself yet.
register_module(/* name, module */) {
}
register_buffer(/* name, buffer */) {
}
register_parameter(/* name, parameter, is_buffer */) {
}
});
this.registerType('torch.ModuleDict', class {
constructor(module) {
Expand Down Expand Up @@ -9659,6 +9765,7 @@ python.Execution = class {
const graph = new torch.Graph();
graph.set_op_version(operator_set_version);
const fn = new torch.jit.GraphFunction(name, graph, creator);
fn.__ast__ = def;
if (shouldMangle && this.find_function(name)) {
// name = mangle(name);
}
Expand Down Expand Up @@ -9845,7 +9952,7 @@ python.Execution = class {
cls = this._cu.get_class(new torch.jit.QualifiedName(name));
if (!cls) {
const torch = this._torch;
cls = new torch.ClassType(name, this._cu, true);
cls = torch.ClassType.create(name, this._cu, true);
this._cu.register_type(cls);
}
} else {
Expand All @@ -9862,7 +9969,20 @@ python.Execution = class {
return cls;
}
});
this.registerType('torch.export.unflatten.UnflattenedModule', class extends torch.nn.modules.module.Module {});
// Port of torch.export.unflatten.UnflattenedModule.__init__.
this.registerType('torch.export.UnflattenedModule', class extends torch.nn.modules.module.Module {
    constructor(export_module, flat_args_adapter) {
        super();
        const export_graph = copy.deepcopy(export_module.graph);
        // Fix: was 'self.graph_signature' — a Python leftover. 'self' is
        // undefined in module scope, so the assignment threw instead of
        // setting the instance field.
        this.graph_signature = copy.deepcopy(export_module.graph_signature);
        // Fix: torch.fx.Graph is a class and must be constructed with 'new';
        // calling it as a plain function throws a TypeError.
        this.graph = new torch.fx.Graph();
        this.graph.owning_module = this;
        this.module_call_graph = copy.deepcopy(export_module.module_call_graph);
        this.flat_args_adapter = flat_args_adapter;
        // Whether flat args have been adapted to the module call spec yet.
        this.adapted = false;
        // this._run_with_interpreter = RUN_WITH_INTERPRETER
        this._inplace_buffer_mutations(export_graph, this.graph_signature);
    }
});
this.registerType('torch.export.graph_signature.ExportGraphSignature', class {
constructor(input_specs, output_specs) {
this.input_specs = input_specs;
Expand Down Expand Up @@ -10017,7 +10137,10 @@ python.Execution = class {
});
this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
this.registerFunction('torch.export.unflatten');
// Mirrors torch.export.unflatten(): strip effect tokens from the exported
// program, then wrap it in an UnflattenedModule.
this.registerFunction('torch.export.unflatten', (module, flat_args_adapter) => {
    const stripped = torch.export._remove_effect_tokens(module);
    return new torch.export.UnflattenedModule(stripped, flat_args_adapter);
});
// Wraps the exported program's root and graph in a torch.fx GraphModule.
this.registerFunction('torch._export.exported_program._create_graph_module_for_export',
    (root, graph) => new torch.fx.graph_module.GraphModule(root, graph));
Expand Down
Loading

0 comments on commit 982ff60

Please sign in to comment.