1// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#include "minpybind.h"
8#include <frameobject.h>
9#include <opcode.h>
10#include <utility>
11#include <new>
12#include <iostream>
13#include <vector>
14//#include <torch/csrc/autograd/python_variable.h>
15#include <torch/csrc/utils/python_compat.h>
16#include <torch/csrc/Export.h>
17#include <ATen/functorch/BatchedTensorImpl.h>
18#include <ATen/functorch/DynamicLayer.h>
19#include <ATen/ATen.h>
20#include <memory>
21#include "arena.h"
22#include "python_variable_simple.h"
23
24#if IS_PYTHON_3_11_PLUS
25#define Py_BUILD_CORE
26#include "internal/pycore_opcode.h"
27#undef Py_BUILD_CORE
28#endif
29
// Each object type below exposes C++ API functions to
// * construct the object, returning a ref-counted handle
// * access the actual API, with methods that take/return C-typed values
33
34// extend minpybind.h to include
35// * typed handles so that -> can get to their raw API
36// * object/handle distinction for the typed handles
37
38// class Dim: ---------------
39py::handle torch_Tensor___mul__;
40py::handle _Tensor;
41py::handle _Tensor_sum;
42py::handle NamedTuple;
43py::dict_view pointwise;
44py::handle torch_Tensor_expand;
45binaryfunc THPVariable_getitem;
46objobjargproc THPVariable_setitem;
47py::handle no_slice;
48PyTypeObject* torch_Tensor;
49py::handle torch_Tensor_copy_;
50py::handle torch_Tensor_split;
51bool pointwise_optimize = true;
52PyTypeObject* DimType = nullptr;
53
54static void maybeInitializeGlobals() {
    // globals that depend on the python dim library,
    // which we can't look up until we finish initializing the _C module
57 if (_Tensor.ptr()) {
58 return;
59 }
60 auto dim = py::import("functorch.dim");
61 _Tensor = dim.attr("_Tensor");
62 pointwise = dim.attr("pointwise");
63 _Tensor_sum = _Tensor.attr("sum");
64 DimType = (PyTypeObject*) py::import("functorch.dim").attr("Dim").ptr();
65}
66
67PyObject* Tensor_getitem(PyObject* self, PyObject* index);
68int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value);
69
70void replaceMappingIfMatches(py::handle tp) {
71 auto T = (PyTypeObject*) tp.ptr();
72 bool recurse = false;
73 if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) {
74 T->tp_as_mapping->mp_subscript = Tensor_getitem;
75 recurse = true;
76 }
77 if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) {
78 T->tp_as_mapping->mp_ass_subscript = Tensor_setitem;
79 recurse = true;
80 }
81 if (recurse) {
82 auto result = tp.attr("__subclasses__").call();
83 py::list_view lv(result);
84 for (auto i : lv.enumerate()) {
85 replaceMappingIfMatches(lv[i]);
86 }
87 }
88}
89
90static void initializeGlobals(Arena & A) {
91 auto torch = py::import("torch");
92 torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
93 torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
94
95 torch_Tensor_expand = torch.attr("_C").attr("_TensorBase").attr("expand");
96 torch_Tensor_split = torch.attr("_C").attr("_TensorBase").attr("split");
97 torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
98 auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
99 auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
100 THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
101 THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
102 NamedTuple = py::import("typing").attr("NamedTuple");
103 no_slice = PySlice_New(NULL, NULL, NULL);
104
105}
106
107py::handle DimensionBindError_;
108static py::handle DimensionBindError() {
109 if(!DimensionBindError_.ptr()) {
110 DimensionBindError_ = py::import("functorch.dim").attr("DimensionBindError");
111 }
112 return DimensionBindError_;
113}
114
115static int64_t n_dims_created = 65;
116
117struct Dim : public py::base<Dim> {
118 int64_t level_; // for stable comparisons in prototype
119 py::object name_;
120 Dim()
121 : level_(n_dims_created++) {}
122 void init(py::object name, int64_t s = -1) {
123 name_ = std::move(name);
124 size_ = s;
125 }
126
127 static bool check_exact(py::handle v) {
128 return Py_TYPE(v.ptr()) == DimType;
129 }
130
131 int64_t size() const {
132 if (size_ == -1) {
133 py::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr());
134 }
135 return size_;
136 }
137 void set_size(int64_t v) {
138 if (size_ == -1) {
139 size_ = v;
140 } else if(size_ != v) {
141 py::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v);
142 }
143 }
144 bool is_bound() const {
145 return size_ != -1;
146 }
147 static py::obj<Dim> create(py::object name, int64_t s = -1) {
148 if (!DimType) {
149 maybeInitializeGlobals();
150 }
151 auto r = Dim::alloc(DimType);
152 r->init(std::move(name), s);
153 return r;
154 }
155 static PyTypeObject Type;
156 const at::Tensor& range() {
157 if (!range_.defined()) {
158 range_ = at::arange(size());
159 }
160 return range_;
161 }
162 const at::Tensor& batchtensor() {
163 if (!batchtensor_.defined()) {
164 batchtensor_ = at::functorch::addBatchDim(range(), 0, level_);
165 }
166 return batchtensor_;
167 }
168private:
169 int64_t size_{-1};
170 at::Tensor range_;
171 at::Tensor batchtensor_;
172};
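
// Illustrative sketch (not part of the original file): the lazy accessors above mean
// that, for a bound dim, roughly the following holds
//
//   auto d = Dim::create(py::unicode_from_string("i"), /*s=*/3);
//   d->range();        // at::arange(3), materialized on first use and cached
//   d->batchtensor();  // that same range wrapped by functorch's addBatchDim at
//                      // batch dim 0, using this dim's level_ as the vmap level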
173
174struct DimEntry {
    // A DimEntry stores either a negative number indicating which positional dimension
    // this is, counted from the right, or a pointer to a first-class dimension.
    // Pointers never have their highest bit set, so a negative value tells us
    // the entry is positional rather than a Dim.
179 bool is_positional() const {
180 return data_ < 0;
181 }
182 bool is_none() const {
183 return data_ == 0;
184 }
185 int64_t position() const {
186 return data_;
187 }
188 py::hdl<Dim> dim() const {
189 Dim* result;
190 std::memcpy(&result, &data_, sizeof(Dim*));
191 return py::hdl<Dim>(result);
192 }
193
194 DimEntry()
195 : data_(0) {}
196
197 DimEntry(int64_t pos)
198 : data_(pos) {
199 AT_ASSERT(pos < 0);
200 }
201 DimEntry(py::hdl<Dim> d) {
202 std::memcpy(&data_, &d, sizeof(int64_t));
203 }
204 bool operator==(const DimEntry& rhs) const {
205 return data_ == rhs.data_;
206 }
207private:
208 int64_t data_;
209};
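
// Minimal sketch of the tagged encoding above (illustrative, not part of the original;
// some_dim stands for any py::hdl<Dim>):
//
//   DimEntry e1(-2);         // positional entry: data_ == -2, is_positional() is true
//   DimEntry e2;             // none entry: data_ == 0, is_none() is true
//   DimEntry e3(some_dim);   // a Dim* never has its high bit set, so data_ >= 0 and
//                            // e3.dim() recovers the original py::hdl<Dim>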
210
211std::ostream& operator<<(std::ostream& ss, DimEntry entry) {
212 if (entry.is_none()) {
213 ss << "None";
214 } else if (entry.is_positional()) {
215 ss << entry.position();
216 } else {
217 ss << entry.dim();
218 }
219 return ss;
220}
221
222// Dim wrapper methods
223
224static int Dim_init(py::hdl<Dim> self, PyObject *args, PyObject *kwds) {
225 PY_BEGIN
226 static char* kwlist[] = {"name", "size", nullptr};
227 py::handle name;
228 py::handle size = nullptr;
229 if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &name, &size)) {
230 return -1;
231 }
232 self->init(py::object::borrow(name), (size.ptr() && !py::is_none(size)) ? py::to_int(size) : -1);
233 return 0;
234 PY_END(-1)
235}
236
237static PyObject* Dim_repr(Dim* self) {
238 PY_BEGIN
239 py::object name = (self->name_.ptr()) ? self->name_ : py::unicode_from_string("<uninitialized dim>");
240 return name.release();
241 PY_END(nullptr)
242}
243
244
245static PyObject* Dim_getsize(Dim* self, void*) {
246 PY_BEGIN
247 return py::from_int(self->size()).release();
248 PY_END(nullptr)
249}
250
251int Dim_setsize(Dim* self, PyObject* size, void*) {
252 PY_BEGIN
253 self->set_size(py::to_int(size));
254 return 0;
255 PY_END(-1)
256}
257
258static PyObject* Dim_getis_bound(Dim* self, void*) {
259 return PyBool_FromLong(self->is_bound());
260}
261
262static PyObject* Dim_getlevel(Dim* self, void*) {
263 return PyLong_FromLong(self->level_);
264}
265
266static PyObject* Dim_get_levels(Dim* self, void*) {
267 py::tuple t(1);
268 t.set(0, py::object::borrow(self->ptr()));
269 return t.release();
270}
271
272static PyObject* Dim_get_has_device(Dim* self, void*) {
273 Py_RETURN_FALSE;
274}
275
276static PyObject* Dim_get_tensor(Dim* self, void*) {
277 return THPVariable_Wrap(self->range());
278}
279
280static PyObject* Dim_get_batchtensor(Dim* self, void*) {
281 return THPVariable_Wrap(self->batchtensor());
282}
283
284
285static PyGetSetDef Dim_getsetters[] = {
286 {"size", (getter) Dim_getsize, (setter) Dim_setsize,
287 "Dimension size", NULL},
288 {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL},
289 {"_level", (getter) Dim_getlevel, NULL, "_level", NULL},
290 {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL},
291 {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL},
292 {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL},
293 {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL},
294 {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return py::from_int(1).release(); }, NULL, "ndim", NULL},
295 {NULL} /* Sentinel */
296};
297
298PyTypeObject Dim::Type = {
299 PyVarObject_HEAD_INIT(NULL, 0)
300 "_C.Dim", /* tp_name */
301 sizeof(Dim), /* tp_basicsize */
302 0, /* tp_itemsize */
303 Dim::dealloc_stub, /* tp_dealloc */
304 0, /* tp_vectorcall_offset */
305 0, /* tp_getattr */
306 0, /* tp_setattr */
307 0, /* tp_as_async */
308 (reprfunc)Dim_repr, /* tp_repr */
309 0, /* tp_as_number */
310 0, /* tp_as_sequence */
311 0, /* tp_as_mapping */
312 0, /* tp_hash */
313 0, /* tp_call */
314 0, /* tp_str */
315 0, /* tp_getattro */
316 0, /* tp_setattro */
317 0, /* tp_as_buffer */
318 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
319 "Dim Object", /* tp_doc */
320 0, /* tp_traverse */
321 0, /* tp_clear */
322 0, /* tp_richcompare */
323 0, /* tp_weaklistoffset */
324 0, /* tp_iter */
325 0, /* tp_iternext */
326 0, /* tp_methods */
327 0, /* tp_members */
328 Dim_getsetters, /* tp_getset */
329 0, /* tp_base */
330 0, /* tp_dict */
331 0, /* tp_descr_get */
332 0, /* tp_descr_set */
333 0, /* tp_dictoffset */
334 (initproc)(void*) Dim_init, /* tp_init */
335 0, /* tp_alloc */
336 Dim::new_stub, /* tp_new */
337};
338
339// class DimList ------------
340
341struct DimList : public py::base<DimList> {
342 py::object name_;
343 std::vector<py::obj<Dim>> dims_;
344 static PyTypeObject Type;
345 void init(py::object name) {
346 name_ = std::move(name);
347 }
348 void set_dims(std::vector<py::obj<Dim>> dims) {
349 bound_ = true;
350 dims_ = std::move(dims);
351 }
352 bool is_bound() {
353 return bound_;
354 }
355 void bind_len(int64_t size) {
356 if (bound_) {
357 int64_t b_size = dims_.size();
358 if (b_size != size) {
359 py::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size);
360 }
361 } else {
362 bound_ = true;
363 dims_.resize(size);
364 for (Py_ssize_t i = 0; i < size; ++i) {
365 dims_[i] = Dim::create(py::unicode_from_format("%S%i", name_.ptr(), (int)i));
366 }
367 }
368 }
369 int64_t size() const {
370 if (!bound_) {
371 py::raise_error(DimensionBindError(), "DimList not bound");
372 }
373 return dims_.size();
374 }
375 void set_bound(bool b) {
376 bound_ = b;
377 }
378private:
379 bool bound_ = false;
380};
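
// Illustrative sketch of how binding works (glossing over py::base allocation details):
//
//   auto xs = DimList::create(py::unicode_from_string("xs"));  // unbound list
//   xs->bind_len(3);   // fabricates dims named xs0, xs1, xs2 (their sizes stay unbound)
//   xs->bind_len(3);   // ok: length matches
//   xs->bind_len(4);   // raises DimensionBindError ("Dimlist has size 3 but ...")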
381
382
383static int DimList_init(DimList *self, PyObject *args, PyObject *kwds);
384
385static PyObject* DimList_repr(DimList* self) {
386 PY_BEGIN
387 if (self->is_bound()) {
388 size_t size = self->dims_.size();
389 py::tuple t(size);
390 for(size_t i = 0; i < size; ++i) {
391 t.set(i, self->dims_[i]);
392 }
393 return py::repr(t).release();
394 } else if(!py::is_none(self->name_)) {
395 return py::unicode_from_format("*%S", self->name_.ptr()).release();
396 } else {
397 return py::unicode_from_string("<unbound_dimlist>").release();
398 }
399 PY_END(nullptr)
400}
401
402static PyObject* DimList_bind(DimList *self,
403 PyObject *const *args,
404 Py_ssize_t nargs,
405 PyObject *kwnames) {
406 PY_BEGIN
407 py::handle sizes;
408 static const char * const _keywords[] = {"sizes", nullptr};
409 static _PyArg_Parser parser = {"O", _keywords, 0};
410 if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
411 return nullptr;
412 }
413 if (!py::is_sequence(sizes)) {
414 py::raise_error(PyExc_ValueError, "expected a sequence");
415 }
416 py::sequence_view seq = sizes;
417 auto size = seq.size();
418 self->bind_len(size);
419 for (Py_ssize_t i = 0; i < size; ++i) {
420 self->dims_[i]->set_size(py::to_int(seq[i]));
421 }
422 Py_RETURN_NONE;
423 PY_END(nullptr)
424}
425
426static PyObject* DimList_bind_len(DimList *self,
427 PyObject *const *args,
428 Py_ssize_t nargs,
429 PyObject *kwnames) {
430 PY_BEGIN
431 int size;
432 static const char * const _keywords[] = {"N", nullptr};
433 static _PyArg_Parser parser = {"i", _keywords, 0};
434 if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
435 return nullptr;
436 }
437 self->bind_len(size);
438 Py_RETURN_NONE;
439 PY_END(nullptr)
440}
441
442static PyMethodDef DimList_methods[] = {
443 {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS},
444 {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS},
445 {NULL, NULL, 0, NULL} /* Sentinel */
446};
447
448
449static Py_ssize_t DimList_len(DimList* self) {
450 PY_BEGIN
451 return self->size();
452 PY_END(-1)
453}
454
455PyObject * DimList_item(DimList* self, Py_ssize_t idx) {
456 PY_BEGIN
457 if (!self->is_bound()) {
458 py::raise_error(DimensionBindError(), "DimList not bound");
459 }
460 if (idx < 0 || (size_t) idx >= self->dims_.size()) {
461 py::raise_error(PyExc_IndexError, "index out of bounds");
462 }
463 py::object r = self->dims_[idx];
464 return r.release();
465 PY_END(nullptr)
466}
467
468PySequenceMethods DimList_seq {
469 (lenfunc) DimList_len, //lenfunc sq_length;
470 0, //binaryfunc sq_concat;
471 0, //ssizeargfunc sq_repeat;
472 (ssizeargfunc) DimList_item, //ssizeargfunc sq_item;
473 0, //void *was_sq_slice;
474 0, //ssizeobjargproc sq_ass_item;
475 0, //void *was_sq_ass_slice;
476 0, //objobjproc sq_contains;
477
478 0, //binaryfunc sq_inplace_concat;
479 0, //ssizeargfunc sq_inplace_repeat;
480};
481
482static PyObject* DimList_getis_bound(DimList* self, void*) {
483 return PyBool_FromLong(self->is_bound());
484}
485
486static PyGetSetDef DimList_getsetters[] = {
487 {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL},
488 {NULL} /* Sentinel */
489};
490
491
492static PyObject* DimList_subscript(DimList* self, py::handle idx) {
493 PY_BEGIN
494 if (py::is_int(idx)) {
495 return DimList_item(self, py::to_int(idx));
496 } else if (py::is_slice(idx)) {
497 if (!self->is_bound()) {
498 py::raise_error(DimensionBindError(), "DimList not bound");
499 }
500 py::slice_view s(idx, self->dims_.size());
501 py::tuple r(s.slicelength);
502 for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) {
503 r.set(j++, self->dims_[i]);
504 }
505 return r.release();
506 } else {
507 py::raise_error(PyExc_ValueError, "expected an int or a slice");
508 return nullptr;
509 }
510 PY_END(nullptr)
511}
512
513PyMappingMethods DimList_mapping = {
514 0, //lenfunc mp_length;
515 (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript;
516 0, //objobjargproc mp_ass_subscript;
517};
518
519
520
521PyTypeObject DimList::Type = {
522 PyVarObject_HEAD_INIT(NULL, 0)
523 "_C.DimList", /* tp_name */
524 sizeof(DimList), /* tp_basicsize */
525 0, /* tp_itemsize */
526 DimList::dealloc_stub, /* tp_dealloc */
527 0, /* tp_vectorcall_offset */
528 0, /* tp_getattr */
529 0, /* tp_setattr */
530 0, /* tp_as_async */
531 (reprfunc)DimList_repr, /* tp_repr */
532 0, /* tp_as_number */
533 &DimList_seq, /* tp_as_sequence */
534 &DimList_mapping, /* tp_as_mapping */
535 0, /* tp_hash */
536 0, /* tp_call */
537 0, /* tp_str */
538 0, /* tp_getattro */
539 0, /* tp_setattro */
540 0, /* tp_as_buffer */
541 0, /* tp_flags */
542 "DimList Object", /* tp_doc */
543 0, /* tp_traverse */
544 0, /* tp_clear */
545 0, /* tp_richcompare */
546 0, /* tp_weaklistoffset */
547 0, /* tp_iter */
548 0, /* tp_iternext */
549 DimList_methods, /* tp_methods */
550 0, /* tp_members */
551 DimList_getsetters, /* tp_getset */
552 0, /* tp_base */
553 0, /* tp_dict */
554 0, /* tp_descr_get */
555 0, /* tp_descr_set */
556 0, /* tp_dictoffset */
557 (initproc) DimList_init, /* tp_init */
558 0, /* tp_alloc */
559 DimList::new_stub, /* tp_new */
560};
561
562static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) {
563 PY_BEGIN
564 static char* kwlist[] = {"len_or_dims", "name", nullptr};
565 py::handle len_or_dims = nullptr;
566 PyObject* name = nullptr;
567 if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", kwlist, &len_or_dims, &name)) {
568 return -1;
569 }
570 self->init(py::object::borrow(name ? name : Py_None));
571 if (len_or_dims.ptr()) {
572 if(py::is_int(len_or_dims)) {
573 self->bind_len(py::to_int(len_or_dims));
574 } else if (py::is_sequence(len_or_dims)) {
575 py::sequence_view s(len_or_dims);
576 std::vector<py::obj<Dim>> dims;
577 size_t size = s.size();
578 dims.reserve(size);
579 for (size_t i = 0; i < size; ++i) {
580 auto r = s[i];
581 if (py::is_int(r)) {
582 dims.emplace_back(Dim::create(py::unicode_from_format("%S%i", self->name_.ptr(), (int)i), py::to_int(r)));
583 } else {
584 dims.emplace_back(Dim::wrap(r));
585 }
586 }
587 self->set_dims(std::move(dims));
588 } else {
589 PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions");
590 return -1;
591 }
592 return 0;
593 }
594 return 0;
595 PY_END(-1);
596}
597
598// Tensor -----------------------------
599
600PyTypeObject* TensorType = nullptr; // the python wrapper type.
601at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_);
602static py::object run_torch_function(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise);
603void free_levels_dims(Slice<DimEntry> levels);
604
605struct Tensor;
606
607struct DelayedOperator {
608 DelayedOperator(py::object o, py::vector_args a)
609 : orig(std::move(o)), args(a) {
610 auto all = a.size();
        // the DelayedOperator will outlive this call, so take ownership of the
        // temporaries in the vector_args
614 auto buf = new py::handle[all];
615 memcpy(buf, args.args, sizeof(py::handle)*all);
616 args.args = buf;
617 for (auto i : args.enumerate_all()) {
618 Py_INCREF(args.args[i].ptr());
619 }
620 Py_XINCREF(args.kwnames.ptr());
621 }
622 ~DelayedOperator() {
623 for (auto i : args.enumerate_all()) {
624 Py_DECREF(args[i].ptr());
625 }
626 if (args.has_keywords()) {
627 Py_XDECREF(args.kwnames.ptr());
628 }
629 delete [] args.args;
630 }
631 py::object orig;
632 py::vector_args args;
633};
634
635struct Tensor : public py::base<Tensor> {
636private:
637 at::Tensor tensor_;
638 at::Tensor batchtensor_;
639 OwnedSlice<DimEntry> levels_;
640 bool has_device_;
641 std::unique_ptr<DelayedOperator> delayed_;
642public:
643
644 at::Tensor& tensor(Arena& A) {
645 if (C10_UNLIKELY(!tensor_.defined())) {
646 AT_ASSERT(delayed_);
647 auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true));
648 tensor_ = t->tensor(A);
649 delayed_.reset();
            // don't force creation of the batch tensor if it wasn't already provided.
651 batchtensor_ = t->batchtensor_;
652 AT_ASSERT(levels() == t->levels());
653 }
654 return tensor_;
655 }
656 at::Tensor& batchtensor(Arena& A) {
657 if (C10_UNLIKELY(!batchtensor_.defined())) {
658 batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice());
659 }
660 return batchtensor_;
661 }
662 Slice<DimEntry> levels() {
663 return levels_.slice();
664 }
665 bool has_device() {
666 return has_device_;
667 }
668 DelayedOperator* delayed() {
669 return delayed_.get();
670 }
671 static PyTypeObject Type;
672
673 static bool check_exact(py::handle v) {
674 return Py_TYPE(v.ptr()) == TensorType;
675 }
676
677
678 static py::obj<Tensor> create() {
679 if (!TensorType) {
680 TensorType = (PyTypeObject*) py::import("functorch.dim").attr("Tensor").ptr();
681 }
682 return Tensor::alloc(TensorType);
683 }
684 void capture_levels(Slice<DimEntry> levels) {
685 // grab ownership of the dims inside levels
686 for (auto l : levels) {
687 if (!l.is_positional()) {
688 py::object::borrow(l.dim()).release();
689 }
690 }
691 levels_.set(levels, free_levels_dims);
692 }
693 static py::object from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device);
694 static py::obj<Tensor> create_delayed(py::object op, py::vector_args args, Slice<DimEntry> levels, bool has_device);
695 friend struct EnableAllLayers;
696};
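
// Rough life cycle of a delayed Tensor (illustrative; see __torch_function__ below):
//
//   auto t = Tensor::create_delayed(op, args, levels, has_device);
//   // ...no aten work has happened yet; only levels and has_device are known...
//   t->tensor(A);   // replays the original call via run_torch_function(A, op, args, true),
//                   // caches the resulting tensor, and drops the captured DelayedOperator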
697
698at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_) {
699 auto levels = Slice<DimEntry>();
700 levels.extend(A, levels_);
701 while (true) {
702 int64_t min_real_index = -1;
703 int64_t min_index = -1;
704 int64_t min_value = INT_MAX;
705 int64_t i = 0;
706 int64_t r = 0;
707 for (auto l : levels) {
708 if (!l.is_none()) {
709 if (!l.is_positional() && l.dim()->level_ < min_value) {
710 min_value = l.dim()->level_;
711 min_index = i;
712 min_real_index = r;
713 }
714 ++i;
715 }
716 ++r;
717 }
718 if (min_index == -1) {
719 return t;
720 }
721 auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value);
722 t = std::move(t2);
723 levels[min_real_index] = DimEntry();
724 }
725}
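
// Worked example (illustrative): for levels [i (level_ 7), -1, j (level_ 5)] the loop
// above first wraps j, the remaining dim with the smallest level_, at index 2, then i
// at index 0, and leaves the positional level alone, producing
// addBatchDim(addBatchDim(t, 2, 5), 0, 7).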
726
727void free_levels_dims(Slice<DimEntry> levels) {
728 for(auto e : levels) {
729 if (!e.is_positional()) {
730 py::object::steal(e.dim());
731 }
732 }
733}
734
// the version in the header does an unnecessary refcount increment/decrement
736inline at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) {
737 if (at::functorch::isBatchedTensor(tensor)) {
738 return static_cast<at::functorch::BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
739 }
740 return nullptr;
741}
742
743inline TensorRef unchecked_tensor_from(py::handle p) {
744 auto v = (THPVariable*) p.ptr();
745 return TensorRef(*v->cdata);
746}
747
748int64_t ndim_of_levels(Slice<DimEntry> levels) {
749 int64_t r = 0;
750 for (auto l : levels) {
751 if (l.is_positional()) {
752 ++r;
753 }
754 }
755 return r;
756}
757
758struct TensorInfo {
759 TensorRef tensor;
760 Slice<DimEntry> levels;
761 bool has_device;
762 TensorRef batchedtensor;
763 int64_t ndim() const {
764 return ndim_of_levels(levels);
765 }
766 operator bool() const {
767 return tensor;
768 }
769
770 static TensorInfo create(Arena& A, py::handle h, bool ensure_batched=true, bool ensure_present=true) {
771 if (Tensor::check_exact(h)) {
772 auto t = Tensor::unchecked_wrap(h);
773 return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()};
774 } else if (Dim::check_exact(h)) {
775 auto d = Dim::unchecked_wrap(h);
776 return TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()};
777 } else if (THPVariable_Check(h.ptr())) {
778 TensorRef t = unchecked_tensor_from(h);
779 Slice<DimEntry> levels;
780 for (auto i : irange(-t->dim(), 0)) {
781 levels.append(A, i);
782 }
783 return TensorInfo {t, levels, true, t};
784 } else {
785 if (ensure_present) {
786 py::raise_error(PyExc_ValueError, "expected a tensor object");
787 }
788 return TensorInfo {};
789 }
790 }
791
792
793};
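
// TensorInfo::create normalizes the three kinds of inputs this library accepts
// (summary of the branches above, for reference):
//   * a functorch.dim Tensor -> its plain tensor, its levels, and (optionally) its batched tensor
//   * a Dim                  -> the dim's range() tensor with a single-entry level list
//   * a plain torch.Tensor   -> the tensor itself with purely positional levels -ndim .. -1
// anything else raises ValueError, or yields an empty TensorInfo when ensure_present is false.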
794
795py::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device) {
796 size_t seen_dims = 0;
797 int last = 0;
798 //auto sz = tensor.sizes();
799 for (auto i : levels.enumerate()) {
800 auto l = levels[i];
801 if (l.is_positional()) {
802 AT_ASSERT(last == 0 || last + 1 == l.position());
803 last = l.position();
804 } else {
805 py::object::borrow(l.dim()).release();
806 //AT_ASSERT(sz[i] == l.dim()->size());
807 ++seen_dims;
808 }
809 }
810 AT_ASSERT(last == 0 || last == -1);
811 if (!seen_dims) {
812 return py::object::steal(THPVariable_Wrap(std::move(tensor)));
813 }
814
815 py::obj<Tensor> self = Tensor::create();
816 self->tensor_ = std::move(tensor);
817 AT_ASSERT(self->tensor_.dim() == levels.size());
818 self->levels_.set(levels, free_levels_dims);
819 self->has_device_ = has_device;
820 py::object r = std::move(self);
821 return r;
822}
823
824
825static PyObject* py_Tensor_from_positional(PyObject *self,
826 PyObject *const *args,
827 Py_ssize_t nargs,
828 PyObject *kwnames) {
829 Arena A;
830 PY_BEGIN
831 #define ARGS(_) _(py::handle, tensor) _(py::handle, py_levels) _(int, has_device)
832 MPY_PARSE_ARGS_KWNAMES("OOp", ARGS)
833 #undef ARGS
834
835 if (!THPVariable_Check(tensor.ptr())) {
836 py::raise_error(PyExc_ValueError, "_tensor is not a Tensor?");
837 }
838
839 Slice<DimEntry> levels;
840 py::sequence_view sq(py_levels);
841 for (auto i : sq.enumerate()) {
842 py::object v = sq[i];
843 if (py::is_int(v)) {
844 auto vi = py::to_int(v);
845 levels.append(A, vi);
846 } else {
847 auto dim = Dim::wrap(std::move(v));
848 py::hdl<Dim> hdim = dim;
849 levels.append(A, hdim);
850 }
851 }
852 return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release();
853 PY_END(nullptr)
854}
855
856py::obj<Tensor> Tensor::create_delayed(py::object op, py::vector_args args, Slice<DimEntry> levels, bool has_device) {
857 py::obj<Tensor> self = Tensor::create();
858 self->capture_levels(levels);
859 self->has_device_ = has_device;
860 self->delayed_ = std::make_unique<DelayedOperator>(op, args);
861 return self;
862}
863
864py::list slice_to_list(Slice<py::handle> h) {
865 py::list lst(h.size());
866 for (auto i : h.enumerate()) {
867 lst.set(i, py::object::borrow(h[i]));
868 }
869 return lst;
870}
871
872py::tuple slice_to_tuple(Slice<py::handle> h) {
873 py::tuple lst(h.size());
874 for (auto i : h.enumerate()) {
875 lst.set(i, py::object::borrow(h[i]));
876 }
877 return lst;
878}
879
880enum UType {
881 U_ELEM,
882 U_TUPLE_LIKE,
883 U_DICT,
884};
885
886struct Unflatten {
887 py::object operator()(Slice<py::handle>& elements) {
888 py::object r;
889 switch (type) {
890 case U_ELEM: {
891 r = py::object::borrow(elements[0]);
892 elements = elements.slice(1);
893 } break;
894 case U_TUPLE_LIKE: {
895 py::tuple tup(children.size());
896 for (auto i : children.enumerate()) {
897 tup.set(i, children[i](elements));
898 }
899 r = obj.call(tup);
900 } break;
901 case U_DICT: {
902 r = py::object::checked_steal(PyDict_New());
903 py::dict_view rv(r);
904 py::dict_view d(obj);
905 Py_ssize_t pos = 0;
906 py::handle k, v;
907 for (int i = 0; d.next(&pos, &k, &v); ++i) {
908 rv.set(k, children[i](elements));
909 }
910 } break;
911 }
912 return r;
913 }
914 UType type;
915 py::handle obj;
916 Slice<Unflatten> children;
917};
918
919Unflatten tree_flatten(Arena& A, py::handle agg, Slice<py::handle>& flat_elements) {
920 Slice<Unflatten> c;
921 UType utype;
922 py::handle obj;
923 if (py::list_view::check(agg)) {
924 obj = agg.type();
925 utype = U_TUPLE_LIKE;
926 py::list_view l(agg);
927 for (auto i : l.enumerate()) {
928 c.append(A, tree_flatten(A, l[i], flat_elements));
929 }
930 } else if (py::tuple_view::check(agg)) {
931 obj = agg.type();
932 utype = U_TUPLE_LIKE;
933 // includes named tuples
934 py::tuple_view l(agg);
935 for (auto i : l.enumerate()) {
936 c.append(A, tree_flatten(A, l[i], flat_elements));
937 }
938 } else if (py::dict_view::check(agg)) {
939 utype = U_DICT;
940 py::dict_view d(agg);
941 obj = agg;
942 Py_ssize_t pos = 0;
943 py::handle k, v;
944 while (d.next(&pos, &k, &v)) {
945 c.append(A, tree_flatten(A, v, flat_elements));
946 }
947 } else {
948 utype = U_ELEM;
949 flat_elements.append(A, agg);
950 }
951 return Unflatten {utype, obj, c};
952}
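
// Illustrative round trip (assuming A is an Arena and agg is any Python container):
//
//   Slice<py::handle> flat;
//   Unflatten un = tree_flatten(A, agg, flat);  // leaves appended to flat, left to right
//   // ...rewrite entries of flat in place (see tree_map below)...
//   Slice<py::handle> it = flat;
//   py::object rebuilt = un(it);                // same container structure, new leaves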
953
954struct UnflattenVectorArgs {
955 py::vector_args operator()(Arena& A, Slice<py::handle>& elements) {
956 if (!had_nested) {
957 auto args = elements.begin();
958 elements = Slice<py::handle>();
959 return py::vector_args(args, nargs, kwnames);
960 }
961 Slice<py::handle> args;
962 for (auto u : children) {
963 args.append(A, A.autorelease(u(elements)));
964 }
965 return py::vector_args(args.begin(), nargs, kwnames);
966 }
967 Slice<Unflatten> children;
968 Py_ssize_t nargs;
969 py::handle kwnames;
970 bool had_nested;
971};
972
973UnflattenVectorArgs tree_flatten(Arena& A, py::vector_args args, Slice<py::handle>& flat_elements) {
974 UnflattenVectorArgs r;
975 r.kwnames = args.kwnames;
976 r.nargs = args.nargs;
977 r.had_nested = false;
978 auto N = args.size();
979 for(auto i : irange(N)) {
980 auto typ = Py_TYPE(args[i].ptr());
        // fast check that this argument isn't a nested container that needs flattening.
982 bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType;
983 if (!is_element) {
984 flat_elements.extend(A, args.args, args.args + i);
985 for (auto j : irange(i)) {
986 (void)j;
987 r.children.append(A, Unflatten {U_ELEM});
988 }
989 for (auto j : irange(i, N)) {
990 r.children.append(A, tree_flatten(A, args[j], flat_elements));
991 if (r.children.back().type != U_ELEM) {
992 r.had_nested = true;
993 }
994 }
995 return r;
996 }
997 }
998 flat_elements.extend(A, args.args, args.args + N);
999 return r;
1000}
1001
1002
1003struct UnflattenArena {
1004 Arena A;
1005 Unflatten unflatten;
1006};
1007
1008static PyObject* py_unflatten(PyObject *self,
1009 PyObject *const *args,
1010 Py_ssize_t nargs,
1011 PyObject *kwnames) {
1012 PY_BEGIN
1013 #define ARGS(_) _(py::handle, ns)
1014 MPY_PARSE_ARGS_KWNAMES("O", ARGS)
1015 #undef ARGS
1016 py::sequence_view sv(ns);
    // because we do not have an autorelease pool yet...
1018 Arena A;
1019 Slice<py::handle> slice;
1020 py::handle Tuple = (PyObject*) &PyTuple_Type;
1021 auto inputs = Tuple.call(ns);
1022 py::tuple_view tv(inputs);
1023 for (auto i : tv.enumerate()) {
1024 slice.append(A, tv[i]);
1025 }
1026 auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena");
1027 auto r = AA->unflatten(slice).release();
1028 AT_ASSERT(r != nullptr);
1029 return r;
1030 PY_END(nullptr)
1031}
1032
1033PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS};
1034
1035void free_unflatten_arena(PyObject * pc) {
1036 delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena");
1037}
1038
1039static PyObject* py_tree_flatten(PyObject *self,
1040 PyObject *const *args,
1041 Py_ssize_t nargs,
1042 PyObject *kwnames) {
1043 PY_BEGIN
1044 #define ARGS(_) _(py::handle, tree)
1045 MPY_PARSE_ARGS_KWNAMES("O", ARGS)
1046 #undef ARGS
1047 auto A = new UnflattenArena;
1048 Slice<py::handle> elements;
1049 A->unflatten = tree_flatten(A->A, tree, elements);
1050 auto cap = py::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena));
1051 auto unflatten = py::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release()));
1052 py::tuple r(2);
1053 r.set(0, slice_to_list(elements));
1054 r.set(1, std::move(unflatten));
1055 return r.release();
1056 PY_END(nullptr)
1057}
1058
1059
1060
1061py::object tree_map(Arena& A, std::function<py::handle(py::handle)> fn, py::handle agg) {
1062 Slice<py::handle> elements;
1063 auto unflatten = tree_flatten(A, agg, elements);
1064 for (auto i : elements.enumerate()) {
1065 elements[i] = fn(elements[i]);
1066 }
1067 return unflatten(elements);
1068}
1069
1070// prereq: isinstance(h, _Tensor)
1071inline int64_t _Tensor_ndim(py::handle h) {
1072 if (Tensor::check(h)) {
1073 int64_t r = 0;
1074 for (auto l : Tensor::unchecked_wrap(h)->levels()) {
1075 if (l.is_positional()) {
1076 ++r;
1077 }
1078 }
1079 return r;
1080 }
1081 // Dim or DelayedMulTensor
1082 return 0;
1083}
1084
1085inline py::handle handle_from_tensor(Arena& A, TensorRef t) {
1086 // fast case: tensor is live in python
1087 c10::optional<PyObject*> mb_obj =
1088 t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
1089 if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
1090 return *mb_obj;
1091 }
1092 return A.autorelease(py::object::checked_steal(THPVariable_Wrap(*t)));
1093}
1094
1095struct EnableAllLayers {
1096 EnableAllLayers(Arena& A, Slice<DimEntry> levels) {
1097 std::vector<std::pair<int64_t, int64_t>> layers;
1098 layers.reserve(levels.size());
1099 for (auto l : levels) {
1100 if (!l.is_positional()) {
1101 auto d = l.dim();
1102 levels_to_dim_.append(A, d);
1103 }
1104 }
1105 std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](py::hdl<Dim> lhs, py::hdl<Dim> rhs) { return lhs->level_ < rhs->level_;});
1106
1107 for (auto i : levels_to_dim_.enumerate()) {
1108 auto batch_size = levels_to_dim_[i]->size();
1109 auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different);
1110 if (i == 0) {
1111 levels_start_ = level;
1112 }
1113 }
1114 }
1115
1116 ~EnableAllLayers() {
1117 auto to_remove = levels_start_ + levels_to_dim_.size() - 1;
1118 for (auto i : levels_to_dim_.enumerate()) {
1119 AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i);
1120 }
1121 }
1122
1123 py::obj<Tensor> from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) {
1124 Slice<DimEntry> levels;
1125 for (auto i : irange(-batchedtensor.dim(), 0)) {
1126 levels.append(A, i);
1127 }
1128 TensorRef tensor;
1129 at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor);
1130 while(true) {
1131 auto level = impl->level();
1132 AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size());
1133 py::hdl<Dim> dim = levels_to_dim_[level - levels_start_].ptr();
1134 levels.insert(A, impl->bdim(), dim);
1135 at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value());
1136 if (!nimpl) {
1137 tensor = impl->value();
1138 break;
1139 }
1140 impl = nimpl;
1141 }
1142
1143 py::obj<Tensor> self = Tensor::create();
1144 // grab ownership of the tensors
1145 self->tensor_ = *tensor;
1146 self->batchtensor_ = std::move(batchedtensor);
1147 self->has_device_ = has_device;
1148 self->capture_levels(levels);
1149 return self;
1150 }
1151 void inplace_update_layers(TensorRef batchtensor, Slice<DimEntry> levels) {
        // XXX - requires a patch to functorch that adds the ability to set the level (_unsafe_set_level)
1153 auto impl = maybeGetBatchedImpl(*batchtensor);
1154 for (auto i : levels_to_dim_.reversed_enumerate()) {
1155 if (!impl) {
1156 break;
1157 }
1158 if (levels.contains(levels_to_dim_[i])) {
1159 impl->_unsafe_set_level(levels_start_ + i);
1160 impl = maybeGetBatchedImpl(impl->value());
1161
1162 }
1163 }
1164 }
1165private:
1166 int64_t levels_start_{};
1167 Slice<py::hdl<Dim>> levels_to_dim_;
1168};
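
// Typical use, mirroring the functorch path of run_torch_function below (illustrative):
//
//   EnableAllLayers guard(A, result_levels);        // one vmap layer per first-class dim
//   guard.inplace_update_layers(batched, levels);   // retag existing batched inputs
//   py::object res = orig.call_vector(uargs);       // run the op under the vmap layers
//   guard.from_batched(A, THPVariable_Unpack(res.ptr()), has_device);
//   // ~EnableAllLayers pops the layers again, innermost first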
1169
1170TensorRef _match_levels(Arena& A, TensorRef v, Slice<DimEntry> from_levels, Slice<DimEntry> to_levels, bool drop_levels=false) {
1171 if (from_levels == to_levels) {
1172 return v;
1173 }
1174 // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0.
1175 at::IntArrayRef sz = v->sizes();
1176 at::IntArrayRef sd = v->strides();
1177 AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size());
1178 Slice<int64_t> nsz;
1179 Slice<int64_t> nsd;
1180 for (auto l : to_levels) {
1181 auto oidx = from_levels.index(l);
1182 if (!oidx) {
1183 nsz.append(A, l.is_positional() ? 1 : l.dim()->size());
1184 nsd.append(A, 0);
1185 } else {
1186 auto idx = *oidx;
1187 nsz.append(A, sz[idx]);
1188 nsd.append(A, sd[idx]);
1189 }
1190 }
1191 return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset()));
1192}
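
// Worked example (illustrative, assuming ordinary non-zero strides): if v has
// from_levels [i, -1] with sizes [3, 4], matching it to to_levels [j, i, -1] yields
// sizes [j->size(), 3, 4] and strides [0, stride(i), stride(-1)], i.e. the new j axis
// is a zero-stride broadcast and no data is copied.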
1193
1194static py::object run_torch_function(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise) {
1195 if (!pointwise_optimize) {
1196 is_pointwise = false;
1197 }
1198 // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n";
1199
1200 Slice<py::hdl<Dim>> all_dims;
1201 Slice<py::handle> flat_args;
1202 auto unflatten_args = tree_flatten(A, args, flat_args);
1203 TensorRef device_holding_tensor;
1204
1205 Slice<TensorInfo> infos;
1206 Slice<DimEntry> result_levels;
1207 for (auto f : flat_args) {
1208 infos.append(A, TensorInfo::create(A, f, !is_pointwise, false));
1209 if (infos.back()) {
1210 TensorInfo& info = infos.back();
1211 AT_ASSERT(is_pointwise || info.batchedtensor);
1212 if (!device_holding_tensor && info.has_device) {
1213 device_holding_tensor = infos.back().tensor;
1214 }
1215 for (auto l : info.levels) {
1216 if (!result_levels.contains(l)) {
1217 result_levels.append(A, l);
1218 }
1219 }
1220 }
1221 }
1222
1223 if (is_pointwise) {
1224 for (auto i : flat_args.enumerate()) {
1225 if (infos[i]) {
1226 TensorRef tensor = infos[i].tensor;
1227 if (device_holding_tensor && !infos[i].has_device) {
1228 tensor = A.autorelease(tensor->to(device_holding_tensor->device()));
1229 }
1230 auto ml = _match_levels(A, tensor, infos[i].levels, result_levels);
1231 flat_args[i] = handle_from_tensor(A, std::move(ml));
1232 }
1233 }
1234
1235 Slice<py::handle> flat_it = flat_args;
1236 py::vector_args uargs = unflatten_args(A, flat_it);
1237
1238 py::object result = orig.call_vector(uargs);
1239
1240 // fast wrap for normal case where operator just returns a tensor.
1241 if (THPVariable_Check(result.ptr())) {
1242 return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor);
1243 }
1244 auto wrap = [&](py::handle h) {
1245 if (THPVariable_Check(h.ptr())){
1246 return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor));
1247 }
1248 return h;
1249 };
1250 return tree_map(A, wrap, result);
1251 } else {
1252 // std::cout << orig << " calling functorch...\n";
1253 // std::cout << "rl: " << result_levels << "\n";
1254 EnableAllLayers guard(A, result_levels);
1255 for (auto i : flat_args.enumerate()) {
1256 if (infos[i]) {
1257 TensorRef batched = infos[i].batchedtensor;
1258 if (device_holding_tensor && !infos[i].has_device) {
1259 batched = A.autorelease(batched->to(device_holding_tensor->device()));
1260 }
1261 guard.inplace_update_layers(batched, infos[i].levels);
1262 flat_args[i] = handle_from_tensor(A, batched);
1263 }
1264 }
1265 Slice<py::handle> flat_it = flat_args;
1266 py::vector_args uargs = unflatten_args(A, flat_it);
1267 AT_ASSERT(flat_it.size() == 0);
1268 py::object result = orig.call_vector(uargs);
1269 auto wrap = [&](py::handle h) {
1270 if (THPVariable_Check(h.ptr())) {
1271 return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor));
1272 }
1273 return h;
1274 };
1275 if (THPVariable_Check(result.ptr())) {
1276 return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor);
1277 }
1278 return tree_map(A, wrap, result);
1279 }
1280}
1281
1282
1283static py::object __torch_function__(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise) {
1284 if (orig == torch_Tensor___mul__) {
1285 AT_ASSERT(args.nargs == 2 && !args.has_keywords());
1286 auto lhs = args[0];
1287 auto rhs = args[1];
1288 if (py::isinstance(lhs, _Tensor) && py::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) {
1289 bool has_device = false;
1290 Slice<DimEntry> levels;
1291 for (auto i : args.enumerate_positional()) {
1292 auto t = TensorInfo::create(A, args[i], false);
                // something like an integer or boolean mask * rhs, which the matrix-multiply lowering doesn't promote correctly
1294 if (!t.tensor->is_floating_point()) {
1295 return run_torch_function(A, orig, args, is_pointwise);
1296 }
1297 has_device = has_device || t.has_device;
1298 for (auto l : t.levels) {
1299 if (!levels.contains(l)) {
1300 levels.append(A, l);
1301 }
1302 }
1303 }
1304 // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n";
1305 return Tensor::create_delayed(py::object::borrow(orig), args, levels, has_device);
1306 }
1307 }
1308 return run_torch_function(A, orig, args, is_pointwise);
1309}
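
// Illustrative flow for the special case above (hedged; the surrounding Python library
// decides when the delayed product is finally materialized):
//
//   lhs * rhs   // both 0-dim first-class tensors -> Tensor::create_delayed(...) records
//               // the call instead of running it
//   // a later use that needs the value (e.g. Tensor::tensor()) replays it as a pointwise
//   // op, and reductions over shared dims can instead be lowered to dot() below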
1310
1311py::vector_args as_vector_args(Arena& A, py::handle args, py::handle kwargs) {
1312 auto pos_args = (py::handle*) &PyTuple_GET_ITEM(args.ptr(), 0);
1313 auto pos_n = PyTuple_GET_SIZE(args.ptr());
1314 if (!kwargs.ptr()) {
1315 return py::vector_args(pos_args, pos_n, nullptr);
1316 }
1317 Slice<py::handle> all_args;
1318 Slice<py::handle> kwnames;
1319 all_args.extend(A, pos_args, pos_args + pos_n);
1320 py::dict_view dv(kwargs);
1321 Py_ssize_t pos = 0;
1322 py::handle key, value;
1323 while (dv.next(&pos, &key, &value)) {
1324 all_args.append(A, value);
1325 kwnames.append(A, key);
1326 }
1327 return py::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames)));
1328}
1329
1330static PyObject* py___torch_function__(PyObject *self,
1331 PyObject *const *args,
1332 Py_ssize_t nargs,
1333 PyObject *kwnames) {
1334 Arena A;
1335 PY_BEGIN
1336 maybeInitializeGlobals();
1337 AT_ASSERT(nargs == 4 || nargs == 5);
1338 auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr);
1339 bool is_pointwise = pointwise.contains(args[1]);
1340 return __torch_function__(A, args[1], std::move(va), is_pointwise).release();
1341 PY_END(nullptr)
1342}
1343
1344py::object levels_to_tuple(Slice<DimEntry> slice) {
1345 py::tuple t(slice.size());
1346 for (auto i : slice.enumerate()) {
1347 t.set(i, slice[i].is_positional() ? py::from_int(slice[i].position()) : py::object::borrow(slice[i].dim()));
1348 }
1349 py::object r = std::move(t);
1350 return r;
1351}
1352
1353PyObject* Tensor_ndim(Tensor* self, void*) {
1354 Py_ssize_t i = 0;
1355 for (auto l : self->levels()) {
1356 if (l.is_positional()) {
1357 ++i;
1358 }
1359 }
1360 return py::from_int(i).release();
1361}
1362
1363static PyGetSetDef Tensor_getsetters[] = {
1364 {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return py::from_bool(((Tensor*)self)->has_device()).release(); }, NULL},
1365 {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* {
1366 Arena A;
1367 return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL},
1368 {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* {
1369 Arena A;
1370 return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL},
1371 {"_levels", (getter) [](PyObject* self, void*) -> PyObject* {
1372 PY_BEGIN
1373 return levels_to_tuple(((Tensor*)self)->levels()).release();
1374 PY_END(nullptr)
1375 }},
1376 {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL},
1377 {NULL} /* Sentinel */
1378};
1379
1380static PyMethodDef Tensor_methods[] = {
1381 {NULL, NULL, 0, NULL} /* Sentinel */
1382};
1383
1384
1385PyTypeObject Tensor::Type = {
1386 PyVarObject_HEAD_INIT(NULL, 0)
1387 "_C.Tensor", /* tp_name */
1388 sizeof(Tensor), /* tp_basicsize */
1389 0, /* tp_itemsize */
1390 Tensor::dealloc_stub, /* tp_dealloc */
1391 0, /* tp_vectorcall_offset */
1392 0, /* tp_getattr */
1393 0, /* tp_setattr */
1394 0, /* tp_as_async */
1395 0, /* tp_repr */
1396 0, /* tp_as_number */
1397 0, /* tp_as_sequence */
1398 0, /* tp_as_mapping */
1399 0, /* tp_hash */
1400 0, /* tp_call */
1401 0, /* tp_str */
1402 0, /* tp_getattro */
1403 0, /* tp_setattro */
1404 0, /* tp_as_buffer */
1405 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */
1406 "Tensor Object", /* tp_doc */
1407 0, /* tp_traverse */
1408 0, /* tp_clear */
1409 0, /* tp_richcompare */
1410 0, /* tp_weaklistoffset */
1411 0, /* tp_iter */
1412 0, /* tp_iternext */
1413 Tensor_methods, /* tp_methods */
1414 0, /* tp_members */
1415 Tensor_getsetters, /* tp_getset */
1416 0, /* tp_base */
1417 0, /* tp_dict */
1418 0, /* tp_descr_get */
1419 0, /* tp_descr_set */
1420 0, /* tp_dictoffset */
1421 0, /* tp_init */
1422 0, /* tp_alloc */
1423 Tensor::new_stub, /* tp_new */
1424};
1425
1426
1427// dim() --------------------
1428
1429bool relevant_op(_Py_CODEUNIT c) {
1430 switch(_Py_OPCODE(c)) {
1431 case STORE_NAME:
1432 case STORE_GLOBAL:
1433 case STORE_FAST:
1434 case STORE_DEREF:
1435 return true;
1436 default:
1437 return false;
1438 }
1439}
1440
1441py::object create_dim(py::object name, py::handle size) {
1442 auto d = Dim::create(std::move(name));
1443 if (!py::is_none(size)) {
1444 d->set_size(py::to_int(size));
1445 }
1446 return std::move(d);
1447}
1448
1449py::object create_dimlist(py::object name, py::handle size) {
1450 auto d = DimList::create(std::move(name));
1451 if (!py::is_none(size)) {
1452 if (py::is_int(size)) {
1453 d->bind_len(py::to_int(size));
1454 } else {
1455 py::sequence_view s(size);
1456 d->bind_len(s.size());
1457 for (auto i : irange(d->size())) {
1458 d->dims_[i]->set_size(py::to_int(s[i]));
1459 }
1460 }
1461 }
1462 return std::move(d);
1463}
1464
1465
1466
1467// Python wrappers that make new reflection primitives available for older runtimes
1468#if !(IS_PYTHON_3_11_PLUS)
1469#define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code))
1470#endif
1471
1472struct PyInstDecoder {
1473 PyInstDecoder(PyCodeObject* code_object, int lasti)
1474 : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {}
1475 // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols
1476 // See https://github.com/pytorch/pytorch/issues/93854
1477 void next() {
1478 #if IS_PYTHON_3_11_PLUS && !defined(_WIN32)
1479 offset_ += _PyOpcode_Caches[opcode()];
1480 #endif
1481 offset_ += 1;
1482 }
1483 int opcode() {
1484 auto r = _Py_OPCODE(code_[offset_]);
1485 #if IS_PYTHON_3_11_PLUS && !defined(_WIN32)
1486 r = _PyOpcode_Deopt[r];
1487 #endif
1488 return r;
1489 }
1490 int oparg() {
1491 return _Py_OPARG(code_[offset_]);
1492 }
1493
1494 py::object name() {
1495 py::object names;
1496 switch(opcode()) {
1497 case STORE_NAME:
1498 case STORE_GLOBAL:
1499 names = py::object::borrow(code_object_->co_names);
1500 break;
1501 case STORE_FAST:
1502 names = py::object::steal(PyCode_GetVarnames(code_object_));
1503 break;
1504 case STORE_DEREF:
1505 names = py::object::steal(PyCode_GetCellvars(code_object_));
1506 break;
1507 default:
1508 return py::object();
1509 }
1510 return py::object::steal(PySequence_GetItem(names.ptr(), oparg()));
1511 }
1512private:
1513 PyCodeObject* code_object_;
1514 _Py_CODEUNIT* code_;
1515 int offset_;
1516};
1517
1518template<py::object (*create_object)(py::object, py::handle)>
1519static PyObject* _dims(PyObject *self,
1520 PyObject *const *args,
1521 Py_ssize_t nargs,
1522 PyObject *kwnames) {
1523 PY_BEGIN
1524 Py_ssize_t specified_ndims = -1;
1525 Py_ssize_t found_ndims = 0;
1526 Py_ssize_t sizes = -1;
1527 py::handle n = Py_None;
1528 py::handle py_sizes = Py_None;
1529
1530 if (nargs || kwnames) {
1531 py::vector_args va(args, nargs, kwnames);
1532 va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0);
1533 if (!py::is_none(py_sizes)) {
1534 sizes = py::sequence_view(py_sizes).size();
1535 specified_ndims = sizes;
1536 }
1537 if (!py::is_none(n)) {
1538 specified_ndims = py::to_int(n);
1539 }
1540 }
1541
1542 PyThreadState* state = PyThreadState_GET();
1543 auto f = py::obj<PyFrameObject>::steal(PyThreadState_GetFrame(state));
1544 auto c = py::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
1545 auto lasti = PyFrame_GetLasti(f.ptr());
1546 auto decoder = PyInstDecoder(c.ptr(), lasti);
1547 #if IS_PYTHON_3_11_PLUS
    // When Python 3.11 specializes bytecode adaptively, lasti points to the PRECALL
    // instruction rather than the CALL instruction that follows it
1550 if (decoder.opcode() == PRECALL) {
1551 decoder.next();
1552 }
1553 #endif
1554 decoder.next();
1555
1556 if (relevant_op(decoder.opcode())) {
1557 found_ndims = 1;
1558 } else if (decoder.opcode() == UNPACK_SEQUENCE) {
1559 found_ndims = decoder.oparg();
1560 decoder.next();
1561 }
1562
1563 if (specified_ndims == -1) {
1564 if (found_ndims == 0) {
1565 py::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified");
1566 }
1567 specified_ndims = found_ndims;
1568 }
1569 if (found_ndims != specified_ndims) {
1570 found_ndims = 0; // avoid taking the wrong names for dimensions
1571 }
1572
1573 auto genobject = [&](int i) -> py::object {
1574 py::object name;
1575 if (i < found_ndims) {
1576 name = decoder.name();
1577 }
1578 if (!name.ptr()) {
1579 name = py::unicode_from_format("d%d", i);
            found_ndims = 0; // once we fail to find a name, we can't find any more
1581 } else {
1582 decoder.next();
1583 }
1584 return create_object(std::move(name), sizes != -1 ? py::sequence_view(py_sizes)[i] : py::handle(Py_None));
1585 };
1586 if (sizes != -1 && sizes != specified_ndims) {
1587 py::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes));
1588 }
1589 if (specified_ndims == 1) {
1590 return genobject(0).release();
1591 }
1592 py::tuple result(specified_ndims);
1593 for (int i = 0; i < specified_ndims; ++i) {
1594 result.set(i, genobject(i));
1595 }
1596 return result.release();
1597 PY_END(nullptr)
1598}
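
// Illustrative behavior (hedged): writing `x, y = dims()` in Python compiles to an
// UNPACK_SEQUENCE(2) followed by two STORE_* instructions, so the decoder above infers
// two dims named "x" and "y"; `i = dims(sizes=[3])` infers one dim named "i" bound to
// size 3; when the assignment target can't be decoded, names fall back to d0, d1, ...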
1599
1600int64_t dim_index(const std::vector<py::obj<Dim>>& dims, py::hdl<Dim> dim) {
1601 for (int64_t i = 0, N = dims.size(); i < N; ++i) {
1602 if (dims[i].ptr() == dim.ptr()) {
1603 return i;
1604 }
1605 }
1606 return -1;
1607}
1608
1609
1610struct DotPart {
1611 Slice<DimEntry> dims;
1612 size_t total_size = 1;
1613 void append(Arena& A, py::hdl<Dim> d) {
1614 total_size *= d->size();
1615 dims.append(A, d);
1616 }
1617};
1618
1619template<typename T>
1620static at::ArrayRef<T> as_array_ref(Slice<T> t) {
1621 return at::ArrayRef<T>(t.begin(), t.end());
1622}
1623
1624TensorRef dot_prepare(Arena& A, std::initializer_list<DotPart> parts, const TensorInfo& t) {
1625 Slice<DimEntry> new_levels;
1626 bool needs_reshape = false;
1627 for (auto p : parts) {
1628 if (p.dims.size() != 1) {
1629 needs_reshape = true;
1630 }
1631 new_levels.extend(A, p.dims);
1632 }
1633 auto r = _match_levels(A, t.tensor, t.levels, new_levels, true);
1634 if (!needs_reshape) {
1635 return r;
1636 }
1637 Slice<int64_t> view;
1638 for (auto p : parts) {
1639 view.append(A, p.total_size);
1640 }
1641 return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end())));
1642}
1643
1644py::object dot_finish(Arena& A, std::initializer_list<DotPart> parts, at::Tensor r) {
1645 Slice<DimEntry> result_levels;
1646 bool needs_reshape = false;
1647 for (auto p : parts) {
1648 if (p.dims.size() != 1) {
1649 needs_reshape = true;
1650 }
1651 result_levels.extend(A, p.dims);
1652 }
1653 if (needs_reshape) {
1654 Slice<int64_t> new_size;
1655 for (auto l : result_levels) {
1656 new_size.append(A, l.dim()->size());
1657 }
1658 r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end()));
1659 }
1660 return Tensor::from_positional(A, std::move(r), result_levels, true);
1661}
1662
1663
1664
1665py::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry> sum) {
1666 auto lhs_strides = lhs.tensor->strides();
1667 auto rhs_strides = rhs.tensor->strides();
1668
1669 DotPart lro_dims;
1670 DotPart lo_dims;
1671 DotPart ro_dims;
1672 DotPart lr_dims;
1673
1674 auto insert_dim = [&] (py::hdl<Dim> d, at::optional<int> lhs_idx, at::optional<int> rhs_idx) {
1675 bool reduced = sum.contains(d);
1676 int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0;
1677 int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0;
1678 if (reduced) {
1679 // lr
1680 lr_dims.append(A, d);
1681 } else {
1682 if ((lhs_stride == 0) == (rhs_stride == 0)) {
1683 // lro
1684 lro_dims.append(A, d);
1685 } else if (lhs_stride != 0) {
1686 // lo
1687 lo_dims.append(A, d);
1688 } else {
1689 AT_ASSERT(rhs_stride != 0);
1690 ro_dims.append(A, d);
1691 }
1692 }
1693 };
1694
1695
1696 auto rhs_seen = A.allocate<bool>(rhs.levels.size());
1697 std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false);
1698
1699 for (auto i : lhs.levels.enumerate()) {
1700 auto d = lhs.levels[i];
1701 auto rhs_idx = rhs.levels.index(d);
1702 if (rhs_idx) {
1703 rhs_seen[*rhs_idx] = true;
1704 }
1705 insert_dim(d.dim(), i, rhs_idx);
1706 }
1707
1708 for (auto i : rhs.levels.enumerate()) {
1709 if (rhs_seen[i]) {
1710 continue;
1711 }
1712 auto d = rhs.levels[i];
1713 insert_dim(d.dim(), at::nullopt, i);
1714 }
1715
1716 if (lr_dims.dims.size() != sum.size()) {
1717 for (auto & d : sum) {
1718 if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) {
                py::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr());
1720 }
1721 }
1722 }
1723
1724 // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n";
1725 // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n";
1726
    // if both sides share batch dims (lro), use bmm; with no batch dims, plain mm suffices
1728 if (lro_dims.dims.size() != 0) {
1729 auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs);
1730 auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs);
1731 return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_));
1732 } else {
1733 auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs);
1734 auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs);
1735 return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_));
1736 }
1737
1738}
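
// Worked example (illustrative, assuming ordinary non-zero strides): with
// lhs levels [b, i, k], rhs levels [b, k, j], and sum = {k}, the classification above
// gives lro = {b}, lo = {i}, ro = {j}, lr = {k}; dot_prepare aligns lhs to (b, i, k)
// and rhs to (b, k, j), at::bmm produces (b, i, j), and dot_finish labels the result
// with levels [b, i, j].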
1739
1740static PyObject* test_c(PyObject *self,
1741 PyObject *const *args,
1742 Py_ssize_t nargs,
1743 PyObject *kwnames) {
1744 PY_BEGIN
1745
1746 Arena A;
1747 Slice<int> s(A, 3, 4, 5);
1748 AT_ASSERT(s.size() == 3 && s.capacity() == 8);
1749 AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5);
1750 s.append(A, 6);
1751 AT_ASSERT(s[3] == 6);
1752 for(int i : irange(10)) {
1753 s.append(A, i);
1754 }
1755 AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16);
1756
1757 Slice<int> s2(A, -1, -2, -3);
1758 AT_ASSERT(s2[1] == -2 && s[0] == 3);
1759
1760 auto ss = s.slice(1,2);
1761 AT_ASSERT(ss.size() == 1);
1762 AT_ASSERT(ss[0] == 4);
1763 AT_ASSERT(ss.capacity() == 1);
1764 ss.append(A, -4);
1765 AT_ASSERT(ss.size() == 2 && ss[1] == -4);
1766 ss[0] = 3;
1767 AT_ASSERT(s[1] == 4);
1768
1769 s.insert(A, s.slice(1, 4), ss);
1770 AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0);
1771
1772 auto sz = s.size();
1773 s.insert(A, s.slice(1, 1), 4);
1774 AT_ASSERT(s[1] == 4 && sz + 1 == s.size());
1775
1776
1777 Slice<int> d(A, 0, 1, 2, 3, 4);
1778
1779 Slice<int> b(A, 0, 1, 2, 3, 4);
1780 b.insert(A, b.slice(1,1), d);
1781 AT_ASSERT(b.size() == 10);
1782 AT_ASSERT(b[1] == 0);
1783 AT_ASSERT(b[5] == 4);
1784 AT_ASSERT(b.back() == 4);
1785
1786 Py_RETURN_NONE;
1787
1788 PY_END(nullptr);
1789}
1790
1791static DimEntry _wrap_dim(py::handle d, size_t N, bool keepdim);
1792
1793static PyObject* order(PyObject *_,
1794 PyObject *const *args,
1795 Py_ssize_t nargs,
1796 PyObject *kwnames) {
1797 Arena A;
1798 PY_BEGIN
1799 if (kwnames) {
1800 py::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
1801 }
1802 AT_ASSERT(nargs-- > 0);
1803 Slice<DimEntry> orig_levels;
1804 Slice<DimEntry> levels;
1805 TensorRef data;
1806 py::handle self = args++[0];
1807 bool has_device;
1808 if (Tensor::check_exact(self)) {
1809 auto t = Tensor::unchecked_wrap(self);
1810 orig_levels = t->levels();
1811 data = t->tensor(A);
1812 has_device = t->has_device();
1813 } else {
1814 auto d = Dim::unchecked_wrap(self);
1815 orig_levels.append(A, d);
1816 data = d->range();
1817 has_device = false;
1818 }
1819
1820 Slice<DimEntry> flat_positional_dims;
1821 Slice<std::pair<int, int>> to_flatten;
1822 levels.extend(A, orig_levels);
1823
1824 int orig_ndim = ndim_of_levels(levels);
1825 auto append = [&](DimEntry d) {
1826 auto midx = levels.index(d);
1827 if (!midx) {
1828 if (d.is_positional()) {
1829 py::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim));
1830 } else {
1831 py::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr());
1832 }
1833 }
1834 levels[*midx] = DimEntry();
1835 flat_positional_dims.append(A, d);
1836 };
1837
1838 int n_new_positional = 0;
1839 for (auto i :irange(nargs)) {
1840 py::handle arg = args[i];
1841 DimEntry entry = _wrap_dim(arg, orig_ndim, false);
1842 if (!entry.is_none()) {
1843 append(entry);
1844 ++n_new_positional;
1845 } else if (DimList::check(arg)) {
1846 auto dl = DimList::unchecked_wrap(arg);
1847 for (py::obj<Dim> & d : dl->dims_) {
1848 append(py::hdl<Dim>(d));
1849 ++n_new_positional;
1850 }
1851 } else {
1852 ++n_new_positional;
1853 if (!py::is_sequence(arg)) {
1854 py::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]");
1855 }
1856 py::sequence_view sq(arg);
1857 auto N = sq.size();
1858 to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N));
1859 for (auto j : irange(N)) {
1860 DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false);
1861 if (e.is_none()) {
1862 py::raise_error(PyExc_ValueError, "expected a Dim, or int");
1863 }
1864 append(e);
1865 }
1866 }
1867 }
1868
1869 int ndim = 0;
1870 int insert_point = -1;
1871 Slice<DimEntry> new_levels;
1872 for (auto l : levels) {
1873 if (l.is_none()) {
1874 continue;
1875 }
1876 if (l.is_positional()) {
1877 ndim++;
1878 if (insert_point == -1) {
1879 insert_point = new_levels.size();
1880 new_levels.extend(A, flat_positional_dims);
1881 }
1882 }
1883 new_levels.append(A, l);
1884 }
1885 if (insert_point == -1) {
1886 insert_point = new_levels.size();
1887 new_levels.extend(A, flat_positional_dims);
1888 }
1889
1890 at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels);
1891
1892 if (to_flatten.size()) {
1893 Slice<int64_t> view;
1894 auto sz = ndata.sizes();
1895 // before the new positional dims
1896 for (auto i : irange(0, insert_point)) {
1897 view.append(A, sz[i]);
1898 }
1899 int i = 0;
1900 for (auto to_flat : to_flatten) {
1901 for (;i < to_flat.first; ++i) {
1902 view.append(A, sz[insert_point + i]);
1903 }
1904 int64_t new_size = 1;
1905 int last = i + to_flat.second;
1906 for (; i < last; ++i) {
1907 new_size *= sz[insert_point + i];
1908 }
1909 view.append(A, new_size);
1910 }
1911 for (; i < flat_positional_dims.size(); ++i) {
1912 view.append(A, sz[insert_point + i]);
1913 }
1914 // after the new positional dims
1915 for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) {
1916 view.append(A, sz[i]);
1917 }
        // we shortened the number of dimensions, so remove the extra entries from new_levels
        // (we will renumber the remaining positional dimensions later)
1920 auto n_to_remove = flat_positional_dims.size() - n_new_positional;
1921 new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice<DimEntry>());
1922 ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end()));
1923 }
1924
    // renumber the positional dimensions
1926 int seen = 0;
1927 for (auto i : new_levels.reversed_enumerate()) {
1928 if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) {
1929 new_levels[i] = --seen;
1930 }
1931 }
1932 return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release();
1933
1934 PY_END(nullptr)
1935}
1936
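// expand(): when every extra argument is a Dim object, adds those dims as new
// stride-0 first-class dimensions; otherwise it defers to the regular
// Tensor.expand (directly or via __torch_function__). Illustrative sketch,
// assuming the functorch.dim API:
//   i, b = dims(sizes=[3, 8])
//   x = torch.randn(3)[i]
//   xb = x.expand(b)              # adds first-class dim b without copying data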
1937static PyObject* expand(PyObject *_,
1938 PyObject *const *args,
1939 Py_ssize_t nargs,
1940 PyObject *kwnames) {
1941 Arena A;
1942 PY_BEGIN
1943 AT_ASSERT(nargs-- > 0);
1944 auto info = TensorInfo::create(A, args++[0], false);
1945 for (auto i : irange(nargs)) {
1946 if (!Dim::check(args[i])) {
1947 maybeInitializeGlobals();
1948 py::vector_args vargs(args - 1, nargs + 1, kwnames);
1949 if (THPVariable_Check(args[-1])) {
1950 return torch_Tensor_expand.call_vector(vargs).release();
1951 } else {
1952 return __torch_function__(A, torch_Tensor_expand, vargs, false).release();
1953 }
1954 }
1955 }
1956 const at::Tensor& data = *info.tensor;
1957 auto levels = info.levels;
1958 Slice<DimEntry> new_levels;
1959 Slice<int64_t> sz;
1960 Slice<int64_t> sd;
1961 for (auto i : irange(nargs)) {
1962 auto d = Dim::unchecked_wrap(args[i]);
1963 if (levels.contains(d) || new_levels.contains(d)) {
1964 py::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr());
1965 }
1966 new_levels.append(A, d);
1967 sz.append(A, d->size());
1968 sd.append(A, 0);
1969 }
1970 new_levels.extend(A, levels);
1971 at::IntArrayRef osz = data.sizes();
1972 at::IntArrayRef osd = data.strides();
1973 sz.extend(A, osz.begin(), osz.end());
1974 sd.extend(A, osd.begin(), osd.end());
1975 at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset());
1976 return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release();
1977 PY_END(nullptr)
1978}
1979
1980
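// _bind_dims_to_size: given the size/stride of one original dimension and a
// pack of dims that covers it, infer the size of at most one unbound dim and
// append the sizes/strides the packed dims would have if they were separate
// dimensions (row-major within the pack). For example (illustrative): packing a
// dimension of size 12 with stride 2 into (a, b) where a.size == 3 infers
// b.size == 4 and yields sizes (3, 4) with strides (8, 2).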
1981void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd,
1982 Slice<py::hdl<Dim>> dims, Slice<int64_t>& nsz, Slice<int64_t>& nsd) {
1983 int64_t rhs_prod = 1;
1984 for (auto i : dims.enumerate()) {
1985 if (!dims[i]->is_bound()) {
1986 for (auto j : irange(i + 1, dims.size())) {
1987 if (!dims[j]->is_bound()) {
1988 py::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr());
1989 }
1990 rhs_prod *= dims[j]->size();
1991 }
1992 if (sz % rhs_prod != 0) {
1993 py::tuple tup(dims.size());
1994 for (auto j : dims.enumerate()) {
1995 tup.set(j, dims[j]->is_bound() ? py::from_int(dims[j]->size()) : py::unicode_from_string("?"));
1996 }
1997 py::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr());
1998 }
1999 int64_t inferred_size = sz / rhs_prod;
2000 dims[i]->set_size(inferred_size);
2001 rhs_prod = sz;
2002 break;
2003 }
2004 rhs_prod *= dims[i]->size();
2005 }
2006 if (rhs_prod != sz) {
2007 py::tuple tup(dims.size());
2008 for (auto j : dims.enumerate()) {
2009 tup.set(j, py::object::borrow(dims[j]));
2010 }
        py::raise_error(DimensionBindError(), "Dimension sizes do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr());
2012 }
2013 auto new_strides = A.allocate<int64_t>(dims.size());
2014 auto prev_stride = sd;
2015 for (auto i : dims.reversed_enumerate()) {
2016 new_strides[i] = prev_stride;
2017 prev_stride = dims[i]->size()*prev_stride;
2018 }
2019 for (auto i : dims.enumerate()) {
2020 nsd.append(A, new_strides[i]);
2021 nsz.append(A, dims[i]->size());
2022 }
2023}
2024
2025inline bool has_dims(py::handle d) {
2026 return Dim::check_exact(d) || Tensor::check_exact(d);
2027}
2028
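// The helpers below implement indexing with first-class dims. getsetitem()
// classifies the index expression, getsetitem_flat() lines everything up with
// self's levels, and invoke_getitem() only calls the stock THPVariable_getitem
// when real advanced indexing is required. Illustrative sketch, assuming the
// functorch.dim API:
//   i = dims(sizes=[3])
//   t = torch.randn(3, 3)
//   t[i]        # simple binding: the first dimension becomes first-class dim i
//   t[i, i]     # i used twice -> diagonal, computed via advanced indexing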
2029struct IndexingInfo {
    bool can_call_original; // if true, it is safe to just call getitem or setitem; these objects do not need special handling
2031 bool advanced_indexing; // requires actual lookup
2032 TensorRef self;
2033 Slice<py::handle> flat_inputs;
2034 Slice<DimEntry> result_levels;
2035 bool has_device;
2036};
2037
2038static Slice<py::handle> as_slice(py::tuple_view tv) {
2039 PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0);
2040 return Slice<py::handle>((py::handle*)begin, (py::handle*) (begin + tv.size()));
2041}
2042
2043static Slice<py::handle> as_slice(py::list_view tv) {
2044 PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0);
2045 return Slice<py::handle>((py::handle*)begin, (py::handle*) (begin + tv.size()));
2046}
2047
2048
2049bool maybe_dimpack(Slice<py::handle>& elements, py::handle s, bool check_first=true) {
2050 // can we avoid rechecking?
2051 if (py::list_view::check(s)) {
2052 py::list_view tv(s);
2053 if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2054 elements = as_slice(tv);
2055 return true;
2056 }
2057 }
2058 // can we avoid rechecking?
2059 if (py::tuple_view::check(s)) {
2060 py::tuple_view tv(s);
2061 if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2062 elements = as_slice(tv);
2063 return true;
2064 }
2065 }
2066 return false;
}
2068
2069bool is_dimpack(py::handle s) {
2070 Slice<py::handle> e;
2071 return maybe_dimpack(e, s);
2072}
2073
2074IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<py::handle> input, Slice<DimEntry> keys, Slice<py::handle> values, bool has_dimpacks_or_none);
2075static py::object invoke_getitem(Arena& A, const IndexingInfo& iinfo);
2076
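// index(): groups of dims listed together on the lhs are flattened into a
// single dimension of self, which is then indexed by the corresponding rhs
// entry. Illustrative sketch, assuming the functorch.dim API (names are only
// examples):
//   i, j = dims(2)
//   k = dims(sizes=[4])
//   t = torch.randn(2, 3)[i, j]
//   flat = torch.randint(6, (4,))[k]   # positions in the flattened (i, j) space
//   t.index((i, j), flat)              # result is indexed by k instead of i, j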
2077static py::object index(Arena& A, py::handle self, py::handle dims, py::handle indices) {
2078 maybeInitializeGlobals();
2079 Slice<py::handle> dims_list;
2080 Slice<py::handle> indices_list;
    // we allow for matching single dims to multiple dims,
    // so we first normalize everything into the case where there is a list on both the lhs and the rhs
2083 bool lhs_list = py::tuple_view::check(dims) || py::list_view::check(dims);
2084 bool rhs_list = py::tuple_view::check(indices) || py::list_view::check(indices);
2085 if (lhs_list && rhs_list) {
2086 py::sequence_view dv(dims);
2087 py::sequence_view ind(indices);
2088 Py_ssize_t N = dv.size();
2089 if (N != ind.size()) {
2090 py::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size()));
2091 }
2092 for (auto i : irange(N)) {
2093 dims_list.append(A, A.autorelease(dv[i]));
2094 indices_list.append(A, A.autorelease(ind[i]));
2095 }
2096 } else {
2097 dims_list.append(A, dims);
2098 indices_list.append(A, indices);
2099 }
2100
    // dims being indexed can be grouped together into a single index space, and we have to
    // flatten them into a single dimension before we can index them...
2103 auto self_info = TensorInfo::create(A, self, false);
2104 auto ndim = self_info.ndim();
2105 Slice<DimEntry> new_levels;
2106 Slice<DimEntry> to_flatten;
2107 Slice<DimEntry> dims_list_flat;
2108 auto parse_dim_entry = [&](py::handle s) -> DimEntry {
2109 auto d = _wrap_dim(s, ndim, false);
2110 if (d.is_none()) {
            py::raise_error(PyExc_TypeError, "expected a dimension specifier but found %R", s.ptr());
2112 }
2113 return d;
2114 };
2115 auto dim_not_present = [&](DimEntry d) {
2116 if (d.is_positional()) {
2117 py::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim);
2118 } else {
2119 py::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr());
2120 }
2121 };
2122
2123 for (auto i : dims_list.enumerate()) {
2124 Slice<py::handle> m;
2125 if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) {
            if (m.size() == 0) {
                // an empty dim pack has plausible semantics (the index is always 0),
                // so the value is just dropped
                dims_list_flat.append(A, DimEntry());
                continue; // nothing left to parse; avoid reading m[0] below
            }
2130 auto first = parse_dim_entry(m[0]);
2131 dims_list_flat.append(A, first);
2132 if (m.size() == 1) {
2133 continue;
2134 }
2135 if (to_flatten.size() == 0) {
2136 new_levels.extend(A, self_info.levels);
2137 }
2138 Slice<DimEntry> rest;
2139 for (auto i : irange(1, m.size())) {
2140 auto d = parse_dim_entry(m[i]);
2141 if (!new_levels.remove(A, d)) {
2142 dim_not_present(d);
2143 }
2144 rest.append(A, d);
2145 }
2146
2147 auto first_idx = new_levels.index(first);
2148 if (!first_idx) {
2149 dim_not_present(first);
2150 }
2151 new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest);
2152 to_flatten.extend(A, rest);
2153 } else {
2154 dims_list_flat.append(A, parse_dim_entry(dims_list[i]));
2155 }
2156 }
2157 if (to_flatten.size() > 0) {
2158 TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels);
2159 at::IntArrayRef sizes = rearranged->sizes();
2160 Slice<int64_t> new_sizes;
2161 Slice<DimEntry> reshape_levels;
2162 for (auto i : new_levels.enumerate()) {
2163 if (to_flatten.contains(new_levels[i])) {
2164 new_sizes.back() *= sizes[i];
2165 } else {
2166 new_sizes.append(A, sizes[i]);
2167 reshape_levels.append(A, new_levels[i]);
2168 }
2169 }
2170 self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
2171
        self_info.levels = reshape_levels; // note: we use the first level in a flattened group to represent the group for the rest of the op.
        // be careful not to rely on that dimension's size, because it doesn't match the size of the whole group
2174 }
2175 bool has_dimpacks = false;
2176 for (auto idx : indices_list) {
2177 if (py::tuple_view::check(idx) || py::list_view::check(idx)) {
2178 has_dimpacks = true;
2179 break;
2180 }
2181 }
2182 IndexingInfo info = getsetitem_flat(A, self_info, Slice<py::handle>(), dims_list_flat, indices_list, has_dimpacks);
2183 return invoke_getitem(A, info);
2184}
2185
// extractIndices (below) returns true when the indices were flattened out of a tuple, list or sequence...
2187
2188Slice<py::handle> slice_from_sequence(Arena& A, py::handle value) {
2189 if (py::tuple_view::check(value)) {
2190 return as_slice(py::tuple_view(value));
2191 } else if (py::list_view::check(value)) {
2192 return as_slice(py::list_view(value));
2193 } else {
2194 py::sequence_view sv(value);
2195 Slice<py::handle> r;
2196 for (auto i : sv.enumerate()) {
2197 r.append(A, A.autorelease(sv[i]));
2198 }
2199 return r;
2200 }
2201}
2202
2203bool extractIndices(Arena& A, py::handle index, Slice<py::handle>& indices) {
2204 if (py::tuple_view::check(index)) {
2205 indices.extend(A, as_slice(py::tuple_view(index)));
2206 return true;
2207 } else if (THPVariable_Check(index.ptr())) {
2208 indices.append(A, index);
2209 return false;
2210 } else if (!py::is_sequence(index)) {
2211 indices.append(A, index);
2212 return false;
2213 }
    // a copy of treatSequenceAsTuple, modified to also handle Dim objects and our wrapped tensors
2215 py::sequence_view sv(index);
2216 if (sv.size() >= 32) {
2217 indices.extend(A, slice_from_sequence(A, index));
2218 return true;
2219 }
2220 for (auto i : sv.enumerate()) {
2221 py::handle item;
2222 try {
2223 item = sv[i];
2224 } catch (py::exception_set & e) {
2225 PyErr_Clear();
2226 indices.append(A, index);
2227 return false;
2228 }
2229 if (THPVariable_Check(item.ptr()) || py::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || py::is_none(item) || has_dims(item)) {
2230 indices.extend(A, slice_from_sequence(A, index));
2231 return true;
2232 }
2233 }
2234 indices.append(A, index);
2235 return false;
2236}
2237
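// getsetitem scans the index expression once, classifying each element (Dim,
// Tensor, DimList, dim pack, `...`, None, int, slice), counts how many of
// self's dimensions are consumed, expands `...` or an unbound DimList to cover
// the remainder, and inlines bound DimLists before delegating to
// getsetitem_flat. Illustrative sketch, assuming the functorch.dim API:
//   i = dims(sizes=[3])
//   t = torch.randn(3, 4, 5)
//   t[i, ...]        # '...' expands to ':' slices for the remaining two dims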
2238static IndexingInfo getsetitem(Arena & A, py::handle self, py::handle index, bool tensors_have_dims) {
2239 bool can_call_original_getitem = !tensors_have_dims;
2240
2241 Slice<py::handle> input;
2242 if (has_dims(index)) {
2243 input.append(A, index);
2244 } else {
2245 bool is_sequence = extractIndices(A, index, input);
2246 // nothing about first class dims here, fallback to getitem
2247 if (can_call_original_getitem && !is_sequence) {
2248 return { true };
2249 }
2250 }
2251
2252 int64_t dims_indexed = 0;
2253 int64_t expanding_object = -1;
2254 DimList* unbound_dim_list = nullptr;
2255 auto check_expanding = [&](int64_t i) {
2256 if (expanding_object != -1) {
2257 py::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i);
2258 }
2259 expanding_object = i;
2260 };
2261 Slice<int64_t> dimlists;
2262
    // calculate how many dimensions have been indexed in order to compute the size of ...
    // or expand a potentially unbound dimension list.
2265
2266 bool has_dimpacks_or_none = false;
2267 for (auto i : input.enumerate()) {
2268 py::handle s = input[i];
2269 if (Dim::check_exact(s) || Tensor::check_exact(s)) {
2270 can_call_original_getitem = false;
2271 ++dims_indexed;
2272 } else if (s.ptr() == Py_Ellipsis) {
2273 check_expanding(i);
2274 } else if (DimList::check(s)) {
2275 can_call_original_getitem = false;
2276 auto dl = DimList::unchecked_wrap(s);
2277 if (!dl->is_bound()) {
2278 check_expanding(i);
2279 unbound_dim_list = dl.ptr();
2280 } else {
2281 dims_indexed += dl->dims_.size();
2282 }
2283 dimlists.append(A, i);
2284 } else if (py::is_none(s)) {
2285 has_dimpacks_or_none = true;
2286 } else if (is_dimpack(s)) {
2287 can_call_original_getitem = false;
2288 has_dimpacks_or_none = true;
2289 ++dims_indexed;
2290 } else {
2291 ++dims_indexed;
2292 }
2293 }
2294
2295 // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem.
2296 if (can_call_original_getitem) {
2297 return {true};
2298 }
2299
2300 // std::cout << "__getitem__ " << self << " " << index << "\n";
2301
2302 TensorInfo self_info = TensorInfo::create(A, self, false, true);
2303 auto ndim = self_info.ndim();
2304 if (dims_indexed > ndim) {
2305 py::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim);
2306 }
2307 // expand any unbound dimension list, or expand ... into individual : slices.
2308 auto expanding_dims = ndim - dims_indexed;
2309 if (expanding_object != -1) {
2310 if (unbound_dim_list) {
2311 unbound_dim_list->bind_len(expanding_dims);
2312 } else {
2313 // ...
2314 Slice<py::handle> no_slices;
2315 for (auto i : irange(expanding_dims)) {
2316 (void) i;
2317 no_slices.append(A, no_slice);
2318 }
2319 input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices);
2320 }
2321 }
2322
2323 // flatten out any dimensions stored in dimlist elements directly into the inputs
2324 // std::cout << dimlists << " <- dim lists!\n";
2325 for (int64_t i = dimlists.size() - 1; i >=0; --i) {
2326 auto idx = dimlists[i];
2327 // we added more elements to input because of ...
2328 // so we need to also adjust the index to get back to where the
2329 // dimlist existed
2330 if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) {
2331 idx += expanding_dims;
2332 }
2333 auto dl = DimList::unchecked_wrap(input[idx]);
2334 // XXX would be better if we used an OwnedSlice in DimList
2335 Slice<py::handle> more_dims((py::handle*) &*dl->dims_.begin(), (py::handle*) &*dl->dims_.end());
2336 input.insert(A, input.slice(idx, idx + 1), more_dims);
2337 }
2338
2339 return getsetitem_flat(A, self_info, input, Slice<DimEntry>(), Slice<py::handle>(), has_dimpacks_or_none);
2340}
2341
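// getsetitem_flat is shared by __getitem__/__setitem__ (everything arrives
// positionally through `input`) and by index() (the dim -> index pairing
// arrives through `keys`/`values`). It binds Dim sizes, decides whether real
// advanced indexing is needed, and restrides self to account for dim packs and
// None entries.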
2342IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<py::handle> input, Slice<DimEntry> keys, Slice<py::handle> values, bool has_dimpacks_or_none) {
2343 // At this point:
2344 // ..., DimList have been eliminated
2345 // Dim, Tensor, Tuple[Dim,...], int, slice still remain
2346
2347
2348 // we have to count how many times we see a dimension.
2349 // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing.
2350 Slice<py::hdl<Dim>> seen_dims;
2351 Slice<int64_t> seen_dims_nuses;
2352 auto add_dim = [&](py::hdl<Dim> entry) {
2353 auto midx = seen_dims.index(entry);
2354 if (!midx) {
2355 seen_dims.append(A, entry);
2356 seen_dims_nuses.append(A, 1);
2357 } else {
2358 ++seen_dims_nuses[*midx];
2359 }
2360 };
2361
2362 Slice<py::handle> input_it = input;
2363
2364 Slice<py::handle> flat_inputs;
    // an entry of flat_inputs is an empty py::handle when the actual value is
    // the tensor-like object stored in the corresponding tensor_inputs entry
2367 Slice<TensorInfo> tensor_inputs;
2368
2369 auto append_flat_handle = [&](py::handle h) {
2370 flat_inputs.append(A, h);
2371 tensor_inputs.append(A, TensorInfo());
2372 };
2373 TensorRef device_holding_tensor;
2374 auto append_tensor_input = [&](TensorInfo ti) {
2375 flat_inputs.append(A, py::handle());
2376 tensor_inputs.append(A, ti);
2377 if (ti.has_device && !device_holding_tensor) {
2378 device_holding_tensor = ti.tensor;
2379 }
2380 };
2381
2382 Slice<int64_t> nsz;
2383 Slice<int64_t> nsd;
2384 at::IntArrayRef sz = self_info.tensor->sizes();
2385 at::IntArrayRef sd = self_info.tensor->strides();
2386
2387 auto append_size = [&](int i) {
2388 if (has_dimpacks_or_none) {
2389 nsz.append(A, sz[i]);
2390 nsd.append(A, sd[i]);
2391 }
2392 };
2393 // std::cout << "self levels: " << self_info.levels << "\n";
2394
2395 auto parse_nones = [&]() {
2396 while (input_it.size() && py::is_none(input_it[0])) {
2397 append_flat_handle(no_slice);
2398 nsz.append(A, 1);
2399 nsd.append(A, 0);
2400 input_it = input_it.slice(1);
2401 }
2402 };
2403
2404
2405 auto append_item = [&](int i, py::handle arg) {
2406 if (Dim::check_exact(arg)) {
2407 auto d = Dim::unchecked_wrap(arg);
2408 d->set_size(sz[i]);
2409 add_dim(d);
2410 append_size(i);
2411 append_flat_handle(arg);
2412 return;
2413 }
2414 auto info = TensorInfo::create(A, arg, false, false);
2415 if (info) {
2416 append_size(i);
2417 append_tensor_input(info);
2418 for (auto il : info.levels) {
2419 if (!il.is_positional()) {
2420 add_dim(il.dim());
2421 }
2422 }
2423 return;
2424 }
2425
2426 if (has_dimpacks_or_none) {
2427 Slice<py::handle> mp;
2428 if (maybe_dimpack(mp, arg)) {
2429 // dim pack
2430 Slice<py::hdl<Dim>> dim_pack;
2431 for (auto d : mp) {
2432 dim_pack.append(A, Dim::wrap(d));
2433 add_dim(dim_pack.back());
2434 append_flat_handle(dim_pack.back());
2435 }
2436 _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd);
2437 return;
2438 }
2439 }
2440
2441 append_size(i);
2442 append_flat_handle(arg);
2443 };
2444
    // pair up each indexing expression with the dimension of self that it indexes.
    // self may have first-class dims, which do not participate in the indexing.
2447 for (auto i : self_info.levels.enumerate()) {
2448 auto l = self_info.levels[i];
2449 auto idx = keys.index(l);
2450 if (idx) {
2451 append_item(i, values[*idx]);
2452 } else if (l.is_positional()) {
2453 // grab and index from the positional list
2454 parse_nones();
2455 if (!input_it.size()) {
2456 // we might have fewer indices than tensor dimensions,
2457 // which implicitly indexes the remaining dimensions with :
2458 append_flat_handle(no_slice);
2459 append_size(i);
2460 } else {
2461 py::handle arg = input_it[0];
2462 input_it = input_it.slice(1);
2463 append_item(i, arg);
2464 }
2465 } else {
2466 add_dim(l.dim());
2467 append_flat_handle(l.dim());
2468 append_size(i);
2469 }
2470 }
    // any trailing Nones may have no existing dimension associated with them in self.
2472 parse_nones();
2473
2474 // we have to restride the tensor to collapse dimension packs and introduce our none dimensions.
2475 if (has_dimpacks_or_none) {
2476 self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset()));
2477 }
2478
2479
2480 // figure out what the shape of the indexing tensors will be
2481 // and what the shape of the resulting tensor will be
2482 Slice<DimEntry> result_levels;
2483 Slice<DimEntry> index_levels;
2484 int64_t tensor_insert_point = -1;
2485 bool requires_getindex = false;
2486 auto mark_tensor_index = [&] {
2487 if (tensor_insert_point == -1) {
2488 tensor_insert_point = result_levels.size();
2489 } else if (tensor_insert_point != result_levels.size()) {
2490 tensor_insert_point = 0;
2491 }
2492 };
2493 for (auto i : flat_inputs.enumerate()) {
2494 auto inp = flat_inputs[i];
2495 if(tensor_inputs[i]) {
2496 requires_getindex = true;
2497 mark_tensor_index();
2498 for (auto l : tensor_inputs[i].levels) {
2499 // std::cout << "Consider to add " << l << "\n";
2500 if (!index_levels.contains(l)) {
2501 index_levels.append(A, l);
2502 }
2503 }
2504 } else if (Dim::check_exact(inp)) {
2505 auto d = Dim::unchecked_wrap(inp);
            // dimensions used only once are just binding operations
2507 if (1 == seen_dims_nuses[*seen_dims.index(d)]) {
2508 flat_inputs[i] = no_slice;
2509 result_levels.append(A, d);
2510 } else {
2511 requires_getindex = true;
2512 flat_inputs[i] = py::handle();
2513 tensor_inputs[i] = TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, TensorRef()};
2514 if (!index_levels.contains(d)) {
2515 index_levels.append(A, d);
2516 }
2517 mark_tensor_index();
2518 }
2519 } else {
2520 if (inp.ptr() != no_slice.ptr()) {
2521 requires_getindex = true;
2522 }
2523 if (!py::is_int(inp)) {
2524 // note: actual positional indexes are accurately computed later
2525 result_levels.append(A, -1);
2526 }
2527 }
2528 }
2529
    // indexing dimensions appear in the result at the _first use of a tensor_ in the indexing,
    // so insert the index levels into the result levels at that spot
2532 if (tensor_insert_point != -1) {
2533 result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels);
2534 }
2535
2536 // std::cout << "flat inputs: " << flat_inputs << "\n";
2537 // std::cout << "result_levels: " << result_levels << "\n";
2538 // std::cout << "index_levels: " << index_levels << "\n";
2539
2540 // get all the tensors to be the right shape for indexing
2541 if (requires_getindex) {
2542 for (auto i : flat_inputs.enumerate()) {
2543 if (tensor_inputs[i]) {
2544 AT_ASSERT(!flat_inputs[i].ptr());
2545 // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n";
2546 TensorRef t = tensor_inputs[i].tensor;
2547 if (!tensor_inputs[i].has_device && device_holding_tensor) {
2548 t = A.autorelease(t->to(device_holding_tensor->device()));
2549 }
2550 flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels));
2551 }
2552 }
2553 }
2554
    // previously we didn't know how many positional dimensions there would be, so we couldn't
    // number them correctly; fill that in now.
2557 auto seen_positionals = 0;
2558 for (auto i : result_levels.reversed_enumerate()) {
2559 if (result_levels[i].is_positional()) {
2560 result_levels[i] = -(++seen_positionals);
2561 }
2562 }
2563
2564 return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device};
2565}
2566
2567static py::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) {
2568 at::Tensor rtensor;
2569 if (iinfo.advanced_indexing) {
2570 auto self_hdl = handle_from_tensor(A, iinfo.self);
2571 auto tup = slice_to_tuple(iinfo.flat_inputs);
2572 // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n";
2573 auto pytensor = py::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr()));
2574 rtensor = THPVariable_Unpack(pytensor.ptr());
2575 } else {
2576 // std::cout << "skipping original getindex\n";
2577 rtensor = *iinfo.self;
2578 }
2579 // std::cout << "returning (from_positional)\n";
2580 return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device);
2581}
2582
2583static py::object __getitem__(Arena & A, py::handle self, py::handle index) {
2584 maybeInitializeGlobals();
2585 auto iinfo = getsetitem(A, self, index, has_dims(self));
2586 if (iinfo.can_call_original) {
2587 return py::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr()));
2588 }
2589
2590 return invoke_getitem(A, iinfo);
2591}
2592
2593
2594PyObject* Tensor_getitem(PyObject* self, PyObject* index) {
2595 Arena A;
2596 PY_BEGIN
2597 return __getitem__(A, self, index).release();
2598 PY_END(nullptr);
2599}
2600
2601static void __setitem__(Arena & A, py::handle self, py::handle index, py::handle rhs) {
2602 maybeInitializeGlobals();
2603 auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs));
2604 if (iinfo.can_call_original) {
2605 if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) {
2606 throw py::exception_set();
2607 }
2608 return;
2609 }
2610
2611 auto rhs_info = TensorInfo::create(A, rhs, false, false);
2612 if (rhs_info) { // otherwise rhs can be a scalar...
2613 for (auto l : rhs_info.levels) {
2614 if (!iinfo.result_levels.contains(l)) {
2615 if (l.is_positional()) {
                    py::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", rhs_info.ndim(), ndim_of_levels(iinfo.result_levels));
2617 } else {
2618 auto tup = levels_to_tuple(iinfo.result_levels);
2619 py::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr());
2620 }
2621 }
2622 }
2623 auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels);
2624 rhs = handle_from_tensor(A, rhs_matched);
2625 }
2626 self = handle_from_tensor(A, iinfo.self);
2627
2628 if (iinfo.advanced_indexing) {
2629 auto tup = slice_to_tuple(iinfo.flat_inputs);
2630 if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) {
2631 throw py::exception_set();
2632 }
2633 } else {
2634 torch_Tensor_copy_.call(self, rhs);
2635 }
2636}
2637
2638
2639int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) {
2640 Arena A;
2641 PY_BEGIN
2642 __setitem__(A, self, index, value);
2643 return 0;
2644 PY_END(-1);
2645}
2646
2647static PyObject* py___getitem__(PyObject *_,
2648 PyObject *const *args,
2649 Py_ssize_t nargs,
2650 PyObject *kwnames) {
2651 Arena A;
2652 PY_BEGIN
2653 AT_ASSERT(nargs == 2);
2654 return __getitem__(A, args[0], args[1]).release();
2655 PY_END(nullptr)
2656}
2657
2658static PyObject* py___setitem__(PyObject *_,
2659 PyObject *const *args,
2660 Py_ssize_t nargs,
2661 PyObject *kwnames) {
2662 Arena A;
2663 PY_BEGIN
2664 AT_ASSERT(nargs == 3);
2665 __setitem__(A, args[0], args[1], args[2]);
2666 Py_RETURN_NONE;
2667 PY_END(nullptr)
2668}
2669
2670
2671static PyObject* py_index(PyObject *_,
2672 PyObject *const *args,
2673 Py_ssize_t nargs,
2674 PyObject *kwnames) {
2675 Arena A;
2676 PY_BEGIN
2677 py::vector_args va(args, nargs, kwnames);
2678 py::handle self, dims, indices;
2679 va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3);
2680 return index(A, self, dims, indices).release();
2681 PY_END(nullptr)
2682}
2683
2684
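// py_stack: match every input to a common set of levels, stack them along a
// new first-class dim, and bind that dim's size to the number of inputs.
// Illustrative sketch, assuming stack is exposed by the functorch.dim Python
// module:
//   b, d = dims(2)                 # d left unbound; stack binds it to 2
//   x = torch.randn(4, 3)[b]
//   y = torch.randn(4, 3)[b]
//   s = stack([x, y], d)           # s carries dims d and b plus a positional dim of size 3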
2685static PyObject* py_stack(PyObject *_,
2686 PyObject *const *args,
2687 Py_ssize_t nargs,
2688 PyObject *kwnames) {
2689 Arena A;
2690 PY_BEGIN
2691 py::vector_args va(args, nargs, kwnames);
2692 py::handle tensors, new_dim, dim;
2693 va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2);
2694
2695 Slice<DimEntry> result_levels;
2696 Slice<TensorInfo> infos;
2697 py::sequence_view sv(tensors);
2698 auto new_dim_d = Dim::wrap(new_dim);
2699 for (auto i : sv.enumerate()) {
2700 infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false));
2701 for (auto l : infos.back().levels) {
2702 if (!result_levels.contains(l)) {
2703 result_levels.append(A, l);
2704 }
2705 }
2706 }
2707 new_dim_d->set_size(infos.size());
2708 std::vector<at::Tensor> inputs;
2709 inputs.reserve(infos.size());
2710 for (auto in : infos) {
2711 inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels));
2712 }
2713 auto ndim = ndim_of_levels(result_levels);
2714 int64_t rawdim = 0;
2715 if (dim.ptr()) {
2716 auto d = _wrap_dim(dim, ndim, false);
2717 auto idx = result_levels.index(d);
2718 if (!idx) {
2719 py::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr());
2720 }
2721 rawdim = *idx;
2722 }
2723 auto result = at::stack(inputs, rawdim);
2724 result_levels.insert(A, rawdim, new_dim_d);
2725 return Tensor::from_positional(A, std::move(result), result_levels, true).release();
2726 PY_END(nullptr)
2727}
2728
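// py_split: the split sizes may all be Dim objects (mixing ints and dims is an
// error). Unbound dims have their sizes inferred by dividing the remaining
// length as evenly as possible, and each returned piece carries the
// corresponding first-class dim. Illustrative sketch, assuming the
// functorch.dim API:
//   i = dims(sizes=[10])
//   a, b = dims(2)
//   t = torch.randn(10)[i]
//   front, back = t.split([a, b], i)   # a.size and b.size inferred as 5 and 5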
2729static PyObject* py_split(PyObject *_,
2730 PyObject *const *args,
2731 Py_ssize_t nargs,
2732 PyObject *kwnames) {
2733 Arena A;
2734 PY_BEGIN
2735 maybeInitializeGlobals();
2736 py::vector_args va(args, nargs, kwnames);
2737 py::handle self, split_size_or_sections, dim;
2738 va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2);
2739 bool dim_is_object = dim.ptr() && Dim::check_exact(dim);
2740 Slice<py::handle> sizes;
2741
2742 bool all_dims = true;
2743 bool all_ints = true;
2744
2745 if (!py::is_int(split_size_or_sections)) {
2746 py::sequence_view sv(split_size_or_sections);
2747 for (auto i : sv.enumerate()) {
2748 sizes.append(A, A.autorelease(sv[i]));
2749 if (Dim::check_exact(sizes.back())) {
2750 all_ints = false;
2751 } else {
2752 all_dims = false;
2753 }
2754 }
2755 }
2756 if (all_ints) {
2757 if (dim_is_object) {
2758 py::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions.");
2759 }
2760 // call original split (if self has dimensions this will use torch function to do the split)
2761 return torch_Tensor_split.call_vector(py::vector_args(args, nargs, kwnames)).release();
2762 }
2763 if (!all_dims) {
2764 py::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix");
2765 }
2766
2767 auto self_info = TensorInfo::create(A, self, false);
2768 auto ndim = self_info.ndim();
    if (!dim_is_object && ndim == 0) {
2770 py::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor");
2771 }
2772 DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim;
2773
2774 auto idx = self_info.levels.index(dim_l);
2775 if (!idx) {
2776 if (!dim.ptr()) {
2777 dim = A.autorelease(py::from_int(0));
2778 }
        py::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr());
2780 }
2781 Slice<int64_t> indices;
2782
2783 int64_t total_size = 0;
2784 Slice<int64_t> unbound;
2785 for (auto i : sizes.enumerate()) {
2786 auto d = Dim::unchecked_wrap(sizes[i]);
2787 if (d->is_bound()) {
2788 indices.append(A, d->size());
2789 total_size += indices.back();
2790 } else {
2791 indices.append(A, 0);
2792 unbound.append(A, i);
2793 }
2794 }
2795 auto tensor_size = self_info.tensor->sizes()[*idx];
2796
2797 if (unbound.size()) {
2798 if (total_size > tensor_size) {
2799 py::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size));
2800 }
2801 auto remaining_size = tensor_size - total_size;
2802 auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size();
2803 for (auto u : unbound) {
2804 auto sz = std::min(chunk_size, remaining_size);
2805 Dim::unchecked_wrap(sizes[u])->set_size(sz);
2806 indices[u] = sz;
2807 remaining_size -= sz;
2808 }
2809 } else if (tensor_size != total_size) {
        py::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) does not match the source dim (%d)", int(total_size), int(tensor_size));
2811 }
2812
2813 auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx);
2814 py::tuple result(result_tensors.size());
2815 Slice<DimEntry> new_levels;
2816 new_levels.extend(A, self_info.levels);
2817 for (auto i : sizes.enumerate()) {
2818 new_levels[*idx] = Dim::unchecked_wrap(sizes[i]);
2819 result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true));
2820 }
2821
2822 return result.release();
2823
2824 PY_END(nullptr)
2825}
2826
2827
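// _wrap_dim: convert an argument that may be a Dim object or an int into a
// DimEntry; ints are normalized to negative positional positions, and anything
// else yields a none DimEntry so callers can detect unhandled values.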
2828static DimEntry _wrap_dim(py::handle d, size_t N, bool keepdim) {
2829 if (Dim::check(d)) {
2830 if (keepdim) {
2831 py::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True");
2832 }
2833 return Dim::unchecked_wrap(d);
2834 } else if (py::is_int(d)) {
2835 auto i = py::to_int(d);
2836 while (i >= 0) {
2837 i -= N;
2838 }
2839 return i;
2840 } else {
2841 return DimEntry();
2842 }
2843}
2844
2845static Slice<DimEntry> _wrap_dims(Arena& A, py::handle d, size_t N, bool keepdim) {
2846 auto de = _wrap_dim(d, N, keepdim);
2847 Slice<DimEntry> r;
2848 if (!de.is_none()) {
2849 r.append(A, de);
2850 } else {
2851 py::sequence_view sq(d);
2852 for (auto i : sq.enumerate()) {
2853 r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim));
2854 }
2855 }
2856 return r;
2857}
2858
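// WrappedOperator holds an original torch callable plus the metadata needed to
// re-dispatch it with first-class dims: the name/offset of its dim argument,
// the keepdim offset, and whether it is pointwise or a reduction. function()
// manufactures a PyCFunction whose implementation is patched_dim_method or
// call_torch_function below.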
2859struct WrappedOperator : public py::base<WrappedOperator> {
2860 py::object orig;
2861 PyMethodDef method_def;
2862 py::object name, doc;
2863
2864 bool is_pointwise = false;
2865 int64_t dim_offset = 0;
2866 int64_t keepdim_offset = 1;
2867 std::string dim_name;
2868 bool single_dim = false;
2869 bool reduce = true;
2870
2871 static PyTypeObject Type;
2872
2873 void init(py::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") {
2874 orig = std::move(orig_);
2875 method_def.ml_meth = wrapper_implementation;
2876 name = orig.attr("__name__");
2877 doc = orig.attr("__doc__");
2878 dim_name = std::move(dim_name_);
2879 if (!py::is_none(doc) && !dim_name.empty()) {
2880 doc = py::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str());
2881 }
2882 method_def.ml_name = py::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr());
2883 method_def.ml_doc = py::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr());
2884 method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS;
2885 }
2886
2887 py::object function() {
2888 return py::object::checked_steal(PyCFunction_New(&method_def, ptr()));
2889 }
2890
2891};
2892
2893PyTypeObject WrappedOperator::Type = {
2894 PyVarObject_HEAD_INIT(NULL, 0)
2895 "_C.WrappedOperator", /* tp_name */
2896 sizeof(WrappedOperator), /* tp_basicsize */
2897 0, /* tp_itemsize */
2898 WrappedOperator::dealloc_stub, /* tp_dealloc */
2899 0, /* tp_vectorcall_offset */
2900 0, /* tp_getattr */
2901 0, /* tp_setattr */
2902 0, /* tp_as_async */
2903 0, /* tp_repr */
2904 0, /* tp_as_number */
2905 0, /* tp_as_sequence */
2906 0, /* tp_as_mapping */
2907 0, /* tp_hash */
2908 0, /* tp_call */
2909 0, /* tp_str */
2910 0, /* tp_getattro */
2911 0, /* tp_setattro */
2912 0, /* tp_as_buffer */
2913 Py_TPFLAGS_DEFAULT, /* tp_flags */
2914 "Wrapped Object Holder", /* tp_doc */
2915 0, /* tp_traverse */
2916 0, /* tp_clear */
2917 0, /* tp_richcompare */
2918 0, /* tp_weaklistoffset */
2919 0, /* tp_iter */
2920 0, /* tp_iternext */
2921 0, /* tp_methods */
2922 0, /* tp_members */
2923 0, /* tp_getset */
2924 0, /* tp_base */
2925 0, /* tp_dict */
2926 0, /* tp_descr_get */
2927 0, /* tp_descr_set */
2928 0, /* tp_dictoffset */
2929 0, /* tp_init */
2930 0, /* tp_alloc */
2931 WrappedOperator::new_stub, /* tp_new */
2932};
2933
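// patched_dim_method: if no dim argument was supplied, run the original op
// under EnableAllLayers so first-class dims are batched over; otherwise
// translate Dim arguments into integer positions, call the original op on the
// plain tensor, and rewrap the results with the surviving levels. End-user
// effect (illustrative sketch, assuming the functorch.dim API):
//   i, j = dims(2)
//   x = torch.randn(3, 4)[i, j]
//   x.sum(i)       # dim translated to an integer position; the result keeps j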
2934static PyObject* patched_dim_method(PyObject * self_,
2935 PyObject *const *args,
2936 Py_ssize_t nargs,
2937 PyObject *kwnames) {
2938 Arena A;
2939 auto self = WrappedOperator::unchecked_wrap(self_);
2940 PY_BEGIN
2941
2942 py::vector_args va(args, nargs, kwnames);
2943
2944 auto _getarg = [&](const char* name, int64_t offset_) -> py::handle {
2945 auto offset = offset_ + 1; // do not include self
2946 auto idx = va.index(name, offset);
2947 return idx == -1 ? py::handle() : va[idx];
2948 };
2949 Slice<py::handle> patched_args;
2950 patched_args.extend(A, va.begin(), va.end());
2951 auto _patcharg = [&](const char* name, int64_t offset_, py::handle value) {
2952 auto offset = offset_ + 1; // do not include self
2953 auto idx = va.index(name, offset);
2954 if (idx == -1) {
2955 py::raise_error(PyExc_ValueError, "Missing argument %s", name);
2956 }
2957 patched_args[idx] = value;
2958 };
2959
2960 auto dim = _getarg(self->dim_name.c_str(), self->dim_offset);
2961 if (!dim.ptr()) {
2962 auto info = TensorInfo::create(A, args[0], true);
2963 EnableAllLayers l(A, info.levels);
2964 l.inplace_update_layers(info.batchedtensor, info.levels);
2965 patched_args[0] = handle_from_tensor(A, info.batchedtensor);
2966 auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
2967 return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release();
2968 }
2969
2970 auto info = TensorInfo::create(A, args[0]);
2971 auto keepdim = false;
2972 if (self->reduce) {
2973 auto py_keepdim = _getarg("keepdim", self->keepdim_offset);
2974 if (py_keepdim.ptr()) {
2975 keepdim = py::to_bool(py_keepdim);
2976 }
2977 }
2978
2979 auto ndim = info.ndim();
2980 auto dims = _wrap_dims(A, dim, ndim, keepdim);
2981 Slice<int64_t> dim_indices;
2982 auto seen = A.allocate<bool>(info.levels.size());
2983 std::fill(seen, seen + info.levels.size(), false);
2984
2985 for (auto d : dims) {
2986 auto midx = info.levels.index(d);
2987 if (!midx) {
2988 auto tup = levels_to_tuple(info.levels);
2989 py::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr());
2990 }
2991 seen[*midx] = true;
2992 dim_indices.append(A, *midx);
2993 }
2994 Slice<DimEntry> new_levels;
2995 if (self->reduce && !keepdim) {
2996 for (auto i : info.levels.enumerate()) {
2997 if (!seen[i]) {
2998 new_levels.append(A, info.levels[i]);
2999 }
3000 }
3001 } else {
3002 new_levels = info.levels;
3003 }
3004 py::object py_indices;
3005 if (dim_indices.size() == 1) {
3006 py_indices = py::from_int(dim_indices[0]);
3007 } else {
3008 py::tuple tup(dim_indices.size());
3009 for (auto i : dim_indices.enumerate()) {
3010 tup.set(i, py::from_int(dim_indices[i]));
3011 }
3012 py_indices = std::move(tup);
3013 }
3014 _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices);
3015 patched_args[0] = handle_from_tensor(A, info.tensor);
3016 auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
3017 auto wrap = [&](py::handle h) {
3018 if (THPVariable_Check(h.ptr())) {
3019 return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device));
3020 }
3021 return h;
3022 };
3023 return tree_map(A, wrap, r).release();
3024 PY_END(nullptr)
3025}
3026
3027static PyObject* _wrap(PyObject * self_,
3028 PyObject *const *args,
3029 Py_ssize_t nargs,
3030 PyObject *kwnames) {
3031 Arena A;
3032 PY_BEGIN
3033
3034 #define ARGS(_) _(py::handle, orig) _(py::handle, dim_offset) _(py::handle, keepdim_offset) \
3035 _(py::handle, dim_name) _(py::handle, single_dim) _(py::handle, reduce)
3036 MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS)
3037
3038 std::string dim_name_str;
3039 if (dim_name.ptr()) {
3040 dim_name_str = PyUnicode_AsUTF8(dim_name.ptr());
3041 } else {
3042 dim_name_str = "dim";
3043 }
3044 auto info = WrappedOperator::create(py::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str));
3045 if (dim_offset.ptr()) {
3046 info->dim_offset = py::to_int(dim_offset);
3047 }
3048 if (keepdim_offset.ptr()) {
3049 info->keepdim_offset = py::to_int(keepdim_offset);
3050 }
3051
3052 if (single_dim.ptr()) {
3053 info->single_dim = py::to_bool(single_dim);
3054 }
3055 if (reduce.ptr()) {
3056 info->reduce = py::to_bool(reduce);
3057 }
3058 return info->function().release();
3059 #undef ARGS
3060
3061 PY_END(nullptr)
3062}
3063
3064static PyObject* call_torch_function(PyObject *self,
3065 PyObject *const *args,
3066 Py_ssize_t nargs,
3067 PyObject *kwnames) {
3068 PY_BEGIN
3069 Arena A;
3070 maybeInitializeGlobals();
3071 auto info = WrappedOperator::unchecked_wrap(self);
3072 return __torch_function__(A, info->orig, py::vector_args(args, nargs, kwnames), info->is_pointwise).release();
3073 PY_END(nullptr)
3074}
3075
3076static PyObject* _wrap_method(PyObject *self,
3077 PyObject *const *args,
3078 Py_ssize_t nargs,
3079 PyObject *kwnames) {
3080 PY_BEGIN
3081 AT_ASSERT(nargs == 2);
    // XXX - ignore the wrapped python function; we will call the torch function directly
3083 py::handle orig = args[0];
3084 if (!pointwise.ptr()) {
3085 auto dim = py::import("functorch.dim");
3086 pointwise = dim.attr("pointwise");
3087 }
3088 auto info = WrappedOperator::create(py::object::borrow(orig), (PyCFunction)(void*) call_torch_function);
3089 info->is_pointwise = pointwise.contains(orig);
3090 return PyInstanceMethod_New(info->function().release());
3091 PY_END(nullptr);
3092}
3093
3094
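// Tensor_sum: fast path for sum over a delayed pointwise multiply. When the
// operand is an `a * b` that has not been materialized yet, the reduction is
// forwarded to dot() so the whole expression becomes a single matmul-style
// contraction. Illustrative sketch, assuming the functorch.dim API:
//   i, j, k = dims(3)
//   A_ = torch.randn(3, 4)[i, k]
//   B_ = torch.randn(4, 5)[k, j]
//   C_ = (A_ * B_).sum(k)          # fused here into one contraction over k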
3095static PyObject* Tensor_sum(PyObject * self_,
3096 PyObject *const *args,
3097 Py_ssize_t nargs,
3098 PyObject *kwnames) {
3099 Arena A;
3100 PY_BEGIN
3101 maybeInitializeGlobals();
3102 py::vector_args va(args, nargs, kwnames);
3103 auto self_ = Tensor::unchecked_wrap(args[0]);
3104 auto d = self_->delayed();
3105 if (!d) {
3106 return _Tensor_sum.call_vector(va).release();
3107 }
3108 py::handle self, dim, keepdim, dtype;
3109 va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1);
3110
3111 if (dtype.ptr() || (keepdim.ptr() && py::to_bool(keepdim))) {
3112 // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n";
3113 return _Tensor_sum.call_vector(va).release();
3114 }
3115 auto levels = self_->levels();
3116
3117 auto N = ndim_of_levels(levels);
3118 auto reduced_dims = _wrap_dims(A, dim, N, false);
3119
3120 return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release();
3121 PY_END(nullptr)
3122}
3123
3124static PyObject* _parse_test(PyObject * self_,
3125 PyObject *const *args,
3126 Py_ssize_t nargs,
3127 PyObject *kwnames) {
3128 PY_BEGIN
3129 maybeInitializeGlobals();
3130
3131 int required = py::to_int(args[0]);
3132 int kwonly = py::to_int(args[1]);
3133
3134 py::vector_args va(args + 2, nargs - 2, kwnames);
3135
3136
3137 py::handle a, b, c, d;
3138 va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly);
3139 py::tuple r(4);
3140 r.set(0, py::object::borrow(a.ptr() ? a : Py_None));
3141 r.set(1, py::object::borrow(b.ptr() ? b : Py_None));
3142 r.set(2, py::object::borrow(c.ptr() ? c : Py_None));
3143 r.set(3, py::object::borrow(d.ptr() ? d : Py_None));
3144 return r.release();
3145
3146 PY_END(nullptr)
3147}
3148
3149static PyObject* _set_pointwise_optimize(PyObject * self_,
3150 PyObject *const *args,
3151 Py_ssize_t nargs,
3152 PyObject *kwnames) {
3153 PY_BEGIN
3154 py::handle value;
3155 py::vector_args va(args, nargs, kwnames);
3156 va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1);
3157 pointwise_optimize = py::to_bool(value);
3158 Py_RETURN_NONE;
3159 PY_END(nullptr)
3160}
3161
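// _patch_tensor_class: exposed so the Python side can route plain torch.Tensor
// indexing through the first-class-dim aware Tensor_getitem/Tensor_setitem
// defined above.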
3162static PyObject* _patch_tensor_class(PyObject * self_,
3163 PyObject *const *args,
3164 Py_ssize_t nargs,
3165 PyObject *kwnames) {
3166 PY_BEGIN
3167
3168 auto torch = py::import("torch");
3169 auto py_TensorBase = torch.attr("_C").attr("_TensorBase");
3170 replaceMappingIfMatches(py_TensorBase);
3171
3172 Py_RETURN_NONE;
3173 PY_END(nullptr)
3174}
3175
3176
3177const char* dims_doc = R"""(
3178dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...]
3179
3180Creates and returns one or more Dim objects.
3181
Args:
    n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
    sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be
      created, specifying each dimension's size, or None to leave the size unset.
3186
3187Example::
3188 >>> batch, channel, width, height = dims(4)
3189 >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
3190)""";
3191
3192static PyMethodDef methods[] = {
3193 {"dims", (PyCFunction)(void*) _dims<create_dim>, METH_FASTCALL | METH_KEYWORDS, dims_doc},
3194 {"dimlists", (PyCFunction)(void*) _dims<create_dimlist>, METH_FASTCALL | METH_KEYWORDS},
3195 {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS},
3196 {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS},
3197 {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS},
3198 {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS},
3199 {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS},
3200 {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS},
3201 {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS},
3202 {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS},
3203 {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS},
3204 {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS},
3205 {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS},
3206 {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS},
3207 {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS},
3208 {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS},
3209 {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS},
3210 {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS},
3211 {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS},
3212 {NULL, NULL, 0, NULL} /* Sentinel */
3213};
3214
3215static struct PyModuleDef module_def = {
3216 PyModuleDef_HEAD_INIT,
3217 "_C", /* name of module */
3218 NULL, /* module documentation, may be NULL */
3219 -1, /* size of per-interpreter state of the module,
3220 or -1 if the module keeps state in global variables. */
3221 methods
3222};
3223
3224PyObject* Dim_init(void) {
3225 Arena A;
3226 try {
3227 py::object mod = py::object::checked_steal(PyModule_Create(&module_def));
3228 Dim::ready(mod, "Dim");
3229 DimList::ready(mod, "DimList");
3230 Tensor::ready(mod, "Tensor");
3231 WrappedOperator::ready(mod, "_WrappedOperator");
3232 Py_INCREF(&PyInstanceMethod_Type);
3233 PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type);
3234
3235 initializeGlobals(A);
3236 return mod.release();
3237 } catch(py::exception_set& err) {
3238 return nullptr;
3239 }
3240}
3241