1 | #include <torch/csrc/jit/backends/backend.h> |
2 | #include <torch/csrc/jit/backends/backend_preprocess.h> |
3 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
4 | #include <torch/csrc/jit/passes/inliner.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace { |
9 | // For this backend, the actual compilation happens in preprocess function AOT. |
10 | // Put here for demonstration of backend |
11 | // as a whole piece. It's used when compilation is required. A dummy function |
12 | // can be passed when there's no usage of compilation in runtime backend lib. |
13 | c10::IValue preprocess( |
14 | const Module& mod, |
15 | const c10::Dict<IValue, IValue>& method_compile_spec, |
16 | const BackendDebugHandleGenerator& generate_debug_handles) { |
17 | // The output of this process would produce a dictionary |
18 | // Key: method name. |
19 | // Val: compiled blob (represented by a string). |
20 | c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get()); |
21 | |
22 | for (const auto& method : mod.get_methods()) { |
23 | auto graph = toGraphFunction(method.function()).graph()->copy(); |
24 | // Must inline the graph for debug info map. |
25 | Inline(*graph); |
26 | // This is here because to test module hierarchy we will have |
27 | // getattr nodes which after inlining dont serve any purpose. |
28 | // Without removing them we will run into compilation errors. |
29 | // So eliminate deadcode just remove those getattr nodes. |
30 | EliminateDeadCode(graph); |
31 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
32 | auto key = method.name(); |
33 | auto node_debug_handles = generate_debug_handles(graph); |
34 | std::stringstream ss; |
35 | for (const auto& node : graph->nodes()) { |
36 | switch (node->kind()) { |
37 | case prim::Constant: |
38 | ss << node->kind().toDisplayString() << "#" |
39 | << toIValue(node->output()).value(); |
40 | ss << "<debug_handle>" << node_debug_handles[node]; |
41 | break; |
42 | // NOLINTNEXTLINE(bugprone-branch-clone) |
43 | case aten::add: |
44 | ss << node->kind().toQualString(); |
45 | ss << "<debug_handle>" << node_debug_handles[node]; |
46 | break; |
47 | case aten::sub: |
48 | ss << node->kind().toQualString(); |
49 | ss << "<debug_handle>" << node_debug_handles[node]; |
50 | break; |
51 | default: |
52 | TORCH_CHECK( |
53 | false, |
54 | "The node of " , |
55 | node->kind().toQualString(), |
56 | " is not supported in this compiler. Source code: " , |
57 | node->sourceRange().str()); |
58 | break; |
59 | } |
60 | ss << "," ; |
61 | } |
62 | std::string blob = ss.str(); |
63 | if (!blob.empty()) { |
64 | blob.pop_back(); |
65 | } |
66 | compiled.insert(method.name(), blob); |
67 | } |
68 | return compiled; |
69 | } |
70 | |
71 | constexpr auto backend_name = "backend_with_compiler_demo" ; |
72 | static auto pre_reg = backend_preprocess_register(backend_name, preprocess); |
73 | } // namespace |
74 | |
75 | } // namespace jit |
76 | } // namespace torch |
77 | |