1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <ATen/core/jit_type.h> |
5 | #include <ATen/core/qualified_name.h> |
6 | #include <ATen/core/stack.h> |
7 | #include <pybind11/complex.h> |
8 | #include <pybind11/pybind11.h> |
9 | #include <pybind11/pytypes.h> |
10 | #include <torch/csrc/Device.h> |
11 | #include <torch/csrc/Dtype.h> |
12 | #include <torch/csrc/Export.h> |
13 | #include <torch/csrc/Layout.h> |
14 | #include <torch/csrc/QScheme.h> |
15 | #include <torch/csrc/Stream.h> |
16 | #include <torch/csrc/jit/api/module.h> |
17 | #include <torch/csrc/jit/frontend/schema_matching.h> |
18 | #include <torch/csrc/jit/frontend/tracer.h> |
19 | #include <torch/csrc/jit/python/module_python.h> |
20 | #include <torch/csrc/jit/python/python_custom_class.h> |
21 | #include <torch/csrc/jit/python/python_tracer.h> |
22 | #include <torch/csrc/jit/resource_guard.h> |
23 | #include <torch/csrc/jit/runtime/operator.h> |
24 | #include <torch/csrc/utils/auto_gil.h> |
25 | #include <torch/csrc/utils/pybind.h> |
26 | #include <torch/csrc/utils/python_arg_parser.h> |
27 | #include <torch/csrc/utils/six.h> |
28 | #ifdef USE_DISTRIBUTED |
29 | #include <torch/csrc/distributed/rpc/py_rref.h> |
30 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
31 | #endif |
32 | |
33 | #include <ATen/core/function_schema.h> |
34 | #include <c10/core/Stream.h> |
35 | #ifdef USE_C10D_NCCL |
36 | #include <c10/cuda/CUDACachingAllocator.h> |
37 | #include <c10/cuda/CUDAStream.h> |
38 | #endif |
39 | #include <c10/util/Exception.h> |
40 | #include <c10/util/Optional.h> |
41 | #include <c10/util/irange.h> |
42 | |
43 | #include <algorithm> |
44 | #include <cstddef> |
45 | #include <string> |
46 | #include <utility> |
47 | #include <vector> |
48 | |
49 | // The visibility attribute is to avoid a warning about storing a field in the |
50 | // struct that has a different visibility (from pybind) than the struct. |
51 | #ifdef _WIN32 |
52 | #define VISIBILITY_HIDDEN |
53 | #else |
54 | #define VISIBILITY_HIDDEN __attribute__((visibility("hidden"))) |
55 | #endif |
56 | |
57 | namespace torch { |
58 | namespace jit { |
59 | |
60 | void clear_registered_instances(void* ptr); |
61 | |
62 | TORCH_API IValue toIValue( |
63 | py::handle obj, |
64 | const TypePtr& type, |
65 | c10::optional<int32_t> N = c10::nullopt); |
66 | |
67 | TORCH_API py::object toPyObject(IValue ivalue); |
68 | |
69 | // Hack to overload the behavior of toIValue to accept Python |
70 | // numbers in places where a Tensor is expected |
71 | // See also torch::should_allow_numbers_as_tensors |
72 | class ToIValueAllowNumbersAsTensors { |
73 | bool old_; |
74 | |
75 | public: |
76 | ToIValueAllowNumbersAsTensors(bool enable); |
77 | ~ToIValueAllowNumbersAsTensors(); |
78 | }; |
79 | |
// Wrap a Python function so its destruction is guarded by the GIL.
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error:
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
84 | struct VISIBILITY_HIDDEN PythonFunctionGuard { |
85 | explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {} |
86 | |
87 | ~PythonFunctionGuard() { |
88 | pybind11::gil_scoped_acquire ag; |
89 | func_.dec_ref(); |
    // Explicitly set the PyObject* to nullptr to prevent py::object's dtor
    // from decref'ing the PyObject again.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.ptr() = nullptr;
94 | } |
95 | |
96 | py::function func_; |
97 | }; |
98 | |
99 | // The PythonFutureWrapper for ivalue::Future |
100 | // |
// NB: VISIBILITY_HIDDEN is for silencing the compile error,
102 | // "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility |
103 | // than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func' |
104 | // [-Werror=attributes]" |
105 | // |
106 | // NB: inherit from enable_shared_from_this because then(py::function) needs to |
107 | // get a shared_ptr from this pointer. |
108 | struct VISIBILITY_HIDDEN PythonFutureWrapper |
109 | : std::enable_shared_from_this<PythonFutureWrapper> { |
110 | using UnwrapFunc = std::function<void(py::object)>; |
111 | |
112 | explicit PythonFutureWrapper( |
113 | c10::intrusive_ptr<c10::ivalue::Future> fut, |
114 | c10::optional<UnwrapFunc> unwrap_func = c10::nullopt) |
115 | : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {} |
116 | |
117 | explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete; |
118 | PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete; |
119 | |
120 | bool done() { |
121 | return fut->completed(); |
122 | } |
123 | |
124 | py::object value() { |
    // Acquire the GIL, as toPyObject creates new py::objects without
    // grabbing the GIL itself.
127 | py::gil_scoped_acquire acquire; |
128 | py::object py_obj = toPyObject(fut->value()); |
129 | // unwrap_func is a general compositional function that takes in a |
130 | // py::object and executes some python function. It is currently mostly used |
131 | // to throw python exceptions. |
132 | if (unwrap_func) { |
133 | (*unwrap_func)(py_obj); |
134 | } |
135 | return py_obj; |
136 | } |
137 | |
138 | py::object wait() { |
139 | fut->wait(); |
140 | if (jit::tracer::isTracing()) { |
141 | auto graph = jit::tracer::getTracingState()->graph; |
142 | |
143 | Value* fut_val = jit::tracer::getValueTrace(fut); |
144 | auto output = graph->insert(aten::wait, {fut_val}); |
145 | jit::tracer::setValueTrace(fut->value(), output); |
146 | } |
147 | return value(); |
148 | } |
149 | |
150 | // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper> |
151 | // (i.e., torch._C.Future) as the only argument. If the type mismatches, an |
152 | // error will be thrown when waiting for the value of this returned Future. |
153 | std::shared_ptr<PythonFutureWrapper> then(py::function cb) { |
    // We need an additional layer of wrapping here to guard the
    // destruction of the py::function object. The Future owns a
    // reference to the py::function in its callback vector, but the
    // Future does not acquire the GIL on destruction.
158 | auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb)); |
159 | |
160 | return std::make_shared<jit::PythonFutureWrapper>(fut->then( |
161 | // Capture a copy of the ivalue::Future instead of the `this` pointer |
162 | // because the PythonFutureWrapper object could have been deleted |
163 | // when the callbacks are fired. For example, RPC only captures the |
164 | // ivalue::Future instead of PythonFutureWrapper in JitFuture's |
165 | // callback functions. Hence, if user code does not hold a reference to |
166 | // this PythonFutureWrapper object, there is no guarantee that the |
167 | // PythonFutureWrapper is still valid when running the callback. |
168 | [pyFut(this->getPtr()), |
169 | pf(std::move(pf))](c10::ivalue::Future& /* unused */) -> IValue { |
170 | try { |
171 | pybind11::gil_scoped_acquire ag; |
172 | return toIValue(pf->func_(pyFut), PyObjectType::get()); |
173 | } catch (py::error_already_set& e) { |
174 | auto err = std::runtime_error(c10::str( |
175 | "Got the following error when running the callback: " , |
176 | e.what())); |
177 | { |
178 | pybind11::gil_scoped_acquire ag; |
            // Release ownership of the py::objects and also restore the
            // Python error indicator.
            e.restore();
            // Clear the Python error indicator, as we have already recorded
            // the exception in the error message above.
            PyErr_Clear();
185 | } |
186 | |
187 | throw err; |
188 | } |
189 | }, |
190 | PyObjectType::get())); |
191 | } |
192 | |
193 | void add_done_callback(py::function cb) { |
194 | auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb)); |
195 | // NOLINTNEXTLINE(modernize-avoid-bind) |
196 | fut->addCallback(std::bind( |
197 | [pyFut(this->getPtr())](std::shared_ptr<PythonFunctionGuard> pf) { |
198 | try { |
199 | pybind11::gil_scoped_acquire ag; |
200 | pf->func_(pyFut); |
201 | } catch (py::error_already_set& e) { |
202 | { |
203 | pybind11::gil_scoped_acquire ag; |
              // Release ownership of the py::objects and also restore the
              // Python error indicator.
              e.restore();
              // Clear the Python error indicator, as we have already recorded
              // the exception in the log message below.
              PyErr_Clear();
210 | } |
211 | // Log and ignore exceptions raised through the callback |
212 | LOG(ERROR) << "Got the following error when running the callback: " |
213 | << e.what(); |
214 | |
215 | } catch (const std::exception& e) { |
216 | // Log and ignore exceptions raised through the callback |
217 | LOG(ERROR) << "Got the following error when running the callback: " |
218 | << e.what(); |
219 | } |
220 | }, |
221 | std::move(pf))); |
222 | } |
223 | |
224 | void markCompleted(const py::object& pyValue) { |
225 | DCHECK(PyGILState_Check()); |
226 | IValue value = toIValue(pyValue, PyObjectType::get()); |
227 | |
228 | py::gil_scoped_release release; |
229 | fut->markCompleted(std::move(value)); |
230 | } |
231 | |
232 | c10::intrusive_ptr<c10::ivalue::Future> fut; |
233 | // unwrap_func works like a callback for the value returned by |
234 | // PythonFutureWrapper::wait(). |
235 | c10::optional<UnwrapFunc> unwrap_func; |
236 | |
237 | private: |
238 | std::shared_ptr<PythonFutureWrapper> getPtr() { |
239 | return shared_from_this(); |
240 | } |
241 | }; |
242 | |
// The PythonAwaitWrapper for ivalue::Await
//
// Expresses delayed function execution with lazy semantics,
// i.e. Await[W] in eager mode can be used as W.
// When an attribute of type W is requested, Await[W] will return the
// attribute of W, transparently calling wait() beforehand.
// There are no lazy semantics in script mode; an explicit
// wait(Await[W]) -> W must be called to convert to type W.
//
// The Await object takes shared ownership of the specified function and its
// arguments. After the first call to wait() it owns the result. Deliberately
// no type inference for eager mode.
255 | struct VISIBILITY_HIDDEN PythonAwaitWrapper |
256 | : std::enable_shared_from_this<PythonAwaitWrapper> { |
257 | explicit PythonAwaitWrapper(c10::intrusive_ptr<c10::ivalue::Await> aw) |
258 | : aw_(std::move(aw)) {} |
259 | explicit PythonAwaitWrapper(py::handle input) { |
260 | args_ = py::tuple(1u); |
261 | args_[0] = input; |
262 | auto type = PyObjectType::get(); |
263 | aw_ = c10::make_intrusive<c10::ivalue::Await>(type); |
264 | aw_->markCompleted(toIValue(input, type)); |
265 | } |
266 | |
267 | explicit PythonAwaitWrapper(py::function pf, py::tuple args) { |
268 | pyfg_ = std::make_shared<torch::jit::PythonFunctionGuard>(std::move(pf)); |
269 | args_ = std::move(args); |
270 | std::function<IValue()> f = [fg(pyfg_), &args(args_)]() { |
271 | pybind11::gil_scoped_acquire ag; |
272 | return toIValue(fg->func_(*args), PyObjectType::get()); |
273 | }; |
274 | aw_ = c10::make_intrusive<c10::ivalue::Await>( |
275 | PyObjectType::get(), std::move(f)); |
276 | } |
277 | |
278 | explicit PythonAwaitWrapper(const PythonAwaitWrapper&) = delete; |
279 | PythonAwaitWrapper& operator=(const PythonAwaitWrapper&) = delete; |
280 | |
281 | py::object wait() { |
282 | py::gil_scoped_acquire acquire; |
283 | return toPyObject(aw_->wait()); |
284 | } |
285 | |
  // Nowait semantics denote the trivial case in which the Await is
  // constructed directly from the result.
288 | bool is_nowait() { |
289 | return pyfg_ == nullptr; |
290 | } |
291 | |
292 | const py::function fn() { |
293 | TORCH_CHECK( |
294 | pyfg_, "Await constructed as awaitable_nowait does not have fn" ); |
295 | return pyfg_->func_; |
296 | } |
297 | |
298 | const py::tuple args() { |
299 | return args_; |
300 | } |
301 | |
302 | TypePtr type() { |
303 | return aw_->type(); |
304 | } |
305 | |
306 | c10::intrusive_ptr<c10::ivalue::Await> aw_; |
307 | std::shared_ptr<torch::jit::PythonFunctionGuard> pyfg_; |
308 | py::tuple args_; |
309 | |
310 | private: |
311 | std::shared_ptr<PythonAwaitWrapper> getPtr() { |
312 | return shared_from_this(); |
313 | } |
314 | }; |
315 | |
316 | // error reporting: when reporting user-caused errors, these functions should |
317 | // not use AT_ERROR macros, since these macros add stack trace information |
318 | // that is confusing to display to the end user since it always reports |
319 | // locations in libtorch code rather than user code. |
320 | |
321 | inline std::shared_ptr<CompilationUnit> get_python_cu() { |
322 | return py::module::import("torch.jit._state" ) |
323 | .attr("_python_cu" ) |
324 | .cast<std::shared_ptr<CompilationUnit>>(); |
325 | } |
326 | |
327 | struct TypedIValue : public std::pair<IValue, TypePtr> { |
328 | using pair::pair; |
329 | |
330 | IValue& ivalue() { |
331 | return this->first; |
332 | } |
333 | TypePtr& type() { |
334 | return this->second; |
335 | } |
336 | }; |
337 | |
338 | inline TypedIValue toDictKeyIValue(py::handle key) { |
339 | if (py::isinstance<py::str>(key)) { |
340 | return TypedIValue( |
341 | ConstantString::create(py::cast<std::string>(key)), StringType::get()); |
342 | } else if (py::isinstance<py::int_>(key)) { |
343 | return TypedIValue(py::cast<int64_t>(key), IntType::get()); |
344 | } else if (py::isinstance<py::float_>(key)) { |
345 | return TypedIValue(py::cast<double>(key), FloatType::get()); |
346 | } else { |
347 | AT_ERROR("Dictionary inputs may only have string, int, or float keys" ); |
348 | } |
349 | } |
350 | |
351 | inline c10::optional<TypePtr> unifyOrInitializeType( |
352 | const TypePtr& accum, |
353 | const TypePtr& unify) { |
354 | if (!accum) { |
355 | return unify; |
356 | } |
357 | return unifyTypes(accum, unify); |
358 | } |
359 | |
360 | using InferredType = c10::InferredType; |
361 | |
362 | InferredType tryToInferContainerType(py::handle input); |
363 | |
// Try to infer the type of a Python object.
// The type cannot be inferred if:
//   input is an empty container (list, dict)
//   input is a list with element types that cannot be unified
//   input is a dict with key or value types that cannot be unified
369 | inline InferredType tryToInferType(py::handle input) { |
370 | // Try tensor types |
371 | if (THPVariable_Check(input.ptr())) { |
372 | return InferredType(TensorType::get()); |
373 | } |
374 | |
375 | if (input.is_none()) { |
376 | return InferredType(NoneType::get()); |
377 | } |
378 | |
379 | if (py::isinstance<StrongFunctionPtr>(input)) { |
380 | auto fn = py::cast<StrongFunctionPtr>(input).function_; |
381 | return InferredType(FunctionType::create(fn)); |
382 | } |
383 | |
384 | // Try basic types first |
385 | if (py::isinstance<py::bool_>(input)) { |
386 | return InferredType(BoolType::get()); |
387 | // NOLINTNEXTLINE(bugprone-branch-clone) |
388 | } else if (py::isinstance<py::int_>(input)) { |
389 | return InferredType(IntType::get()); |
390 | } else if (py::isinstance<py::float_>(input)) { |
391 | return InferredType(FloatType::get()); |
392 | } else if (PyComplex_CheckExact(input.ptr())) { |
393 | return InferredType(ComplexType::get()); |
394 | } else if (py::isinstance<py::str>(input)) { |
395 | return InferredType(StringType::get()); |
396 | } else if (THPLayout_Check(input.ptr())) { |
397 | return InferredType(IntType::get()); |
398 | } else if (THPDevice_Check(input.ptr())) { |
399 | return InferredType(DeviceObjType::get()); |
400 | } else if (THPStream_Check(input.ptr())) { |
401 | return InferredType(StreamObjType::get()); |
402 | } else if (THPDtype_Check(input.ptr())) { |
403 | return InferredType(IntType::get()); |
404 | } else if (THPQScheme_Check(input.ptr())) { |
405 | return InferredType(IntType::get()); |
408 | } |
409 | |
410 | auto enum_type = py::module::import("enum" ).attr("Enum" ); |
411 | py::bool_ isEnumValue = py::isinstance(input, enum_type); |
412 | if (py::cast<bool>(isEnumValue)) { |
    auto enum_class = input.attr("__class__");
    auto enum_type = py::cast<TypePtr>(
        py::module::import("torch.jit.annotations")
            .attr("try_ann_to_type")(enum_class, SourceRange()));
417 | return InferredType(std::move(enum_type)); |
418 | } |
419 | |
420 | py::bool_ isClass = |
421 | py::module::import("inspect" ).attr("isclass" )(input.get_type()); |
422 | if (py::cast<bool>(isClass)) { |
423 | // Assume that the class is compiled already or will compile. Invalidate |
424 | // this later if needed. |
425 | bool class_compiled = true; |
426 | |
427 | // Check if the type is already compiled. |
    py::object existing_ty = py::module::import("torch.jit._state")
                                 .attr("_get_script_class")(input.get_type());
430 | |
431 | if (existing_ty.is_none()) { |
432 | // If not, try to compile it. |
      py::bool_ can_compile = py::module::import("torch._jit_internal")
                                  .attr("can_compile_class")(input.get_type());
435 | |
436 | if (py::cast<bool>(can_compile)) { |
437 | // Try to compile the class. This is wrapped in a try-catch because |
438 | // compilation of class types can raise an Exception and in that case, |
439 | // we want to defer to other attempts at type inference below rather |
440 | // than fail compilation altogether. |
441 | try { |
442 | py::module::import("torch.jit._script" ) |
443 | .attr("_recursive_compile_class" )( |
444 | input.get_type(), SourceRange()); |
445 | } catch (...) { |
446 | // Invalidate the assumption that the class compiled so that we don't |
447 | // look up and return its JIT type as the type for the input. |
448 | class_compiled = false; |
449 | } |
450 | } |
451 | } |
452 | |
453 | // If the class compiled successfully, look up the existing JIT type by |
454 | // qualified name and return it. |
455 | if (class_compiled) { |
      auto script_class = py::module::import("torch.jit._state")
                              .attr("_get_script_class")(input.get_type());
458 | |
459 | if (!script_class.is_none()) { |
460 | auto class_type = py::cast<ClassTypePtr>(script_class); |
461 | |
462 | if (class_type && !class_type->is_module()) { |
463 | return InferredType(std::move(class_type)); |
464 | } |
465 | } |
466 | } |
467 | } |
468 | |
469 | if (py::isinstance<Object>(input)) { |
470 | auto object = py::cast<Object>(input); |
471 | return InferredType(object.type()); |
472 | #ifdef USE_RPC |
473 | } else if (py::isinstance<torch::distributed::rpc::PyRRef>(input)) { |
474 | auto rref_ivalue = input.cast<torch::distributed::rpc::PyRRef>().toIValue(); |
475 | return InferredType(rref_ivalue.type()); |
476 | #endif |
477 | } |
478 | |
  auto await_type = py::module::import("torch._awaits").attr("_Await");
480 | py::bool_ is_await = py::isinstance(input, await_type); |
481 | if (py::cast<bool>(is_await)) { |
482 | auto awptr = input.cast<std::shared_ptr<PythonAwaitWrapper>>(); |
483 | return InferredType(AwaitType::create(awptr->aw_->elementType())); |
484 | } |
485 | |
486 | if (as_module(py::cast<py::object>(input))) { |
487 | return InferredType("Cannot infer type of ScriptModule" ); |
488 | } |
489 | |
  auto module_type = py::module::import("torch.nn").attr("Module");
491 | py::bool_ is_module = py::isinstance(input, module_type); |
492 | if (py::cast<bool>(is_module)) { |
493 | return InferredType("Cannot infer concrete type of torch.nn.Module" ); |
494 | } |
495 | |
496 | // Try container types |
497 | return tryToInferContainerType(input); |
498 | } |
499 | |
500 | inline InferredType tryToInferContainerType(py::handle input) { |
501 | if (six::isTuple(input)) { |
502 | py::tuple tuple = py::cast<py::tuple>(input); |
503 | std::vector<TypePtr> element_types; |
504 | element_types.reserve(tuple.size()); |
505 | |
506 | for (py::handle elem : tuple) { |
507 | auto type_match = tryToInferType(elem); |
508 | if (type_match.success()) { |
509 | element_types.push_back(type_match.type()); |
510 | } else { |
511 | // Forward error message along |
512 | return type_match.reason(); |
513 | } |
514 | } |
515 | return InferredType(TupleType::create(std::move(element_types))); |
516 | } else if (PyDict_Check(input.ptr())) { |
517 | // Check to make sure we can generate useful input/output types |
518 | auto dict = py::cast<py::dict>(input); |
519 | size_t len = py::len(dict); |
520 | if (!len) { |
521 | return InferredType("Dictionary inputs must have entries" ); |
522 | } |
523 | |
524 | TypePtr key_type = nullptr; |
525 | TypePtr value_type = nullptr; |
526 | |
527 | for (auto entry : dict) { |
528 | // Try to infer the key type and unify it with the existing one |
529 | auto entry_key_type_match = tryToInferType(entry.first); |
530 | if (!entry_key_type_match.success()) { |
531 | return entry_key_type_match.reason(); |
532 | } |
533 | auto unified_key = |
534 | unifyOrInitializeType(key_type, entry_key_type_match.type()); |
535 | if (!unified_key) { |
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            key_type->repr_str(),
            " and ",
            (entry_key_type_match.type())->repr_str()));
541 | } |
542 | |
543 | // Try to infer the value type and unify it with the existing one |
544 | auto entry_value_type_match = tryToInferType(entry.second); |
545 | if (!entry_value_type_match.success()) { |
546 | return entry_value_type_match.reason(); |
547 | } |
548 | auto unified_value = |
549 | unifyOrInitializeType(value_type, entry_value_type_match.type()); |
550 | if (!unified_value) { |
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            value_type->repr_str(),
            " and ",
            (entry_value_type_match.type())->repr_str()));
