#pragma once

#include <atomic>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>

C10_DECLARE_bool(torch_jit_enable_new_executor);

namespace torch {
namespace jit {
struct GraphExecutorState;
struct Code;

enum ExecutorExecutionMode {
  SIMPLE,
  PROFILING,
};

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
      : code(graph, std::move(function_name)), graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// Note that these structs do not manage the lifetime of their members.
// They are only valid right after a call to getDebugState() and should
// never be used again once another GraphExecutor function is called.
// (See the usage sketch below the struct.)

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
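
// A minimal sketch of the intended access pattern (hypothetical; assumes a
// live GraphExecutor `executor` that has already been run at least once):
//
//   GraphExecutorState state = executor.getDebugState();
//   const Graph* last_graph = state.graph; // borrowed pointer, not owned
//   size_t num_plans = state.execution_plans.size();
//   executor.run(stack); // after this call, `state` must not be used again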

struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_get_optimize = false;
};
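
// A hedged usage sketch: EnableProfilingGuard is an RAII guard that enables
// the profiling executor for its scope; judging by the member names, it saves
// and restores the previous executor-mode and optimize settings:
//
//   {
//     EnableProfilingGuard profiling_guard;
//     // graphs compiled and run in this scope use the profiling executor
//   } // previous settings restored by the destructor
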
56
57struct GraphExecutorImplBase;
58struct TORCH_API GraphExecutor {
59 GraphExecutor() = default;
60 GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);
61
62 GraphExecutor(
63 const std::shared_ptr<Graph>& graph,
64 std::string function_name,
65 ExecutorExecutionMode executor_mode);
66
67 void run(Stack& inputs);
68 c10::intrusive_ptr<Future> runAsync(
69 Stack& stack,
70 TaskLauncher taskLauncher = at::launch);
71
  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // `remaining_bailout_depth` is 0, the `GraphExecutor` won't perform any
  // profiling or specialization, which is equivalent to the SIMPLE executor
  // mode. If `remaining_bailout_depth` is greater than 0, the `GraphExecutor`
  // will profile and specialize its input graph based on the profiled
  // information. Whenever a bailout check fails, a new `GraphExecutor` is
  // created with its `remaining_bailout_depth` reduced by 1.
  // If no bailout depth is passed, the depth will be initialized from the
  // current global fusion strategy settings.
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
  GraphExecutorState getDebugState();

  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};
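
// A hypothetical end-to-end sketch (the Graph would normally come from
// scripting or tracing; `graph` and the input tensor are placeholders):
//
//   GraphExecutor executor(graph, "forward");
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2}));
//   // Inspect the plan selected for these inputs, capping profiled
//   // recompilations at 1:
//   const ExecutionPlan& plan = executor.getPlanFor(stack, 1);
//   executor.run(stack); // inputs are consumed; outputs are left on the stack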

TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);

// These passes must run before a graph is valid to pass to the interpreter,
// regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();

TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();

struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
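
// Usage sketch: GraphOptimizerEnabledGuard scopes the global graph-executor
// optimization flag, e.g. to force an unoptimized run while debugging
// (`executor` and `stack` are assumed from the earlier sketch):
//
//   {
//     GraphOptimizerEnabledGuard optimizer_guard(/*state=*/false);
//     executor.run(stack); // runs with graph optimizations disabled
//   } // previous optimization setting restored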

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);

// For debugging we expose a way to get the last graph that was actually
// run. Previous approaches allowed querying the GraphExecutor for the graph
// it would run under certain circumstances (graphFor), but this was fragile
// because we sometimes change how these decisions are made. This interface
// still allows our tests to look at optimized graphs, but with less
// plumbing.
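//
// A test-style sketch of that interface (hypothetical; reuses `executor` and
// `stack` from the sketches above):
//
//   executor.run(stack);
//   std::shared_ptr<Graph> optimized = lastExecutedOptimizedGraph();
//   // inspect `optimized` here rather than asking the executor which graph
//   // it *would* run for a given input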
} // namespace detail

} // namespace jit
} // namespace torch