1#pragma once
2
3#include <ATen/core/Dimname.h>
4#include <ATen/core/class_type.h>
5#include <ATen/core/jit_type.h>
6#include <ATen/core/stack.h>
7#include <ATen/core/symbol.h>
8#include <c10/util/Exception.h>
9#include <torch/csrc/Export.h>
10
11#include <torch/csrc/jit/frontend/source_range.h>
12#include <torch/csrc/utils/variadic.h>
13
14#include <cstdint>
15#include <iostream>
16#include <memory>
17#include <mutex>
18#include <unordered_map>
19#include <vector>
20
21namespace torch {
22namespace jit {
23struct Node;
24struct Value;
25struct Graph;
26struct Module;
27
28namespace tracer {
29
30using ::c10::ivalue::Shared;
31
32using ::c10::IValue;
33using ::c10::ivalue::Future;
34
35using ::c10::ArrayRef;
36using ::c10::TupleType;
37using ::c10::TupleTypePtr;
38using ::c10::ivalue::ConstantString;
39
40using torch::autograd::Variable;
41using variable_list = std::vector<Variable>;
42
// Global default for TracingState::warn — new TracingStates copy this flag
// (see TracingState below). Returning a mutable reference suggests callers
// may toggle it to silence tracer warnings globally; confirm in tracer.cpp.
TORCH_API std::atomic<bool>& getTracerStateWarnMode();
44
// State for a single tracing session. Owns the Graph being built and maps
// live runtime IValues to the graph Value*s that represent them, scoped by
// a stack of frames (one per nested traced call; see enterFrame/leaveFrame).
struct TORCH_API TracingState
    : public std::enable_shared_from_this<TracingState> {
  TracingState();
  ~TracingState();

  // The graph under construction for this trace.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<Graph> graph;
  // Whether tracer warnings are emitted for this trace (see warn() below);
  // seeded from the global flag returned by getTracerStateWarnMode().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool warn = getTracerStateWarnMode();
  // NOTE(review): presumably enables strict trace checking (see
  // STRICT_TRACER_MSG and the `strict` parameter of trace()) — confirm in
  // tracer.cpp.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool strict = true;
  // NOTE(review): presumably forces in-place ops to be traced as their
  // out-of-place variants — confirm in tracer.cpp.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool force_outplace = false;
  // Hook used to recover a human-readable name for a Variable (e.g. its
  // Python-side name). Defaults to an empty string for every variable.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::function<std::string(const Variable& var)> lookup_var_name_fn =
      [](const Variable& var) { return ""; };

  // Push a fresh, empty value-mapping frame (used when entering a nested
  // traced call so its mappings can be discarded on exit).
  void enterFrame() {
    env_stack.emplace_back();
  }

  // Pop the innermost value-mapping frame.
  void leaveFrame() {
    env_stack.pop_back();
  }

  // Associate the runtime value `v` with graph value `value`.
  void setValue(const IValue& v, Value* value);
  // Remove any mapping for `var`.
  void delValue(const IValue& var);
  // Look up the graph value corresponding to `var` (defined in tracer.cpp).
  Value* getValue(const IValue& var);
  Value* getOutput(const IValue& var, size_t i);
  bool hasValue(const IValue& var) const;

  Node* createNode(c10::Symbol op_name, size_t num_outputs);
  void insertNode(Node* node);

 private:
  using WeakIValue = at::WeakIValue;

  // Hash by identity (WeakIValue::hash), matching the identity-based
  // equality predicate below.
  struct WeakIValueHasher {
    size_t operator()(const WeakIValue& t) const {
      return t.hash();
    }
  };

  // Two keys are equal iff they refer to the same underlying object.
  struct WeakIValueEq {
    bool operator()(const WeakIValue& t1, const WeakIValue& t2) const {
      return t1.isSameIdentity(t2);
    }
  };

  // Keys are weak, so stashing a value here does not keep it alive.
  using Frame =
      std::unordered_map<WeakIValue, Value*, WeakIValueHasher, WeakIValueEq>;
  std::vector<Frame> env_stack;
};
98
99// This is meant to be used as a thread local place, where we can store extra
100// info that gets lost when we call into ATen from Python bindings. One example
101// for when this happens is when we get an IntArrayRef argument with e.g. sizes
102// for view. When tracing, those might be tensors, which let us encode extra
103// data dependencies, but once they get to the ATen call where we actually have
// the tracing logic, they get converted into a raw IntArrayRef, and we lose
105// all information. To prevent this, we temporarily stash it in here.
106// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ArgumentStash {
  // Traced Value*s for each element of an IntArrayRef argument; elements
  // that were plain integers (not traced tensors) remain nullptr.
  struct IntArrayRefTrace : std::vector<Value*> {
    IntArrayRefTrace(int size) : std::vector<Value*>(size, nullptr) {}
  };

  // True when no IntArrayRef info is stashed on this thread.
  // NOTE(review): this ignores `values` — presumably intentional; confirm
  // against the call sites.
  static bool empty() {
    return stash.intlists.empty();
  }

  // Stash the traced Value for element `idx` of the IntArrayRef argument
  // named `arg_name` (of total length `size`).
  TORCH_API static void stashIntArrayRefElem(
      const std::string& arg_name,
      size_t size,
      size_t idx,
      const Variable& var);

  static bool hasIntArrayRef(const std::string& arg_name) {
    return stash.intlists.count(arg_name) > 0;
  }

  // Remove and return the stashed trace for `arg_name`. Throws
  // (std::out_of_range from .at) if nothing was stashed under that name.
  static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) {
    auto info = std::move(stash.intlists.at(arg_name));
    stash.intlists.erase(arg_name);
    return info;
  }

  // Value stashing: Use these methods to stash arguments which correspond
  // to regular Value*'s in the graph. i.e. they don't require special
  // handling like in the case of IntArrayRefs
  TORCH_API static void stashValue(
      const std::string& arg_name,
      size_t idx,
      const Variable& var,
      const c10::TypePtr& type = nullptr);

  static bool hasValue(const std::string& arg_name) {
    return stash.values.count(arg_name) > 0;
  }

  // Remove and return the stashed Value for `arg_name`; throws if absent.
  static Value* popValue(const std::string& arg_name) {
    auto info = stash.values.at(arg_name);
    stash.values.erase(arg_name);
    return info;
  }

 private:
  // One stash per thread (defined in tracer.cpp).
  static thread_local ArgumentStash stash;
  std::unordered_map<std::string, IntArrayRefTrace> intlists;
  std::unordered_map<std::string, Value*> values;
};
156
// Retrieve or set the current tracing state. Returns a nullptr if tracing is
// disabled. (Storage is defined in tracer.cpp.)
TORCH_API const std::shared_ptr<TracingState>& getTracingState();
TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
161
162inline bool isTracing() {
163 return static_cast<bool>(getTracingState());
164}
165
// Signature of the handler invoked for tracer warnings (see setWarn below).
using warn_fn_type = void (*)(const std::string& msg);
// Canonical warning-kind strings passed as `_kind` to warn()/_do_warn();
// defined in tracer.cpp.
TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
TORCH_API extern const char* WARN_CONSTRUCTOR;
TORCH_API extern const char* WARN_RESIZE;
TORCH_API extern const char* STRICT_TRACER_MSG;
// Implementation detail of warn(); presumably forwards to the installed
// warn handler — confirm in tracer.cpp.
TORCH_API void _do_warn(const char* _reason, const char* _kind);
172inline void warn(const char* _reason, const char* _kind = nullptr) {
173 if (const auto& state = getTracingState()) {
174 if (!state->warn)
175 return;
176 _do_warn(_reason, _kind);
177 }
178}
// Install the handler that receives tracer warning messages.
TORCH_API void setWarn(warn_fn_type fn);
180
181struct TORCH_API NoWarn {
182 NoWarn() : state(getTracingState()) {
183 if (state) {
184 prev = state->warn;
185 state->warn = false;
186 }
187 }
188 ~NoWarn() {
189 if (state) {
190 state->warn = prev;
191 }
192 }
193 std::shared_ptr<TracingState> state;
194 bool prev{false};
195};
196
// RAII guard that pushes a fresh tracing frame on construction and pops it
// on destruction (see TracingState::enterFrame/leaveFrame).
// NOTE(review): dereferences getTracingState() without a null check — this
// assumes tracing is active for the guard's whole lifetime; confirm at call
// sites.
struct WithNestedTracingFrame {
  WithNestedTracingFrame() {
    getTracingState()->enterFrame();
  }

  ~WithNestedTracingFrame() {
    getTracingState()->leaveFrame();
  }
};
// Record the source location for node `n`; the recording strategy can be
// overridden via setRecordSourceLocation (e.g. from Python bindings).
TORCH_API void recordSourceLocation(Node* n);
TORCH_API void setRecordSourceLocation(void (*v)(Node*));

// Retrieve the current Python call stack; the provider can be overridden
// via setPythonCallstack.
TORCH_API std::vector<StackEntry> pythonCallstack();
TORCH_API void setPythonCallstack(std::vector<StackEntry> (*v)());

// Having finished adding a new 'node' to the graph IR 'setValueTrace'
// associates this node with an output variable, so that further operations
// involving this variable know which node in the IR to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value);

// Remove the trace mapping for `var` (see TracingState::delValue).
TORCH_API void delValueTrace(const IValue& var);

// Pause tracing; presumably the returned callable resumes it — confirm in
// tracer.cpp.
TORCH_API std::function<void()> pauseTracing();

// Look up the graph Value associated with `var` in the active trace.
TORCH_API Value* getValueTrace(const IValue& var);

// Run `traced_fn` on `inputs` while recording a graph, returning the
// resulting TracingState together with the outputs. `strict` and
// `force_outplace` seed the corresponding TracingState flags.
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
    Stack inputs,
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool strict = true,
    bool force_outplace = false,
    Module* self = nullptr,
    const std::vector<std::string>& argument_names = {});

