1 | #pragma once |
2 | |
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/FunctionRef.h>

#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
8 | |
namespace c10 {
// Forward declaration only. The complete type must come from one of the
// includes above (presumably ATen/core/function_schema.h), because
// torch::jit::Function::operator() calls members on it.
struct FunctionSchema;
} // namespace c10
12 | |
13 | namespace at { |
14 | TORCH_API void launch(std::function<void()> func); |
15 | } |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | struct Graph; |
21 | struct Code; |
22 | |
23 | namespace mobile { |
24 | struct Code; |
25 | } |
26 | |
27 | using Stack = std::vector<at::IValue>; |
28 | using Kwargs = std::unordered_map<std::string, at::IValue>; |
29 | struct RecursiveMethodCallError : public std::exception {}; |
30 | using TaskLauncher = std::function<void(std::function<void()>)>; |
31 | |
32 | TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast=false); |
33 | |
34 | // A Function is a pure Graph with no implicit `self` object bound. |
35 | // It contains schema information and the executor that manages the |
36 | // execution of the function. Method is a wrapper around an |
37 | // underlying Function that also provides a `self` object. |
38 | struct TORCH_API Function { |
39 | virtual c10::string_view doc_string() const { |
40 | static constexpr c10::string_view no_doc_string = "" ; |
41 | return no_doc_string; |
42 | } |
43 | |
44 | virtual bool isGraphFunction() const { |
45 | return false; |
46 | } |
47 | |
48 | virtual void run(Stack& stack) = 0; |
49 | |
50 | virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync( |
51 | Stack& /*stack*/, |
52 | TaskLauncher taskLauncher = at::launch) { |
53 | (void)taskLauncher; // Suppress unused variable warning |
54 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); |
55 | return {}; |
56 | } |
57 | |
58 | at::IValue operator()( |
59 | Stack stack, |
60 | const Kwargs& kwargs = Kwargs()) { |
61 | getSchema().checkAndNormalizeInputs(stack, kwargs); |
62 | run(stack); |
63 | return stack.front(); |
64 | } |
65 | |
66 | virtual const c10::QualifiedName& qualname() const = 0; |
67 | |
68 | const std::string& name() const { |
69 | return qualname().name(); |
70 | } |
71 | |
72 | // if this isn't yet defined, run its method_creator function |
73 | virtual void ensure_defined() = 0; |
74 | |
75 | virtual const c10::FunctionSchema& getSchema() const = 0; |
76 | |
77 | virtual size_t num_inputs() const = 0; |
78 | |
79 | virtual Function& setSchema(c10::FunctionSchema schema) = 0; |
80 | |
81 | // call() defines how different interpreter implementations interacts with |
82 | // Function objects. Basically interpreters need to provide a callback to |
83 | // communicate to Functions what to do if provided a Code object. |
84 | // Alternatively we could design the signature to return an optional Code |
85 | // object, but that requires special handling the null case in interpreter |
86 | // and the fallback behavior is not well defined by interpreter but rather |
87 | // Function themselves, so a callback approach is more reasonable than |
88 | // returning values. |
89 | // If call() returns true, then callback completes successfully, otherwise |
90 | // call() returns false. |
91 | |
92 | // Overload for server interpreter, a bailout size is needed for graph executor. |
93 | virtual bool call(Stack&, c10::optional<size_t>, c10::function_ref<void(const Code&)>) { |
94 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); |
95 | return false; |
96 | } |
97 | |
98 | // Overload for mobile interpreter. |
99 | virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) { |
100 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); |
101 | return false; |
102 | } |
103 | |
104 | virtual ~Function() = default; |
105 | }; |
106 | } // namespace jit |
107 | } // namespace torch |
108 | |