#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>

#include <sstream>
#include <string>
#include <tuple>
#include <vector>
5 | |
6 | #ifndef NO_PROFILING |
7 | #include <torch/csrc/jit/mobile/profiler_edge.h> |
8 | #endif |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
// Implementation of a PyTorch Backend that can process, compile and execute
// TorchScript Modules composed of 'add' and 'sub' operators. It only supports
// modules that compute the sum or difference of two inputs (i.e. in1 + in2 or
// in1 - in2), so the methods of such models expect exactly 2 inputs of type
// Tensor. This backend is used to demonstrate the flow of compilation and
// execution with a minimum amount of work. It is not intended to be a
// practical backend that can be used for actual inference.
20 | |
21 | // Implementation details: |
22 | // |
23 | // Compilation |
24 | // 1. A backend with minimum compilation features, "backend_with_compiler_demo" |
25 | // is added. |
26 | // 2. The compilation happens AOT in the preprocess function registered to this |
27 | // backend. |
// 3. Compiled results are stored in a string blob for each method. They are
// serialized to the lowered module with the __getstate__ function.
// 4. For features not handled by the backend compiler, an error message that
// includes the model source code is thrown.
32 | // |
33 | // Runtime |
34 | // 1. The compiled blob is loaded in __setstate__ method. |
// 2. The compile function of the backend parses the preprocessed blob into
// the format (a list of tokens) that the backend can understand.
37 | // 3. The execute function of the backend executes the specified method |
38 | // (handle). |
39 | |
40 | namespace { |
41 | std::vector<std::tuple<std::string, int64_t>> parseMethodHandle( |
42 | const std::string& blob) { |
43 | std::vector<std::tuple<std::string, int64_t>> result; |
44 | std::stringstream s_stream(blob); |
  constexpr char debug_handle_token[] = "<debug_handle>";
46 | while (s_stream.good()) { |
47 | std::string substr; |
48 | getline(s_stream, substr, ','); |
49 | auto debug_handle_pos = substr.find(debug_handle_token); |
50 | int64_t debug_handle{-1}; |
    auto instruction = substr;
    if (debug_handle_pos != std::string::npos) {
      instruction = substr.substr(0, debug_handle_pos);
      // 14 is the length of "<debug_handle>"; the handle value follows it.
      debug_handle = stoi(substr.substr(debug_handle_pos + 14));
55 | } |
56 | result.push_back(std::make_tuple(instruction, debug_handle)); |
57 | } |
58 | return result; |
59 | } |
60 | |
61 | float* float_data_ptr(const at::Tensor& t) { |
62 | return t.unsafeGetTensorImpl()->data_ptr_impl<float>(); |
63 | } |
64 | } // namespace |
65 | |
66 | class BackendWithCompiler : public PyTorchBackendInterface { |
67 | public: |
68 | // Constructor. |
69 | // NOLINTNEXTLINE(modernize-use-equals-default) |
70 | explicit BackendWithCompiler() {} |
71 | // NOLINTNEXTLINE(modernize-use-override) |
72 | virtual ~BackendWithCompiler() = default; |
73 | |
74 | bool is_available() override { |
75 | return true; |
76 | } |
77 | |
  // Since the actual compilation is done AOT for this backend, compile just
  // forwards everything along. In a non-toy setup this could grab information
  // from the runtime that might be relevant to execution, such as build
  // flags, the resolution of the device's camera, or basically any
  // runtime-specific information that wouldn't be available server side,
  // where preprocess is called.
84 | c10::impl::GenericDict compile( |
85 | c10::IValue processed, |
86 | c10::impl::GenericDict method_compile_spec) override { |
87 | auto dict = processed.toGenericDict(); |
88 | auto handles = |
89 | c10::Dict<std::string, std::vector<std::tuple<std::string, int64_t>>>(); |
90 | for (const auto& kv : dict) { |
91 | auto tokens = parseMethodHandle(kv.value().toStringRef()); |
92 | handles.insert(kv.key().toStringRef(), tokens); |
93 | } |
94 | return c10::impl::toGenericDict(handles); |
95 | } |
96 | |
97 | // Function that actually executes the model in the backend. Here there is |
98 | // nothing to dispatch to, so the backend is implemented locally within |
  // execute and it only supports add, subtract, and constant. In a non-toy
100 | // backend you can imagine how this function could be used to actually |
101 | // dispatch the inputs to the relevant backend/device. |
102 | c10::impl::GenericList execute( |
103 | c10::IValue |
104 | handle, // example: [('prim::Constant#1', 14), ('aten::add', 15)] |
105 | c10::impl::GenericList inputs) override { |
106 | TORCH_INTERNAL_ASSERT(inputs.size() == 2); |
107 | c10::IValue val0 = inputs[0]; |
108 | at::Tensor x = val0.toTensor(); |
109 | c10::IValue val1 = inputs[1]; |
110 | at::Tensor h = val1.toTensor(); |
111 | std::vector<std::tuple<int64_t, int64_t, std::string>> op_runtimes_us; |
112 | op_runtimes_us.reserve(handle.toList().size()); |
113 | |
114 | c10::List<at::Tensor> output_list; |
115 | #ifndef NO_PROFILING |
116 | auto start_us = torch::profiler::impl::getTime() / 1000; |
117 | #endif |
118 | for (const auto& token : handle.toList()) { |
119 | IValue val = token; |
120 | auto instruction = val.toTupleRef().elements()[0].toStringRef(); |
121 | auto debug_handle = val.toTupleRef().elements()[1].toInt(); |
122 | #ifndef NO_PROFILING |
123 | auto start_time_us = torch::profiler::impl::getTime() / 1000; |
124 | #endif |
125 | try { |
126 | if (instruction.rfind("prim::Constant" , 0) == 0) { |
127 | // 15 is the length of 'prim::Constant#' the constant val comes after |
128 | TORCH_CHECK( |
129 | instruction.size() > 15, |
130 | "Constant value is expected in " , |
131 | instruction); |
132 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
133 | auto sub = instruction.substr(15); |
134 | } else if (instruction == "aten::add" || instruction == "aten::sub" ) { |
135 | TORCH_CHECK(x.sizes() == h.sizes()); |
136 | if (x.dim() > 1 || (x.dim() == 1 && x.size(0) > 1)) { |
137 | TORCH_WARN( |
138 | "Only the first elements of the tensors are added or subbed." ); |
139 | } |
140 | TORCH_CHECK( |
141 | (x.scalar_type() == c10::ScalarType::Float && |
142 | h.scalar_type() == c10::ScalarType::Float), |
143 | "Only float tensors are compatible for add and sub." ); |
144 | at::Tensor y = at::detail::empty_cpu(x.sizes(), at::kFloat); |
145 | auto x_ptr = float_data_ptr(x); |
146 | auto h_ptr = float_data_ptr(h); |
147 | auto y_ptr = float_data_ptr(y); |
148 | #ifndef NO_PROFILING |
149 | RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( |
150 | x_ptr, |
151 | x.numel() * sizeof(float), |
152 | x.numel() * sizeof(float), |
153 | x.numel() * sizeof(float) + y.numel() * sizeof(float) + |
154 | h.numel() * sizeof(float), |
155 | c10::Device(c10::kCPU)); |
156 | #endif |
157 | if (instruction == "aten::add" ) { |
158 | y_ptr[0] = x_ptr[0] + h_ptr[0]; |
159 | } else { |
160 | y_ptr[0] = x_ptr[0] - h_ptr[0]; |
161 | } |
162 | output_list.emplace_back(y); |
163 | } else { |
164 | TORCH_CHECK( |
165 | false, |
166 | "Instruction, " , |
167 | instruction, |
168 | " is not supported. " , |
169 | "Contact the backend POC for details. " ); |
170 | } |
171 | } catch (c10::Error& e) { |
172 | TORCH_DELEGATED_BACKEND_THROW(false, e.what(), debug_handle); |
173 | } |
174 | #ifndef NO_PROFILING |
175 | auto end_time_us = torch::profiler::impl::getTime() / 1000; |
176 | auto duration = end_time_us - start_time_us; |
177 | op_runtimes_us.emplace_back(duration, debug_handle, instruction); |
178 | #endif |
179 | } |
180 | #ifndef NO_PROFILING |
181 | for (const auto& tup : op_runtimes_us) { |
182 | RECORD_BACKEND_EVENT_TO_EDGE_PROFILER( |
183 | start_us, |
184 | start_us + std::get<0>(tup), |
185 | std::get<1>(tup), |
186 | std::get<2>(tup), |
187 | "test_backend" ); |
188 | start_us = start_us + std::get<0>(tup); |
189 | } |
190 | #endif |
191 | return c10::impl::toList(output_list); |
192 | } |
193 | }; |
194 | |
195 | namespace { |
constexpr auto backend_name = "backend_with_compiler_demo";
197 | static auto cls = torch::jit::backend<BackendWithCompiler>(backend_name); |
198 | } // namespace |
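
// As a rough sketch (assumed usage, not part of this file), a scripted module
// can be lowered to this backend from Python via the to_backend API:
//
//   lowered_module = torch._C._jit_to_backend(
//       "backend_with_compiler_demo", scripted_module, method_compile_spec)
//
// where scripted_module and method_compile_spec are placeholders for the
// caller's torch.jit.script()-ed module and per-method compile spec. The
// lowered module is then driven through the compile() and execute() methods
// defined in this file.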
199 | |
200 | } // namespace jit |
201 | } // namespace torch |
202 | |