// Abort the current trace (defined in tracer.cpp).
TORCH_API void abandon();
233
// NB: these serve both as intermediate steps in addInputs below,
// as well as the overloads that terminate template recursion
// Scalar and optional-scalar overloads.
TORCH_API void addInputs(Node* n, const char* name, int64_t value);
TORCH_API void addInputs(Node* n, const char* name, c10::SymInt value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    c10::optional<int64_t> value);
TORCH_API void addInputs(Node* n, const char* name, bool value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<bool>& value);
TORCH_API void addInputs(Node* n, const char* name, double value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<double>& value);
TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Scalar>& value);
// Tensor overloads.
TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Tensor>& value);
// Integer-list overloads (these are where ArgumentStash's IntArrayRef
// handling comes into play).
TORCH_API void addInputs(Node* n, const char* name, ArrayRef<int64_t> value);
TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    c10::optional<c10::SymInt> value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<ArrayRef<int64_t>>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::OptionalIntArrayRef& opt_value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::OptionalSymIntArrayRef& opt_value);
// Tensor-list overloads; `allow_undefined` permits undefined tensors in the
// list.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    ArrayRef<at::Tensor> value,
    bool allow_undefined = false);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    std::vector<at::Tensor> value,
    bool allow_undefined = false);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::ITensorListRef value,
    bool allow_undefined = false);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const List<c10::optional<at::Tensor>>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
    const c10::ClassTypePtr& class_type);
