#include <ATen/TracerMode.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
#include <torch/library.h>

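// Manually written Tracer kernels for ops whose tracing cannot be
// autogenerated (these must appear in `MANUAL_TRACER` in
// tools/autograd/gen_variable_type.py), plus a boxed fallback that traces
// every other op generically from its schema.
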
using namespace at;

namespace torch {
namespace TraceType {

namespace {

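// Manual trace for copy_. When the tracer is forcing out-of-place ops and
// self has no other views, the in-place copy is recorded as an
// expand_as(src, self) node instead; otherwise an aten::copy_ node is
// recorded directly. The actual copy always runs with tracing disabled.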
Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
  jit::Value* output = nullptr;
  if (torch::jit::tracer::isTracing()) {
    const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
    auto& graph = state.graph;
    if (state.force_outplace && self.storage().use_count() <= 1) {
      // If there are no views of self, an in-place copy is equivalent to
      // expanding src to the same size as self.
      jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
      jit::tracer::addInputs(node, "src", src);
      jit::tracer::addInputs(node, "self", self);
      graph->insertNode(node);
      output = node->output();
    } else {
      output = graph->insert(
          jit::aten::copy_,
          {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
      jit::tracer::recordSourceLocation(output->node());
    }
    jit::tracer::ensureUniqueIfOutOfPlaced(
        "copy_ (possibly due to an assignment)", self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    self.copy_(src, non_blocking);
  }

  if (torch::jit::tracer::isTracing()) {
    jit::tracer::setOutput(output, self);
  }
  return self;
}

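// Manual trace for resize_: a traced graph cannot represent resizing, so we
// warn and drop the stale value trace for self.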
const Tensor& resize_(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  if (torch::jit::tracer::isTracing()) {
    if (jit::tracer::ArgumentStash::hasIntArrayRef("size")) {
      jit::tracer::ArgumentStash::popIntArrayRef("size");
    }
    jit::tracer::warn("resize_", jit::tracer::WARN_RESIZE);
    jit::tracer::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    self.resize_(size, std::move(optional_memory_format));
  }
  return self;
}

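// Manual trace for resize_as_: same caveat as resize_ above.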
const Tensor& resize_as_(
    const Tensor& self,
    const Tensor& the_template,
    c10::optional<MemoryFormat> optional_memory_format) {
  if (torch::jit::tracer::isTracing()) {
    jit::tracer::warn("resize_as_", jit::tracer::WARN_RESIZE);
    jit::tracer::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    self.resize_as_(the_template, std::move(optional_memory_format));
  }
  return self;
}

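// Manual trace for detach: record an aten::detach node, then run the real
// detach under NoTracerDispatchMode so the inner call is not traced again.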
Tensor detach(const Tensor& self) {
  torch::jit::Node* node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    node = graph->create(jit::aten::detach, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    graph->insertNode(node);
  }

  auto result = [&]() {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    return self.detach();
  }();

  if (jit::tracer::isTracing()) {
    jit::tracer::addOutput(node, result);
  }
  return result;
}

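// In-place variant of detach. The extra ensureUniqueIfOutOfPlaced check
// guards against aliases of self when the tracer rewrites in-place ops
// out-of-place.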
Tensor& detach_(Tensor& self) {
  torch::jit::Node* node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    node = graph->create(jit::aten::detach, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    graph->insertNode(node);
    jit::tracer::ensureUniqueIfOutOfPlaced("detach_", self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    self.detach_();
  }

  if (jit::tracer::isTracing()) {
    jit::tracer::addOutput(node, self);
  }
  return self;
}

// Invariant:
// - Ops registered to DispatchKey::Tracer below must be included in
//   `MANUAL_TRACER` in tools/autograd/gen_variable_type.py
TORCH_LIBRARY_IMPL(aten, Tracer, m) {
  m.impl("resize_", resize_);
  m.impl("resize_as_", resize_as_);
  m.impl("detach", TORCH_FN(detach));
  m.impl("detach_", detach_);
  m.impl("copy_", copy_);

  // Skip tracing for the following ops by explicitly registering a
  // fallthrough kernel.
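  // A fallthrough makes the dispatcher skip the Tracer key and redispatch
  // to the next kernel, as if nothing were registered here, so these ops
  // never show up in the traced graph.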
  m.impl("_backward", CppFunction::makeFallthrough());
  m.impl("set_data", CppFunction::makeFallthrough());
  m.impl("data", CppFunction::makeFallthrough());
  m.impl("is_leaf", CppFunction::makeFallthrough());
  m.impl("output_nr", CppFunction::makeFallthrough());
  m.impl("_version", CppFunction::makeFallthrough());
  m.impl("requires_grad_", CppFunction::makeFallthrough());
  m.impl("retain_grad", CppFunction::makeFallthrough());
  m.impl("_fw_primal", CppFunction::makeFallthrough());
  m.impl("_make_dual", CppFunction::makeFallthrough());
}

} // namespace

} // namespace TraceType
} // namespace torch

namespace torch {
namespace jit {
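// Boxed fallback that traces any operator without a manual Tracer kernel:
// build a graph node from the op's schema, convert each IValue argument on
// the stack into a traced input, run the real op with tracing disabled, and
// finally record the outputs.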
void general_trace_function(const c10::OperatorHandle& op, Stack* stack) {
  const auto input_size = op.schema().arguments().size();
  const auto output_size = op.schema().returns().size();

  Node* node = nullptr;
  std::shared_ptr<tracer::TracingState> tracer_state;

  // Trace the inputs before unwrapping; otherwise we may lose the input
  // information.
  if (tracer::isTracing()) {
    tracer_state = tracer::getTracingState();
    auto symbol = Symbol::fromQualString(op.schema().name());
    const auto& graph = tracer::getTracingState()->graph;
    node = graph->create(symbol, 0);
    tracer::recordSourceLocation(node);
    const auto& args = op.schema().arguments();
    int i = 0;
    for (auto iter = stack->end() - input_size; iter != stack->end();
         ++iter, ++i) {
      // TODO: refactor the graph APIs (e.g., addInputs) so we can get rid
      // of this giant if-else block; we will clean up this tech debt in
      // follow-up PRs.
      auto type = args[i].type();
      if (type->kind() == TypeKind::OptionalType) {
        if (iter->isNone()) {
          Value* none = graph->insertNode(graph->createNone())->output();
          node->addInput(none);
          continue;
        } else {
          type = type->expectRef<OptionalType>().getElementType();
        }
      }
      if (type->isSubtypeOf(*TensorType::get())) {
        AT_ASSERT(iter->isTensor());
        tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
      } else if (type->kind() == TypeKind::FloatType) {
        AT_ASSERT(iter->isDouble());
        tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
      } else if (type->kind() == TypeKind::IntType) {
        AT_ASSERT(iter->isInt());
        tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
      } else if (type->kind() == TypeKind::BoolType) {
        AT_ASSERT(iter->isBool());
        tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
      } else if (type->kind() == TypeKind::StringType) {
        AT_ASSERT(iter->isString());
        tracer::addInputs(node, args[i].name().c_str(), iter->toStringView());
      } else if (type->kind() == TypeKind::NumberType) {
        tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
      } else if (type->kind() == TypeKind::ListType) {
        const auto& elem_type = type->expectRef<ListType>().getElementType();
        if (elem_type->isSubtypeOf(*TensorType::get())) {
          AT_ASSERT(iter->isTensorList());
          auto list = iter->toTensorVector();
          tracer::addInputs(node, args[i].name().c_str(), list);
        } else if (auto class_type = elem_type->cast<ClassType>()) {
          AT_ASSERT(iter->isList());
          auto list = iter->toList();
          std::vector<c10::intrusive_ptr<c10::ivalue::Object>> objects;
          for (IValue iv : list) {
            objects.emplace_back(std::move(iv).toObject());
          }
          tracer::addInputs(node, args[i].name().c_str(), objects, class_type);
        } else if (elem_type->kind() == TypeKind::FloatType) {
          AT_ASSERT(iter->isDoubleList());
          // NB: the tracer doesn't currently support tracing double lists.
          // We special-case them here, assuming all the doubles in the
          // list are constants.
          auto value = iter->toDoubleVector();
          std::vector<Value*> info(value.size());
          for (const auto value_index : c10::irange(value.size())) {
            info[value_index] = graph->insertConstant(value[value_index]);
            tracer::recordSourceLocation(info[value_index]->node());
          }
          node->addInput(
              graph->insertNode(graph->createList(FloatType::get(), info))
                  ->output());
        } else if (elem_type->kind() == TypeKind::IntType) {
          AT_ASSERT(iter->isIntList());
          tracer::addInputs(
              node,
              args[i].name().c_str(),
              c10::IntArrayRef(iter->toIntVector()));
        } else if (elem_type->kind() == TypeKind::BoolType) {
          AT_ASSERT(iter->isBoolList());
          tracer::addInputs(
              node, args[i].name().c_str(), iter->toBoolList().vec());
        } else {
          throw std::runtime_error(
              "unsupported input list type: " + elem_type->str());
        }
      } else if (iter->isObject()) {
        tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
      } else {
        throw std::runtime_error("unsupported input type: " + type->str());
      }
    }
    graph->insertNode(node);

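    // Detach the tracing state before running the op so that ops invoked
    // inside the kernel are not traced into the graph.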
    tracer::setTracingState(nullptr);
  }

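  // Run the actual (untraced) implementation of the operator.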
  op.callBoxed(stack);

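  // Restore the tracing state and record the op's outputs on the node.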
  if (tracer_state) {
    tracer::setTracingState(std::move(tracer_state));
    int i = 0;
    for (auto iter = stack->end() - output_size; iter != stack->end();
         ++iter, ++i) {
      const auto& type = op.schema().returns()[i].type();
      if (type->isSubtypeOf(*TensorType::get())) {
        AT_ASSERT(iter->isTensor());
        tracer::addOutput(node, iter->toTensor());
      } else if (type->kind() == TypeKind::ListType) {
        const auto& elem_type = type->expectRef<ListType>().getElementType();
        if (elem_type->isSubtypeOf(*TensorType::get())) {
          AT_ASSERT(iter->isTensorList());
          tracer::addOutput(node, iter->toTensorList());
        } else {
          throw std::runtime_error(
              "unsupported output list type: " + elem_type->str());
        }
      } else if (type->kind() == TypeKind::ClassType) {
        AT_ASSERT(iter->isObject());
        tracer::addOutput(node, iter->toObject());
      } else {
        throw std::runtime_error(
            "unsupported output type: " + type->str() +
            ", from operator: " + toString(op.operator_name()));
      }
    }
  }
}
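
// Register the boxed fallback for all operators (the `_` wildcard
// namespace) under the Tracer key; any op without a manual kernel above is
// traced through general_trace_function.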
TORCH_LIBRARY_IMPL(_, Tracer, m) {
  m.fallback(CppFunction::makeFromBoxedFunction<&general_trace_function>());
}

} // namespace jit
} // namespace torch