1 | #pragma once |
2 | |
3 | #include <ATen/core/Tensor.h> |
4 | #include <ATen/core/function_schema.h> |
5 | #include <c10/macros/Export.h> |
6 | |
7 | // NOTE: [Jit Decomposition Interface] |
8 | // |
// For context on why this is needed at all, see NOTE: [forward-mode AD
// decompositions mechanism].
11 | // |
12 | // Introducing that mechanism from the NOTE is problematic because: |
13 | // - it relies on TorchScript, so now VariableTypeX.cpp depends on TorchScript. |
14 | // - there exist internal builds like lite_trainer, which depend on VariableType |
15 | // but do not depend on TorchScript. |
16 | // |
// For internal builds like lite_trainer to pass, and for OSS builds that do
// depend on TorchScript to still support the forward AD decomp mechanism, we
// use a PImpl-like pattern that replaces the static dependency with a dynamic
// one:
// - during static initialization, if the library is built with TorchScript,
//   setJitDecompImpl is called in decomposition_registry.cpp, setting a global
//   pointer to the impl;
// - when the program runs, if getJitDecompImpl returns a non-null pointer, we
//   carry on normally; otherwise we gracefully error out.
26 | // |
// For extra context, see VariableHooksInterface.h, where a similar technique
// is used.
29 | |
30 | namespace torch { |
31 | namespace autograd { |
32 | namespace impl { |
33 | |
34 | struct TORCH_API JitDecompInterface { |
35 | virtual ~JitDecompInterface() = default; |
36 | virtual bool has_jit_decomposition( |
37 | const c10::FunctionSchema& schema) const = 0; |
38 | virtual void run_jit_decomposition( |
39 | const c10::OperatorHandle& op, |
40 | jit::Stack* stack) const = 0; |
41 | }; |
42 | |
43 | TORCH_API void setJitDecompImpl(JitDecompInterface* impl); |
44 | TORCH_API JitDecompInterface* getJitDecompImpl(); |
45 | |
46 | struct TORCH_API JitDecompRegisterer { |
47 | explicit JitDecompRegisterer(JitDecompInterface* impl) { |
48 | setJitDecompImpl(impl); |
49 | } |
50 | }; |
51 | |
52 | } // namespace impl |
53 | } // namespace autograd |
54 | } // namespace torch |
55 | |