1 | #include <ATen/TracerMode.h> |
2 | #include <ATen/core/op_registration/op_registration.h> |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/jit/frontend/tracer.h> |
7 | #include <torch/csrc/jit/ir/ir.h> |
8 | #include <torch/csrc/utils/memory.h> |
9 | #include <torch/library.h> |
10 | |
11 | using namespace at; |
12 | |
13 | namespace torch { |
14 | namespace TraceType { |
15 | |
16 | namespace { |
17 | |
18 | Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) { |
19 | jit::Value* output = nullptr; |
20 | if (torch::jit::tracer::isTracing()) { |
21 | const jit::tracer::TracingState& state = *jit::tracer::getTracingState(); |
22 | auto& graph = state.graph; |
23 | if (state.force_outplace && self.storage().use_count() <= 1) { |
24 | // if you have no views of self, then an in place copy is equivalent to |
25 | // making sure we expand src to the same size as self |
26 | jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1); |
27 | jit::tracer::addInputs(node, "src" , src); |
28 | jit::tracer::addInputs(node, "self" , self); |
29 | graph->insertNode(node); |
30 | output = node->output(); |
31 | } else { |
32 | output = graph->insert( |
33 | jit::aten::copy_, |
34 | {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)}); |
35 | jit::tracer::recordSourceLocation(output->node()); |
36 | } |
37 | jit::tracer::ensureUniqueIfOutOfPlaced( |
38 | "copy_ (possibly due to an assignment)" , self); |
39 | } |
40 | |
41 | { |
42 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
43 | self.copy_(src, non_blocking); |
44 | } |
45 | |
46 | if (torch::jit::tracer::isTracing()) { |
47 | jit::tracer::setOutput(output, self); |
48 | } |
49 | return self; |
50 | } |
51 | |
52 | const Tensor& resize_( |
53 | const Tensor& self, |
54 | IntArrayRef size, |
55 | c10::optional<MemoryFormat> optional_memory_format) { |
56 | if (torch::jit::tracer::isTracing()) { |
57 | if (jit::tracer::ArgumentStash::hasIntArrayRef("size" )) { |
58 | jit::tracer::ArgumentStash::popIntArrayRef("size" ); |
59 | } |
60 | jit::tracer::warn("resize_" , jit::tracer::WARN_RESIZE); |
61 | jit::tracer::delValueTrace(self); |
62 | } |
63 | |
64 | { |
65 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
66 | self.resize_(size, std::move(optional_memory_format)); |
67 | } |
68 | return self; |
69 | } |
70 | |
71 | const Tensor& resize_as_( |
72 | const Tensor& self, |
73 | const Tensor& the_template, |
74 | c10::optional<MemoryFormat> optional_memory_format) { |
75 | if (torch::jit::tracer::isTracing()) { |
76 | jit::tracer::warn("resize_as_" , jit::tracer::WARN_RESIZE); |
77 | jit::tracer::delValueTrace(self); |
78 | } |
79 | |
80 | { |
81 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
82 | self.resize_as_(the_template, std::move(optional_memory_format)); |
83 | } |
84 | return self; |
85 | } |
86 | |
87 | Tensor detach(const Tensor& self) { |
88 | torch::jit::Node* node = nullptr; |
89 | if (jit::tracer::isTracing()) { |
90 | auto& graph = jit::tracer::getTracingState()->graph; |
91 | node = graph->create(jit::aten::detach, /*num_outputs=*/0); |
92 | jit::tracer::recordSourceLocation(node); |
93 | jit::tracer::addInputs(node, "self" , self); |
94 | graph->insertNode(node); |
95 | } |
96 | |
97 | auto result = [&]() { |
98 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
99 | return self.detach(); |
100 | }(); |
101 | |
102 | if (jit::tracer::isTracing()) { |
103 | jit::tracer::addOutput(node, result); |
104 | } |
105 | return result; |
106 | } |
107 | |
108 | Tensor& detach_(Tensor& self) { |
109 | torch::jit::Node* node = nullptr; |
110 | if (jit::tracer::isTracing()) { |
111 | auto& graph = jit::tracer::getTracingState()->graph; |
112 | node = graph->create(jit::aten::detach, /*num_outputs=*/0); |
113 | jit::tracer::recordSourceLocation(node); |
114 | jit::tracer::addInputs(node, "self" , self); |
115 | graph->insertNode(node); |
116 | jit::tracer::ensureUniqueIfOutOfPlaced("detach_" , self); |
117 | } |
118 | |
119 | { |
120 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
121 | self.detach_(); |
122 | } |
123 | |
124 | if (jit::tracer::isTracing()) { |
125 | jit::tracer::addOutput(node, self); |
126 | } |
127 | return self; |
128 | } |
129 | |
130 | // Invariant: |
131 | // - Ops registered to DispatchKey::Tracer below must be included in |
132 | // `MANUAL_TRACER` in tools/autograd/gen_variable_type.py |
133 | TORCH_LIBRARY_IMPL(aten, Tracer, m) { |
134 | m.impl("resize_" , resize_); |
135 | m.impl("resize_as_" , resize_as_); |
136 | m.impl("detach" , TORCH_FN(detach)); |
137 | m.impl("detach_" , detach_); |
138 | m.impl("copy_" , copy_); |
139 | |
140 | // Skip tracing for the following ops by registering fallthrough kernel |
141 | // explicitly. |
142 | m.impl("_backward" , CppFunction::makeFallthrough()); |
143 | m.impl("set_data" , CppFunction::makeFallthrough()); |
144 | m.impl("data" , CppFunction::makeFallthrough()); |
145 | m.impl("is_leaf" , CppFunction::makeFallthrough()); |
146 | m.impl("output_nr" , CppFunction::makeFallthrough()); |
147 | m.impl("_version" , CppFunction::makeFallthrough()); |
148 | m.impl("requires_grad_" , CppFunction::makeFallthrough()); |
149 | m.impl("retain_grad" , CppFunction::makeFallthrough()); |
150 | m.impl("_fw_primal" , CppFunction::makeFallthrough()); |
151 | m.impl("_make_dual" , CppFunction::makeFallthrough()); |
152 | } |
153 | |
154 | } // namespace |
155 | |
156 | } // namespace TraceType |
157 | } // namespace torch |
158 | |
159 | namespace torch { |
160 | namespace jit { |
161 | void general_trace_function(const c10::OperatorHandle& op, Stack* stack) { |
162 | const auto input_size = op.schema().arguments().size(); |
163 | const auto output_size = op.schema().returns().size(); |
164 | |
165 | Node* node = nullptr; |
166 | std::shared_ptr<tracer::TracingState> tracer_state; |
167 | |
168 | // trace the input before unwrapping, otherwise we may lose |
169 | // the input information |
170 | if (tracer::isTracing()) { |
171 | tracer_state = tracer::getTracingState(); |
172 | auto symbol = Symbol::fromQualString(op.schema().name()); |
173 | const auto& graph = tracer::getTracingState()->graph; |
174 | node = graph->create(symbol, 0); |
175 | tracer::recordSourceLocation(node); |
176 | const auto& args = op.schema().arguments(); |
177 | int i = 0; |
178 | for (auto iter = stack->end() - input_size; iter != stack->end(); |
179 | ++iter, ++i) { |
180 | // TODO we need to refactor graph APIs (e.g., addInputs) |
181 | // appropriately; after that, we can get rid of the giant if-else |
182 | // block we will clean this tech debt together in the following PRs |
183 | auto type = args[i].type(); |
184 | if (type->kind() == TypeKind::OptionalType) { |
185 | if (iter->isNone()) { |
186 | Value* none = graph->insertNode(graph->createNone())->output(); |
187 | node->addInput(none); |
188 | continue; |
189 | } else { |
190 | type = type->expectRef<OptionalType>().getElementType(); |
191 | } |
192 | } |
193 | if (type->isSubtypeOf(*TensorType::get())) { |
194 | AT_ASSERT(iter->isTensor()); |
195 | tracer::addInputs(node, args[i].name().c_str(), iter->toTensor()); |
196 | } else if (type->kind() == TypeKind::FloatType) { |
197 | AT_ASSERT(iter->isDouble()); |
198 | tracer::addInputs(node, args[i].name().c_str(), iter->toDouble()); |
199 | } else if (type->kind() == TypeKind::IntType) { |
200 | AT_ASSERT(iter->isInt()); |
201 | tracer::addInputs(node, args[i].name().c_str(), iter->toInt()); |
202 | } else if (type->kind() == TypeKind::BoolType) { |
203 | AT_ASSERT(iter->isBool()); |
204 | tracer::addInputs(node, args[i].name().c_str(), iter->toBool()); |
205 | } else if (type->kind() == TypeKind::StringType) { |
206 | AT_ASSERT(iter->isString()); |
207 | tracer::addInputs(node, args[i].name().c_str(), iter->toStringView()); |
208 | } else if (type->kind() == TypeKind::NumberType) { |
209 | tracer::addInputs(node, args[i].name().c_str(), iter->toScalar()); |
210 | } else if (type->kind() == TypeKind::ListType) { |
211 | const auto& elem_type = type->expectRef<ListType>().getElementType(); |
212 | if (elem_type->isSubtypeOf(*TensorType::get())) { |
213 | AT_ASSERT(iter->isTensorList()); |
214 | auto list = iter->toTensorVector(); |
215 | tracer::addInputs(node, args[i].name().c_str(), list); |
216 | } else if (auto class_type = elem_type->cast<ClassType>()) { |
217 | AT_ASSERT(iter->isList()); |
218 | auto list = iter->toList(); |
219 | std::vector<c10::intrusive_ptr<c10::ivalue::Object>> objects; |
220 | for (IValue iv : list) { |
221 | objects.emplace_back(std::move(iv).toObject()); |
222 | } |
223 | tracer::addInputs(node, args[i].name().c_str(), objects, class_type); |
224 | } else if (elem_type->kind() == TypeKind::FloatType) { |
225 | AT_ASSERT(iter->isDoubleList()); |
226 | // NB: now, tracer doesn't support tracing double list. We add |
227 | // special handling here, since in our case, we assume that all the |
228 | // doubles in the list are constants |
229 | auto value = iter->toDoubleVector(); |
230 | std::vector<Value*> info(value.size()); |
231 | for (const auto value_index : c10::irange(value.size())) { |
232 | info[value_index] = graph->insertConstant(value[value_index]); |
233 | tracer::recordSourceLocation(info[value_index]->node()); |
234 | } |
235 | node->addInput( |
236 | graph->insertNode(graph->createList(FloatType::get(), info)) |
237 | ->output()); |
238 | } else if (elem_type->kind() == TypeKind::IntType) { |
239 | AT_ASSERT(iter->isIntList()); |
240 | tracer::addInputs( |
241 | node, |
242 | args[i].name().c_str(), |
243 | c10::IntArrayRef(iter->toIntVector())); |
244 | } else if (elem_type->kind() == TypeKind::BoolType) { |
245 | AT_ASSERT(iter->isBoolList()); |
246 | tracer::addInputs( |
247 | node, args[i].name().c_str(), iter->toBoolList().vec()); |
248 | } else { |
249 | throw std::runtime_error( |
250 | "unsupported input list type: " + elem_type->str()); |
251 | } |
252 | } else if (iter->isObject()) { |
253 | tracer::addInputs(node, args[i].name().c_str(), iter->toObject()); |
254 | } else { |
255 | throw std::runtime_error("unsupported input type: " + type->str()); |
256 | } |
257 | } |
258 | graph->insertNode(node); |
259 | |
260 | tracer::setTracingState(nullptr); |
261 | } |
262 | |
263 | op.callBoxed(stack); |
264 | |
265 | if (tracer_state) { |
266 | tracer::setTracingState(std::move(tracer_state)); |
267 | int i = 0; |
268 | for (auto iter = stack->end() - output_size; iter != stack->end(); |
269 | ++iter, ++i) { |
270 | const auto& type = op.schema().returns()[i].type(); |
271 | if (type->isSubtypeOf(*TensorType::get())) { |
272 | AT_ASSERT(iter->isTensor()); |
273 | tracer::addOutput(node, iter->toTensor()); |
274 | } else if (type->kind() == TypeKind::ListType) { |
275 | const auto& elem_type = type->expectRef<ListType>().getElementType(); |
276 | if (elem_type->isSubtypeOf(*TensorType::get())) { |
277 | AT_ASSERT(iter->isTensorList()); |
278 | tracer::addOutput(node, iter->toTensorList()); |
279 | } else { |
280 | throw std::runtime_error( |
281 | "unsupported ouptut list type: " + elem_type->str()); |
282 | } |
283 | } else if (type->kind() == TypeKind::ClassType) { |
284 | AT_ASSERT(iter->isObject()); |
285 | tracer::addOutput(node, iter->toObject()); |
286 | } else { |
287 | throw std::runtime_error( |
288 | "unsupported output type: " + type->str() + |
289 | ", from operator: " + toString(op.operator_name())); |
290 | } |
291 | } |
292 | } |
293 | } |
294 | TORCH_LIBRARY_IMPL(_, Tracer, m) { |
295 | m.fallback(CppFunction::makeFromBoxedFunction<&general_trace_function>()); |
296 | } |
297 | |
298 | } // namespace jit |
299 | } // namespace torch |
300 | |