1 | #include <torch/csrc/jit/backends/backend.h> |
2 | #include <torch/csrc/jit/backends/backend_debug_handler.h> |
3 | #include <torch/csrc/jit/backends/backend_preprocess.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | // This test JIT backend is intended to do the minimal amount of work |
8 | // necessary to test that the JIT backend registration endpoints and |
9 | // code generation are working correctly. It is not intended to |
10 | // produce numerically correct results. |
11 | template <bool isAvailable> |
12 | class TestBackend : public PyTorchBackendInterface { |
13 | public: |
14 | // Constructor. |
15 | // NOLINTNEXTLINE(modernize-use-equals-default) |
16 | explicit TestBackend() {} |
17 | // NOLINTNEXTLINE(modernize-use-override) |
18 | virtual ~TestBackend() = default; |
19 | |
20 | bool is_available() override { |
21 | return isAvailable; |
22 | } |
23 | |
24 | c10::impl::GenericDict compile( |
25 | c10::IValue processed, |
26 | c10::impl::GenericDict method_compile_spec) override { |
27 | auto spec = |
28 | c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec); |
29 | |
30 | // Return the same string as a value for every key in method_compile_spec. |
31 | auto handles = c10::Dict<std::string, std::string>(); |
32 | for (const auto& it : spec) { |
33 | handles.insert(it.key(), it.key()); |
34 | } |
35 | return c10::impl::toGenericDict(handles); |
36 | } |
37 | c10::impl::GenericList execute( |
38 | c10::IValue handle, |
39 | c10::impl::GenericList inputs) override { |
40 | TORCH_INTERNAL_ASSERT(handle.isString()); |
41 | TORCH_INTERNAL_ASSERT(inputs.size() > 0); |
42 | |
43 | c10::List<at::Tensor> output_list; |
44 | |
45 | // Implement simple accumulator and negative accumulator (?) ops. Return one |
46 | // or both of them depending on the handle to make sure multiple outputs are |
47 | // handled. |
48 | c10::IValue value = inputs[0]; |
49 | at::Tensor accum = value.toTensor(); |
50 | accum = accum.clone(); |
51 | at::Tensor sub_accum = value.toTensor(); |
52 | sub_accum = sub_accum.clone(); |
53 | |
54 | for (size_t i = 1, e = inputs.size(); i < e; ++i) { |
55 | value = inputs[i]; |
56 | accum.add_(value.toTensor(), 1.0); |
57 | sub_accum.sub_(value.toTensor(), 1.0); |
58 | } |
59 | |
60 | if (handle.toStringRef() == "accum" ) { |
61 | output_list.emplace_back(accum); |
62 | } else if (handle.toStringRef() == "sub_accum" ) { |
63 | output_list.emplace_back(sub_accum); |
64 | } else if (handle.toStringRef() == "forward" ) { |
65 | output_list.emplace_back(accum); |
66 | output_list.emplace_back(sub_accum); |
67 | } |
68 | |
69 | return c10::impl::toList(output_list); |
70 | } |
71 | }; |
72 | |
73 | namespace { |
// Preprocessing hook registered for both test backends below. The test
// backend performs no ahead-of-time transformation: it hands the module's
// underlying IValue through unchanged as the "processed" payload that
// TestBackend::compile later receives.
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const BackendDebugHandleGenerator& generate_debug_handles) {
  // method_compile_spec and generate_debug_handles are intentionally
  // unused; a real backend would consume them during lowering.
  return mod._ivalue();
}
80 | |
// Register the always-available variant. The static `backend<>` object
// registers the backend class with the JIT backend registry at load time,
// and backend_preprocess_register attaches the preprocess hook under the
// same name.
constexpr auto backend_name = "test_backend";
static auto cls_available =
    torch::jit::backend<TestBackend<true>>(backend_name);
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);

// Register a second variant whose is_available() reports false, so tests
// can cover the unavailable-backend code paths.
constexpr auto backend_unavailable_name = "test_backend_unavailable";
static auto cls_unavailable =
    torch::jit::backend<TestBackend<false>>(backend_unavailable_name);
static auto pre_reg_unavailable =
    backend_preprocess_register(backend_unavailable_name, preprocess);
91 | |
92 | } // namespace |
93 | } // namespace jit |
94 | } // namespace torch |
95 | |