1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue_inl.h> |
4 | #include <ATen/core/qualified_name.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <torch/csrc/jit/api/module.h> |
7 | #include <torch/csrc/jit/frontend/parser.h> |
8 | #include <torch/csrc/jit/frontend/resolver.h> |
9 | #include <torch/csrc/jit/frontend/script_type_parser.h> |
10 | #include <torch/csrc/jit/frontend/source_range.h> |
11 | #include <torch/csrc/jit/ir/ir.h> |
12 | #include <torch/csrc/jit/serialization/export.h> |
13 | #include <torch/custom_class.h> |
14 | #include <functional> |
15 | #include <memory> |
16 | #include <regex> |
17 | #include <string> |
18 | #include <vector> |
19 | |
20 | namespace torch { |
21 | namespace jit { |
22 | |
23 | using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>; |
24 | |
25 | struct SourceImporterImpl : public Resolver, |
26 | std::enable_shared_from_this<SourceImporterImpl> { |
27 | SourceImporterImpl( |
28 | std::shared_ptr<CompilationUnit> cu, |
29 | const std::vector<at::IValue>* constant_table, |
30 | SourceLoader source_loader, |
31 | size_t version); |
32 | TypePtr findNamedType(const QualifiedName& name); |
33 | Function* findFunction(const QualifiedName& name); |
34 | void parseSourceIfNeeded(const std::string& qualifier); |
35 | void LEGACY_import_methods( |
36 | const Module& mod, |
37 | const std::shared_ptr<Source>& src); |
38 | |
39 | std::shared_ptr<SugaredValue> resolveValue( |
40 | const std::string& name, |
41 | GraphFunction& m, |
42 | const SourceRange& loc) override; |
43 | TypePtr resolveType(const std::string& name, const SourceRange& loc) override; |
44 | |
45 | private: |
46 | void importFunction(const std::string& qualifier, const Def& def); |
47 | void importNamedType(const std::string& qualifier, const ClassDef& class_def); |
48 | c10::optional<Assign> attributeAssignmentSpecialHandlingHack( |
49 | const QualifiedName& qualified_classname, |
50 | const Assign& assign); |
51 | void importClass( |
52 | const QualifiedName& qualified_classname, |
53 | const ClassDef& class_def, |
54 | bool is_module); |
55 | void importEnum( |
56 | const QualifiedName& qualified_name, |
57 | const ClassDef& enum_def); |
58 | void importNamedTuple( |
59 | const QualifiedName& qualified_name, |
60 | const ClassDef& named_tuple_def); |
61 | |
62 | void parsePossibleVersionNumber(Lexer& L); |
63 | |
64 | void parseImports(Lexer& L); |
65 | |
66 | std::shared_ptr<CompilationUnit> cu_; |
67 | std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_; |
68 | SourceLoader source_loader_; |
69 | c10::optional<size_t> version_ = c10::nullopt; |
70 | std::unordered_set<std::string> loaded_sources_; |
71 | // named types and functions loaded from a file but not yet defined because |
72 | // their type has not been requested yet. |
73 | std::unordered_map<QualifiedName, TreeRef> to_be_defined_; |
74 | }; |
75 | |
// Given a directory of serialized TorchScript sources, this class allows
// individual named types to be loaded from source. It resolves the
// dependencies between source files and parses the source files as necessary.
80 | |
81 | struct TORCH_API SourceImporter { |
82 | SourceImporter( |
83 | // The compilation unit that will own the imported source |
84 | std::shared_ptr<CompilationUnit> cu, |
85 | const std::vector<at::IValue>* constant_table, |
86 | SourceLoader loader, |
87 | size_t version); |
88 | |
89 | TypePtr loadType(const QualifiedName& name) const; |
90 | |
91 | // Add the methods defined in `src` to the module `mod`, using SourceImporter |
92 | // to resolve any classes via loadType |
93 | void LEGACY_import_methods( |
94 | const Module& mod, |
95 | const std::shared_ptr<Source>& src); |
96 | ~SourceImporter(); |
97 | |
98 | private: |
99 | std::shared_ptr<SourceImporterImpl> pImpl; |
100 | }; |
101 | |
102 | } // namespace jit |
103 | } // namespace torch |
104 | |