1#include <torch/csrc/jit/backends/backend.h>
2#include <torch/csrc/jit/backends/backend_preprocess.h>
3#include <torch/csrc/jit/passes/dead_code_elimination.h>
4#include <torch/csrc/jit/passes/inliner.h>
5
6namespace torch {
7namespace jit {
8namespace {
9// For this backend, the actual compilation happens in preprocess function AOT.
10// Put here for demonstration of backend
11// as a whole piece. It's used when compilation is required. A dummy function
12// can be passed when there's no usage of compilation in runtime backend lib.
13c10::IValue preprocess(
14 const Module& mod,
15 const c10::Dict<IValue, IValue>& method_compile_spec,
16 const BackendDebugHandleGenerator& generate_debug_handles) {
17 // The output of this process would produce a dictionary
18 // Key: method name.
19 // Val: compiled blob (represented by a string).
20 c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
21
22 for (const auto& method : mod.get_methods()) {
23 auto graph = toGraphFunction(method.function()).graph()->copy();
24 // Must inline the graph for debug info map.
25 Inline(*graph);
26 // This is here because to test module hierarchy we will have
27 // getattr nodes which after inlining dont serve any purpose.
28 // Without removing them we will run into compilation errors.
29 // So eliminate deadcode just remove those getattr nodes.
30 EliminateDeadCode(graph);
31 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
32 auto key = method.name();
33 auto node_debug_handles = generate_debug_handles(graph);
34 std::stringstream ss;
35 for (const auto& node : graph->nodes()) {
36 switch (node->kind()) {
37 case prim::Constant:
38 ss << node->kind().toDisplayString() << "#"
39 << toIValue(node->output()).value();
40 ss << "<debug_handle>" << node_debug_handles[node];
41 break;
42 // NOLINTNEXTLINE(bugprone-branch-clone)
43 case aten::add:
44 ss << node->kind().toQualString();
45 ss << "<debug_handle>" << node_debug_handles[node];
46 break;
47 case aten::sub:
48 ss << node->kind().toQualString();
49 ss << "<debug_handle>" << node_debug_handles[node];
50 break;
51 default:
52 TORCH_CHECK(
53 false,
54 "The node of ",
55 node->kind().toQualString(),
56 " is not supported in this compiler. Source code: ",
57 node->sourceRange().str());
58 break;
59 }
60 ss << ",";
61 }
62 std::string blob = ss.str();
63 if (!blob.empty()) {
64 blob.pop_back();
65 }
66 compiled.insert(method.name(), blob);
67 }
68 return compiled;
69}
70
71constexpr auto backend_name = "backend_with_compiler_demo";
72static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
73} // namespace
74
75} // namespace jit
76} // namespace torch
77