1#include <torch/csrc/jit/backends/backend.h>
2#include <torch/csrc/jit/backends/backend_debug_handler.h>
3#include <torch/csrc/jit/backends/backend_preprocess.h>
4
5namespace torch {
6namespace jit {
7// This test JIT backend is intended to do the minimal amount of work
8// necessary to test that the JIT backend registration endpoints and
9// code generation are working correctly. It is not intended to
10// produce numerically correct results.
11template <bool isAvailable>
12class TestBackend : public PyTorchBackendInterface {
13 public:
14 // Constructor.
15 // NOLINTNEXTLINE(modernize-use-equals-default)
16 explicit TestBackend() {}
17 // NOLINTNEXTLINE(modernize-use-override)
18 virtual ~TestBackend() = default;
19
20 bool is_available() override {
21 return isAvailable;
22 }
23
24 c10::impl::GenericDict compile(
25 c10::IValue processed,
26 c10::impl::GenericDict method_compile_spec) override {
27 auto spec =
28 c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
29
30 // Return the same string as a value for every key in method_compile_spec.
31 auto handles = c10::Dict<std::string, std::string>();
32 for (const auto& it : spec) {
33 handles.insert(it.key(), it.key());
34 }
35 return c10::impl::toGenericDict(handles);
36 }
37 c10::impl::GenericList execute(
38 c10::IValue handle,
39 c10::impl::GenericList inputs) override {
40 TORCH_INTERNAL_ASSERT(handle.isString());
41 TORCH_INTERNAL_ASSERT(inputs.size() > 0);
42
43 c10::List<at::Tensor> output_list;
44
45 // Implement simple accumulator and negative accumulator (?) ops. Return one
46 // or both of them depending on the handle to make sure multiple outputs are
47 // handled.
48 c10::IValue value = inputs[0];
49 at::Tensor accum = value.toTensor();
50 accum = accum.clone();
51 at::Tensor sub_accum = value.toTensor();
52 sub_accum = sub_accum.clone();
53
54 for (size_t i = 1, e = inputs.size(); i < e; ++i) {
55 value = inputs[i];
56 accum.add_(value.toTensor(), 1.0);
57 sub_accum.sub_(value.toTensor(), 1.0);
58 }
59
60 if (handle.toStringRef() == "accum") {
61 output_list.emplace_back(accum);
62 } else if (handle.toStringRef() == "sub_accum") {
63 output_list.emplace_back(sub_accum);
64 } else if (handle.toStringRef() == "forward") {
65 output_list.emplace_back(accum);
66 output_list.emplace_back(sub_accum);
67 }
68
69 return c10::impl::toList(output_list);
70 }
71};
72
73namespace {
74c10::IValue preprocess(
75 const Module& mod,
76 const c10::Dict<IValue, IValue>& method_compile_spec,
77 const BackendDebugHandleGenerator& generate_debug_handles) {
78 return mod._ivalue();
79}
80
81constexpr auto backend_name = "test_backend";
82static auto cls_available =
83 torch::jit::backend<TestBackend<true>>(backend_name);
84static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
85
86constexpr auto backend_unavailable_name = "test_backend_unavailable";
87static auto cls_unavailable =
88 torch::jit::backend<TestBackend<false>>(backend_unavailable_name);
89static auto pre_reg_unavailable =
90 backend_preprocess_register(backend_unavailable_name, preprocess);
91
92} // namespace
93} // namespace jit
94} // namespace torch
95