1 | #pragma once |
2 | |
3 | #include <ATen/core/Dimname.h> |
4 | #include <ATen/core/class_type.h> |
5 | #include <ATen/core/jit_type.h> |
6 | #include <ATen/core/stack.h> |
7 | #include <ATen/core/symbol.h> |
8 | #include <c10/util/Exception.h> |
9 | #include <torch/csrc/Export.h> |
10 | |
11 | #include <torch/csrc/jit/frontend/source_range.h> |
12 | #include <torch/csrc/utils/variadic.h> |
13 | |
#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
20 | |
21 | namespace torch { |
22 | namespace jit { |
// Forward declarations of JIT types. Within this header they are only used
// through raw pointers and std::shared_ptr, so incomplete types suffice.
struct Graph;
struct Module;
struct Node;
struct Value;
27 | |
28 | namespace tracer { |
29 | |
30 | using ::c10::ivalue::Shared; |
31 | |
32 | using ::c10::IValue; |
33 | using ::c10::ivalue::Future; |
34 | |
35 | using ::c10::ArrayRef; |
36 | using ::c10::TupleType; |
37 | using ::c10::TupleTypePtr; |
38 | using ::c10::ivalue::ConstantString; |
39 | |
40 | using torch::autograd::Variable; |
41 | using variable_list = std::vector<Variable>; |
42 | |
43 | TORCH_API std::atomic<bool>& getTracerStateWarnMode(); |
44 | |
45 | struct TORCH_API TracingState |
46 | : public std::enable_shared_from_this<TracingState> { |
47 | TracingState(); |
48 | ~TracingState(); |
49 | |
50 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
51 | std::shared_ptr<Graph> graph; |
52 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
53 | bool warn = getTracerStateWarnMode(); |
54 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
55 | bool strict = true; |
56 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
57 | bool force_outplace = false; |
58 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
59 | std::function<std::string(const Variable& var)> lookup_var_name_fn = |
60 | [](const Variable& var) { return "" ; }; |
61 | |
62 | void enterFrame() { |
63 | env_stack.emplace_back(); |
64 | } |
65 | |
66 | void leaveFrame() { |
67 | env_stack.pop_back(); |
68 | } |
69 | |
70 | void setValue(const IValue& v, Value* value); |
71 | void delValue(const IValue& var); |
72 | Value* getValue(const IValue& var); |
73 | Value* getOutput(const IValue& var, size_t i); |
74 | bool hasValue(const IValue& var) const; |
75 | |
76 | Node* createNode(c10::Symbol op_name, size_t num_outputs); |
77 | void insertNode(Node* node); |
78 | |
79 | private: |
80 | using WeakIValue = at::WeakIValue; |
81 | |
82 | struct WeakIValueHasher { |
83 | size_t operator()(const WeakIValue& t) const { |
84 | return t.hash(); |
85 | } |
86 | }; |
87 | |
88 | struct WeakIValueEq { |
89 | bool operator()(const WeakIValue& t1, const WeakIValue& t2) const { |
90 | return t1.isSameIdentity(t2); |
91 | } |
92 | }; |
93 | |
94 | using Frame = |
95 | std::unordered_map<WeakIValue, Value*, WeakIValueHasher, WeakIValueEq>; |
96 | std::vector<Frame> env_stack; |
97 | }; |
98 | |
99 | // This is meant to be used as a thread local place, where we can store extra |
100 | // info that gets lost when we call into ATen from Python bindings. One example |
101 | // for when this happens is when we get an IntArrayRef argument with e.g. sizes |
102 | // for view. When tracing, those might be tensors, which let us encode extra |
103 | // data dependencies, but once they get to the ATen call where we actually have |
104 | // the tracing logic, they get converted into a raw IntArrayRef, and we loose |
105 | // all information. To prevent this, we temporarily stash it in here. |
106 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
107 | struct ArgumentStash { |
108 | struct IntArrayRefTrace : std::vector<Value*> { |
109 | IntArrayRefTrace(int size) : std::vector<Value*>(size, nullptr) {} |
110 | }; |
111 | |
112 | static bool empty() { |
113 | return stash.intlists.empty(); |
114 | } |
115 | |
116 | TORCH_API static void stashIntArrayRefElem( |
117 | const std::string& arg_name, |
118 | size_t size, |
119 | size_t idx, |
120 | const Variable& var); |
121 | |
122 | static bool hasIntArrayRef(const std::string& arg_name) { |
123 | return stash.intlists.count(arg_name) > 0; |
124 | } |
125 | |
126 | static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) { |
127 | auto info = std::move(stash.intlists.at(arg_name)); |
128 | stash.intlists.erase(arg_name); |
129 | return info; |
130 | } |
131 | |
132 | // Value stashing: Use these methods to stash arguments which correspond |
133 | // to regular Value*'s in the graph. i.e. they don't require special |
134 | // handling like in the case of IntArrayRefs |
135 | TORCH_API static void stashValue( |
136 | const std::string& arg_name, |
137 | size_t idx, |
138 | const Variable& var, |
139 | const c10::TypePtr& type = nullptr); |
140 | |
141 | static bool hasValue(const std::string& arg_name) { |
142 | return stash.values.count(arg_name) > 0; |
143 | } |
144 | |
145 | static Value* popValue(const std::string& arg_name) { |
146 | auto info = stash.values.at(arg_name); |
147 | stash.values.erase(arg_name); |
148 | return info; |
149 | } |
150 | |
151 | private: |
152 | static thread_local ArgumentStash stash; |
153 | std::unordered_map<std::string, IntArrayRefTrace> intlists; |
154 | std::unordered_map<std::string, Value*> values; |
155 | }; |
156 | |
157 | // Retrieve or set the current tracing state. Returns a nullptr if tracing is |
158 | // disabled. |
159 | TORCH_API const std::shared_ptr<TracingState>& getTracingState(); |
160 | TORCH_API void setTracingState(std::shared_ptr<TracingState> state); |
161 | |
162 | inline bool isTracing() { |
163 | return static_cast<bool>(getTracingState()); |
164 | } |
165 | |
166 | using warn_fn_type = void (*)(const std::string& msg); |
167 | TORCH_API extern const char* WARN_PYTHON_DATAFLOW; |
168 | TORCH_API extern const char* WARN_CONSTRUCTOR; |
169 | TORCH_API extern const char* WARN_RESIZE; |
170 | TORCH_API extern const char* STRICT_TRACER_MSG; |
171 | TORCH_API void _do_warn(const char* _reason, const char* _kind); |
172 | inline void warn(const char* _reason, const char* _kind = nullptr) { |
173 | if (const auto& state = getTracingState()) { |
174 | if (!state->warn) |
175 | return; |
176 | _do_warn(_reason, _kind); |
177 | } |
178 | } |
179 | TORCH_API void setWarn(warn_fn_type fn); |
180 | |
181 | struct TORCH_API NoWarn { |
182 | NoWarn() : state(getTracingState()) { |
183 | if (state) { |
184 | prev = state->warn; |
185 | state->warn = false; |
186 | } |
187 | } |
188 | ~NoWarn() { |
189 | if (state) { |
190 | state->warn = prev; |
191 | } |
192 | } |
193 | std::shared_ptr<TracingState> state; |
194 | bool prev{false}; |
195 | }; |
196 | |
197 | struct WithNestedTracingFrame { |
198 | WithNestedTracingFrame() { |
199 | getTracingState()->enterFrame(); |
200 | } |
201 | |
202 | ~WithNestedTracingFrame() { |
203 | getTracingState()->leaveFrame(); |
204 | } |
205 | }; |
206 | TORCH_API void recordSourceLocation(Node* n); |
207 | TORCH_API void setRecordSourceLocation(void (*v)(Node*)); |
208 | |
209 | TORCH_API std::vector<StackEntry> pythonCallstack(); |
210 | TORCH_API void setPythonCallstack(std::vector<StackEntry> (*v)()); |
211 | |
212 | // Having finished adding a new 'node' to the graph IR 'setValueTrace' |
213 | // associates this node with an output variable, so that further operations |
214 | // involving this variable know which node in the IR to reference. |
215 | TORCH_API void setValueTrace(const IValue& v, Value* value); |
216 | |
217 | TORCH_API void delValueTrace(const IValue& var); |
218 | |
219 | TORCH_API std::function<void()> pauseTracing(); |
220 | |
221 | TORCH_API Value* getValueTrace(const IValue& var); |
222 | |
223 | TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace( |
224 | Stack inputs, |
225 | const std::function<Stack(Stack)>& traced_fn, |
226 | std::function<std::string(const Variable&)> var_name_lookup_fn, |
227 | bool strict = true, |
228 | bool force_outplace = false, |
229 | Module* self = nullptr, |
230 | const std::vector<std::string>& argument_names = {}); |
231 | |
232 | TORCH_API void abandon(); |
233 | |
234 | // NB: those serve both as an intermediate steps in addInputs below, |
235 | // as well as the overloads that terminate template recursion |
236 | TORCH_API void addInputs(Node* n, const char* name, int64_t value); |
237 | TORCH_API void addInputs(Node* n, const char* name, c10::SymInt value); |
238 | TORCH_API void addInputs( |
239 | Node* n, |
240 | const char* name, |
241 | c10::optional<int64_t> value); |
242 | TORCH_API void addInputs(Node* n, const char* name, bool value); |
243 | TORCH_API void addInputs( |
244 | Node* n, |
245 | const char* name, |
246 | const c10::optional<bool>& value); |
247 | TORCH_API void addInputs(Node* n, const char* name, double value); |
248 | TORCH_API void addInputs( |
249 | Node* n, |
250 | const char* name, |
251 | const c10::optional<double>& value); |
252 | TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value); |
253 | TORCH_API void addInputs( |
254 | Node* n, |
255 | const char* name, |
256 | const c10::optional<at::Scalar>& value); |
257 | TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value); |
258 | TORCH_API void addInputs( |
259 | Node* n, |
260 | const char* name, |
261 | const c10::optional<at::Tensor>& value); |
262 | TORCH_API void addInputs(Node* n, const char* name, ArrayRef<int64_t> value); |
263 | TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value); |
264 | TORCH_API void addInputs( |
265 | Node* n, |
266 | const char* name, |
267 | c10::optional<c10::SymInt> value); |
268 | TORCH_API void addInputs( |
269 | Node* n, |
270 | const char* name, |
271 | const c10::optional<ArrayRef<int64_t>>& value); |
272 | TORCH_API void addInputs( |
273 | Node* n, |
274 | const char* name, |
275 | const at::OptionalIntArrayRef& opt_value); |
276 | TORCH_API void addInputs( |
277 | Node* n, |
278 | const char* name, |
279 | const at::OptionalSymIntArrayRef& opt_value); |
280 | TORCH_API void addInputs( |
281 | Node* n, |
282 | const char* name, |
283 | ArrayRef<at::Tensor> value, |
284 | bool allow_undefined = false); |
285 | TORCH_API void addInputs( |
286 | Node* n, |
287 | const char* name, |
288 | std::vector<at::Tensor> value, |
289 | bool allow_undefined = false); |
290 | TORCH_API void addInputs( |
291 | Node* n, |
292 | const char* name, |
293 | at::ITensorListRef value, |
294 | bool allow_undefined = false); |
295 | TORCH_API void addInputs( |
296 | Node* n, |
297 | const char* name, |
298 | const List<c10::optional<at::Tensor>>& value); |
299 | TORCH_API void addInputs( |
300 | Node* n, |
301 | const char* name, |
302 | ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value, |
303 | const c10::ClassTypePtr& class_type); |
304 | TORCH_API void addInputs(Node* n, const char* name, ArrayRef<double> value); |
305 | TORCH_API void addInputs( |
306 | Node* n, |
307 | const char* name, |
308 | const c10::optional<ArrayRef<double>>& value); |
309 | TORCH_API void addInputs( |
310 | Node* n, |
311 | const char* name, |
312 | const c10::string_view value); |
313 | TORCH_API void addInputs( |
314 | Node* n, |
315 | const char* name, |
316 | const c10::optional<c10::string_view>& value); |
317 | TORCH_API void addInputs(Node* n, const char* name, at::Device value); |
318 | TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream); |
319 | TORCH_API void addInputs(Node* n, const char* name, at::Layout value); |
320 | TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value); |
321 | TORCH_API void addInputs( |
322 | Node* n, |
323 | const char* name, |
324 | const c10::optional<at::ScalarType>& value); |
325 | TORCH_API void addInputs( |
326 | Node* n, |
327 | const char* name, |
328 | const c10::optional<at::Device>& value); |
329 | TORCH_API void addInputs( |
330 | Node* n, |
331 | const char* name, |
332 | const c10::optional<at::Layout>& value); |
333 | TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value); |
334 | TORCH_API void addInputs( |
335 | Node* n, |
336 | const char* name, |
337 | c10::optional<at::DimnameList> value); |
338 | TORCH_API void addInputs( |
339 | Node* n, |
340 | const char* name, |
341 | const c10::optional<at::MemoryFormat>& value); |
342 | TORCH_API void addInputs( |
343 | Node* n, |
344 | const char* name, |
345 | const c10::optional<at::Generator>& value); |
346 | |
347 | inline void addInputs( |
348 | Node* n, |
349 | const char* name, |
350 | const std::vector<bool>& value) { |
351 | AT_ERROR("Tracing a list of bool type is currently not supported!" ); |
352 | } |
353 | |
354 | template <typename T> |
355 | void addInputs(Node* n, const char* name, ArrayRef<T> value) { |
356 | AT_ERROR("Tracing a list of arbitrary type is currently not supported!" ); |
357 | } |
358 | template <typename K, typename V> |
359 | void addInputs( |
360 | Node* n, |
361 | const char* name, |
362 | const std::unordered_map<K, V>& value) { |
363 | AT_ERROR("Tracing a dict of arbitrary types is currently not supported!" ); |
364 | } |
365 | |
366 | template <size_t N> |
367 | void addInputs(Node* n, const char* name, std::array<bool, N> value) { |
368 | throw std::runtime_error( |
369 | "Found an unsupported argument type in the JIT tracer. File a bug report." ); |
370 | } |
371 | |
372 | TORCH_API void addInputs( |
373 | Node* n, |
374 | const char* name, |
375 | const c10::intrusive_ptr<c10::ivalue::Object>& obj); |
376 | |
377 | TORCH_API void ensureUniqueIfOutOfPlaced( |
378 | const char* name, |
379 | const at::Tensor& tensor); |
380 | TORCH_API void ensureUniqueIfOutOfPlaced( |
381 | const char* name, |
382 | const c10::optional<at::Tensor>& tensor); |
383 | |
384 | template < |
385 | typename T, |
386 | typename = torch::enable_if_t<( |
387 | !std::is_convertible<torch::decay_t<T>, at::TensorList>::value && |
388 | !std::is_convertible<torch::decay_t<T>, c10::List<at::Tensor>>::value && |
389 | !std::is_convertible<torch::decay_t<T>, at::Tensor>::value && |
390 | !std::is_convertible< |
391 | torch::decay_t<T>, |
392 | c10::intrusive_ptr<c10::ivalue::Object>>::value)>> |
393 | void addOutput(Node* node, T&&) { |
394 | AT_ERROR( |
395 | "Found an unsupported argument type " , |
396 | c10::demangle_type<T>(), |
397 | " in the JIT tracer. File a bug report." ); |
398 | } |
399 | TORCH_API void addOutput(Node* node, const at::Tensor& tensor); |
400 | TORCH_API void setOutput(Value* value, const at::Tensor& output); |
401 | TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list); |
402 | TORCH_API void addOutput(Node* node, const c10::List<at::Tensor>& list); |
403 | TORCH_API void addOutput( |
404 | Node* node, |
405 | const c10::intrusive_ptr<c10::ivalue::Object>& output); |
406 | |
407 | TORCH_API autograd::Variable getSizeOf( |
408 | const autograd::Variable& var, |
409 | int64_t dim); |
410 | |
411 | TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var); |
412 | |
413 | } // namespace tracer |
414 | } // namespace jit |
415 | } // namespace torch |
416 | |