#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>

#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#ifndef NO_PROFILING
#include <torch/csrc/jit/mobile/profiler_edge.h>
#endif

namespace torch {
namespace jit {

// Implementation of a PyTorch Backend that can process, compile and execute
// TorchScript Modules composed of 'add' and 'sub' operators. It only supports
// modules that compute a sum or difference of two inputs (i.e. in1 + in2 or
// in1 - in2), so the methods of such models expect exactly 2 inputs of type
// Tensor. This backend is used to demonstrate the flow of compilation and
// execution with a minimal amount of work. It's not intended to be a
// practical backend that can be used for actual inference.

// Implementation details:
//
// Compilation
// 1. A backend with minimal compilation features,
//    "backend_with_compiler_demo", is added.
// 2. The compilation happens AOT in the preprocess function registered to
//    this backend.
// 3. The compiled result is stored in a string blob for each method. The
//    blobs are serialized with the lowered module via its __getstate__
//    function.
// 4. An error message that includes the model source code is thrown for
//    features not handled by the backend compiler.
//
// Runtime
// 1. The compiled blob is loaded in the __setstate__ method.
// 2. The compile function of the backend parses the preprocessed blob into
//    the format (a list of tokens) that the backend can understand.
// 3. The execute function of the backend executes the specified method
//    (handle).
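//
// For illustration (the format here is inferred from parseMethodHandle below
// and from the handle example shown in execute), a preprocessed blob for a
// method might look like:
//   "prim::Constant#1<debug_handle>1,aten::add<debug_handle>2"
// which compile turns into the token list
//   [("prim::Constant#1", 1), ("aten::add", 2)].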

namespace {
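// Parse a compiled blob into (instruction, debug_handle) pairs. Entries are
// comma separated; when present, a debug handle follows the literal
// "<debug_handle>" marker, otherwise it defaults to -1.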
std::vector<std::tuple<std::string, int64_t>> parseMethodHandle(
    const std::string& blob) {
  std::vector<std::tuple<std::string, int64_t>> result;
  std::stringstream s_stream(blob);
  constexpr char debug_handle_token[] = "<debug_handle>";
  // Length of the "<debug_handle>" marker, excluding the null terminator.
  constexpr auto debug_handle_token_len = sizeof(debug_handle_token) - 1;
  while (s_stream.good()) {
    std::string substr;
    std::getline(s_stream, substr, ',');
    auto debug_handle_pos = substr.find(debug_handle_token);
    // -1 marks instructions that carry no debug handle.
    int64_t debug_handle{-1};
    auto instruction = substr;
    if (debug_handle_pos != std::string::npos) {
      instruction = substr.substr(0, debug_handle_pos);
      debug_handle =
          std::stoi(substr.substr(debug_handle_pos + debug_handle_token_len));
    }
    result.push_back(std::make_tuple(instruction, debug_handle));
  }
  return result;
}

// Returns a raw float pointer into the tensor's storage for direct element
// access.
float* float_data_ptr(const at::Tensor& t) {
  return t.unsafeGetTensorImpl()->data_ptr_impl<float>();
}
} // namespace

class BackendWithCompiler : public PyTorchBackendInterface {
 public:
  // Constructor.
  // NOLINTNEXTLINE(modernize-use-equals-default)
  explicit BackendWithCompiler() {}
  // NOLINTNEXTLINE(modernize-use-override)
  virtual ~BackendWithCompiler() = default;

  bool is_available() override {
    return true;
  }

  // Since the actual compilation is done AOT for this backend, compile just
  // forwards everything along. In a non-toy setup this could grab information
  // from the runtime that might be relevant for execution, such as build
  // flags, the resolution of the device's camera, or basically any
  // runtime-specific information that wouldn't be available server-side where
  // preprocess is called.
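  //
  // Note that method_compile_spec is not consulted here: all of the
  // compilation work already happened ahead of time in preprocess, so compile
  // only converts each method's string blob into the token list that execute
  // consumes.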
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto dict = processed.toGenericDict();
    auto handles =
        c10::Dict<std::string, std::vector<std::tuple<std::string, int64_t>>>();
    for (const auto& kv : dict) {
      auto tokens = parseMethodHandle(kv.value().toStringRef());
      handles.insert(kv.key().toStringRef(), tokens);
    }
    return c10::impl::toGenericDict(handles);
  }

  // Function that actually executes the model in the backend. Here there is
  // nothing to dispatch to, so the backend is implemented locally within
  // execute and it only supports add, subtract, and constant. In a non-toy
  // backend you can imagine how this function could be used to actually
  // dispatch the inputs to the relevant backend/device.
  c10::impl::GenericList execute(
      c10::IValue
          handle, // example: [('prim::Constant#1', 14), ('aten::add', 15)]
      c10::impl::GenericList inputs) override {
    TORCH_INTERNAL_ASSERT(inputs.size() == 2);
    c10::IValue val0 = inputs[0];
    at::Tensor x = val0.toTensor();
    c10::IValue val1 = inputs[1];
    at::Tensor h = val1.toTensor();
    std::vector<std::tuple<int64_t, int64_t, std::string>> op_runtimes_us;
    op_runtimes_us.reserve(handle.toList().size());

    c10::List<at::Tensor> output_list;
#ifndef NO_PROFILING
    auto start_us = torch::profiler::impl::getTime() / 1000;
#endif
    for (const auto& token : handle.toList()) {
      IValue val = token;
      auto instruction = val.toTupleRef().elements()[0].toStringRef();
      auto debug_handle = val.toTupleRef().elements()[1].toInt();
#ifndef NO_PROFILING
      auto start_time_us = torch::profiler::impl::getTime() / 1000;
#endif
      try {
        if (instruction.rfind("prim::Constant", 0) == 0) {
          // 15 is the length of 'prim::Constant#'; the constant value comes
          // after it.
          TORCH_CHECK(
              instruction.size() > 15,
              "Constant value is expected in ",
              instruction);
          // The parsed constant is not used further by this demo backend.
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          auto sub = instruction.substr(15);
        } else if (instruction == "aten::add" || instruction == "aten::sub") {
          TORCH_CHECK(x.sizes() == h.sizes());
          if (x.dim() > 1 || (x.dim() == 1 && x.size(0) > 1)) {
            TORCH_WARN(
                "Only the first elements of the tensors are added or "
                "subtracted.");
          }
          TORCH_CHECK(
              (x.scalar_type() == c10::ScalarType::Float &&
               h.scalar_type() == c10::ScalarType::Float),
              "Only float tensors are compatible for add and sub.");
          at::Tensor y = at::detail::empty_cpu(x.sizes(), at::kFloat);
          auto x_ptr = float_data_ptr(x);
          auto h_ptr = float_data_ptr(h);
          auto y_ptr = float_data_ptr(y);
#ifndef NO_PROFILING
          RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER(
              x_ptr,
              x.numel() * sizeof(float),
              x.numel() * sizeof(float),
              x.numel() * sizeof(float) + y.numel() * sizeof(float) +
                  h.numel() * sizeof(float),
              c10::Device(c10::kCPU));
#endif
          if (instruction == "aten::add") {
            y_ptr[0] = x_ptr[0] + h_ptr[0];
          } else {
            y_ptr[0] = x_ptr[0] - h_ptr[0];
          }
          output_list.emplace_back(y);
        } else {
          TORCH_CHECK(
              false,
              "Instruction ",
              instruction,
              " is not supported. ",
              "Contact the backend POC for details. ");
        }
      } catch (c10::Error& e) {
        TORCH_DELEGATED_BACKEND_THROW(false, e.what(), debug_handle);
      }
#ifndef NO_PROFILING
      auto end_time_us = torch::profiler::impl::getTime() / 1000;
      auto duration = end_time_us - start_time_us;
      op_runtimes_us.emplace_back(duration, debug_handle, instruction);
#endif
    }
#ifndef NO_PROFILING
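    // Replay the recorded per-op durations as consecutive events on the edge
    // profiler, advancing the running start timestamp by each op's duration.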
    for (const auto& tup : op_runtimes_us) {
      RECORD_BACKEND_EVENT_TO_EDGE_PROFILER(
          start_us,
          start_us + std::get<0>(tup),
          std::get<1>(tup),
          std::get<2>(tup),
          "test_backend");
      start_us = start_us + std::get<0>(tup);
    }
#endif
    return c10::impl::toList(output_list);
  }
};

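// Register the backend under the name "backend_with_compiler_demo" so that
// modules lowered to this backend (e.g. through the to_backend flow) are
// dispatched to this implementation at runtime.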
namespace {
constexpr auto backend_name = "backend_with_compiler_demo";
static auto cls = torch::jit::backend<BackendWithCompiler>(backend_name);
} // namespace

} // namespace jit
} // namespace torch