#pragma once

#include <atomic>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>

C10_DECLARE_bool(torch_jit_enable_new_executor);

namespace torch {
namespace jit {
struct GraphExecutorState;
struct Code;

enum ExecutorExecutionMode {
  SIMPLE,
  PROFILING,
};

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
      : code(graph, std::move(function_name)), graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// Note that these structs don't manage the lifetime of their members.
// They are valid only right after you call getDebugState() and should
// never be used again once another GraphExecutor function is called.

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
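
// Example (a hedged sketch; `executor` is assumed to be a GraphExecutor that
// has already run): take the snapshot and inspect it immediately, without
// holding on to it across other GraphExecutor calls.
//
//   GraphExecutorState state = executor.getDebugState();
//   std::cout << *state.graph << "\n"; // valid only until the next call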

struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_get_optimize = false;
};

struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);

  GraphExecutor(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      ExecutorExecutionMode executor_mode);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch);

  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any
  // profiling or specialization; this is equivalent to the
  // `ExecutorExecutionMode::SIMPLE` mode. If remaining_bailout_depth is
  // greater than 0, `GraphExecutor` profiles and specializes its input graph
  // based on the profiled information. Whenever a bailout check fails (i.e.
  // is triggered), a new `GraphExecutor` is created, and that new
  // `GraphExecutor`'s remaining_bailout_depth is reduced by 1.
  // If no bailout depth is passed, the depth is initialized from the current
  // global fusion strategy settings.
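  //
  // Example (a minimal, hedged sketch; assumes `graph` has been built and
  // `stack` already holds its inputs):
  //
  //   GraphExecutor executor(graph, "forward");
  //   const ExecutionPlan& plan = executor.getPlanFor(stack);
  //   InterpreterState(plan.code).run(stack);
  //
  // run(stack) performs an equivalent plan lookup internally.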
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
  GraphExecutorState getDebugState();

  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};

TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);

// These passes need to run before a graph is valid to pass to the
// interpreter, regardless of whether sizes have been specialized or not.
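//
// A hedged sketch of the intended ordering (assumes `graph` is a
// std::shared_ptr<Graph> that has already been constructed):
//
//   runRequiredPasses(graph);
//   Code code(graph, "forward");
//   InterpreterState(code).run(stack);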
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();

TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();
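
// The atomic accessors above return mutable references. A hedged sketch of
// how they are typically toggled (e.g. in tests):
//
//   getExecutorMode() = true; // select the profiling executor
//   getProfilingMode() = true; // enable profile-guided specialization
//   getNumProfiledRuns() = 1; // profile a single run before optimizing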

struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
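
// Example (a hedged usage sketch): temporarily disable graph-executor
// optimization for one scope; the previous setting is restored when the
// guard goes out of scope.
//
//   {
//     GraphOptimizerEnabledGuard guard(/*state=*/false);
//     executor.run(stack); // runs with optimization turned off
//   }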

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);

// For debugging we expose a way to get the last graph that was actually run.
// Previous approaches allowed querying the GraphExecutor for the graph it
// would run under certain circumstances (graphFor), but this is fragile
// because we sometimes change how these decisions are made. This interface
// still allows our tests to look at optimized graphs, but with less
// plumbing.
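//
// A hedged sketch (assumes `executor` has just been run on `stack`):
//
//   executor.run(stack);
//   std::shared_ptr<Graph> last = lastExecutedOptimizedGraph();
//   std::cout << *last << "\n"; // the graph that actually executed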
} // namespace detail

} // namespace jit
} // namespace torch