#pragma once

#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

namespace torch {
namespace jit {

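// A torch::jit::Function backed by a TorchScript Graph. The graph is kept in
// its original form for debugging and inlining; an optimized copy and a
// GraphExecutor are created lazily, one per autocast specialization (see
// SpecializationKey below). Lazy initialization is guarded by a reentrant
// mutex, so a GraphFunction may be used from multiple threads.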
struct TORCH_API GraphFunction : public Function {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  GraphFunction(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(GraphFunction&)> function_creator,
      c10::optional<ExecutorExecutionMode> executor_execution_mode =
          c10::nullopt)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        executor_execution_mode_(executor_execution_mode),
        function_creator_(std::move(function_creator)) {}

  bool isGraphFunction() const override {
    return true;
  }

  void run(Stack& stack) override;

  std::function<void(GraphFunction&)> function_creator() const {
    return function_creator_;
  }

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch) override;

  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

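  // Returns a lazily computed copy of graph_ for the current autocast
  // specialization; the copy is run through preoptimizeGraph() only when
  // executor optimization is enabled (see getGraphExecutorOptimize()).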
  std::shared_ptr<Graph> optimized_graph() const {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& optimized_graph = optimized_graphs_[currentSpecialization()];
    if (optimized_graph) {
      return *optimized_graph;
    }
    optimized_graph = graph_->copy();
    if (getGraphExecutorOptimize()) {
      preoptimizeGraph(*optimized_graph, force_no_amp_);
    }
    return *optimized_graph;
  }

  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // Private/unstable API. Sets the initial execution mode; this has no
  // effect if an executor has already been created for this function.
  void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) {
    executor_execution_mode_ = mode;
  }
  // Private/unstable API. Sets whether to ignore amp; this has no effect
  // if an executor has already been created for this function.
  void _set_ignore_amp(bool ignore_amp) {
    force_no_amp_ = ignore_amp;
  }

  // If this function isn't defined yet, run its function_creator_.
  void ensure_defined() override;

  size_t num_inputs() const override {
    return graph()->inputs().size();
  }

  Function& setSchema(FunctionSchema schema) override {
    schema_ = make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  const FunctionSchema& getSchema() const override;

  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  bool is_optimized() const {
    TORCH_WARN(
        "GraphFunction::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  void check_single_output() {
    TORCH_CHECK(
        graph()->outputs().size() == 1,
        "Methods (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
  }

  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& executor = executors_[currentSpecialization()];
    if (executor) {
      return *executor;
    }
    check_single_output();
    const std::string& name = name_.name();
    std::shared_ptr<Graph> opt_graph = optimized_graph();
    if (!executor_execution_mode_) {
      executor = GraphExecutor(opt_graph, name);
    } else {
      executor = GraphExecutor(opt_graph, name, *executor_execution_mode_);
    }
    return *executor;
  }

  using Function::call;
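  // Looks up the executor's execution plan for the inputs currently on the
  // stack (at the given bailout depth) and passes its Code to the callback
  // `f`; always returns true for GraphFunction.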
  bool call(
      Stack& stack,
      c10::optional<size_t> bailOut,
      c10::function_ref<void(const Code&)> f) override {
    f(get_executor().getPlanFor(stack, bailOut).code);
    return true;
  }

  void clear_optimized_graphs() {
    optimized_graphs_.fill(c10::nullopt);
  }

 private:
  enum SpecializationKey {
    AutocastOff,
    CpuAutocastOn,
    GpuAutocastOn,
    CpuGpuAutocastOn,

    // This provides the number of specializations
    // (Must be last entry)
    TotalCount
  };

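  // Maps the current autocast state (off / CPU / GPU / both) to one of the
  // SpecializationKey entries above; the key selects which cached optimized
  // graph and executor are used.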
  SpecializationKey currentSpecialization() const;

 private:
  c10::QualifiedName name_;
  // The original, non-optimized graph
  std::shared_ptr<Graph> graph_; // for debugging and for inlining

  // Allows users to choose the Simple or Profiling executor for this function.
  // TODO: add more executors
  mutable c10::optional<ExecutorExecutionMode> executor_execution_mode_;

  // If the graph has already been traced through amp, don't run the amp
  // pass on it again.
  mutable bool force_no_amp_ = false;
  // Optimized graph, computed lazily. Used for inlining.
  mutable std::array<
      c10::optional<std::shared_ptr<Graph>>,
      SpecializationKey::TotalCount>
      optimized_graphs_;

  // GraphFunctions are invokable from multiple threads, so this lock needs to
  // be held when we're initializing a graph executor for the first time or
  // computing the optimized graph. We use a reentrant mutex so that we don't
  // need to worry about causing a deadlock by calling one method from another
  // (e.g. optimized_graph() from get_executor()).
  mutable std::recursive_mutex compile_mutex;

  // executors_[0] - autocast off
  // executors_[1] - autocast cpu on
  // executors_[2] - autocast gpu on
  // executors_[3] - autocast cpu & gpu on
  std::array<c10::optional<GraphExecutor>, SpecializationKey::TotalCount>
      executors_;

  // an optional function that actually creates the method when
  // ensure_defined() is called. This is used by the compiler so
  // that it can construct methods out of order
  std::function<void(GraphFunction&)> function_creator_;

  // if absent, then we generate a default schema based on the graph
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema
  mutable std::unique_ptr<FunctionSchema> schema_;
};

// Shorthands for dynamic_cast<GraphFunction*>.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);
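
// A minimal usage sketch (hypothetical: `fn` is a Function obtained
// elsewhere, e.g. a compiled Module's method):
//
//   Function& fn = ...;
//   if (GraphFunction* gf = tryToGraphFunction(fn)) {
//     std::shared_ptr<Graph> g = gf->graph();            // original graph
//     std::shared_ptr<Graph> og = gf->optimized_graph(); // lazily optimized copy
//   }
//
// toGraphFunction(fn) returns a reference instead of a pointer, for call
// sites where the cast is expected to succeed.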

} // namespace jit
} // namespace torch