1 | #pragma once |
2 | |
#include <memory>

#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | struct TORCH_API GraphFunction : public Function { |
12 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
13 | GraphFunction( |
14 | c10::QualifiedName name, |
15 | std::shared_ptr<Graph> graph, |
16 | std::function<void(GraphFunction&)> function_creator, |
17 | c10::optional<ExecutorExecutionMode> executor_execution_mode = |
18 | c10::nullopt) |
19 | : name_(std::move(name)), |
20 | graph_(std::move(graph)), |
21 | executor_execution_mode_(executor_execution_mode), |
22 | function_creator_(std::move(function_creator)) {} |
23 | |
24 | bool isGraphFunction() const override { |
25 | return true; |
26 | } |
27 | |
28 | void run(Stack& stack) override; |
29 | |
30 | std::function<void(GraphFunction&)> function_creator() const { |
31 | return function_creator_; |
32 | } |
33 | |
34 | c10::intrusive_ptr<c10::ivalue::Future> runAsync( |
35 | Stack& stack, |
36 | TaskLauncher taskLauncher = at::launch) override; |
37 | |
38 | std::shared_ptr<Graph> graph() const { |
39 | return graph_; |
40 | } |
41 | |
42 | std::shared_ptr<Graph> optimized_graph() const { |
43 | std::lock_guard<std::recursive_mutex> lock(compile_mutex); |
44 | auto& optimized_graph = optimized_graphs_[currentSpecialization()]; |
45 | if (optimized_graph) { |
46 | return *optimized_graph; |
47 | } |
48 | optimized_graph = graph_->copy(); |
49 | if (getGraphExecutorOptimize()) { |
50 | preoptimizeGraph(*optimized_graph, force_no_amp_); |
51 | } |
52 | return *optimized_graph; |
53 | } |
54 | |
55 | const c10::QualifiedName& qualname() const override { |
56 | return name_; |
57 | } |
58 | |
59 | // private/unstable api. sets the initial execution mode |
60 | // will not affect executor if there is an existing executor |
61 | // created for this function |
62 | void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) { |
63 | executor_execution_mode_ = mode; |
64 | } |
65 | // private/unstable api. sets flag of whether or not to ignore amp. |
66 | // will not affect executor if there is an existing executor |
67 | // created for this function |
68 | void _set_ignore_amp(bool ignore_amp) { |
69 | force_no_amp_ = ignore_amp; |
70 | } |
71 | |
72 | // if this isn't yet defined, run its method_creator function |
73 | void ensure_defined() override; |
74 | |
75 | size_t num_inputs() const override { |
76 | return graph()->inputs().size(); |
77 | } |
78 | |
79 | Function& setSchema(FunctionSchema schema) override { |
80 | schema_ = make_unique<FunctionSchema>(std::move(schema)); |
81 | return *this; |
82 | } |
83 | |
84 | const FunctionSchema& getSchema() const override; |
85 | |
86 | GraphExecutorState getDebugState() { |
87 | return get_executor().getDebugState(); |
88 | } |
89 | |
90 | bool is_optimized() const { |
91 | TORCH_WARN( |
92 | "GraphFunction::is_optimized() is deprecated and always returns true. " |
93 | "Please use getGraphExecutorOptimize()" ); |
94 | return true; |
95 | } |
96 | |
97 | void check_single_output() { |
98 | TORCH_CHECK( |
99 | graph()->outputs().size() == 1, |
100 | "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs" ); |
101 | } |
102 | |
103 | GraphExecutor& get_executor() { |
104 | ensure_defined(); |
105 | std::lock_guard<std::recursive_mutex> lock(compile_mutex); |
106 | auto& executor = executors_[currentSpecialization()]; |
107 | if (executor) { |
108 | return *executor; |
109 | } |
110 | check_single_output(); |
111 | const std::string& name = name_.name(); |
112 | std::shared_ptr<Graph> opt_graph = optimized_graph(); |
113 | if (!executor_execution_mode_) { |
114 | executor = GraphExecutor(opt_graph, name); |
115 | } else { |
116 | executor = GraphExecutor(opt_graph, name, *executor_execution_mode_); |
117 | } |
118 | return *executor; |
119 | } |
120 | |
121 | using Function::call; |
122 | bool call( |
123 | Stack& stack, |
124 | c10::optional<size_t> bailOut, |
125 | c10::function_ref<void(const Code&)> f) override { |
126 | f(get_executor().getPlanFor(stack, bailOut).code); |
127 | return true; |
128 | } |
129 | |
130 | void clear_optimized_graphs() { |
131 | optimized_graphs_.fill(c10::nullopt); |
132 | } |
133 | |
134 | private: |
135 | enum SpecializationKey { |
136 | AutocastOff, |
137 | CpuAutocastOn, |
138 | GpuAutocastOn, |
139 | CpuGpuAutocastOn, |
140 | |
141 | // This provides the number of specializations |
142 | // (Must be last entry) |
143 | TotalCount |
144 | }; |
145 | |
146 | SpecializationKey currentSpecialization() const; |
147 | |
148 | private: |
149 | c10::QualifiedName name_; |
150 | // The original, non-optimized graph |
151 | std::shared_ptr<Graph> graph_; // for debugging and for inlining |
152 | |
153 | // allows users to specify Simple/Profiling Executor for function |
154 | // TODO: add more executors |
155 | mutable c10::optional<ExecutorExecutionMode> executor_execution_mode_; |
156 | |
157 | // if invoked on a graph that has already traced through amp |
158 | // don't invoke amp pass |
159 | mutable bool force_no_amp_ = false; |
160 | // Optimized graph, computed lazily. Used for inlining. |
161 | mutable std::array< |
162 | c10::optional<std::shared_ptr<Graph>>, |
163 | SpecializationKey::TotalCount> |
164 | optimized_graphs_; |
165 | |
166 | // GraphFunctions are invokable from multiple threads, so this lock needs to |
167 | // be held when we're initializing graph executor for the first time or |
168 | // computing the optimized graph. We're using reentrant mutex so that we don't |
169 | // need to worry about causing a deadlock by calling one method from another |
170 | // (e.g. optimized_graph() from get_executor()). |
171 | mutable std::recursive_mutex compile_mutex; |
172 | |
173 | // executor_[0] - autocast off |
174 | // executor_[1] - autocast cpu on |
175 | // executor_[2] - autocast gpu on |
176 | // executor_[3] - autocast cpu & gpu on |
177 | std::array<c10::optional<GraphExecutor>, SpecializationKey::TotalCount> |
178 | executors_; |
179 | |
180 | // an optional function that actually creates the method when |
181 | // ensure_defined() is called. This is used by the compiler so |
182 | // that it can construct methods out of order |
183 | std::function<void(GraphFunction&)> function_creator_; |
184 | |
185 | // if absent, then we generate a default schema based on the graph |
186 | // mutable because getSchema caches the default schema if one is requested |
187 | // before a call to setSchema |
188 | mutable std::unique_ptr<FunctionSchema> schema_; |
189 | }; |
190 | |
// Shorthands for dynamic_cast<GraphFunction*>: tryToGraphFunction returns
// nullptr on failure; the toGraphFunction overloads throw instead.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);
195 | |
196 | } // namespace jit |
197 | } // namespace torch |
198 | |