556 | } |
557 | |
558 | key_type = *unified_key; |
559 | value_type = *unified_value; |
560 | } |
561 | return InferredType( |
562 | DictType::create(std::move(key_type), std::move(value_type))); |
563 | } else if (PyList_Check(input.ptr())) { |
564 | auto list = py::cast<py::list>(input); |
565 | size_t len = py::len(list); |
566 | if (!len) { |
567 | return InferredType("List trace inputs must have elements" ); |
568 | } |
569 | |
570 | TypePtr element_type = nullptr; |
571 | for (auto elem : list) { |
572 | auto element_type_match = tryToInferType(elem); |
573 | if (!element_type_match.success()) { |
        return InferredType(c10::str(
            "Could not infer type of list element: ",
            element_type_match.reason()));
577 | } |
578 | auto unified_type = |
579 | unifyOrInitializeType(element_type, element_type_match.type()); |
580 | if (!unified_type) { |
        return InferredType(c10::str(
            "List inputs to traced functions must have consistent element type. Found ",
            element_type->repr_str(),
            " and ",
            (element_type_match.type())->repr_str()));
586 | } |
587 | element_type = *unified_type; |
588 | } |
589 | return InferredType(ListType::create(element_type)); |
590 | } else { |
591 | // TODO: this message is not correct anymore, since this InferredType is |
592 | // used from a bunch of circumstances unrelated to tracing. We can re-use |
593 | // this instead of the attribute_failure stuff in concreteType |
    return InferredType(c10::str(
        "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ",
        "are supported ",
        "as inputs or outputs of traced functions",
        ", but instead got value of type ",
        py::str(input.get_type().attr("__name__")),
        "."));
601 | } |
602 | } |
603 | |
604 | inline bool isTraceableType(const TypePtr& type) { |
605 | if (type->isSubtypeOf(*TensorType::get())) { |
606 | return true; |
607 | } |
608 | |
609 | if (auto list_type = type->cast<ListType>()) { |
610 | return isTraceableType(list_type->getElementType()); |
611 | } |
612 | |
613 | if (auto tuple_type = type->cast<TupleType>()) { |
614 | return std::all_of( |
615 | tuple_type->elements().begin(), |
616 | tuple_type->elements().end(), |
617 | [](const TypePtr& element_type) { |
618 | return isTraceableType(element_type); |
619 | }); |
620 | } |
621 | |
622 | if (auto dict_type = type->cast<DictType>()) { |
623 | return isTraceableType(dict_type->getValueType()); |
624 | } |
625 | |
626 | return false; |
627 | } |
628 | |
629 | inline IValue toTypeInferredIValue(py::handle input) { |
630 | auto match = tryToInferType(input); |
631 | if (!match.success()) { |
632 | auto object = py::cast<py::object>(input); |
633 | if (auto mod = as_module(object)) { |
634 | // if obj is already a ScriptModule, just return its ivalue |
635 | auto ptr = mod.value()._ivalue(); |
      // explicit copy semantics for strong ownership of the resource.
637 | return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy( |
638 | ptr.release()); |
639 | } |
640 | |
641 | // Check if the obj is a ScriptObject. |
642 | if (auto script_obj = as_object(object)) { |
643 | auto ptr = script_obj.value()._ivalue(); |
644 | return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy( |
645 | ptr.release()); |
646 | } |
647 | AT_ERROR( |
648 | "Tracer cannot infer type of " , py::str(input), "\n:" , match.reason()); |
649 | } |
650 | return toIValue(input, match.type()); |
651 | } |
652 | |
653 | inline Stack toTraceableStack(const py::tuple& inputs) { |
654 | auto info = toTypeInferredIValue(inputs); |
655 | TORCH_CHECK( |
656 | isTraceableType(info.type()), |
657 | "Type '" , |
658 | info.type()->repr_str(), |
659 | "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and" |
660 | " Tuples of Tensors can be traced" ); |
661 | return info.toTupleRef().elements().vec(); |
662 | } |
663 | |
// Serialize a Python dictionary into a traceable stack. Note that only
// tensor-valued entries are kept; all other values are silently skipped.
665 | inline Stack toTraceableStack(const py::dict& inputs) { |
666 | Stack res; |
667 | for (auto it = inputs.begin(); it != inputs.end(); it++) { |
668 | if (THPVariable_Check(it->second.ptr())) { |
669 | res.push_back(toIValue(it->second, tryToInferType(it->second).type())); |
670 | } |
671 | } |
672 | return res; |
673 | } |
674 | |
675 | inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { |
676 | auto elems = c10::impl::GenericList(elem_type); |
677 | for (auto elem : obj) { |
678 | elems.push_back(toIValue(elem, elem_type)); |
679 | } |
680 | return IValue(elems); |
681 | } |
682 | |
683 | inline IValue createGenericDict( |
684 | const py::dict& obj, |
685 | const TypePtr& key_type, |
686 | const TypePtr& value_type) { |
687 | c10::impl::GenericDict elems(key_type, value_type); |
688 | elems.reserve(py::len(obj)); |
689 | for (auto& entry : obj) { |
690 | elems.insert( |
691 | toIValue(entry.first, key_type), toIValue(entry.second, value_type)); |
692 | } |
693 | return IValue(elems); |
694 | } |
695 | |
696 | template <class T> |
697 | inline void guardAgainstNamedTensor(const T& var) { |
698 | TORCH_CHECK( |
699 | !var.has_names(), |
700 | "NYI: Named tensors are currently unsupported in TorchScript. As a " |
701 | "workaround please drop names via `tensor = tensor.rename(None)`." ); |
702 | } |
703 | |
704 | // Defined in pybind_utils.cpp to break a circular dependency with |
705 | // python_ivalue.h |
706 | IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N); |
707 | |
708 | // Extract custom class registered with torchbind |
709 | template <typename T> |
710 | c10::intrusive_ptr<T> toCustomClass(py::handle obj) { |
711 | static_assert( |
      std::is_base_of<CustomClassHolder, T>::value, "T is not a CustomClass");
713 | const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>(); |
714 | c10::IValue ivalue = toIValue(obj, type); |
715 | return std::move(ivalue).toCustomClass<T>(); |
716 | } |
717 | |
718 | // Small wrapper around getting the type name string from Python to make |
719 | // types easier to interpret, e.g. give the structural type for a NamedTuple |
720 | inline std::string friendlyTypeName(py::handle obj) { |
  if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
    auto field_names =
        py::cast<std::vector<std::string>>(py::getattr(obj, "_fields"));
    std::stringstream ss;
    ss << py::str(obj.get_type().attr("__name__"));
    ss << " (aka NamedTuple(";
727 | bool first = true; |
728 | for (auto& field_name : field_names) { |
729 | if (!first) { |
730 | ss << ", " ; |
731 | } |
732 | ss << field_name; |
733 | first = false; |
734 | } |
735 | ss << "))" ; |
736 | return ss.str(); |
737 | } else { |
    return py::str(obj.get_type().attr("__name__"));
739 | } |
740 | } |
741 | |
742 | // Thrown when trying to create a schema for a list of python |
743 | // arguments that cannot be converted. |
744 | // Can be caught by the caller to attempt to use other schema |
745 | // when there is an overloaded operator. |
746 | struct schema_match_error : public std::runtime_error { |
747 | using std::runtime_error::runtime_error; |
748 | }; |
749 | |
750 | inline IValue argumentToIValue( |
751 | const FunctionSchema& schema, |
752 | size_t argumentPosition, |
753 | py::handle object) { |
754 | const auto& argument = schema.arguments().at(argumentPosition); |
755 | try { |
756 | return toIValue(object, argument.real_type(), argument.N()); |
757 | } catch (const py::cast_error& error) { |
758 | throw schema_match_error(c10::str( |
759 | schema.formatTypeMismatchMsg( |
760 | argument, |
761 | friendlyTypeName(object), |
762 | argumentPosition, |
763 | py::repr(object)), |
764 | "\nCast error details: " , |
765 | error.what())); |
766 | } catch (const py::error_already_set& error) { |
767 | throw schema_match_error(c10::str( |
768 | schema.formatTypeMismatchMsg( |
769 | argument, |
770 | friendlyTypeName(object), |
771 | argumentPosition, |
772 | py::repr(object)), |
773 | "\n Python error details: " , |
774 | error.what())); |
775 | } |
776 | } |
777 | |
778 | inline IValue returnToIValue(const TypePtr& type, py::handle object) { |
779 | try { |
780 | return toIValue(object, type); |
781 | } catch (const py::cast_error& error) { |
782 | throw std::runtime_error(c10::str( |
783 | " expected value of type " , |
784 | type->str(), |
785 | " for return value but instead got value of type " , |
786 | py::str(object.get_type().attr("__name__" )), |
787 | "." , |
788 | "\nValue: " , |
789 | py::repr(object), |
790 | "\nCast error details: " , |
791 | error.what())); |
792 | } |
793 | } |
794 | |
795 | inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) { |
796 | auto py_class = |
797 | py::module::import("torch.jit._state" ) |
798 | .attr("_get_python_class" )(classType->name()->qualifiedName()); |
799 | if (py_class.is_none()) { |
800 | std::stringstream err; |
801 | err << "Unknown reference to ScriptClass " ; |
802 | err << classType->name()->qualifiedName(); |
803 | err << ". (Did you forget to import it?)" ; |
804 | throw std::runtime_error(err.str()); |
805 | } |
806 | return py_class; |
807 | } |
808 | |
809 | struct VISIBILITY_HIDDEN tuple_slice { |
810 | /*implicit*/ tuple_slice(py::tuple tup_) |
811 | : tup(std::move(tup_)), b(0), e(tup.size()) {} |
812 | tuple_slice(py::tuple tup_, int64_t b_) |
813 | : tup(std::move(tup_)), b(b_), e(tup.size()) {} |
814 | tuple_slice(py::tuple tup_, int64_t b_, int64_t e_) |
815 | : tup(std::move(tup_)), b(b_), e(e_) {} |
816 | py::detail::tuple_iterator begin() const { |
817 | return {tup, static_cast<pybind11::ssize_t>(b)}; |
818 | } |
819 | py::detail::tuple_iterator end() const { |
820 | return {tup, static_cast<pybind11::ssize_t>(e)}; |
821 | } |
822 | size_t size() const { |
823 | return e - b; |
824 | } |
825 | py::detail::tuple_accessor operator[](size_t index) const { |
826 | return {tup, static_cast<size_t>(b + index)}; |
827 | } |
828 | |
829 | private: |
830 | py::tuple tup; |
831 | int64_t b; |
832 | int64_t e; |
833 | }; |
834 | |
835 | inline Stack createStackForSchema( |
836 | const FunctionSchema& schema, |
837 | const tuple_slice& args, |
838 | const py::kwargs& kwargs, |
839 | c10::optional<IValue> self) { |
840 | size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size(); |
841 | if (all_arguments > schema.arguments().size()) { |
842 | throw schema_match_error(c10::str( |
843 | schema.name(), |
844 | "() expected at most " , |
845 | schema.arguments().size(), |
846 | " argument(s) but received " , |
847 | all_arguments, |
848 | " argument(s). Declaration: " , |
849 | schema)); |
850 | } |
851 | Stack stack; |
852 | stack.reserve(schema.arguments().size()); |
853 | |
854 | int64_t arg_idx = 0; |
855 | if (self) { |
856 | push(stack, std::move(*self)); |
857 | arg_idx++; |
858 | } |
859 | // First push all positional args. |
860 | for (const auto& arg : args) { |
861 | // ...but refuse to do it if the schema says that this was supposed |
862 | // to be keyword only |
863 | if (schema.arguments()[arg_idx].kwarg_only()) { |
864 | throw schema_match_error(c10::str( |
865 | schema.name(), |
866 | "() takes " , |
867 | arg_idx, |
868 | " positional argument(s) but " , |
869 | self ? 1 + args.size() : args.size(), |
870 | " was/were given. Declaration: " , |
871 | schema)); |
872 | } |
873 | // Use the type information from the schema to convert the PyObject. |
874 | push(stack, argumentToIValue(schema, stack.size(), arg)); |
875 | arg_idx++; |
876 | } |
877 | |
878 | // Now for every remaining non-positional argument in the schema, look for it |
879 | // in the kwargs dict and push it if found, or use its default value if it |
880 | // has one. |
881 | size_t consumed_kwargs = 0; |
882 | for (size_t i = stack.size(); i < schema.arguments().size(); ++i) { |
883 | const auto& arg = schema.arguments()[i]; |
884 | if (kwargs.contains(arg.name().c_str())) { |
885 | push(stack, argumentToIValue(schema, i, kwargs[arg.name().c_str()])); |
886 | consumed_kwargs += 1; |
887 | } else if (arg.default_value()) { |
888 | push(stack, *arg.default_value()); |
889 | } else { |
890 | throw schema_match_error(c10::str( |
891 | schema.name(), |
892 | "() is missing value for argument '" , |
893 | arg.name(), |
894 | "'. Declaration: " , |
895 | schema)); |
896 | } |
897 | } |
898 | |
899 | if (consumed_kwargs != kwargs.size()) { |
900 | std::vector<std::string> names; |
901 | for (const auto& kwarg : kwargs) { |
902 | names.emplace_back(py::cast<std::string>(kwarg.first)); |
903 | } |
904 | throw schema_match_error(schema.findErrorInKwargs(names)); |
905 | } |
906 | |
907 | return stack; |
908 | } |
909 | |
910 | inline py::object createPyObjectForStack(Stack&& stack) { |
911 | if (stack.empty()) { |
912 | return py::none(); |
913 | } |
914 | |
915 | // Return a simple value and not a single-element tuple if there is only one |
916 | // return value. |
917 | if (stack.size() == 1) { |
918 | return toPyObject(std::move(stack[0])); |
919 | } |
920 | |
921 | // If there is more than one return value, pop them into a py::tuple. |
922 | py::tuple return_values(stack.size()); |
923 | for (const auto ret : c10::irange(return_values.size())) { |
924 | return_values[ret] = toPyObject(std::move(stack[ret])); |
925 | } |
926 | |
927 | return std::move(return_values); |
928 | } |
929 | |
930 | // TODO: Remove once we clean up the GraphExecutor usage. |
931 | inline Stack evilDeprecatedBadCreateStackDoNotUse( |
932 | const py::tuple& tuple, |
933 | at::ArrayRef<Value*> inputs, |
    size_t reserve_extra_space = 0) {
935 | if (tuple.size() != inputs.size()) { |
936 | AT_ERROR( |
937 | "expected " + std::to_string(inputs.size()) + " inputs, but got " + |
938 | std::to_string(tuple.size())); |
939 | } |
940 | Stack result; |
941 | result.reserve(tuple.size() + reserve_extra_space); |
942 | for (const auto i : c10::irange(inputs.size())) { |
943 | result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type())); |
944 | } |
945 | return result; |
946 | } |
947 | |
948 | // Run `callee`, potentially inserting a CallFunction/CallMethod node into the |
949 | // tracing graph. |
950 | inline py::object runAndInsertCall( |
951 | Function& callee, |
952 | const tuple_slice& args, |
953 | const py::kwargs& kwargs, |
954 | c10::optional<IValue> self, |
955 | // Lambda that tells this function how to insert `callee` into the graph if |
956 | // we're tracing. |
957 | const std::function<Value*(Graph&, const MatchedSchema& match)>& |
958 | callInserter) { |
959 | auto stack = |
960 | createStackForSchema(callee.getSchema(), args, kwargs, std::move(self)); |
961 | const auto& tracing_state = tracer::getTracingState(); |
962 | if (!tracing_state) { |
963 | pybind11::gil_scoped_release no_gil_guard; |
964 | // If we're not tracing, just run the callee as normal. |
965 | callee.run(stack); |
966 | } else { |
967 | // If we are tracing, insert the appropriate CallFunction or CallMethod node |
968 | // and then run the callee with tracing disabled. |
969 | |
970 | // Get the graph `Value`s that represent the input IValues |
971 | auto inputs = last(stack, callee.num_inputs()); |
972 | auto input_values = |
973 | fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); }); |
    TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1);