TORCH_API void addInputs(Node* n, const char* name, ArrayRef<double> value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<ArrayRef<double>>& value);
// String overloads.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::string_view value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<c10::string_view>& value);
// Device / layout / dtype / memory-format and related overloads.
TORCH_API void addInputs(Node* n, const char* name, at::Device value);
TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream);
TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::ScalarType>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Device>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Layout>& value);
TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    c10::optional<at::DimnameList> value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::MemoryFormat>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Generator>& value);
346
// Tracing lists of bool is unsupported: fail loudly rather than record an
// incorrect graph.
inline void addInputs(
    Node* n,
    const char* name,
    const std::vector<bool>& value) {
  AT_ERROR("Tracing a list of bool type is currently not supported!");
}
353
// Catch-all for list element types not covered by the overloads above:
// always a runtime error.
template <typename T>
void addInputs(Node* n, const char* name, ArrayRef<T> value) {
  AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
}
// Catch-all for dict arguments: always a runtime error.
template <typename K, typename V>
void addInputs(
    Node* n,
    const char* name,
    const std::unordered_map<K, V>& value) {
  AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
}
365
// Fixed-size bool arrays (e.g. output masks) are not traceable either.
// NOTE(review): this throws std::runtime_error while the sibling catch-alls
// above use AT_ERROR (c10::Error) — callers catching c10::Error will miss
// this one; confirm whether the inconsistency is intentional.
template <size_t N>
void addInputs(Node* n, const char* name, std::array<bool, N> value) {
  throw std::runtime_error(
      "Found an unsupported argument type in the JIT tracer. File a bug report.");
}
371
// Custom-class (torchbind) object input.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::intrusive_ptr<c10::ivalue::Object>& obj);

// NOTE(review): given the names and TracingState::force_outplace, these
// presumably validate that a tensor modified in place is safe to trace
// out-of-place — confirm in tracer.cpp.
TORCH_API void ensureUniqueIfOutOfPlaced(
    const char* name,
    const at::Tensor& tensor);
TORCH_API void ensureUniqueIfOutOfPlaced(
    const char* name,
    const c10::optional<at::Tensor>& tensor);
383
// Catch-all addOutput: SFINAE-disabled for Tensor, TensorList,
// List<Tensor>, and custom-class Object (those are handled by the declared
// overloads below); any other output type is a runtime error.
template <
    typename T,
    typename = torch::enable_if_t<(
        !std::is_convertible<torch::decay_t<T>, at::TensorList>::value &&
        !std::is_convertible<torch::decay_t<T>, c10::List<at::Tensor>>::value &&
        !std::is_convertible<torch::decay_t<T>, at::Tensor>::value &&
        !std::is_convertible<
            torch::decay_t<T>,
            c10::intrusive_ptr<c10::ivalue::Object>>::value)>>
void addOutput(Node* node, T&&) {
  AT_ERROR(
      "Found an unsupported argument type ",
      c10::demangle_type<T>(),
      " in the JIT tracer. File a bug report.");
}
// Supported output types (definitions in tracer.cpp).
TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
TORCH_API void setOutput(Value* value, const at::Tensor& output);
TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
TORCH_API void addOutput(Node* node, const c10::List<at::Tensor>& list);
TORCH_API void addOutput(
    Node* node,
    const c10::intrusive_ptr<c10::ivalue::Object>& output);

// NOTE(review): presumably records size(dim) / numel() of `var` as traced
// ops and returns the resulting variable — confirm in tracer.cpp.
TORCH_API autograd::Variable getSizeOf(
    const autograd::Variable& var,
    int64_t dim);

TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
412
413} // namespace tracer
414} // namespace jit
415} // namespace torch
416