#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/stack.h>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif

#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>
#ifdef USE_C10D_NCCL
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif

namespace torch {
namespace jit {

void clear_registered_instances(void* ptr);

TORCH_API IValue toIValue(
    py::handle obj,
    const TypePtr& type,
    c10::optional<int32_t> N = c10::nullopt);

TORCH_API py::object toPyObject(IValue ivalue);

// Hack to overload the behavior of toIValue to accept Python
// numbers in places where a Tensor is expected
// See also torch::should_allow_numbers_as_tensors
class ToIValueAllowNumbersAsTensors {
  bool old_;

 public:
  ToIValueAllowNumbersAsTensors(bool enable);
  ~ToIValueAllowNumbersAsTensors();
};
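
// Illustrative usage sketch (not part of the public API; assumes the GIL is
// held and that `py_number` is a placeholder for a Python int/float handle).
// The guard is an RAII toggle, so the previous setting is restored on scope
// exit.
//
//   {
//     ToIValueAllowNumbersAsTensors guard(/*enable=*/true);
//     // While the guard is alive, toIValue may accept a plain Python number
//     // where the schema expects a Tensor (wrapping it as a scalar tensor).
//     IValue iv = toIValue(py_number, TensorType::get());
//   }  // previous behavior restored here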

// Wrap a Python function to guard its deref.
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error:
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
struct VISIBILITY_HIDDEN PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    pybind11::gil_scoped_acquire ag;
    func_.dec_ref();
    // Explicitly set the PyObject* to nullptr to prevent py::object's dtor
    // from decref'ing the PyObject again.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.ptr() = nullptr;
  }

  py::function func_;
};

// The PythonFutureWrapper for ivalue::Future
//
// NB: VISIBILITY_HIDDEN is for silencing the compilation error,
// "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility
// than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func'
// [-Werror=attributes]"
//
// NB: inherit from enable_shared_from_this because then(py::function) needs to
// get a shared_ptr from this pointer.
struct VISIBILITY_HIDDEN PythonFutureWrapper
    : std::enable_shared_from_this<PythonFutureWrapper> {
  using UnwrapFunc = std::function<void(py::object)>;

  explicit PythonFutureWrapper(
      c10::intrusive_ptr<c10::ivalue::Future> fut,
      c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
      : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}

  explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
  PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;

  bool done() {
    return fut->completed();
  }

  py::object value() {
    // Acquire the GIL because toPyObject creates a new py::object without
    // grabbing the GIL itself.
    py::gil_scoped_acquire acquire;
    py::object py_obj = toPyObject(fut->value());
    // unwrap_func is a general compositional function that takes in a
    // py::object and executes some Python function. It is currently mostly
    // used to throw Python exceptions.
    if (unwrap_func) {
      (*unwrap_func)(py_obj);
    }
    return py_obj;
  }

  py::object wait() {
    fut->wait();
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;

      Value* fut_val = jit::tracer::getValueTrace(fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut->value(), output);
    }
    return value();
  }

  // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper>
  // (i.e., torch._C.Future) as the only argument. If the type mismatches, an
  // error will be thrown when waiting for the value of this returned Future.
  std::shared_ptr<PythonFutureWrapper> then(py::function cb) {
    // We need an additional layer of wrapping here to guard the destruction
    // of the py::function object: the Future owns a reference to the
    // py::function in its callback vector, but the Future does not acquire
    // the GIL on destruction.
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));

    return std::make_shared<jit::PythonFutureWrapper>(fut->then(
        // Capture a copy of the ivalue::Future instead of the `this` pointer
        // because the PythonFutureWrapper object could have been deleted
        // when the callbacks are fired. For example, RPC only captures the
        // ivalue::Future instead of PythonFutureWrapper in JitFuture's
        // callback functions. Hence, if user code does not hold a reference to
        // this PythonFutureWrapper object, there is no guarantee that the
        // PythonFutureWrapper is still valid when running the callback.
        [pyFut(this->getPtr()),
         pf(std::move(pf))](c10::ivalue::Future& /* unused */) -> IValue {
          try {
            pybind11::gil_scoped_acquire ag;
            return toIValue(pf->func_(pyFut), PyObjectType::get());
          } catch (py::error_already_set& e) {
            auto err = std::runtime_error(c10::str(
                "Got the following error when running the callback: ",
                e.what()));
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership of the py::objects and also restore the
              // Python Error Indicator.
              e.restore();
              // Clear the Python Error Indicator as we have recorded the
              // exception in the response message.
              PyErr_Clear();
            }

            throw err;
          }
        },
        PyObjectType::get()));
  }

  void add_done_callback(py::function cb) {
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
    // NOLINTNEXTLINE(modernize-avoid-bind)
    fut->addCallback(std::bind(
        [pyFut(this->getPtr())](std::shared_ptr<PythonFunctionGuard> pf) {
          try {
            pybind11::gil_scoped_acquire ag;
            pf->func_(pyFut);
          } catch (py::error_already_set& e) {
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership of the py::objects and also restore the
              // Python Error Indicator.
              e.restore();
              // Clear the Python Error Indicator as we have recorded the
              // exception in the response message.
              PyErr_Clear();
            }
            // Log and ignore exceptions raised through the callback
            LOG(ERROR) << "Got the following error when running the callback: "
                       << e.what();

          } catch (const std::exception& e) {
            // Log and ignore exceptions raised through the callback
            LOG(ERROR) << "Got the following error when running the callback: "
                       << e.what();
          }
        },
        std::move(pf)));
  }

  void markCompleted(const py::object& pyValue) {
    DCHECK(PyGILState_Check());
    IValue value = toIValue(pyValue, PyObjectType::get());

    py::gil_scoped_release release;
    fut->markCompleted(std::move(value));
  }

  c10::intrusive_ptr<c10::ivalue::Future> fut;
  // unwrap_func works like a callback for the value returned by
  // PythonFutureWrapper::wait().
  c10::optional<UnwrapFunc> unwrap_func;

 private:
  std::shared_ptr<PythonFutureWrapper> getPtr() {
    return shared_from_this();
  }
};
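
// Illustrative usage sketch (not part of the public API; assumes `ivalue_fut`
// is a placeholder for a c10::intrusive_ptr<c10::ivalue::Future> obtained
// elsewhere, and that the caller may block): wrap the Future for Python and
// read its result.
//
//   auto wrapper = std::make_shared<PythonFutureWrapper>(ivalue_fut);
//   if (!wrapper->done()) {
//     py::object result = wrapper->wait();  // blocks, then converts to Python
//   }
//
// Callbacks registered via then()/add_done_callback() receive the wrapper
// itself (exposed to Python as torch._C.Future) as their only argument.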

// The PythonAwaitWrapper for ivalue::Await
//
// Expresses delayed function execution with lazy semantics: in eager mode,
// Await[W] can be used as a W. When an attribute of type W is requested,
// Await[W] returns that attribute of W, transparently calling wait()
// beforehand. There are no lazy semantics in script mode; an explicit
// wait(Await[W]) -> W call is required to convert to type W.
//
// The Await object takes shared ownership of the specified function and its
// arguments. After the first wait() call it owns the result. Deliberately, no
// type inference is done for eager mode.
struct VISIBILITY_HIDDEN PythonAwaitWrapper
    : std::enable_shared_from_this<PythonAwaitWrapper> {
  explicit PythonAwaitWrapper(c10::intrusive_ptr<c10::ivalue::Await> aw)
      : aw_(std::move(aw)) {}
  explicit PythonAwaitWrapper(py::handle input) {
    args_ = py::tuple(1u);
    args_[0] = input;
    auto type = PyObjectType::get();
    aw_ = c10::make_intrusive<c10::ivalue::Await>(type);
    aw_->markCompleted(toIValue(input, type));
  }

  explicit PythonAwaitWrapper(py::function pf, py::tuple args) {
    pyfg_ = std::make_shared<torch::jit::PythonFunctionGuard>(std::move(pf));
    args_ = std::move(args);
    std::function<IValue()> f = [fg(pyfg_), &args(args_)]() {
      pybind11::gil_scoped_acquire ag;
      return toIValue(fg->func_(*args), PyObjectType::get());
    };
    aw_ = c10::make_intrusive<c10::ivalue::Await>(
        PyObjectType::get(), std::move(f));
  }

  explicit PythonAwaitWrapper(const PythonAwaitWrapper&) = delete;
  PythonAwaitWrapper& operator=(const PythonAwaitWrapper&) = delete;

  py::object wait() {
    py::gil_scoped_acquire acquire;
    return toPyObject(aw_->wait());
  }

  // Nowait semantics denote the trivial case in which the Await was
  // constructed directly from its result.
  bool is_nowait() {
    return pyfg_ == nullptr;
  }

  const py::function fn() {
    TORCH_CHECK(
        pyfg_, "Await constructed as awaitable_nowait does not have fn");
    return pyfg_->func_;
  }

  const py::tuple args() {
    return args_;
  }

  TypePtr type() {
    return aw_->type();
  }

  c10::intrusive_ptr<c10::ivalue::Await> aw_;
  std::shared_ptr<torch::jit::PythonFunctionGuard> pyfg_;
  py::tuple args_;

 private:
  std::shared_ptr<PythonAwaitWrapper> getPtr() {
    return shared_from_this();
  }
};
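
// Illustrative usage sketch (assumes the GIL is held; `py_fn` is a placeholder
// for a Python callable and `x` for a py::object argument): construct a
// delayed call and force it later.
//
//   auto aw = std::make_shared<PythonAwaitWrapper>(py_fn, py::make_tuple(x));
//   // Nothing runs until wait() is called:
//   py::object result = aw->wait();
//
// The nowait form wraps an already-available value:
//
//   auto ready = std::make_shared<PythonAwaitWrapper>(py::int_(42));
//   // ready->is_nowait() == true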

// error reporting: when reporting user-caused errors, these functions should
// not use AT_ERROR macros, since these macros add stack trace information
// that is confusing to display to the end user: it always reports locations
// in libtorch code rather than in user code.

inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

struct TypedIValue : public std::pair<IValue, TypePtr> {
  using pair::pair;

  IValue& ivalue() {
    return this->first;
  }
  TypePtr& type() {
    return this->second;
  }
};

inline TypedIValue toDictKeyIValue(py::handle key) {
  if (py::isinstance<py::str>(key)) {
    return TypedIValue(
        ConstantString::create(py::cast<std::string>(key)), StringType::get());
  } else if (py::isinstance<py::int_>(key)) {
    return TypedIValue(py::cast<int64_t>(key), IntType::get());
  } else if (py::isinstance<py::float_>(key)) {
    return TypedIValue(py::cast<double>(key), FloatType::get());
  } else {
    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
  }
}
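
// Illustrative sketch (assumes the GIL is held): dict keys are converted
// together with their JIT type.
//
//   auto k = toDictKeyIValue(py::str("weight"));
//   // k.ivalue() holds a ConstantString, k.type() is StringType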

inline c10::optional<TypePtr> unifyOrInitializeType(
    const TypePtr& accum,
    const TypePtr& unify) {
  if (!accum) {
    return unify;
  }
  return unifyTypes(accum, unify);
}
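
// Illustrative sketch of how the container-inference code below folds element
// types into one accumulator (names are placeholders):
//
//   TypePtr accum = nullptr;
//   if (auto unified = unifyOrInitializeType(accum, IntType::get())) {
//     accum = *unified;  // the first element simply initializes the accumulator
//   }
//   if (auto unified = unifyOrInitializeType(accum, IntType::get())) {
//     accum = *unified;  // identical types trivially unify
//   } else {
//     // inconsistent element types: callers report an error here
//   }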

using InferredType = c10::InferredType;

InferredType tryToInferContainerType(py::handle input);

// Try to infer the type of a Python object.
// The type cannot be inferred if:
//   input is an empty container (list, dict)
//   input is a list with element types that cannot be unified
//   input is a dict with key or value types that cannot be unified
inline InferredType tryToInferType(py::handle input) {
  // Try tensor types
  if (THPVariable_Check(input.ptr())) {
    return InferredType(TensorType::get());
  }

  if (input.is_none()) {
    return InferredType(NoneType::get());
  }

  if (py::isinstance<StrongFunctionPtr>(input)) {
    auto fn = py::cast<StrongFunctionPtr>(input).function_;
    return InferredType(FunctionType::create(fn));
  }

  // Try basic types first
  if (py::isinstance<py::bool_>(input)) {
    return InferredType(BoolType::get());
    // NOLINTNEXTLINE(bugprone-branch-clone)
  } else if (py::isinstance<py::int_>(input)) {
    return InferredType(IntType::get());
  } else if (py::isinstance<py::float_>(input)) {
    return InferredType(FloatType::get());
  } else if (PyComplex_CheckExact(input.ptr())) {
    return InferredType(ComplexType::get());
  } else if (py::isinstance<py::str>(input)) {
    return InferredType(StringType::get());
  } else if (THPLayout_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPDevice_Check(input.ptr())) {
    return InferredType(DeviceObjType::get());
  } else if (THPStream_Check(input.ptr())) {
    return InferredType(StreamObjType::get());
  } else if (THPDtype_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPQScheme_Check(input.ptr())) {
    return InferredType(IntType::get());
  }

  auto enum_type = py::module::import("enum").attr("Enum");
  py::bool_ isEnumValue = py::isinstance(input, enum_type);
  if (py::cast<bool>(isEnumValue)) {
    auto enum_class = input.attr("__class__");
    auto enum_type = py::cast<TypePtr>(
        py::module::import("torch.jit.annotations")
            .attr("try_ann_to_type")(enum_class, SourceRange()));
    return InferredType(std::move(enum_type));
  }

  py::bool_ isClass =
      py::module::import("inspect").attr("isclass")(input.get_type());
  if (py::cast<bool>(isClass)) {
    // Assume that the class is compiled already or will compile. Invalidate
    // this later if needed.
    bool class_compiled = true;

    // Check if the type is already compiled.
    py::object existing_ty = py::module::import("torch.jit._state")
                                 .attr("_get_script_class")(input.get_type());

    if (existing_ty.is_none()) {
      // If not, try to compile it.
      py::bool_ can_compile = py::module::import("torch._jit_internal")
                                  .attr("can_compile_class")(input.get_type());

      if (py::cast<bool>(can_compile)) {
        // Try to compile the class. This is wrapped in a try-catch because
        // compilation of class types can raise an Exception and in that case,
        // we want to defer to other attempts at type inference below rather
        // than fail compilation altogether.
        try {
          py::module::import("torch.jit._script")
              .attr("_recursive_compile_class")(
                  input.get_type(), SourceRange());
        } catch (...) {
          // Invalidate the assumption that the class compiled so that we don't
          // look up and return its JIT type as the type for the input.
          class_compiled = false;
        }
      }
    }

    // If the class compiled successfully, look up the existing JIT type by
    // qualified name and return it.
    if (class_compiled) {
      auto script_class = py::module::import("torch.jit._state")
                              .attr("_get_script_class")(input.get_type());

      if (!script_class.is_none()) {
        auto class_type = py::cast<ClassTypePtr>(script_class);

        if (class_type && !class_type->is_module()) {
          return InferredType(std::move(class_type));
        }
      }
    }
  }

  if (py::isinstance<Object>(input)) {
    auto object = py::cast<Object>(input);
    return InferredType(object.type());
#ifdef USE_RPC
  } else if (py::isinstance<torch::distributed::rpc::PyRRef>(input)) {
    auto rref_ivalue = input.cast<torch::distributed::rpc::PyRRef>().toIValue();
    return InferredType(rref_ivalue.type());
#endif
  }

  auto await_type = py::module::import("torch._awaits").attr("_Await");
  py::bool_ is_await = py::isinstance(input, await_type);
  if (py::cast<bool>(is_await)) {
    auto awptr = input.cast<std::shared_ptr<PythonAwaitWrapper>>();
    return InferredType(AwaitType::create(awptr->aw_->elementType()));
  }

  if (as_module(py::cast<py::object>(input))) {
    return InferredType("Cannot infer type of ScriptModule");
  }

  auto module_type = py::module::import("torch.nn").attr("Module");
  py::bool_ is_module = py::isinstance(input, module_type);
  if (py::cast<bool>(is_module)) {
    return InferredType("Cannot infer concrete type of torch.nn.Module");
  }

  // Try container types
  return tryToInferContainerType(input);
}
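
// Illustrative sketch (assumes the GIL is held): inference on a homogeneous
// Python list.
//
//   py::list xs;
//   xs.append(py::int_(1));
//   xs.append(py::int_(2));
//   auto inferred = tryToInferType(xs);
//   if (inferred.success()) {
//     // inferred.type() is List[int]
//   } else {
//     // inferred.reason() explains why inference failed (e.g. an empty list)
//   }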

inline InferredType tryToInferContainerType(py::handle input) {
  if (six::isTuple(input)) {
    py::tuple tuple = py::cast<py::tuple>(input);
    std::vector<TypePtr> element_types;
    element_types.reserve(tuple.size());

    for (py::handle elem : tuple) {
      auto type_match = tryToInferType(elem);
      if (type_match.success()) {
        element_types.push_back(type_match.type());
      } else {
        // Forward error message along
        return type_match.reason();
      }
    }
    return InferredType(TupleType::create(std::move(element_types)));
  } else if (PyDict_Check(input.ptr())) {
    // Check to make sure we can generate useful input/output types
    auto dict = py::cast<py::dict>(input);
    size_t len = py::len(dict);
    if (!len) {
      return InferredType("Dictionary inputs must have entries");
    }

    TypePtr key_type = nullptr;
    TypePtr value_type = nullptr;

    for (auto entry : dict) {
      // Try to infer the key type and unify it with the existing one
      auto entry_key_type_match = tryToInferType(entry.first);
      if (!entry_key_type_match.success()) {
        return entry_key_type_match.reason();
      }
      auto unified_key =
          unifyOrInitializeType(key_type, entry_key_type_match.type());
      if (!unified_key) {
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            key_type->repr_str(),
            " and ",
            (entry_key_type_match.type())->repr_str()));
      }

      // Try to infer the value type and unify it with the existing one
      auto entry_value_type_match = tryToInferType(entry.second);
      if (!entry_value_type_match.success()) {
        return entry_value_type_match.reason();
      }
      auto unified_value =
          unifyOrInitializeType(value_type, entry_value_type_match.type());
      if (!unified_value) {
        return InferredType(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            value_type->repr_str(),
            " and ",
            (entry_value_type_match.type())->repr_str()));
      }

      key_type = *unified_key;
      value_type = *unified_value;
    }
    return InferredType(
        DictType::create(std::move(key_type), std::move(value_type)));
  } else if (PyList_Check(input.ptr())) {
    auto list = py::cast<py::list>(input);
    size_t len = py::len(list);
    if (!len) {
      return InferredType("List trace inputs must have elements");
    }

    TypePtr element_type = nullptr;
    for (auto elem : list) {
      auto element_type_match = tryToInferType(elem);
      if (!element_type_match.success()) {
        return InferredType(c10::str(
            "Could not infer type of list element: ",
            element_type_match.reason()));
      }
      auto unified_type =
          unifyOrInitializeType(element_type, element_type_match.type());
      if (!unified_type) {
        return InferredType(c10::str(
            "List inputs to traced functions must have consistent element type. Found ",
            element_type->repr_str(),
            " and ",
            (element_type_match.type())->repr_str()));
      }
      element_type = *unified_type;
    }
    return InferredType(ListType::create(element_type));
  } else {
    // TODO: this message is not correct anymore, since this InferredType is
    // used from a bunch of circumstances unrelated to tracing. We can re-use
    // this instead of the attribute_failure stuff in concreteType
    return InferredType(c10::str(
        "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ",
        "are supported ",
        "as inputs or outputs of traced functions",
        ", but instead got value of type ",
        py::str(input.get_type().attr("__name__")),
        "."));
  }
}

inline bool isTraceableType(const TypePtr& type) {
  if (type->isSubtypeOf(*TensorType::get())) {
    return true;
  }

  if (auto list_type = type->cast<ListType>()) {
    return isTraceableType(list_type->getElementType());
  }

  if (auto tuple_type = type->cast<TupleType>()) {
    return std::all_of(
        tuple_type->elements().begin(),
        tuple_type->elements().end(),
        [](const TypePtr& element_type) {
          return isTraceableType(element_type);
        });
  }

  if (auto dict_type = type->cast<DictType>()) {
    return isTraceableType(dict_type->getValueType());
  }

  return false;
}

inline IValue toTypeInferredIValue(py::handle input) {
  auto match = tryToInferType(input);
  if (!match.success()) {
    auto object = py::cast<py::object>(input);
    if (auto mod = as_module(object)) {
      // If obj is already a ScriptModule, just return its ivalue.
      auto ptr = mod.value()._ivalue();
      // Explicit copy semantics for strong ownership of the resource.
      return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
          ptr.release());
    }

    // Check if the obj is a ScriptObject.
    if (auto script_obj = as_object(object)) {
      auto ptr = script_obj.value()._ivalue();
      return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
          ptr.release());
    }
    AT_ERROR(
        "Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
  }
  return toIValue(input, match.type());
}

inline Stack toTraceableStack(const py::tuple& inputs) {
  auto info = toTypeInferredIValue(inputs);
  TORCH_CHECK(
      isTraceableType(info.type()),
      "Type '",
      info.type()->repr_str(),
      "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
      " Tuples of Tensors can be traced");
  return info.toTupleRef().elements().vec();
}
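
// Illustrative sketch (assumes the GIL is held; `a` and `b` are placeholders
// for py::objects wrapping at::Tensors): build a Stack of IValues for the
// tracer.
//
//   py::tuple inputs = py::make_tuple(a, b);
//   Stack stack = toTraceableStack(inputs);
//   // stack[0] and stack[1] now hold the two tensors; a non-traceable input
//   // type would have tripped the TORCH_CHECK above instead.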

// Serialize the Python dictionary into a traceable stack. Only Tensor values
// are pushed; entries whose values are not Tensors are skipped.
inline Stack toTraceableStack(const py::dict& inputs) {
  Stack res;
  for (auto it = inputs.begin(); it != inputs.end(); it++) {
    if (THPVariable_Check(it->second.ptr())) {
      res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
    }
  }
  return res;
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
  auto elems = c10::impl::GenericList(elem_type);
  for (auto elem : obj) {
    elems.push_back(toIValue(elem, elem_type));
  }
  return IValue(elems);
}

inline IValue createGenericDict(
    const py::dict& obj,
    const TypePtr& key_type,
    const TypePtr& value_type) {
  c10::impl::GenericDict elems(key_type, value_type);
  elems.reserve(py::len(obj));
  for (auto& entry : obj) {
    elems.insert(
        toIValue(entry.first, key_type), toIValue(entry.second, value_type));
  }
  return IValue(elems);
}
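
// Illustrative sketch (assumes the GIL is held and `d` is a placeholder for a
// py::dict of str -> Tensor): materialize a typed Dict IValue when the schema
// types are already known.
//
//   IValue iv = createGenericDict(d, StringType::get(), TensorType::get());
//   // iv.isGenericDict() == true; the insertion order of `d` is preserved.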

template <class T>
inline void guardAgainstNamedTensor(const T& var) {
  TORCH_CHECK(
      !var.has_names(),
      "NYI: Named tensors are currently unsupported in TorchScript. As a "
      "workaround please drop names via `tensor = tensor.rename(None)`.");
}

// Defined in pybind_utils.cpp to break a circular dependency with
// python_ivalue.h
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N);

// Extract custom class registered with torchbind
template <typename T>
c10::intrusive_ptr<T> toCustomClass(py::handle obj) {
  static_assert(
      std::is_base_of<CustomClassHolder, T>::value, "T is not a CustomClass");
  const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
  c10::IValue ivalue = toIValue(obj, type);
  return std::move(ivalue).toCustomClass<T>();
}
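
// Illustrative sketch: `MyStackClass` stands in for any torchbind class that
// was registered earlier via torch::class_<MyStackClass>(...). Given a Python
// handle `py_obj` to such an object, recover the strongly-typed C++ pointer.
//
//   c10::intrusive_ptr<MyStackClass> obj = toCustomClass<MyStackClass>(py_obj);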

// Small wrapper around getting the type name string from Python to make
// types easier to interpret, e.g. give the structural type for a NamedTuple
inline std::string friendlyTypeName(py::handle obj) {
  if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
    auto field_names =
        py::cast<std::vector<std::string>>(py::getattr(obj, "_fields"));
    std::stringstream ss;
    ss << py::str(obj.get_type().attr("__name__"));
    ss << " (aka NamedTuple(";
    bool first = true;
    for (auto& field_name : field_names) {
      if (!first) {
        ss << ", ";
      }
      ss << field_name;
      first = false;
    }
    ss << "))";
    return ss.str();
  } else {
    return py::str(obj.get_type().attr("__name__"));
  }
}

// Thrown when trying to create a schema for a list of python
// arguments that cannot be converted.
// Can be caught by the caller to attempt to use other schema
// when there is an overloaded operator.
struct schema_match_error : public std::runtime_error {
  using std::runtime_error::runtime_error;
};

inline IValue argumentToIValue(
    const FunctionSchema& schema,
    size_t argumentPosition,
    py::handle object) {
  const auto& argument = schema.arguments().at(argumentPosition);
  try {
    return toIValue(object, argument.real_type(), argument.N());
  } catch (const py::cast_error& error) {
    throw schema_match_error(c10::str(
        schema.formatTypeMismatchMsg(
            argument,
            friendlyTypeName(object),
            argumentPosition,
            py::repr(object)),
        "\nCast error details: ",
        error.what()));
  } catch (const py::error_already_set& error) {
    throw schema_match_error(c10::str(
        schema.formatTypeMismatchMsg(
            argument,
            friendlyTypeName(object),
            argumentPosition,
            py::repr(object)),
        "\n Python error details: ",
        error.what()));
  }
}

inline IValue returnToIValue(const TypePtr& type, py::handle object) {
  try {
    return toIValue(object, type);
  } catch (const py::cast_error& error) {
    throw std::runtime_error(c10::str(
        " expected value of type ",
        type->str(),
        " for return value but instead got value of type ",
        py::str(object.get_type().attr("__name__")),
        ".",
        "\nValue: ",
        py::repr(object),
        "\nCast error details: ",
        error.what()));
  }
}

inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) {
  auto py_class =
      py::module::import("torch.jit._state")
          .attr("_get_python_class")(classType->name()->qualifiedName());
  if (py_class.is_none()) {
    std::stringstream err;
    err << "Unknown reference to ScriptClass ";
    err << classType->name()->qualifiedName();
    err << ". (Did you forget to import it?)";
    throw std::runtime_error(err.str());
  }
  return py_class;
}

struct VISIBILITY_HIDDEN tuple_slice {
  /*implicit*/ tuple_slice(py::tuple tup_)
      : tup(std::move(tup_)), b(0), e(tup.size()) {}
  tuple_slice(py::tuple tup_, int64_t b_)
      : tup(std::move(tup_)), b(b_), e(tup.size()) {}
  tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
      : tup(std::move(tup_)), b(b_), e(e_) {}
  py::detail::tuple_iterator begin() const {
    return {tup, static_cast<pybind11::ssize_t>(b)};
  }
  py::detail::tuple_iterator end() const {
    return {tup, static_cast<pybind11::ssize_t>(e)};
  }
  size_t size() const {
    return e - b;
  }
  py::detail::tuple_accessor operator[](size_t index) const {
    return {tup, static_cast<size_t>(b + index)};
  }

 private:
  py::tuple tup;
  int64_t b;
  int64_t e;
};
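
// Illustrative sketch: view a py::tuple without its bound `self` argument when
// forwarding Python arguments to createStackForSchema below (`all_args`,
// `self_obj`, `x`, and `y` are placeholder names).
//
//   py::tuple all_args = py::make_tuple(self_obj, x, y);
//   tuple_slice rest(all_args, 1);          // views (x, y) without copying
//   for (const auto& arg : rest) { ... }    // iterates over x and y only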

inline Stack createStackForSchema(
    const FunctionSchema& schema,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    c10::optional<IValue> self) {
  size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size();
  if (all_arguments > schema.arguments().size()) {
    throw schema_match_error(c10::str(
        schema.name(),
        "() expected at most ",
        schema.arguments().size(),
        " argument(s) but received ",
        all_arguments,
        " argument(s). Declaration: ",
        schema));
  }
  Stack stack;
  stack.reserve(schema.arguments().size());

  int64_t arg_idx = 0;
  if (self) {
    push(stack, std::move(*self));
    arg_idx++;
  }
  // First push all positional args.
  for (const auto& arg : args) {
    // ...but refuse to do it if the schema says that this was supposed
    // to be keyword only
    if (schema.arguments()[arg_idx].kwarg_only()) {
      throw schema_match_error(c10::str(
          schema.name(),
          "() takes ",
          arg_idx,
          " positional argument(s) but ",
          self ? 1 + args.size() : args.size(),
          " was/were given. Declaration: ",
          schema));
    }
    // Use the type information from the schema to convert the PyObject.
    push(stack, argumentToIValue(schema, stack.size(), arg));
    arg_idx++;
  }

  // Now for every remaining non-positional argument in the schema, look for it
  // in the kwargs dict and push it if found, or use its default value if it
  // has one.
  size_t consumed_kwargs = 0;
  for (size_t i = stack.size(); i < schema.arguments().size(); ++i) {
    const auto& arg = schema.arguments()[i];
    if (kwargs.contains(arg.name().c_str())) {
      push(stack, argumentToIValue(schema, i, kwargs[arg.name().c_str()]));
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      push(stack, *arg.default_value());
    } else {
      throw schema_match_error(c10::str(
          schema.name(),
          "() is missing value for argument '",
          arg.name(),
          "'. Declaration: ",
          schema));
    }
  }

  if (consumed_kwargs != kwargs.size()) {
    std::vector<std::string> names;
    for (const auto& kwarg : kwargs) {
      names.emplace_back(py::cast<std::string>(kwarg.first));
    }
    throw schema_match_error(schema.findErrorInKwargs(names));
  }

  return stack;
}
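
// Illustrative sketch (assumes the GIL is held where required, and that `fn`,
// `args`, and `kwargs` are placeholders for a jit::Function and the values
// received from a pybind entry point): convert Python arguments into an
// interpreter Stack, run, and convert the result back.
//
//   Stack stack = createStackForSchema(
//       fn.getSchema(), tuple_slice(std::move(args)), kwargs, c10::nullopt);
//   fn.run(stack);
//   py::object result = createPyObjectForStack(std::move(stack));
//
// This mirrors what runAndInsertCall below does on its non-tracing path.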

inline py::object createPyObjectForStack(Stack&& stack) {
  if (stack.empty()) {
    return py::none();
  }

  // Return a simple value and not a single-element tuple if there is only one
  // return value.
  if (stack.size() == 1) {
    return toPyObject(std::move(stack[0]));
  }

  // If there is more than one return value, pop them into a py::tuple.
  py::tuple return_values(stack.size());
  for (const auto ret : c10::irange(return_values.size())) {
    return_values[ret] = toPyObject(std::move(stack[ret]));
  }

  return std::move(return_values);
}

// TODO: Remove once we clean up the GraphExecutor usage.
inline Stack evilDeprecatedBadCreateStackDoNotUse(
    const py::tuple& tuple,
    at::ArrayRef<Value*> inputs,
    size_t reserve_extra_space = 0) {
  if (tuple.size() != inputs.size()) {
    AT_ERROR(
        "expected " + std::to_string(inputs.size()) + " inputs, but got " +
        std::to_string(tuple.size()));
  }
  Stack result;
  result.reserve(tuple.size() + reserve_extra_space);
  for (const auto i : c10::irange(inputs.size())) {
    result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type()));
  }
  return result;
}

// Run `callee`, potentially inserting a CallFunction/CallMethod node into the
// tracing graph.
inline py::object runAndInsertCall(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    c10::optional<IValue> self,
    // Lambda that tells this function how to insert `callee` into the graph if
    // we're tracing.
    const std::function<Value*(Graph&, const MatchedSchema& match)>&
        callInserter) {
  auto stack =
      createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
  const auto& tracing_state = tracer::getTracingState();
  if (!tracing_state) {
    pybind11::gil_scoped_release no_gil_guard;
    // If we're not tracing, just run the callee as normal.
    callee.run(stack);
  } else {
    // If we are tracing, insert the appropriate CallFunction or CallMethod node
    // and then run the callee with tracing disabled.

    // Get the graph `Value`s that represent the input IValues
    auto inputs = last(stack, callee.num_inputs());
    auto input_values =
        fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); });
    TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1)
    auto return_type = callee.getSchema().returns().at(0).type();
    auto graph = tracing_state->graph;
    std::vector<NamedValue> named_values;
    named_values.reserve(input_values.size());
    for (Value* v : input_values) {
      named_values.emplace_back(v);
    }

    // Add a call node.
    MatchedSchema match = matchSchema(
        callee.getSchema(),
        tracer::getPythonInterpreterSourceRange(),
        *graph,
        named_values,
        {});
    auto output_value = callInserter(*graph, match);

    // Actually run the callee. Pause the tracer so that we don't double-add the
    // callee nodes.
    {
      pybind11::gil_scoped_release no_gil_guard;
      ResourceGuard guard(tracer::pauseTracing());
      callee.run(stack);
    }

    // Associate the output IValues with the output `Value`s in the graph
    tracer::setValueTrace(stack.back(), output_value);
  }

  TORCH_CHECK(
      !stack.empty(),
      "Expected values in the stack after execution but found none");
  return toPyObject(std::move(stack.back()));
}

inline c10::optional<py::object> maybeTorchFunctionDispatch(
    const py::object& callee,
    const tuple_slice& args_no_self,
    const py::kwargs& kwargs,
    const c10::QualifiedName qualname) {
  std::vector<py::handle> args_vec;
  for (const auto& arg : args_no_self) {
    args_vec.push_back(arg);
  }
  py::tuple args = py::cast(args_vec);

  // Handle __torch_function__ dispatch
  std::vector<py::handle> overloaded_args;
  size_t total_arg_num = args.size() + kwargs.size();
  for (const auto& arg : args) {
    is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        arg.ptr(),
        &overloaded_args,
        static_cast<int>(total_arg_num),
        false /* throw_error */);
  }
  // NB: for kwargs, we cannot guarantee that the order of appending is the
  // same as the argument order in the operator's schema. This is suboptimal,
  // but should be fine. Later, when we have better schema matching and
  // argument parsing, we could match the operator in `operations` first, and
  // then the order will be guaranteed.
  for (auto item : kwargs) {
    is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
    is_tensor_list_and_append_overloaded(
        item.second.ptr(),
        &overloaded_args,
        total_arg_num,
        false /* throw_error */);
  }
  if (!overloaded_args.empty()) {
    return pybind11::reinterpret_steal<py::object>(
        handle_torch_function_no_python_arg_parser(
            /*overloaded_args=*/overloaded_args,
            /*args=*/args.ptr(),
            /*kwargs=*/kwargs.ptr(),
            /*func_name=*/qualname.name().c_str(),
            /*torch_api_function=*/callee.ptr(),
            /*module_name=*/qualname.prefix().c_str()));
  }

  return c10::nullopt;
}

inline py::object invokeScriptFunctionFromPython(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs) {
  // TODO: we could add __torch_function__ dispatch here but I don't know
  // the implications of doing so

  return runAndInsertCall(
      callee,
      args,
      kwargs,
      /*self=*/c10::nullopt,
      [&](Graph& graph, const MatchedSchema& match) {
        return graph.insertFunctionCall(&callee, match);
      });
}

inline py::object invokeScriptMethodFromPython(
    Method& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs) {
  auto self = callee.owner()._ivalue();

  if (auto torch_fn_result = maybeTorchFunctionDispatch(
          py::cast(callee), args, kwargs, callee.name())) {
    return *torch_fn_result;
  }

  return runAndInsertCall(
      callee.function(),
      args,
      kwargs,
      self,
      [&](Graph& graph, const MatchedSchema& match) {
        return graph.insertMethodCall(callee.name(), match);
      });
}

TORCH_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
    const std::vector<std::shared_ptr<Operator>>& operations,
    py::args args,
    const py::kwargs& kwargs);

TORCH_API py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    py::args args,
    const py::kwargs& kwargs,
    c10::optional<c10::DispatchKey> dk = c10::nullopt);

TORCH_API py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    py::args args,
    const py::kwargs& kwargs,
    bool is_overload,
    c10::optional<c10::DispatchKey> dk = c10::nullopt);

} // namespace jit
} // namespace torch