975 | auto return_type = callee.getSchema().returns().at(0).type(); |
976 | auto graph = tracing_state->graph; |
977 | std::vector<NamedValue> named_values; |
978 | named_values.reserve(input_values.size()); |
979 | for (Value* v : input_values) { |
980 | named_values.emplace_back(v); |
981 | } |
982 | |
983 | // Add a call node. |
984 | MatchedSchema match = matchSchema( |
985 | callee.getSchema(), |
986 | tracer::getPythonInterpreterSourceRange(), |
987 | *graph, |
988 | named_values, |
989 | {}); |
990 | auto output_value = callInserter(*graph, match); |
991 | |
992 | // Actually run the callee. Pause the tracer so that we don't double-add the |
993 | // callee nodes. |
994 | { |
995 | pybind11::gil_scoped_release no_gil_guard; |
996 | ResourceGuard guard(tracer::pauseTracing()); |
997 | callee.run(stack); |
998 | } |
999 | |
1000 | // Associate the output IValues with the output `Value`s in the graph |
1001 | tracer::setValueTrace(stack.back(), output_value); |
1002 | } |
1003 | |
1004 | TORCH_CHECK( |
1005 | !stack.empty(), |
1006 | "Expected values in the stack after execution but found none" ); |
1007 | return toPyObject(std::move(stack.back())); |
1008 | } |
1009 | |
1010 | inline c10::optional<py::object> maybeTorchFunctionDispatch( |
1011 | const py::object& callee, |
1012 | const tuple_slice& args_no_self, |
1013 | const py::kwargs& kwargs, |
1014 | const c10::QualifiedName qualname) { |
1015 | std::vector<py::handle> args_vec; |
1016 | for (const auto& arg : args_no_self) { |
1017 | args_vec.push_back(arg); |
1018 | } |
1019 | py::tuple args = py::cast(args_vec); |
1020 | |
1021 | // Handle __torch_function__ dispatch |
1022 | std::vector<py::handle> overloaded_args; |
1023 | size_t total_arg_num = args.size() + kwargs.size(); |
1024 | for (const auto& arg : args) { |
1025 | is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args); |
1026 | is_tensor_list_and_append_overloaded( |
1027 | arg.ptr(), |
1028 | &overloaded_args, |
1029 | static_cast<int>(total_arg_num), |
1030 | false /* throw_error */); |
1031 | } |
1032 | // NB: for kwargs, we cannot guarantee the order of appending |
1033 | // is the same as the argument order in operator's schema. |
1034 | // This is suboptimal, but should be fine. Later when we have |
1035 | // better schema matching and argument parsing, we could |
1036 | // match the operator in `operations` first, then the order will |
1037 | // be guaranteed. |
1038 | for (auto item : kwargs) { |
1039 | is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args); |
1040 | is_tensor_list_and_append_overloaded( |
1041 | item.second.ptr(), |
1042 | &overloaded_args, |
1043 | total_arg_num, |
1044 | false /* throw_error */); |
1045 | } |
1046 | if (!overloaded_args.empty()) { |
1047 | return pybind11::reinterpret_steal<py::object>( |
1048 | handle_torch_function_no_python_arg_parser( |
1049 | /*overloaded_args=*/overloaded_args, |
1050 | /*args=*/args.ptr(), |
1051 | /*kwargs=*/kwargs.ptr(), |
1052 | /*func_name=*/qualname.name().c_str(), |
1053 | /*torch_api_function=*/callee.ptr(), |
1054 | /*module_name=*/qualname.prefix().c_str())); |
1055 | } |
1056 | |
1057 | return c10::nullopt; |
1058 | } |
1059 | |
1060 | inline py::object invokeScriptFunctionFromPython( |
1061 | Function& callee, |
1062 | const tuple_slice& args, |
1063 | const py::kwargs& kwargs) { |
1064 | // TODO: we could add __torch_function__ dispatch here but I don't know |
1065 | // the implications of doing so |
1066 | |
1067 | return runAndInsertCall( |
1068 | callee, |
1069 | args, |
1070 | kwargs, |
1071 | /*self=*/c10::nullopt, |
1072 | [&](Graph& graph, const MatchedSchema& match) { |
1073 | return graph.insertFunctionCall(&callee, match); |
1074 | }); |
1075 | } |
1076 | |
1077 | inline py::object invokeScriptMethodFromPython( |
1078 | Method& callee, |
1079 | const tuple_slice& args, |
1080 | const py::kwargs& kwargs) { |
1081 | auto self = callee.owner()._ivalue(); |
1082 | |
1083 | if (auto torch_fn_result = maybeTorchFunctionDispatch( |
1084 | py::cast(callee), args, kwargs, callee.name())) { |
1085 | return *torch_fn_result; |
1086 | } |
1087 | |
1088 | return runAndInsertCall( |
1089 | callee.function(), |
1090 | args, |
1091 | kwargs, |
1092 | self, |
1093 | [&](Graph& graph, const MatchedSchema& match) { |
1094 | return graph.insertMethodCall(callee.name(), match); |
1095 | }); |
1096 | } |
1097 | |
1098 | TORCH_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack( |
1099 | const std::vector<std::shared_ptr<Operator>>& operations, |
1100 | py::args args, |
1101 | const py::kwargs& kwargs); |
1102 | |
1103 | TORCH_API py::object invokeOperatorFromPython( |
1104 | const std::vector<std::shared_ptr<Operator>>& operations, |
1105 | py::args args, |
1106 | const py::kwargs& kwargs, |
1107 | c10::optional<c10::DispatchKey> dk = c10::nullopt); |
1108 | |
1109 | TORCH_API py::object _get_operation_for_overload_or_packet( |
1110 | const std::vector<std::shared_ptr<Operator>>& operations, |
1111 | Symbol symbol, |
1112 | py::args args, |
1113 | const py::kwargs& kwargs, |
1114 | bool is_overload, |
1115 | c10::optional<c10::DispatchKey> dk = c10::nullopt); |
1116 | |
1117 | } // namespace jit |
1118 | } // namespace torch |
1119 | |