1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #include "minpybind.h" |
8 | #include <frameobject.h> |
9 | #include <opcode.h> |
#include <utility>
#include <new>
#include <cstring>    // std::memcpy (used for DimEntry bit-casts)
#include <climits>    // INT_MAX (used in _add_batch_dims)
#include <algorithm>  // std::sort, std::fill
#include <functional> // std::function (used by tree_map)
#include <iostream>
#include <vector>
14 | //#include <torch/csrc/autograd/python_variable.h> |
15 | #include <torch/csrc/utils/python_compat.h> |
16 | #include <torch/csrc/Export.h> |
17 | #include <ATen/functorch/BatchedTensorImpl.h> |
18 | #include <ATen/functorch/DynamicLayer.h> |
19 | #include <ATen/ATen.h> |
20 | #include <memory> |
21 | #include "arena.h" |
22 | #include "python_variable_simple.h" |
23 | |
24 | #if IS_PYTHON_3_11_PLUS |
25 | #define Py_BUILD_CORE |
26 | #include "internal/pycore_opcode.h" |
27 | #undef Py_BUILD_CORE |
28 | #endif |
29 | |
// Each C++-defined Python object in this file provides:
// * a constructor that returns a ref-counted handle
// * the actual API, with methods that take/return C-typed values
33 | |
34 | // extend minpybind.h to include |
35 | // * typed handles so that -> can get to their raw API |
36 | // * object/handle distinction for the typed handles |
37 | |
38 | // class Dim: --------------- |
39 | py::handle torch_Tensor___mul__; |
40 | py::handle _Tensor; |
41 | py::handle _Tensor_sum; |
42 | py::handle NamedTuple; |
43 | py::dict_view pointwise; |
44 | py::handle torch_Tensor_expand; |
45 | binaryfunc THPVariable_getitem; |
46 | objobjargproc THPVariable_setitem; |
47 | py::handle no_slice; |
48 | PyTypeObject* torch_Tensor; |
49 | py::handle torch_Tensor_copy_; |
50 | py::handle torch_Tensor_split; |
51 | bool pointwise_optimize = true; |
52 | PyTypeObject* DimType = nullptr; |
53 | |
54 | static void maybeInitializeGlobals() { |
// globals that depend on the Python dim library,
// which we can't look up until we finish initializing the _C module
57 | if (_Tensor.ptr()) { |
58 | return; |
59 | } |
60 | auto dim = py::import("functorch.dim" ); |
61 | _Tensor = dim.attr("_Tensor" ); |
62 | pointwise = dim.attr("pointwise" ); |
63 | _Tensor_sum = _Tensor.attr("sum" ); |
DimType = (PyTypeObject*) dim.attr("Dim").ptr();
65 | } |
66 | |
67 | PyObject* Tensor_getitem(PyObject* self, PyObject* index); |
68 | int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value); |
69 | |
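// Swap torch.Tensor's subscript slots for the versions defined below.
// CPython copies C-level slots into subclasses when they are created, so we
// also walk __subclasses__ and patch any subclass whose slots still point at
// the original THPVariable implementations.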
70 | void replaceMappingIfMatches(py::handle tp) { |
71 | auto T = (PyTypeObject*) tp.ptr(); |
72 | bool recurse = false; |
73 | if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { |
74 | T->tp_as_mapping->mp_subscript = Tensor_getitem; |
75 | recurse = true; |
76 | } |
77 | if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { |
78 | T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; |
79 | recurse = true; |
80 | } |
81 | if (recurse) { |
82 | auto result = tp.attr("__subclasses__" ).call(); |
83 | py::list_view lv(result); |
84 | for (auto i : lv.enumerate()) { |
85 | replaceMappingIfMatches(lv[i]); |
86 | } |
87 | } |
88 | } |
89 | |
90 | static void initializeGlobals(Arena & A) { |
91 | auto torch = py::import("torch" ); |
92 | torch_Tensor = (PyTypeObject*) torch.attr("Tensor" ).ptr(); |
93 | torch_Tensor___mul__ = torch.attr("Tensor" ).attr("__mul__" ); |
94 | |
95 | torch_Tensor_expand = torch.attr("_C" ).attr("_TensorBase" ).attr("expand" ); |
96 | torch_Tensor_split = torch.attr("_C" ).attr("_TensorBase" ).attr("split" ); |
97 | torch_Tensor_copy_ = torch.attr("Tensor" ).attr("copy_" ); |
98 | auto py_TensorBase = torch.attr("_C" ).attr("_TensorBase" ); |
99 | auto TensorBase = (PyTypeObject*) py_TensorBase.ptr(); |
100 | THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; |
101 | THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; |
102 | NamedTuple = py::import("typing" ).attr("NamedTuple" ); |
103 | no_slice = PySlice_New(NULL, NULL, NULL); |
104 | |
105 | } |
106 | |
107 | py::handle DimensionBindError_; |
108 | static py::handle DimensionBindError() { |
109 | if(!DimensionBindError_.ptr()) { |
110 | DimensionBindError_ = py::import("functorch.dim" ).attr("DimensionBindError" ); |
111 | } |
112 | return DimensionBindError_; |
113 | } |
114 | |
115 | static int64_t n_dims_created = 65; |
116 | |
117 | struct Dim : public py::base<Dim> { |
118 | int64_t level_; // for stable comparisons in prototype |
119 | py::object name_; |
120 | Dim() |
121 | : level_(n_dims_created++) {} |
122 | void init(py::object name, int64_t s = -1) { |
123 | name_ = std::move(name); |
124 | size_ = s; |
125 | } |
126 | |
127 | static bool check_exact(py::handle v) { |
128 | return Py_TYPE(v.ptr()) == DimType; |
129 | } |
130 | |
131 | int64_t size() const { |
132 | if (size_ == -1) { |
133 | py::raise_error(PyExc_ValueError, "dimension %S is unbound" , name_.ptr()); |
134 | } |
135 | return size_; |
136 | } |
137 | void set_size(int64_t v) { |
138 | if (size_ == -1) { |
139 | size_ = v; |
140 | } else if(size_ != v) { |
141 | py::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld" , this, this->size_, v); |
142 | } |
143 | } |
144 | bool is_bound() const { |
145 | return size_ != -1; |
146 | } |
147 | static py::obj<Dim> create(py::object name, int64_t s = -1) { |
148 | if (!DimType) { |
149 | maybeInitializeGlobals(); |
150 | } |
151 | auto r = Dim::alloc(DimType); |
152 | r->init(std::move(name), s); |
153 | return r; |
154 | } |
155 | static PyTypeObject Type; |
156 | const at::Tensor& range() { |
157 | if (!range_.defined()) { |
158 | range_ = at::arange(size()); |
159 | } |
160 | return range_; |
161 | } |
162 | const at::Tensor& batchtensor() { |
163 | if (!batchtensor_.defined()) { |
164 | batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); |
165 | } |
166 | return batchtensor_; |
167 | } |
168 | private: |
169 | int64_t size_{-1}; |
170 | at::Tensor range_; |
171 | at::Tensor batchtensor_; |
172 | }; |
173 | |
174 | struct DimEntry { |
// Union of either a negative number indicating which positional dimension
// this entry refers to on the rhs, or a pointer to a first-class Dim.
// Heap pointers never have their sign bit set, so a negative value tells us
// the entry is positional rather than a Dim.
179 | bool is_positional() const { |
180 | return data_ < 0; |
181 | } |
182 | bool is_none() const { |
183 | return data_ == 0; |
184 | } |
185 | int64_t position() const { |
186 | return data_; |
187 | } |
188 | py::hdl<Dim> dim() const { |
189 | Dim* result; |
190 | std::memcpy(&result, &data_, sizeof(Dim*)); |
191 | return py::hdl<Dim>(result); |
192 | } |
193 | |
194 | DimEntry() |
195 | : data_(0) {} |
196 | |
197 | DimEntry(int64_t pos) |
198 | : data_(pos) { |
199 | AT_ASSERT(pos < 0); |
200 | } |
201 | DimEntry(py::hdl<Dim> d) { |
202 | std::memcpy(&data_, &d, sizeof(int64_t)); |
203 | } |
204 | bool operator==(const DimEntry& rhs) const { |
205 | return data_ == rhs.data_; |
206 | } |
207 | private: |
208 | int64_t data_; |
209 | }; |
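// Illustrative sketch of the DimEntry encoding (`d` is a hypothetical
// py::hdl<Dim>, not part of the API):
//   DimEntry none;      // data_ == 0  -> is_none()
//   DimEntry pos(-2);   // data_ <  0  -> is_positional(), position() == -2
//   DimEntry fc(d);     // data_ holds the Dim*; pointers never set the sign
//                       // bit, so is_positional() is false and dim()
//                       // recovers the pointer via memcpy.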
210 | |
211 | std::ostream& operator<<(std::ostream& ss, DimEntry entry) { |
212 | if (entry.is_none()) { |
213 | ss << "None" ; |
214 | } else if (entry.is_positional()) { |
215 | ss << entry.position(); |
216 | } else { |
217 | ss << entry.dim(); |
218 | } |
219 | return ss; |
220 | } |
221 | |
222 | // Dim wrapper methods |
223 | |
224 | static int Dim_init(py::hdl<Dim> self, PyObject *args, PyObject *kwds) { |
225 | PY_BEGIN |
226 | static char* kwlist[] = {"name" , "size" , nullptr}; |
227 | py::handle name; |
228 | py::handle size = nullptr; |
229 | if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O" , kwlist, &name, &size)) { |
230 | return -1; |
231 | } |
232 | self->init(py::object::borrow(name), (size.ptr() && !py::is_none(size)) ? py::to_int(size) : -1); |
233 | return 0; |
234 | PY_END(-1) |
235 | } |
236 | |
237 | static PyObject* Dim_repr(Dim* self) { |
238 | PY_BEGIN |
239 | py::object name = (self->name_.ptr()) ? self->name_ : py::unicode_from_string("<uninitialized dim>" ); |
240 | return name.release(); |
241 | PY_END(nullptr) |
242 | } |
243 | |
244 | |
245 | static PyObject* Dim_getsize(Dim* self, void*) { |
246 | PY_BEGIN |
247 | return py::from_int(self->size()).release(); |
248 | PY_END(nullptr) |
249 | } |
250 | |
251 | int Dim_setsize(Dim* self, PyObject* size, void*) { |
252 | PY_BEGIN |
253 | self->set_size(py::to_int(size)); |
254 | return 0; |
255 | PY_END(-1) |
256 | } |
257 | |
258 | static PyObject* Dim_getis_bound(Dim* self, void*) { |
259 | return PyBool_FromLong(self->is_bound()); |
260 | } |
261 | |
262 | static PyObject* Dim_getlevel(Dim* self, void*) { |
263 | return PyLong_FromLong(self->level_); |
264 | } |
265 | |
266 | static PyObject* Dim_get_levels(Dim* self, void*) { |
267 | py::tuple t(1); |
268 | t.set(0, py::object::borrow(self->ptr())); |
269 | return t.release(); |
270 | } |
271 | |
272 | static PyObject* Dim_get_has_device(Dim* self, void*) { |
273 | Py_RETURN_FALSE; |
274 | } |
275 | |
276 | static PyObject* Dim_get_tensor(Dim* self, void*) { |
277 | return THPVariable_Wrap(self->range()); |
278 | } |
279 | |
280 | static PyObject* Dim_get_batchtensor(Dim* self, void*) { |
281 | return THPVariable_Wrap(self->batchtensor()); |
282 | } |
283 | |
284 | |
285 | static PyGetSetDef Dim_getsetters[] = { |
286 | {"size" , (getter) Dim_getsize, (setter) Dim_setsize, |
287 | "Dimension size" , NULL}, |
288 | {"is_bound" , (getter) Dim_getis_bound, NULL, "is_bound" , NULL}, |
289 | {"_level" , (getter) Dim_getlevel, NULL, "_level" , NULL}, |
290 | {"_levels" , (getter) Dim_get_levels, NULL, "_levels" , NULL}, |
291 | {"_has_device" , (getter) Dim_get_has_device, NULL, "_has_device" , NULL}, |
292 | {"_tensor" , (getter) Dim_get_tensor, NULL, "_tensor" , NULL}, |
293 | {"_batchtensor" , (getter) Dim_get_batchtensor, NULL, "_batchtensor" , NULL}, |
294 | {"ndim" , (getter) [](PyObject* self, void*) -> PyObject* { return py::from_int(1).release(); }, NULL, "ndim" , NULL}, |
295 | {NULL} /* Sentinel */ |
296 | }; |
297 | |
298 | PyTypeObject Dim::Type = { |
299 | PyVarObject_HEAD_INIT(NULL, 0) |
300 | "_C.Dim" , /* tp_name */ |
301 | sizeof(Dim), /* tp_basicsize */ |
302 | 0, /* tp_itemsize */ |
303 | Dim::dealloc_stub, /* tp_dealloc */ |
304 | 0, /* tp_vectorcall_offset */ |
305 | 0, /* tp_getattr */ |
306 | 0, /* tp_setattr */ |
307 | 0, /* tp_as_async */ |
308 | (reprfunc)Dim_repr, /* tp_repr */ |
309 | 0, /* tp_as_number */ |
310 | 0, /* tp_as_sequence */ |
311 | 0, /* tp_as_mapping */ |
312 | 0, /* tp_hash */ |
313 | 0, /* tp_call */ |
314 | 0, /* tp_str */ |
315 | 0, /* tp_getattro */ |
316 | 0, /* tp_setattro */ |
317 | 0, /* tp_as_buffer */ |
318 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ |
319 | "Dim Object" , /* tp_doc */ |
320 | 0, /* tp_traverse */ |
321 | 0, /* tp_clear */ |
322 | 0, /* tp_richcompare */ |
323 | 0, /* tp_weaklistoffset */ |
324 | 0, /* tp_iter */ |
325 | 0, /* tp_iternext */ |
326 | 0, /* tp_methods */ |
327 | 0, /* tp_members */ |
328 | Dim_getsetters, /* tp_getset */ |
329 | 0, /* tp_base */ |
330 | 0, /* tp_dict */ |
331 | 0, /* tp_descr_get */ |
332 | 0, /* tp_descr_set */ |
333 | 0, /* tp_dictoffset */ |
334 | (initproc)(void*) Dim_init, /* tp_init */ |
335 | 0, /* tp_alloc */ |
336 | Dim::new_stub, /* tp_new */ |
337 | }; |
338 | |
339 | // class DimList ------------ |
340 | |
341 | struct DimList : public py::base<DimList> { |
342 | py::object name_; |
343 | std::vector<py::obj<Dim>> dims_; |
344 | static PyTypeObject Type; |
345 | void init(py::object name) { |
346 | name_ = std::move(name); |
347 | } |
348 | void set_dims(std::vector<py::obj<Dim>> dims) { |
349 | bound_ = true; |
350 | dims_ = std::move(dims); |
351 | } |
352 | bool is_bound() { |
353 | return bound_; |
354 | } |
355 | void bind_len(int64_t size) { |
356 | if (bound_) { |
357 | int64_t b_size = dims_.size(); |
358 | if (b_size != size) { |
py::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %lld", (long long) b_size, (long long) size);
360 | } |
361 | } else { |
362 | bound_ = true; |
363 | dims_.resize(size); |
364 | for (Py_ssize_t i = 0; i < size; ++i) { |
365 | dims_[i] = Dim::create(py::unicode_from_format("%S%i" , name_.ptr(), (int)i)); |
366 | } |
367 | } |
368 | } |
369 | int64_t size() const { |
370 | if (!bound_) { |
371 | py::raise_error(DimensionBindError(), "DimList not bound" ); |
372 | } |
373 | return dims_.size(); |
374 | } |
375 | void set_bound(bool b) { |
376 | bound_ = b; |
377 | } |
378 | private: |
379 | bool bound_ = false; |
380 | }; |
381 | |
382 | |
383 | static int DimList_init(DimList *self, PyObject *args, PyObject *kwds); |
384 | |
385 | static PyObject* DimList_repr(DimList* self) { |
386 | PY_BEGIN |
387 | if (self->is_bound()) { |
388 | size_t size = self->dims_.size(); |
389 | py::tuple t(size); |
390 | for(size_t i = 0; i < size; ++i) { |
391 | t.set(i, self->dims_[i]); |
392 | } |
393 | return py::repr(t).release(); |
394 | } else if(!py::is_none(self->name_)) { |
395 | return py::unicode_from_format("*%S" , self->name_.ptr()).release(); |
396 | } else { |
397 | return py::unicode_from_string("<unbound_dimlist>" ).release(); |
398 | } |
399 | PY_END(nullptr) |
400 | } |
401 | |
402 | static PyObject* DimList_bind(DimList *self, |
403 | PyObject *const *args, |
404 | Py_ssize_t nargs, |
405 | PyObject *kwnames) { |
406 | PY_BEGIN |
407 | py::handle sizes; |
408 | static const char * const _keywords[] = {"sizes" , nullptr}; |
409 | static _PyArg_Parser parser = {"O" , _keywords, 0}; |
410 | if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { |
411 | return nullptr; |
412 | } |
413 | if (!py::is_sequence(sizes)) { |
414 | py::raise_error(PyExc_ValueError, "expected a sequence" ); |
415 | } |
416 | py::sequence_view seq = sizes; |
417 | auto size = seq.size(); |
418 | self->bind_len(size); |
419 | for (Py_ssize_t i = 0; i < size; ++i) { |
420 | self->dims_[i]->set_size(py::to_int(seq[i])); |
421 | } |
422 | Py_RETURN_NONE; |
423 | PY_END(nullptr) |
424 | } |
425 | |
426 | static PyObject* DimList_bind_len(DimList *self, |
427 | PyObject *const *args, |
428 | Py_ssize_t nargs, |
429 | PyObject *kwnames) { |
430 | PY_BEGIN |
431 | int size; |
432 | static const char * const _keywords[] = {"N" , nullptr}; |
433 | static _PyArg_Parser parser = {"i" , _keywords, 0}; |
434 | if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { |
435 | return nullptr; |
436 | } |
437 | self->bind_len(size); |
438 | Py_RETURN_NONE; |
439 | PY_END(nullptr) |
440 | } |
441 | |
442 | static PyMethodDef DimList_methods[] = { |
443 | {"bind" , (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS}, |
444 | {"bind_len" , (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS}, |
445 | {NULL, NULL, 0, NULL} /* Sentinel */ |
446 | }; |
447 | |
448 | |
449 | static Py_ssize_t DimList_len(DimList* self) { |
450 | PY_BEGIN |
451 | return self->size(); |
452 | PY_END(-1) |
453 | } |
454 | |
455 | PyObject * DimList_item(DimList* self, Py_ssize_t idx) { |
456 | PY_BEGIN |
457 | if (!self->is_bound()) { |
458 | py::raise_error(DimensionBindError(), "DimList not bound" ); |
459 | } |
460 | if (idx < 0 || (size_t) idx >= self->dims_.size()) { |
461 | py::raise_error(PyExc_IndexError, "index out of bounds" ); |
462 | } |
463 | py::object r = self->dims_[idx]; |
464 | return r.release(); |
465 | PY_END(nullptr) |
466 | } |
467 | |
468 | PySequenceMethods DimList_seq { |
469 | (lenfunc) DimList_len, //lenfunc sq_length; |
470 | 0, //binaryfunc sq_concat; |
471 | 0, //ssizeargfunc sq_repeat; |
472 | (ssizeargfunc) DimList_item, //ssizeargfunc sq_item; |
473 | 0, //void *was_sq_slice; |
474 | 0, //ssizeobjargproc sq_ass_item; |
475 | 0, //void *was_sq_ass_slice; |
476 | 0, //objobjproc sq_contains; |
477 | |
478 | 0, //binaryfunc sq_inplace_concat; |
479 | 0, //ssizeargfunc sq_inplace_repeat; |
480 | }; |
481 | |
482 | static PyObject* DimList_getis_bound(DimList* self, void*) { |
483 | return PyBool_FromLong(self->is_bound()); |
484 | } |
485 | |
486 | static PyGetSetDef DimList_getsetters[] = { |
487 | {"is_bound" , (getter) DimList_getis_bound, NULL, "is_bound" , NULL}, |
488 | {NULL} /* Sentinel */ |
489 | }; |
490 | |
491 | |
492 | static PyObject* DimList_subscript(DimList* self, py::handle idx) { |
493 | PY_BEGIN |
494 | if (py::is_int(idx)) { |
495 | return DimList_item(self, py::to_int(idx)); |
496 | } else if (py::is_slice(idx)) { |
497 | if (!self->is_bound()) { |
498 | py::raise_error(DimensionBindError(), "DimList not bound" ); |
499 | } |
500 | py::slice_view s(idx, self->dims_.size()); |
501 | py::tuple r(s.slicelength); |
502 | for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { |
503 | r.set(j++, self->dims_[i]); |
504 | } |
505 | return r.release(); |
506 | } else { |
507 | py::raise_error(PyExc_ValueError, "expected an int or a slice" ); |
508 | return nullptr; |
509 | } |
510 | PY_END(nullptr) |
511 | } |
512 | |
513 | PyMappingMethods DimList_mapping = { |
514 | 0, //lenfunc mp_length; |
515 | (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript; |
516 | 0, //objobjargproc mp_ass_subscript; |
517 | }; |
518 | |
519 | |
520 | |
521 | PyTypeObject DimList::Type = { |
522 | PyVarObject_HEAD_INIT(NULL, 0) |
523 | "_C.DimList" , /* tp_name */ |
524 | sizeof(DimList), /* tp_basicsize */ |
525 | 0, /* tp_itemsize */ |
526 | DimList::dealloc_stub, /* tp_dealloc */ |
527 | 0, /* tp_vectorcall_offset */ |
528 | 0, /* tp_getattr */ |
529 | 0, /* tp_setattr */ |
530 | 0, /* tp_as_async */ |
531 | (reprfunc)DimList_repr, /* tp_repr */ |
532 | 0, /* tp_as_number */ |
533 | &DimList_seq, /* tp_as_sequence */ |
534 | &DimList_mapping, /* tp_as_mapping */ |
535 | 0, /* tp_hash */ |
536 | 0, /* tp_call */ |
537 | 0, /* tp_str */ |
538 | 0, /* tp_getattro */ |
539 | 0, /* tp_setattro */ |
540 | 0, /* tp_as_buffer */ |
541 | 0, /* tp_flags */ |
542 | "DimList Object" , /* tp_doc */ |
543 | 0, /* tp_traverse */ |
544 | 0, /* tp_clear */ |
545 | 0, /* tp_richcompare */ |
546 | 0, /* tp_weaklistoffset */ |
547 | 0, /* tp_iter */ |
548 | 0, /* tp_iternext */ |
549 | DimList_methods, /* tp_methods */ |
550 | 0, /* tp_members */ |
551 | DimList_getsetters, /* tp_getset */ |
552 | 0, /* tp_base */ |
553 | 0, /* tp_dict */ |
554 | 0, /* tp_descr_get */ |
555 | 0, /* tp_descr_set */ |
556 | 0, /* tp_dictoffset */ |
557 | (initproc) DimList_init, /* tp_init */ |
558 | 0, /* tp_alloc */ |
559 | DimList::new_stub, /* tp_new */ |
560 | }; |
561 | |
562 | static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) { |
563 | PY_BEGIN |
564 | static char* kwlist[] = {"len_or_dims" , "name" , nullptr}; |
565 | py::handle len_or_dims = nullptr; |
566 | PyObject* name = nullptr; |
567 | if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO" , kwlist, &len_or_dims, &name)) { |
568 | return -1; |
569 | } |
570 | self->init(py::object::borrow(name ? name : Py_None)); |
571 | if (len_or_dims.ptr()) { |
572 | if(py::is_int(len_or_dims)) { |
573 | self->bind_len(py::to_int(len_or_dims)); |
574 | } else if (py::is_sequence(len_or_dims)) { |
575 | py::sequence_view s(len_or_dims); |
576 | std::vector<py::obj<Dim>> dims; |
577 | size_t size = s.size(); |
578 | dims.reserve(size); |
579 | for (size_t i = 0; i < size; ++i) { |
580 | auto r = s[i]; |
581 | if (py::is_int(r)) { |
582 | dims.emplace_back(Dim::create(py::unicode_from_format("%S%i" , self->name_.ptr(), (int)i), py::to_int(r))); |
583 | } else { |
584 | dims.emplace_back(Dim::wrap(r)); |
585 | } |
586 | } |
587 | self->set_dims(std::move(dims)); |
588 | } else { |
589 | PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions" ); |
590 | return -1; |
591 | } |
592 | return 0; |
593 | } |
594 | return 0; |
595 | PY_END(-1); |
596 | } |
597 | |
598 | // Tensor ----------------------------- |
599 | |
600 | PyTypeObject* TensorType = nullptr; // the python wrapper type. |
601 | at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_); |
602 | static py::object run_torch_function(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise); |
603 | void free_levels_dims(Slice<DimEntry> levels); |
604 | |
605 | struct Tensor; |
606 | |
607 | struct DelayedOperator { |
608 | DelayedOperator(py::object o, py::vector_args a) |
609 | : orig(std::move(o)), args(a) { |
610 | auto all = a.size(); |
611 | // this will outlive the call so |
612 | // take ownership of temporaries |
613 | // in vector args |
614 | auto buf = new py::handle[all]; |
615 | memcpy(buf, args.args, sizeof(py::handle)*all); |
616 | args.args = buf; |
617 | for (auto i : args.enumerate_all()) { |
618 | Py_INCREF(args.args[i].ptr()); |
619 | } |
620 | Py_XINCREF(args.kwnames.ptr()); |
621 | } |
622 | ~DelayedOperator() { |
623 | for (auto i : args.enumerate_all()) { |
624 | Py_DECREF(args[i].ptr()); |
625 | } |
626 | if (args.has_keywords()) { |
627 | Py_XDECREF(args.kwnames.ptr()); |
628 | } |
629 | delete [] args.args; |
630 | } |
631 | py::object orig; |
632 | py::vector_args args; |
633 | }; |
634 | |
635 | struct Tensor : public py::base<Tensor> { |
636 | private: |
637 | at::Tensor tensor_; |
638 | at::Tensor batchtensor_; |
639 | OwnedSlice<DimEntry> levels_; |
640 | bool has_device_; |
641 | std::unique_ptr<DelayedOperator> delayed_; |
642 | public: |
643 | |
644 | at::Tensor& tensor(Arena& A) { |
645 | if (C10_UNLIKELY(!tensor_.defined())) { |
646 | AT_ASSERT(delayed_); |
647 | auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); |
648 | tensor_ = t->tensor(A); |
649 | delayed_.reset(); |
// don't force creation of the batch tensor if it wasn't already provided.
651 | batchtensor_ = t->batchtensor_; |
652 | AT_ASSERT(levels() == t->levels()); |
653 | } |
654 | return tensor_; |
655 | } |
656 | at::Tensor& batchtensor(Arena& A) { |
657 | if (C10_UNLIKELY(!batchtensor_.defined())) { |
658 | batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); |
659 | } |
660 | return batchtensor_; |
661 | } |
662 | Slice<DimEntry> levels() { |
663 | return levels_.slice(); |
664 | } |
665 | bool has_device() { |
666 | return has_device_; |
667 | } |
668 | DelayedOperator* delayed() { |
669 | return delayed_.get(); |
670 | } |
671 | static PyTypeObject Type; |
672 | |
673 | static bool check_exact(py::handle v) { |
674 | return Py_TYPE(v.ptr()) == TensorType; |
675 | } |
676 | |
677 | |
678 | static py::obj<Tensor> create() { |
679 | if (!TensorType) { |
680 | TensorType = (PyTypeObject*) py::import("functorch.dim" ).attr("Tensor" ).ptr(); |
681 | } |
682 | return Tensor::alloc(TensorType); |
683 | } |
684 | void capture_levels(Slice<DimEntry> levels) { |
685 | // grab ownership of the dims inside levels |
686 | for (auto l : levels) { |
687 | if (!l.is_positional()) { |
688 | py::object::borrow(l.dim()).release(); |
689 | } |
690 | } |
691 | levels_.set(levels, free_levels_dims); |
692 | } |
693 | static py::object from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device); |
694 | static py::obj<Tensor> create_delayed(py::object op, py::vector_args args, Slice<DimEntry> levels, bool has_device); |
695 | friend struct EnableAllLayers; |
696 | }; |
697 | |
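// Wrap `t` in one functorch BatchedTensorImpl per first-class dim in levels_,
// in increasing level_ order (matching the order EnableAllLayers pushes vmap
// layers): each pass finds the remaining dim with the smallest level_, wraps
// it with addBatchDim at its current logical index, and marks that entry as
// consumed.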
698 | at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_) { |
699 | auto levels = Slice<DimEntry>(); |
700 | levels.extend(A, levels_); |
701 | while (true) { |
702 | int64_t min_real_index = -1; |
703 | int64_t min_index = -1; |
704 | int64_t min_value = INT_MAX; |
705 | int64_t i = 0; |
706 | int64_t r = 0; |
707 | for (auto l : levels) { |
708 | if (!l.is_none()) { |
709 | if (!l.is_positional() && l.dim()->level_ < min_value) { |
710 | min_value = l.dim()->level_; |
711 | min_index = i; |
712 | min_real_index = r; |
713 | } |
714 | ++i; |
715 | } |
716 | ++r; |
717 | } |
718 | if (min_index == -1) { |
719 | return t; |
720 | } |
721 | auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); |
722 | t = std::move(t2); |
723 | levels[min_real_index] = DimEntry(); |
724 | } |
725 | } |
726 | |
727 | void free_levels_dims(Slice<DimEntry> levels) { |
728 | for(auto e : levels) { |
729 | if (!e.is_positional()) { |
730 | py::object::steal(e.dim()); |
731 | } |
732 | } |
733 | } |
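// Refcount convention for captured levels: capture_levels() bumps each Dim's
// refcount via the borrow().release() idiom, and free_levels_dims() later
// drops those references by stealing them, so an OwnedSlice<DimEntry> keeps
// its first-class dims alive.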
734 | |
// the version in the header does an unnecessary refcount +/-
736 | inline at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) { |
737 | if (at::functorch::isBatchedTensor(tensor)) { |
738 | return static_cast<at::functorch::BatchedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
739 | } |
740 | return nullptr; |
741 | } |
742 | |
743 | inline TensorRef unchecked_tensor_from(py::handle p) { |
744 | auto v = (THPVariable*) p.ptr(); |
745 | return TensorRef(*v->cdata); |
746 | } |
747 | |
748 | int64_t ndim_of_levels(Slice<DimEntry> levels) { |
749 | int64_t r = 0; |
750 | for (auto l : levels) { |
751 | if (l.is_positional()) { |
752 | ++r; |
753 | } |
754 | } |
755 | return r; |
756 | } |
757 | |
758 | struct TensorInfo { |
759 | TensorRef tensor; |
760 | Slice<DimEntry> levels; |
761 | bool has_device; |
762 | TensorRef batchedtensor; |
763 | int64_t ndim() const { |
764 | return ndim_of_levels(levels); |
765 | } |
766 | operator bool() const { |
767 | return tensor; |
768 | } |
769 | |
770 | static TensorInfo create(Arena& A, py::handle h, bool ensure_batched=true, bool ensure_present=true) { |
771 | if (Tensor::check_exact(h)) { |
772 | auto t = Tensor::unchecked_wrap(h); |
773 | return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()}; |
774 | } else if (Dim::check_exact(h)) { |
775 | auto d = Dim::unchecked_wrap(h); |
776 | return TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()}; |
777 | } else if (THPVariable_Check(h.ptr())) { |
778 | TensorRef t = unchecked_tensor_from(h); |
779 | Slice<DimEntry> levels; |
780 | for (auto i : irange(-t->dim(), 0)) { |
781 | levels.append(A, i); |
782 | } |
783 | return TensorInfo {t, levels, true, t}; |
784 | } else { |
785 | if (ensure_present) { |
786 | py::raise_error(PyExc_ValueError, "expected a tensor object" ); |
787 | } |
788 | return TensorInfo {}; |
789 | } |
790 | } |
791 | |
792 | |
793 | }; |
794 | |
795 | py::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device) { |
796 | size_t seen_dims = 0; |
797 | int last = 0; |
798 | //auto sz = tensor.sizes(); |
799 | for (auto i : levels.enumerate()) { |
800 | auto l = levels[i]; |
801 | if (l.is_positional()) { |
802 | AT_ASSERT(last == 0 || last + 1 == l.position()); |
803 | last = l.position(); |
804 | } else { |
805 | py::object::borrow(l.dim()).release(); |
806 | //AT_ASSERT(sz[i] == l.dim()->size()); |
807 | ++seen_dims; |
808 | } |
809 | } |
810 | AT_ASSERT(last == 0 || last == -1); |
811 | if (!seen_dims) { |
812 | return py::object::steal(THPVariable_Wrap(std::move(tensor))); |
813 | } |
814 | |
815 | py::obj<Tensor> self = Tensor::create(); |
816 | self->tensor_ = std::move(tensor); |
817 | AT_ASSERT(self->tensor_.dim() == levels.size()); |
818 | self->levels_.set(levels, free_levels_dims); |
819 | self->has_device_ = has_device; |
820 | py::object r = std::move(self); |
821 | return r; |
822 | } |
823 | |
824 | |
825 | static PyObject* py_Tensor_from_positional(PyObject *self, |
826 | PyObject *const *args, |
827 | Py_ssize_t nargs, |
828 | PyObject *kwnames) { |
829 | Arena A; |
830 | PY_BEGIN |
831 | #define ARGS(_) _(py::handle, tensor) _(py::handle, py_levels) _(int, has_device) |
832 | MPY_PARSE_ARGS_KWNAMES("OOp" , ARGS) |
833 | #undef ARGS |
834 | |
835 | if (!THPVariable_Check(tensor.ptr())) { |
836 | py::raise_error(PyExc_ValueError, "_tensor is not a Tensor?" ); |
837 | } |
838 | |
839 | Slice<DimEntry> levels; |
840 | py::sequence_view sq(py_levels); |
841 | for (auto i : sq.enumerate()) { |
842 | py::object v = sq[i]; |
843 | if (py::is_int(v)) { |
844 | auto vi = py::to_int(v); |
845 | levels.append(A, vi); |
846 | } else { |
847 | auto dim = Dim::wrap(std::move(v)); |
848 | py::hdl<Dim> hdim = dim; |
849 | levels.append(A, hdim); |
850 | } |
851 | } |
852 | return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release(); |
853 | PY_END(nullptr) |
854 | } |
855 | |
856 | py::obj<Tensor> Tensor::create_delayed(py::object op, py::vector_args args, Slice<DimEntry> levels, bool has_device) { |
857 | py::obj<Tensor> self = Tensor::create(); |
858 | self->capture_levels(levels); |
859 | self->has_device_ = has_device; |
860 | self->delayed_ = std::make_unique<DelayedOperator>(op, args); |
861 | return self; |
862 | } |
863 | |
864 | py::list slice_to_list(Slice<py::handle> h) { |
865 | py::list lst(h.size()); |
866 | for (auto i : h.enumerate()) { |
867 | lst.set(i, py::object::borrow(h[i])); |
868 | } |
869 | return lst; |
870 | } |
871 | |
872 | py::tuple slice_to_tuple(Slice<py::handle> h) { |
873 | py::tuple lst(h.size()); |
874 | for (auto i : h.enumerate()) { |
875 | lst.set(i, py::object::borrow(h[i])); |
876 | } |
877 | return lst; |
878 | } |
879 | |
880 | enum UType { |
881 | U_ELEM, |
882 | U_TUPLE_LIKE, |
883 | U_DICT, |
884 | }; |
885 | |
886 | struct Unflatten { |
887 | py::object operator()(Slice<py::handle>& elements) { |
888 | py::object r; |
889 | switch (type) { |
890 | case U_ELEM: { |
891 | r = py::object::borrow(elements[0]); |
892 | elements = elements.slice(1); |
893 | } break; |
894 | case U_TUPLE_LIKE: { |
895 | py::tuple tup(children.size()); |
896 | for (auto i : children.enumerate()) { |
897 | tup.set(i, children[i](elements)); |
898 | } |
899 | r = obj.call(tup); |
900 | } break; |
901 | case U_DICT: { |
902 | r = py::object::checked_steal(PyDict_New()); |
903 | py::dict_view rv(r); |
904 | py::dict_view d(obj); |
905 | Py_ssize_t pos = 0; |
906 | py::handle k, v; |
907 | for (int i = 0; d.next(&pos, &k, &v); ++i) { |
908 | rv.set(k, children[i](elements)); |
909 | } |
910 | } break; |
911 | } |
912 | return r; |
913 | } |
914 | UType type; |
915 | py::handle obj; |
916 | Slice<Unflatten> children; |
917 | }; |
918 | |
919 | Unflatten tree_flatten(Arena& A, py::handle agg, Slice<py::handle>& flat_elements) { |
920 | Slice<Unflatten> c; |
921 | UType utype; |
922 | py::handle obj; |
923 | if (py::list_view::check(agg)) { |
924 | obj = agg.type(); |
925 | utype = U_TUPLE_LIKE; |
926 | py::list_view l(agg); |
927 | for (auto i : l.enumerate()) { |
928 | c.append(A, tree_flatten(A, l[i], flat_elements)); |
929 | } |
930 | } else if (py::tuple_view::check(agg)) { |
931 | obj = agg.type(); |
932 | utype = U_TUPLE_LIKE; |
933 | // includes named tuples |
934 | py::tuple_view l(agg); |
935 | for (auto i : l.enumerate()) { |
936 | c.append(A, tree_flatten(A, l[i], flat_elements)); |
937 | } |
938 | } else if (py::dict_view::check(agg)) { |
939 | utype = U_DICT; |
940 | py::dict_view d(agg); |
941 | obj = agg; |
942 | Py_ssize_t pos = 0; |
943 | py::handle k, v; |
944 | while (d.next(&pos, &k, &v)) { |
945 | c.append(A, tree_flatten(A, v, flat_elements)); |
946 | } |
947 | } else { |
948 | utype = U_ELEM; |
949 | flat_elements.append(A, agg); |
950 | } |
951 | return Unflatten {utype, obj, c}; |
952 | } |
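// Rough usage sketch (agg, a, b, c, and A are hypothetical): for
// agg = [a, (b, c)], tree_flatten appends a, b, c to flat_elements and
// returns an Unflatten whose call operator rebuilds the same nesting from a
// flat slice:
//   Slice<py::handle> flat;
//   Unflatten u = tree_flatten(A, agg, flat); // flat now holds a, b, c
//   py::object rebuilt = u(flat);             // [a, (b, c)] again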
953 | |
954 | struct UnflattenVectorArgs { |
955 | py::vector_args operator()(Arena& A, Slice<py::handle>& elements) { |
956 | if (!had_nested) { |
957 | auto args = elements.begin(); |
958 | elements = Slice<py::handle>(); |
959 | return py::vector_args(args, nargs, kwnames); |
960 | } |
961 | Slice<py::handle> args; |
962 | for (auto u : children) { |
963 | args.append(A, A.autorelease(u(elements))); |
964 | } |
965 | return py::vector_args(args.begin(), nargs, kwnames); |
966 | } |
967 | Slice<Unflatten> children; |
968 | Py_ssize_t nargs; |
969 | py::handle kwnames; |
970 | bool had_nested; |
971 | }; |
972 | |
973 | UnflattenVectorArgs tree_flatten(Arena& A, py::vector_args args, Slice<py::handle>& flat_elements) { |
974 | UnflattenVectorArgs r; |
975 | r.kwnames = args.kwnames; |
976 | r.nargs = args.nargs; |
977 | r.had_nested = false; |
978 | auto N = args.size(); |
979 | for(auto i : irange(N)) { |
980 | auto typ = Py_TYPE(args[i].ptr()); |
981 | // fast checks that this thing isn't something that is nested. |
982 | bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType; |
983 | if (!is_element) { |
984 | flat_elements.extend(A, args.args, args.args + i); |
985 | for (auto j : irange(i)) { |
986 | (void)j; |
987 | r.children.append(A, Unflatten {U_ELEM}); |
988 | } |
989 | for (auto j : irange(i, N)) { |
990 | r.children.append(A, tree_flatten(A, args[j], flat_elements)); |
991 | if (r.children.back().type != U_ELEM) { |
992 | r.had_nested = true; |
993 | } |
994 | } |
995 | return r; |
996 | } |
997 | } |
998 | flat_elements.extend(A, args.args, args.args + N); |
999 | return r; |
1000 | } |
1001 | |
1002 | |
1003 | struct UnflattenArena { |
1004 | Arena A; |
1005 | Unflatten unflatten; |
1006 | }; |
1007 | |
1008 | static PyObject* py_unflatten(PyObject *self, |
1009 | PyObject *const *args, |
1010 | Py_ssize_t nargs, |
1011 | PyObject *kwnames) { |
1012 | PY_BEGIN |
1013 | #define ARGS(_) _(py::handle, ns) |
1014 | MPY_PARSE_ARGS_KWNAMES("O" , ARGS) |
1015 | #undef ARGS |
1016 | py::sequence_view sv(ns); |
// because we do not have an autorelease pool yet...
1018 | Arena A; |
1019 | Slice<py::handle> slice; |
1020 | py::handle Tuple = (PyObject*) &PyTuple_Type; |
1021 | auto inputs = Tuple.call(ns); |
1022 | py::tuple_view tv(inputs); |
1023 | for (auto i : tv.enumerate()) { |
1024 | slice.append(A, tv[i]); |
1025 | } |
1026 | auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena" ); |
1027 | auto r = AA->unflatten(slice).release(); |
1028 | AT_ASSERT(r != nullptr); |
1029 | return r; |
1030 | PY_END(nullptr) |
1031 | } |
1032 | |
1033 | PyMethodDef py_unflatten_def = {"unflatten" , (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS}; |
1034 | |
1035 | void free_unflatten_arena(PyObject * pc) { |
1036 | delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena" ); |
1037 | } |
1038 | |
1039 | static PyObject* py_tree_flatten(PyObject *self, |
1040 | PyObject *const *args, |
1041 | Py_ssize_t nargs, |
1042 | PyObject *kwnames) { |
1043 | PY_BEGIN |
1044 | #define ARGS(_) _(py::handle, tree) |
1045 | MPY_PARSE_ARGS_KWNAMES("O" , ARGS) |
1046 | #undef ARGS |
1047 | auto A = new UnflattenArena; |
1048 | Slice<py::handle> elements; |
1049 | A->unflatten = tree_flatten(A->A, tree, elements); |
1050 | auto cap = py::object::checked_steal(PyCapsule_New(A, "arena" , free_unflatten_arena)); |
1051 | auto unflatten = py::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release())); |
1052 | py::tuple r(2); |
1053 | r.set(0, slice_to_list(elements)); |
1054 | r.set(1, std::move(unflatten)); |
1055 | return r.release(); |
1056 | PY_END(nullptr) |
1057 | } |
1058 | |
1059 | |
1060 | |
1061 | py::object tree_map(Arena& A, std::function<py::handle(py::handle)> fn, py::handle agg) { |
1062 | Slice<py::handle> elements; |
1063 | auto unflatten = tree_flatten(A, agg, elements); |
1064 | for (auto i : elements.enumerate()) { |
1065 | elements[i] = fn(elements[i]); |
1066 | } |
1067 | return unflatten(elements); |
1068 | } |
1069 | |
1070 | // prereq: isinstance(h, _Tensor) |
1071 | inline int64_t _Tensor_ndim(py::handle h) { |
1072 | if (Tensor::check(h)) { |
1073 | int64_t r = 0; |
1074 | for (auto l : Tensor::unchecked_wrap(h)->levels()) { |
1075 | if (l.is_positional()) { |
1076 | ++r; |
1077 | } |
1078 | } |
1079 | return r; |
1080 | } |
1081 | // Dim or DelayedMulTensor |
1082 | return 0; |
1083 | } |
1084 | |
1085 | inline py::handle handle_from_tensor(Arena& A, TensorRef t) { |
1086 | // fast case: tensor is live in python |
1087 | c10::optional<PyObject*> mb_obj = |
1088 | t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter()); |
1089 | if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { |
1090 | return *mb_obj; |
1091 | } |
1092 | return A.autorelease(py::object::checked_steal(THPVariable_Wrap(*t))); |
1093 | } |
1094 | |
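// RAII guard that pushes one functorch vmap layer per first-class dim in
// `levels` (in increasing Dim::level_ order) and pops them all in the
// destructor, asserting the dynamic-layer stack unwinds in LIFO order.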
1095 | struct EnableAllLayers { |
1096 | EnableAllLayers(Arena& A, Slice<DimEntry> levels) { |
1097 | std::vector<std::pair<int64_t, int64_t>> layers; |
1098 | layers.reserve(levels.size()); |
1099 | for (auto l : levels) { |
1100 | if (!l.is_positional()) { |
1101 | auto d = l.dim(); |
1102 | levels_to_dim_.append(A, d); |
1103 | } |
1104 | } |
1105 | std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](py::hdl<Dim> lhs, py::hdl<Dim> rhs) { return lhs->level_ < rhs->level_;}); |
1106 | |
1107 | for (auto i : levels_to_dim_.enumerate()) { |
1108 | auto batch_size = levels_to_dim_[i]->size(); |
1109 | auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); |
1110 | if (i == 0) { |
1111 | levels_start_ = level; |
1112 | } |
1113 | } |
1114 | } |
1115 | |
1116 | ~EnableAllLayers() { |
1117 | auto to_remove = levels_start_ + levels_to_dim_.size() - 1; |
1118 | for (auto i : levels_to_dim_.enumerate()) { |
1119 | AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i); |
1120 | } |
1121 | } |
1122 | |
1123 | py::obj<Tensor> from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) { |
1124 | Slice<DimEntry> levels; |
1125 | for (auto i : irange(-batchedtensor.dim(), 0)) { |
1126 | levels.append(A, i); |
1127 | } |
1128 | TensorRef tensor; |
1129 | at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor); |
1130 | while(true) { |
1131 | auto level = impl->level(); |
1132 | AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size()); |
1133 | py::hdl<Dim> dim = levels_to_dim_[level - levels_start_].ptr(); |
1134 | levels.insert(A, impl->bdim(), dim); |
1135 | at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value()); |
1136 | if (!nimpl) { |
1137 | tensor = impl->value(); |
1138 | break; |
1139 | } |
1140 | impl = nimpl; |
1141 | } |
1142 | |
1143 | py::obj<Tensor> self = Tensor::create(); |
1144 | // grab ownership of the tensors |
1145 | self->tensor_ = *tensor; |
1146 | self->batchtensor_ = std::move(batchedtensor); |
1147 | self->has_device_ = has_device; |
1148 | self->capture_levels(levels); |
1149 | return self; |
1150 | } |
1151 | void inplace_update_layers(TensorRef batchtensor, Slice<DimEntry> levels) { |
// XXX - requires a patch to functorch to add _unsafe_set_level
1153 | auto impl = maybeGetBatchedImpl(*batchtensor); |
1154 | for (auto i : levels_to_dim_.reversed_enumerate()) { |
1155 | if (!impl) { |
1156 | break; |
1157 | } |
1158 | if (levels.contains(levels_to_dim_[i])) { |
1159 | impl->_unsafe_set_level(levels_start_ + i); |
1160 | impl = maybeGetBatchedImpl(impl->value()); |
1161 | |
1162 | } |
1163 | } |
1164 | } |
1165 | private: |
1166 | int64_t levels_start_{}; |
1167 | Slice<py::hdl<Dim>> levels_to_dim_; |
1168 | }; |
1169 | |
1170 | TensorRef _match_levels(Arena& A, TensorRef v, Slice<DimEntry> from_levels, Slice<DimEntry> to_levels, bool drop_levels=false) { |
1171 | if (from_levels == to_levels) { |
1172 | return v; |
1173 | } |
1174 | // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0. |
1175 | at::IntArrayRef sz = v->sizes(); |
1176 | at::IntArrayRef sd = v->strides(); |
1177 | AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); |
1178 | Slice<int64_t> nsz; |
1179 | Slice<int64_t> nsd; |
1180 | for (auto l : to_levels) { |
1181 | auto oidx = from_levels.index(l); |
1182 | if (!oidx) { |
1183 | nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); |
1184 | nsd.append(A, 0); |
1185 | } else { |
1186 | auto idx = *oidx; |
1187 | nsz.append(A, sz[idx]); |
1188 | nsd.append(A, sd[idx]); |
1189 | } |
1190 | } |
1191 | return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset())); |
1192 | } |
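// Worked example (illustrative): for from_levels = [d0, -1] and
// to_levels = [d1, d0, -1], the result views v with sizes
// {d1->size(), sz[0], sz[1]} and strides {0, sd[0], sd[1]}; the missing d1
// becomes a stride-0 (broadcast) dimension.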
1193 | |
1194 | static py::object run_torch_function(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise) { |
1195 | if (!pointwise_optimize) { |
1196 | is_pointwise = false; |
1197 | } |
1198 | // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n"; |
1199 | |
1200 | Slice<py::hdl<Dim>> all_dims; |
1201 | Slice<py::handle> flat_args; |
1202 | auto unflatten_args = tree_flatten(A, args, flat_args); |
1203 | TensorRef device_holding_tensor; |
1204 | |
1205 | Slice<TensorInfo> infos; |
1206 | Slice<DimEntry> result_levels; |
1207 | for (auto f : flat_args) { |
1208 | infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); |
1209 | if (infos.back()) { |
1210 | TensorInfo& info = infos.back(); |
1211 | AT_ASSERT(is_pointwise || info.batchedtensor); |
1212 | if (!device_holding_tensor && info.has_device) { |
1213 | device_holding_tensor = infos.back().tensor; |
1214 | } |
1215 | for (auto l : info.levels) { |
1216 | if (!result_levels.contains(l)) { |
1217 | result_levels.append(A, l); |
1218 | } |
1219 | } |
1220 | } |
1221 | } |
1222 | |
1223 | if (is_pointwise) { |
1224 | for (auto i : flat_args.enumerate()) { |
1225 | if (infos[i]) { |
1226 | TensorRef tensor = infos[i].tensor; |
1227 | if (device_holding_tensor && !infos[i].has_device) { |
1228 | tensor = A.autorelease(tensor->to(device_holding_tensor->device())); |
1229 | } |
1230 | auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); |
1231 | flat_args[i] = handle_from_tensor(A, std::move(ml)); |
1232 | } |
1233 | } |
1234 | |
1235 | Slice<py::handle> flat_it = flat_args; |
1236 | py::vector_args uargs = unflatten_args(A, flat_it); |
1237 | |
1238 | py::object result = orig.call_vector(uargs); |
1239 | |
1240 | // fast wrap for normal case where operator just returns a tensor. |
1241 | if (THPVariable_Check(result.ptr())) { |
1242 | return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor); |
1243 | } |
1244 | auto wrap = [&](py::handle h) { |
1245 | if (THPVariable_Check(h.ptr())){ |
1246 | return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor)); |
1247 | } |
1248 | return h; |
1249 | }; |
1250 | return tree_map(A, wrap, result); |
1251 | } else { |
1252 | // std::cout << orig << " calling functorch...\n"; |
1253 | // std::cout << "rl: " << result_levels << "\n"; |
1254 | EnableAllLayers guard(A, result_levels); |
1255 | for (auto i : flat_args.enumerate()) { |
1256 | if (infos[i]) { |
1257 | TensorRef batched = infos[i].batchedtensor; |
1258 | if (device_holding_tensor && !infos[i].has_device) { |
1259 | batched = A.autorelease(batched->to(device_holding_tensor->device())); |
1260 | } |
1261 | guard.inplace_update_layers(batched, infos[i].levels); |
1262 | flat_args[i] = handle_from_tensor(A, batched); |
1263 | } |
1264 | } |
1265 | Slice<py::handle> flat_it = flat_args; |
1266 | py::vector_args uargs = unflatten_args(A, flat_it); |
1267 | AT_ASSERT(flat_it.size() == 0); |
1268 | py::object result = orig.call_vector(uargs); |
1269 | auto wrap = [&](py::handle h) { |
1270 | if (THPVariable_Check(h.ptr())) { |
1271 | return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); |
1272 | } |
1273 | return h; |
1274 | }; |
1275 | if (THPVariable_Check(result.ptr())) { |
1276 | return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor); |
1277 | } |
1278 | return tree_map(A, wrap, result); |
1279 | } |
1280 | } |
1281 | |
1282 | |
1283 | static py::object __torch_function__(Arena &A, py::handle orig, py::vector_args args, bool is_pointwise) { |
1284 | if (orig == torch_Tensor___mul__) { |
1285 | AT_ASSERT(args.nargs == 2 && !args.has_keywords()); |
1286 | auto lhs = args[0]; |
1287 | auto rhs = args[1]; |
1288 | if (py::isinstance(lhs, _Tensor) && py::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { |
1289 | bool has_device = false; |
1290 | Slice<DimEntry> levels; |
1291 | for (auto i : args.enumerate_positional()) { |
1292 | auto t = TensorInfo::create(A, args[i], false); |
1293 | // something like a mask * rhs, which matrix multiplies don't correctly promote |
1294 | if (!t.tensor->is_floating_point()) { |
1295 | return run_torch_function(A, orig, args, is_pointwise); |
1296 | } |
1297 | has_device = has_device || t.has_device; |
1298 | for (auto l : t.levels) { |
1299 | if (!levels.contains(l)) { |
1300 | levels.append(A, l); |
1301 | } |
1302 | } |
1303 | } |
1304 | // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; |
1305 | return Tensor::create_delayed(py::object::borrow(orig), args, levels, has_device); |
1306 | } |
1307 | } |
1308 | return run_torch_function(A, orig, args, is_pointwise); |
1309 | } |
1310 | |
1311 | py::vector_args as_vector_args(Arena& A, py::handle args, py::handle kwargs) { |
1312 | auto pos_args = (py::handle*) &PyTuple_GET_ITEM(args.ptr(), 0); |
1313 | auto pos_n = PyTuple_GET_SIZE(args.ptr()); |
1314 | if (!kwargs.ptr()) { |
1315 | return py::vector_args(pos_args, pos_n, nullptr); |
1316 | } |
1317 | Slice<py::handle> all_args; |
1318 | Slice<py::handle> kwnames; |
1319 | all_args.extend(A, pos_args, pos_args + pos_n); |
1320 | py::dict_view dv(kwargs); |
1321 | Py_ssize_t pos = 0; |
1322 | py::handle key, value; |
1323 | while (dv.next(&pos, &key, &value)) { |
1324 | all_args.append(A, value); |
1325 | kwnames.append(A, key); |
1326 | } |
1327 | return py::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); |
1328 | } |
1329 | |
1330 | static PyObject* py___torch_function__(PyObject *self, |
1331 | PyObject *const *args, |
1332 | Py_ssize_t nargs, |
1333 | PyObject *kwnames) { |
1334 | Arena A; |
1335 | PY_BEGIN |
1336 | maybeInitializeGlobals(); |
1337 | AT_ASSERT(nargs == 4 || nargs == 5); |
1338 | auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); |
1339 | bool is_pointwise = pointwise.contains(args[1]); |
1340 | return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); |
1341 | PY_END(nullptr) |
1342 | } |
1343 | |
1344 | py::object levels_to_tuple(Slice<DimEntry> slice) { |
1345 | py::tuple t(slice.size()); |
1346 | for (auto i : slice.enumerate()) { |
1347 | t.set(i, slice[i].is_positional() ? py::from_int(slice[i].position()) : py::object::borrow(slice[i].dim())); |
1348 | } |
1349 | py::object r = std::move(t); |
1350 | return r; |
1351 | } |
1352 | |
1353 | PyObject* Tensor_ndim(Tensor* self, void*) { |
1354 | Py_ssize_t i = 0; |
1355 | for (auto l : self->levels()) { |
1356 | if (l.is_positional()) { |
1357 | ++i; |
1358 | } |
1359 | } |
1360 | return py::from_int(i).release(); |
1361 | } |
1362 | |
1363 | static PyGetSetDef Tensor_getsetters[] = { |
1364 | {"_has_device" , (getter) [](PyObject* self, void*) -> PyObject* { return py::from_bool(((Tensor*)self)->has_device()).release(); }, NULL}, |
1365 | {"_tensor" , (getter) [](PyObject* self, void*) -> PyObject* { |
1366 | Arena A; |
1367 | return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL}, |
1368 | {"_batchtensor" , (getter) [](PyObject* self, void*) -> PyObject* { |
1369 | Arena A; |
1370 | return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL}, |
1371 | {"_levels" , (getter) [](PyObject* self, void*) -> PyObject* { |
1372 | PY_BEGIN |
1373 | return levels_to_tuple(((Tensor*)self)->levels()).release(); |
1374 | PY_END(nullptr) |
1375 | }}, |
1376 | {"ndim" , (getter) Tensor_ndim, NULL, "ndim" , NULL}, |
1377 | {NULL} /* Sentinel */ |
1378 | }; |
1379 | |
1380 | static PyMethodDef Tensor_methods[] = { |
1381 | {NULL, NULL, 0, NULL} /* Sentinel */ |
1382 | }; |
1383 | |
1384 | |
1385 | PyTypeObject Tensor::Type = { |
1386 | PyVarObject_HEAD_INIT(NULL, 0) |
1387 | "_C.Tensor" , /* tp_name */ |
1388 | sizeof(Tensor), /* tp_basicsize */ |
1389 | 0, /* tp_itemsize */ |
1390 | Tensor::dealloc_stub, /* tp_dealloc */ |
1391 | 0, /* tp_vectorcall_offset */ |
1392 | 0, /* tp_getattr */ |
1393 | 0, /* tp_setattr */ |
1394 | 0, /* tp_as_async */ |
1395 | 0, /* tp_repr */ |
1396 | 0, /* tp_as_number */ |
1397 | 0, /* tp_as_sequence */ |
1398 | 0, /* tp_as_mapping */ |
1399 | 0, /* tp_hash */ |
1400 | 0, /* tp_call */ |
1401 | 0, /* tp_str */ |
1402 | 0, /* tp_getattro */ |
1403 | 0, /* tp_setattro */ |
1404 | 0, /* tp_as_buffer */ |
1405 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */ |
1406 | "Tensor Object" , /* tp_doc */ |
1407 | 0, /* tp_traverse */ |
1408 | 0, /* tp_clear */ |
1409 | 0, /* tp_richcompare */ |
1410 | 0, /* tp_weaklistoffset */ |
1411 | 0, /* tp_iter */ |
1412 | 0, /* tp_iternext */ |
1413 | Tensor_methods, /* tp_methods */ |
1414 | 0, /* tp_members */ |
1415 | Tensor_getsetters, /* tp_getset */ |
1416 | 0, /* tp_base */ |
1417 | 0, /* tp_dict */ |
1418 | 0, /* tp_descr_get */ |
1419 | 0, /* tp_descr_set */ |
1420 | 0, /* tp_dictoffset */ |
1421 | 0, /* tp_init */ |
1422 | 0, /* tp_alloc */ |
1423 | Tensor::new_stub, /* tp_new */ |
1424 | }; |
1425 | |
1426 | |
1427 | // dim() -------------------- |
1428 | |
1429 | bool relevant_op(_Py_CODEUNIT c) { |
1430 | switch(_Py_OPCODE(c)) { |
1431 | case STORE_NAME: |
1432 | case STORE_GLOBAL: |
1433 | case STORE_FAST: |
1434 | case STORE_DEREF: |
1435 | return true; |
1436 | default: |
1437 | return false; |
1438 | } |
1439 | } |
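// These store opcodes are what assignments like `x = dims()` or
// `x, y = dims()` compile to, so finding one just after the call site lets
// dims() reuse the assignment targets as dimension names.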
1440 | |
1441 | py::object create_dim(py::object name, py::handle size) { |
1442 | auto d = Dim::create(std::move(name)); |
1443 | if (!py::is_none(size)) { |
1444 | d->set_size(py::to_int(size)); |
1445 | } |
1446 | return std::move(d); |
1447 | } |
1448 | |
1449 | py::object create_dimlist(py::object name, py::handle size) { |
1450 | auto d = DimList::create(std::move(name)); |
1451 | if (!py::is_none(size)) { |
1452 | if (py::is_int(size)) { |
1453 | d->bind_len(py::to_int(size)); |
1454 | } else { |
1455 | py::sequence_view s(size); |
1456 | d->bind_len(s.size()); |
1457 | for (auto i : irange(d->size())) { |
1458 | d->dims_[i]->set_size(py::to_int(s[i])); |
1459 | } |
1460 | } |
1461 | } |
1462 | return std::move(d); |
1463 | } |
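// Rough Python-level usage of the factories above (names are inferred from
// the enclosing assignment by _dims below):
//   i, j = dims()           # two unbound first-class dims named 'i' and 'j'
//   k = dims(sizes=[3])     # a single dim already bound to size 3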
1464 | |
1465 | |
1466 | |
// Compatibility shims that make newer CPython reflection primitives
// available on older runtimes.
1468 | #if !(IS_PYTHON_3_11_PLUS) |
1469 | #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) |
1470 | #endif |
1471 | |
1472 | struct PyInstDecoder { |
1473 | PyInstDecoder(PyCodeObject* code_object, int lasti) |
1474 | : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} |
1475 | // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols |
1476 | // See https://github.com/pytorch/pytorch/issues/93854 |
1477 | void next() { |
1478 | #if IS_PYTHON_3_11_PLUS && !defined(_WIN32) |
1479 | offset_ += _PyOpcode_Caches[opcode()]; |
1480 | #endif |
1481 | offset_ += 1; |
1482 | } |
1483 | int opcode() { |
1484 | auto r = _Py_OPCODE(code_[offset_]); |
1485 | #if IS_PYTHON_3_11_PLUS && !defined(_WIN32) |
1486 | r = _PyOpcode_Deopt[r]; |
1487 | #endif |
1488 | return r; |
1489 | } |
1490 | int oparg() { |
1491 | return _Py_OPARG(code_[offset_]); |
1492 | } |
1493 | |
1494 | py::object name() { |
1495 | py::object names; |
1496 | switch(opcode()) { |
1497 | case STORE_NAME: |
1498 | case STORE_GLOBAL: |
1499 | names = py::object::borrow(code_object_->co_names); |
1500 | break; |
1501 | case STORE_FAST: |
1502 | names = py::object::steal(PyCode_GetVarnames(code_object_)); |
1503 | break; |
1504 | case STORE_DEREF: |
1505 | names = py::object::steal(PyCode_GetCellvars(code_object_)); |
1506 | break; |
1507 | default: |
1508 | return py::object(); |
1509 | } |
1510 | return py::object::steal(PySequence_GetItem(names.ptr(), oparg())); |
1511 | } |
1512 | private: |
1513 | PyCodeObject* code_object_; |
1514 | _Py_CODEUNIT* code_; |
1515 | int offset_; |
1516 | }; |
1517 | |
1518 | template<py::object (*create_object)(py::object, py::handle)> |
1519 | static PyObject* _dims(PyObject *self, |
1520 | PyObject *const *args, |
1521 | Py_ssize_t nargs, |
1522 | PyObject *kwnames) { |
1523 | PY_BEGIN |
1524 | Py_ssize_t specified_ndims = -1; |
1525 | Py_ssize_t found_ndims = 0; |
1526 | Py_ssize_t sizes = -1; |
1527 | py::handle n = Py_None; |
1528 | py::handle py_sizes = Py_None; |
1529 | |
1530 | if (nargs || kwnames) { |
1531 | py::vector_args va(args, nargs, kwnames); |
1532 | va.parse("dims" , {"n" , "sizes" }, {&n, &py_sizes}, 0); |
1533 | if (!py::is_none(py_sizes)) { |
1534 | sizes = py::sequence_view(py_sizes).size(); |
1535 | specified_ndims = sizes; |
1536 | } |
1537 | if (!py::is_none(n)) { |
1538 | specified_ndims = py::to_int(n); |
1539 | } |
1540 | } |
1541 | |
1542 | PyThreadState* state = PyThreadState_GET(); |
1543 | auto f = py::obj<PyFrameObject>::steal(PyThreadState_GetFrame(state)); |
1544 | auto c = py::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr())); |
1545 | auto lasti = PyFrame_GetLasti(f.ptr()); |
1546 | auto decoder = PyInstDecoder(c.ptr(), lasti); |
1547 | #if IS_PYTHON_3_11_PLUS |
// When Python 3.11 has specialized ("adapted") the bytecode, lasti points
// at the PRECALL instruction rather than the CALL instruction after it.
1550 | if (decoder.opcode() == PRECALL) { |
1551 | decoder.next(); |
1552 | } |
1553 | #endif |
1554 | decoder.next(); |
1555 | |
1556 | if (relevant_op(decoder.opcode())) { |
1557 | found_ndims = 1; |
1558 | } else if (decoder.opcode() == UNPACK_SEQUENCE) { |
1559 | found_ndims = decoder.oparg(); |
1560 | decoder.next(); |
1561 | } |
1562 | |
1563 | if (specified_ndims == -1) { |
1564 | if (found_ndims == 0) { |
1565 | py::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified" ); |
1566 | } |
1567 | specified_ndims = found_ndims; |
1568 | } |
1569 | if (found_ndims != specified_ndims) { |
1570 | found_ndims = 0; // avoid taking the wrong names for dimensions |
1571 | } |
1572 | |
1573 | auto genobject = [&](int i) -> py::object { |
1574 | py::object name; |
1575 | if (i < found_ndims) { |
1576 | name = decoder.name(); |
1577 | } |
1578 | if (!name.ptr()) { |
1579 | name = py::unicode_from_format("d%d" , i); |
found_ndims = 0; // once we fail to find a name, we can't find any more
1581 | } else { |
1582 | decoder.next(); |
1583 | } |
1584 | return create_object(std::move(name), sizes != -1 ? py::sequence_view(py_sizes)[i] : py::handle(Py_None)); |
1585 | }; |
1586 | if (sizes != -1 && sizes != specified_ndims) { |
1587 | py::raise_error(PyExc_ValueError, "expected %d sizes but found %d" , int(specified_ndims), int(sizes)); |
1588 | } |
1589 | if (specified_ndims == 1) { |
1590 | return genobject(0).release(); |
1591 | } |
1592 | py::tuple result(specified_ndims); |
1593 | for (int i = 0; i < specified_ndims; ++i) { |
1594 | result.set(i, genobject(i)); |
1595 | } |
1596 | return result.release(); |
1597 | PY_END(nullptr) |
1598 | } |
1599 | |
1600 | int64_t dim_index(const std::vector<py::obj<Dim>>& dims, py::hdl<Dim> dim) { |
1601 | for (int64_t i = 0, N = dims.size(); i < N; ++i) { |
1602 | if (dims[i].ptr() == dim.ptr()) { |
1603 | return i; |
1604 | } |
1605 | } |
1606 | return -1; |
1607 | } |
1608 | |
1609 | |
1610 | struct DotPart { |
1611 | Slice<DimEntry> dims; |
1612 | size_t total_size = 1; |
1613 | void append(Arena& A, py::hdl<Dim> d) { |
1614 | total_size *= d->size(); |
1615 | dims.append(A, d); |
1616 | } |
1617 | }; |
1618 | |
1619 | template<typename T> |
1620 | static at::ArrayRef<T> as_array_ref(Slice<T> t) { |
1621 | return at::ArrayRef<T>(t.begin(), t.end()); |
1622 | } |
1623 | |
1624 | TensorRef dot_prepare(Arena& A, std::initializer_list<DotPart> parts, const TensorInfo& t) { |
1625 | Slice<DimEntry> new_levels; |
1626 | bool needs_reshape = false; |
1627 | for (auto p : parts) { |
1628 | if (p.dims.size() != 1) { |
1629 | needs_reshape = true; |
1630 | } |
1631 | new_levels.extend(A, p.dims); |
1632 | } |
1633 | auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); |
1634 | if (!needs_reshape) { |
1635 | return r; |
1636 | } |
1637 | Slice<int64_t> view; |
1638 | for (auto p : parts) { |
1639 | view.append(A, p.total_size); |
1640 | } |
1641 | return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); |
1642 | } |
1643 | |
1644 | py::object dot_finish(Arena& A, std::initializer_list<DotPart> parts, at::Tensor r) { |
1645 | Slice<DimEntry> result_levels; |
1646 | bool needs_reshape = false; |
1647 | for (auto p : parts) { |
1648 | if (p.dims.size() != 1) { |
1649 | needs_reshape = true; |
1650 | } |
1651 | result_levels.extend(A, p.dims); |
1652 | } |
1653 | if (needs_reshape) { |
1654 | Slice<int64_t> new_size; |
1655 | for (auto l : result_levels) { |
1656 | new_size.append(A, l.dim()->size()); |
1657 | } |
1658 | r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); |
1659 | } |
1660 | return Tensor::from_positional(A, std::move(r), result_levels, true); |
1661 | } |
1662 | |
1663 | |
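// dot() implements a generalized contraction over first-class dims by
// classifying every dim of the operands. An illustrative sketch (the dim
// names here are hypothetical):
//     C = (A[b, i, j] * B[b, j, k]).sum(j)
//   b: lro (in lhs, rhs, and output -> the batch dim of bmm)
//   i: lo  (lhs and output only); k: ro (rhs and output only)
//   j: lr  (lhs and rhs only, reduced away)
// dot_prepare reshapes each operand so its axes are these groups, bmm/mm
// performs the contraction, and dot_finish restores the first-class dims.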
1664 | |
1665 | py::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry> sum) { |
1666 | auto lhs_strides = lhs.tensor->strides(); |
1667 | auto rhs_strides = rhs.tensor->strides(); |
1668 | |
1669 | DotPart lro_dims; |
1670 | DotPart lo_dims; |
1671 | DotPart ro_dims; |
1672 | DotPart lr_dims; |
1673 | |
1674 | auto insert_dim = [&] (py::hdl<Dim> d, at::optional<int> lhs_idx, at::optional<int> rhs_idx) { |
1675 | bool reduced = sum.contains(d); |
1676 | int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; |
1677 | int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; |
1678 | if (reduced) { |
1679 | // lr |
1680 | lr_dims.append(A, d); |
1681 | } else { |
1682 | if ((lhs_stride == 0) == (rhs_stride == 0)) { |
1683 | // lro |
1684 | lro_dims.append(A, d); |
1685 | } else if (lhs_stride != 0) { |
1686 | // lo |
1687 | lo_dims.append(A, d); |
1688 | } else { |
1689 | AT_ASSERT(rhs_stride != 0); |
1690 | ro_dims.append(A, d); |
1691 | } |
1692 | } |
1693 | }; |
1694 | |
1695 | |
1696 | auto rhs_seen = A.allocate<bool>(rhs.levels.size()); |
1697 | std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); |
1698 | |
1699 | for (auto i : lhs.levels.enumerate()) { |
1700 | auto d = lhs.levels[i]; |
1701 | auto rhs_idx = rhs.levels.index(d); |
1702 | if (rhs_idx) { |
1703 | rhs_seen[*rhs_idx] = true; |
1704 | } |
1705 | insert_dim(d.dim(), i, rhs_idx); |
1706 | } |
1707 | |
1708 | for (auto i : rhs.levels.enumerate()) { |
1709 | if (rhs_seen[i]) { |
1710 | continue; |
1711 | } |
1712 | auto d = rhs.levels[i]; |
1713 | insert_dim(d.dim(), at::nullopt, i); |
1714 | } |
1715 | |
1716 | if (lr_dims.dims.size() != sum.size()) { |
1717 | for (auto & d : sum) { |
1718 | if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { |
py::raise_error(DimensionBindError(), "summing over non-existent dimension %S" , d.dim().ptr());
1720 | } |
1721 | } |
1722 | } |
1723 | |
1724 | // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; |
1725 | // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n"; |
1726 | |
// if there are batch (lro) dims, use bmm; otherwise a plain mm suffices
1728 | if (lro_dims.dims.size() != 0) { |
1729 | auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); |
1730 | auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); |
1731 | return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); |
1732 | } else { |
1733 | auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); |
1734 | auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); |
1735 | return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); |
1736 | } |
1737 | |
1738 | } |
1739 | |
1740 | static PyObject* test_c(PyObject *self, |
1741 | PyObject *const *args, |
1742 | Py_ssize_t nargs, |
1743 | PyObject *kwnames) { |
1744 | PY_BEGIN |
1745 | |
1746 | Arena A; |
1747 | Slice<int> s(A, 3, 4, 5); |
1748 | AT_ASSERT(s.size() == 3 && s.capacity() == 8); |
1749 | AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); |
1750 | s.append(A, 6); |
1751 | AT_ASSERT(s[3] == 6); |
1752 | for(int i : irange(10)) { |
1753 | s.append(A, i); |
1754 | } |
1755 | AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); |
1756 | |
1757 | Slice<int> s2(A, -1, -2, -3); |
1758 | AT_ASSERT(s2[1] == -2 && s[0] == 3); |
1759 | |
1760 | auto ss = s.slice(1,2); |
1761 | AT_ASSERT(ss.size() == 1); |
1762 | AT_ASSERT(ss[0] == 4); |
1763 | AT_ASSERT(ss.capacity() == 1); |
1764 | ss.append(A, -4); |
1765 | AT_ASSERT(ss.size() == 2 && ss[1] == -4); |
1766 | ss[0] = 3; |
1767 | AT_ASSERT(s[1] == 4); |
1768 | |
1769 | s.insert(A, s.slice(1, 4), ss); |
1770 | AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); |
1771 | |
1772 | auto sz = s.size(); |
1773 | s.insert(A, s.slice(1, 1), 4); |
1774 | AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); |
1775 | |
1776 | |
1777 | Slice<int> d(A, 0, 1, 2, 3, 4); |
1778 | |
1779 | Slice<int> b(A, 0, 1, 2, 3, 4); |
1780 | b.insert(A, b.slice(1,1), d); |
1781 | AT_ASSERT(b.size() == 10); |
1782 | AT_ASSERT(b[1] == 0); |
1783 | AT_ASSERT(b[5] == 4); |
1784 | AT_ASSERT(b.back() == 4); |
1785 | |
1786 | Py_RETURN_NONE; |
1787 | |
1788 | PY_END(nullptr); |
1789 | } |
1790 | |
1791 | static DimEntry _wrap_dim(py::handle d, size_t N, bool keepdim); |
1792 | |
1793 | static PyObject* order(PyObject *_, |
1794 | PyObject *const *args, |
1795 | Py_ssize_t nargs, |
1796 | PyObject *kwnames) { |
1797 | Arena A; |
1798 | PY_BEGIN |
1799 | if (kwnames) { |
1800 | py::raise_error(PyExc_TypeError, "unexpected keyword arguments %S" , kwnames); |
1801 | } |
1802 | AT_ASSERT(nargs-- > 0); |
1803 | Slice<DimEntry> orig_levels; |
1804 | Slice<DimEntry> levels; |
1805 | TensorRef data; |
1806 | py::handle self = args++[0]; |
1807 | bool has_device; |
1808 | if (Tensor::check_exact(self)) { |
1809 | auto t = Tensor::unchecked_wrap(self); |
1810 | orig_levels = t->levels(); |
1811 | data = t->tensor(A); |
1812 | has_device = t->has_device(); |
1813 | } else { |
1814 | auto d = Dim::unchecked_wrap(self); |
1815 | orig_levels.append(A, d); |
1816 | data = d->range(); |
1817 | has_device = false; |
1818 | } |
1819 | |
1820 | Slice<DimEntry> flat_positional_dims; |
1821 | Slice<std::pair<int, int>> to_flatten; |
1822 | levels.extend(A, orig_levels); |
1823 | |
1824 | int orig_ndim = ndim_of_levels(levels); |
1825 | auto append = [&](DimEntry d) { |
1826 | auto midx = levels.index(d); |
1827 | if (!midx) { |
1828 | if (d.is_positional()) { |
py::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but dimension %d was specified, or it was specified twice" , int(orig_ndim), int(d.position() + orig_ndim));
1830 | } else { |
1831 | py::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice" , levels_to_tuple(orig_levels).ptr(), d.dim().ptr()); |
1832 | } |
1833 | } |
1834 | levels[*midx] = DimEntry(); |
1835 | flat_positional_dims.append(A, d); |
1836 | }; |
1837 | |
1838 | int n_new_positional = 0; |
1839 | for (auto i :irange(nargs)) { |
1840 | py::handle arg = args[i]; |
1841 | DimEntry entry = _wrap_dim(arg, orig_ndim, false); |
1842 | if (!entry.is_none()) { |
1843 | append(entry); |
1844 | ++n_new_positional; |
1845 | } else if (DimList::check(arg)) { |
1846 | auto dl = DimList::unchecked_wrap(arg); |
1847 | for (py::obj<Dim> & d : dl->dims_) { |
1848 | append(py::hdl<Dim>(d)); |
1849 | ++n_new_positional; |
1850 | } |
1851 | } else { |
1852 | ++n_new_positional; |
1853 | if (!py::is_sequence(arg)) { |
1854 | py::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]" ); |
1855 | } |
1856 | py::sequence_view sq(arg); |
1857 | auto N = sq.size(); |
1858 | to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); |
1859 | for (auto j : irange(N)) { |
1860 | DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); |
1861 | if (e.is_none()) { |
py::raise_error(PyExc_ValueError, "expected a Dim or int" );
1863 | } |
1864 | append(e); |
1865 | } |
1866 | } |
1867 | } |
1868 | |
1869 | int ndim = 0; |
1870 | int insert_point = -1; |
1871 | Slice<DimEntry> new_levels; |
1872 | for (auto l : levels) { |
1873 | if (l.is_none()) { |
1874 | continue; |
1875 | } |
1876 | if (l.is_positional()) { |
1877 | ndim++; |
1878 | if (insert_point == -1) { |
1879 | insert_point = new_levels.size(); |
1880 | new_levels.extend(A, flat_positional_dims); |
1881 | } |
1882 | } |
1883 | new_levels.append(A, l); |
1884 | } |
1885 | if (insert_point == -1) { |
1886 | insert_point = new_levels.size(); |
1887 | new_levels.extend(A, flat_positional_dims); |
1888 | } |
1889 | |
1890 | at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); |
1891 | |
1892 | if (to_flatten.size()) { |
1893 | Slice<int64_t> view; |
1894 | auto sz = ndata.sizes(); |
1895 | // before the new positional dims |
1896 | for (auto i : irange(0, insert_point)) { |
1897 | view.append(A, sz[i]); |
1898 | } |
1899 | int i = 0; |
1900 | for (auto to_flat : to_flatten) { |
1901 | for (;i < to_flat.first; ++i) { |
1902 | view.append(A, sz[insert_point + i]); |
1903 | } |
1904 | int64_t new_size = 1; |
1905 | int last = i + to_flat.second; |
1906 | for (; i < last; ++i) { |
1907 | new_size *= sz[insert_point + i]; |
1908 | } |
1909 | view.append(A, new_size); |
1910 | } |
1911 | for (; i < flat_positional_dims.size(); ++i) { |
1912 | view.append(A, sz[insert_point + i]); |
1913 | } |
1914 | // after the new positional dims |
1915 | for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) { |
1916 | view.append(A, sz[i]); |
1917 | } |
// we shortened the number of dimensions, so remove them from new_levels;
// we will renumber them later
1920 | auto n_to_remove = flat_positional_dims.size() - n_new_positional; |
1921 | new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice<DimEntry>()); |
1922 | ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); |
1923 | } |
1924 | |
// renumber the positional dimensions
1926 | int seen = 0; |
1927 | for (auto i : new_levels.reversed_enumerate()) { |
1928 | if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) { |
1929 | new_levels[i] = --seen; |
1930 | } |
1931 | } |
1932 | return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release(); |
1933 | |
1934 | PY_END(nullptr) |
1935 | } |
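// Illustrative Python usage of order (a hedged sketch assuming the
// functorch.dim API; not taken verbatim from its docs):
//     i, j = dims(2)
//     t = torch.rand(3, 4)[i, j]   # bind both dims
//     t.order(j, i)                # a plain 4x3 positional tensor
// The requested dims (or flattened dim packs) become the leading positional
// dimensions of the result.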
1936 | |
1937 | static PyObject* expand(PyObject *_, |
1938 | PyObject *const *args, |
1939 | Py_ssize_t nargs, |
1940 | PyObject *kwnames) { |
1941 | Arena A; |
1942 | PY_BEGIN |
1943 | AT_ASSERT(nargs-- > 0); |
1944 | auto info = TensorInfo::create(A, args++[0], false); |
1945 | for (auto i : irange(nargs)) { |
1946 | if (!Dim::check(args[i])) { |
1947 | maybeInitializeGlobals(); |
1948 | py::vector_args vargs(args - 1, nargs + 1, kwnames); |
1949 | if (THPVariable_Check(args[-1])) { |
1950 | return torch_Tensor_expand.call_vector(vargs).release(); |
1951 | } else { |
1952 | return __torch_function__(A, torch_Tensor_expand, vargs, false).release(); |
1953 | } |
1954 | } |
1955 | } |
1956 | const at::Tensor& data = *info.tensor; |
1957 | auto levels = info.levels; |
1958 | Slice<DimEntry> new_levels; |
1959 | Slice<int64_t> sz; |
1960 | Slice<int64_t> sd; |
1961 | for (auto i : irange(nargs)) { |
1962 | auto d = Dim::unchecked_wrap(args[i]); |
1963 | if (levels.contains(d) || new_levels.contains(d)) { |
1964 | py::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims" , d.ptr()); |
1965 | } |
1966 | new_levels.append(A, d); |
1967 | sz.append(A, d->size()); |
1968 | sd.append(A, 0); |
1969 | } |
1970 | new_levels.extend(A, levels); |
1971 | at::IntArrayRef osz = data.sizes(); |
1972 | at::IntArrayRef osd = data.strides(); |
1973 | sz.extend(A, osz.begin(), osz.end()); |
1974 | sd.extend(A, osd.begin(), osd.end()); |
1975 | at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset()); |
1976 | return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release(); |
1977 | PY_END(nullptr) |
1978 | } |
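// Illustrative usage (hedged sketch): expanding by first-class dims adds new
// stride-0 dimensions, so no data is copied:
//     b = dims(sizes=[8])
//     t = torch.rand(3, 4)
//     expand(t, b)   # levels (b, -2, -1), still viewing t's storage
// If any argument is not a Dim, we defer to torch.Tensor.expand above.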
1979 | |
1980 | |
1981 | void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd, |
1982 | Slice<py::hdl<Dim>> dims, Slice<int64_t>& nsz, Slice<int64_t>& nsd) { |
1983 | int64_t rhs_prod = 1; |
1984 | for (auto i : dims.enumerate()) { |
1985 | if (!dims[i]->is_bound()) { |
1986 | for (auto j : irange(i + 1, dims.size())) { |
1987 | if (!dims[j]->is_bound()) { |
1988 | py::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R" , dims[i].ptr(), dims[j].ptr()); |
1989 | } |
1990 | rhs_prod *= dims[j]->size(); |
1991 | } |
1992 | if (sz % rhs_prod != 0) { |
1993 | py::tuple tup(dims.size()); |
1994 | for (auto j : dims.enumerate()) { |
1995 | tup.set(j, dims[j]->is_bound() ? py::from_int(dims[j]->size()) : py::unicode_from_string("?" )); |
1996 | } |
1997 | py::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R" , (int) sz, tup.ptr()); |
1998 | } |
1999 | int64_t inferred_size = sz / rhs_prod; |
2000 | dims[i]->set_size(inferred_size); |
2001 | rhs_prod = sz; |
2002 | break; |
2003 | } |
2004 | rhs_prod *= dims[i]->size(); |
2005 | } |
2006 | if (rhs_prod != sz) { |
2007 | py::tuple tup(dims.size()); |
2008 | for (auto j : dims.enumerate()) { |
2009 | tup.set(j, py::object::borrow(dims[j])); |
2010 | } |
py::raise_error(DimensionBindError(), "Dimension sizes do not match (%d != %d) when matching dimension pack %R" , (int) sz, (int) rhs_prod, tup.ptr());
2012 | } |
2013 | auto new_strides = A.allocate<int64_t>(dims.size()); |
2014 | auto prev_stride = sd; |
2015 | for (auto i : dims.reversed_enumerate()) { |
2016 | new_strides[i] = prev_stride; |
2017 | prev_stride = dims[i]->size()*prev_stride; |
2018 | } |
2019 | for (auto i : dims.enumerate()) { |
2020 | nsd.append(A, new_strides[i]); |
2021 | nsz.append(A, dims[i]->size()); |
2022 | } |
2023 | } |
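// Illustrative example (hedged): binding the dim pack (i, j) to a tensor
// dimension of size 6 and stride s, where j is already bound to 3, infers
// i.size() == 2 and appends row-major sizes/strides for the pack:
//     nsz += [2, 3];  nsd += [3*s, s]
// Two unbound dims in the same pack is an error because the factorization
// would be ambiguous.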
2024 | |
2025 | inline bool has_dims(py::handle d) { |
2026 | return Dim::check_exact(d) || Tensor::check_exact(d); |
2027 | } |
2028 | |
2029 | struct IndexingInfo { |
bool can_call_original; // if true, it is safe to just call getitem or setitem; these objects do not need special handling
2031 | bool advanced_indexing; // requires actual lookup |
2032 | TensorRef self; |
2033 | Slice<py::handle> flat_inputs; |
2034 | Slice<DimEntry> result_levels; |
2035 | bool has_device; |
2036 | }; |
2037 | |
2038 | static Slice<py::handle> as_slice(py::tuple_view tv) { |
2039 | PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0); |
2040 | return Slice<py::handle>((py::handle*)begin, (py::handle*) (begin + tv.size())); |
2041 | } |
2042 | |
2043 | static Slice<py::handle> as_slice(py::list_view tv) { |
2044 | PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0); |
2045 | return Slice<py::handle>((py::handle*)begin, (py::handle*) (begin + tv.size())); |
2046 | } |
2047 | |
2048 | |
2049 | bool maybe_dimpack(Slice<py::handle>& elements, py::handle s, bool check_first=true) { |
2050 | // can we avoid rechecking? |
2051 | if (py::list_view::check(s)) { |
2052 | py::list_view tv(s); |
2053 | if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { |
2054 | elements = as_slice(tv); |
2055 | return true; |
2056 | } |
2057 | } |
2058 | // can we avoid rechecking? |
2059 | if (py::tuple_view::check(s)) { |
2060 | py::tuple_view tv(s); |
2061 | if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { |
2062 | elements = as_slice(tv); |
2063 | return true; |
2064 | } |
2065 | } |
2066 | return false; |
}
2068 | |
2069 | bool is_dimpack(py::handle s) { |
2070 | Slice<py::handle> e; |
2071 | return maybe_dimpack(e, s); |
2072 | } |
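// A "dim pack" is a list or tuple of Dims used as a single index, e.g.
// x[[i, j]] (illustrative), which views one positional dimension of x as the
// two first-class dims i and j.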
2073 | |
2074 | IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<py::handle> input, Slice<DimEntry> keys, Slice<py::handle> values, bool has_dimpacks_or_none); |
2075 | static py::object invoke_getitem(Arena& A, const IndexingInfo& iinfo); |
2076 | |
2077 | static py::object index(Arena& A, py::handle self, py::handle dims, py::handle indices) { |
2078 | maybeInitializeGlobals(); |
2079 | Slice<py::handle> dims_list; |
2080 | Slice<py::handle> indices_list; |
// we allow matching a single dim to multiple dims,
// so we first normalize everything into the case where there is a list on both the lhs and the rhs
2083 | bool lhs_list = py::tuple_view::check(dims) || py::list_view::check(dims); |
2084 | bool rhs_list = py::tuple_view::check(indices) || py::list_view::check(indices); |
2085 | if (lhs_list && rhs_list) { |
2086 | py::sequence_view dv(dims); |
2087 | py::sequence_view ind(indices); |
2088 | Py_ssize_t N = dv.size(); |
2089 | if (N != ind.size()) { |
2090 | py::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length" , int(N), int(ind.size())); |
2091 | } |
2092 | for (auto i : irange(N)) { |
2093 | dims_list.append(A, A.autorelease(dv[i])); |
2094 | indices_list.append(A, A.autorelease(ind[i])); |
2095 | } |
2096 | } else { |
2097 | dims_list.append(A, dims); |
2098 | indices_list.append(A, indices); |
2099 | } |
2100 | |
// dims being indexed can be grouped together into a single index space, and we have to
// flatten them into a single dimension before we can index them...
2103 | auto self_info = TensorInfo::create(A, self, false); |
2104 | auto ndim = self_info.ndim(); |
2105 | Slice<DimEntry> new_levels; |
2106 | Slice<DimEntry> to_flatten; |
2107 | Slice<DimEntry> dims_list_flat; |
2108 | auto parse_dim_entry = [&](py::handle s) -> DimEntry { |
2109 | auto d = _wrap_dim(s, ndim, false); |
2110 | if (d.is_none()) { |
py::raise_error(PyExc_TypeError, "expected a dimension specifier but found %R" , s.ptr());
2112 | } |
2113 | return d; |
2114 | }; |
2115 | auto dim_not_present = [&](DimEntry d) { |
2116 | if (d.is_positional()) { |
2117 | py::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions" , d.position() + ndim , ndim); |
2118 | } else { |
2119 | py::raise_error(PyExc_TypeError, "dimension %R not in tensor" , d.dim()->ptr()); |
2120 | } |
2121 | }; |
2122 | |
2123 | for (auto i : dims_list.enumerate()) { |
2124 | Slice<py::handle> m; |
2125 | if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { |
if (m.size() == 0) {
// plausible semantics exist for 0 elements (e.g. the index will always be 0)
dims_list_flat.append(A, DimEntry()); // value is just dropped
continue; // nothing to parse; m[0] below would be out of bounds
}
2130 | auto first = parse_dim_entry(m[0]); |
2131 | dims_list_flat.append(A, first); |
2132 | if (m.size() == 1) { |
2133 | continue; |
2134 | } |
2135 | if (to_flatten.size() == 0) { |
2136 | new_levels.extend(A, self_info.levels); |
2137 | } |
2138 | Slice<DimEntry> rest; |
2139 | for (auto i : irange(1, m.size())) { |
2140 | auto d = parse_dim_entry(m[i]); |
2141 | if (!new_levels.remove(A, d)) { |
2142 | dim_not_present(d); |
2143 | } |
2144 | rest.append(A, d); |
2145 | } |
2146 | |
2147 | auto first_idx = new_levels.index(first); |
2148 | if (!first_idx) { |
2149 | dim_not_present(first); |
2150 | } |
2151 | new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); |
2152 | to_flatten.extend(A, rest); |
2153 | } else { |
2154 | dims_list_flat.append(A, parse_dim_entry(dims_list[i])); |
2155 | } |
2156 | } |
2157 | if (to_flatten.size() > 0) { |
2158 | TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels); |
2159 | at::IntArrayRef sizes = rearranged->sizes(); |
2160 | Slice<int64_t> new_sizes; |
2161 | Slice<DimEntry> reshape_levels; |
2162 | for (auto i : new_levels.enumerate()) { |
2163 | if (to_flatten.contains(new_levels[i])) { |
2164 | new_sizes.back() *= sizes[i]; |
2165 | } else { |
2166 | new_sizes.append(A, sizes[i]); |
2167 | reshape_levels.append(A, new_levels[i]); |
2168 | } |
2169 | } |
2170 | self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); |
2171 | |
self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op
// we need to be careful not to rely on the dimension's size because it doesn't match the size of the whole group
2174 | } |
2175 | bool has_dimpacks = false; |
2176 | for (auto idx : indices_list) { |
2177 | if (py::tuple_view::check(idx) || py::list_view::check(idx)) { |
2178 | has_dimpacks = true; |
2179 | break; |
2180 | } |
2181 | } |
2182 | IndexingInfo info = getsetitem_flat(A, self_info, Slice<py::handle>(), dims_list_flat, indices_list, has_dimpacks); |
2183 | return invoke_getitem(A, info); |
2184 | } |
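// Illustrative Python usage of index (a hedged sketch assuming the
// functorch.dim API):
//     i, j = dims(2)
//     t = torch.rand(3, 4)[i, j]
//     t.index(i, torch.tensor([0, 2]))   # select rows 0 and 2 along i
// When dims is a tuple such as (i, j), the grouped dims are first flattened
// into a single dimension (handled above) before getsetitem_flat runs.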
2185 | |
// extractIndices (below) returns true when the indices were flattened out of a tuple, list, or sequence...
2187 | |
2188 | Slice<py::handle> slice_from_sequence(Arena& A, py::handle value) { |
2189 | if (py::tuple_view::check(value)) { |
2190 | return as_slice(py::tuple_view(value)); |
2191 | } else if (py::list_view::check(value)) { |
2192 | return as_slice(py::list_view(value)); |
2193 | } else { |
2194 | py::sequence_view sv(value); |
2195 | Slice<py::handle> r; |
2196 | for (auto i : sv.enumerate()) { |
2197 | r.append(A, A.autorelease(sv[i])); |
2198 | } |
2199 | return r; |
2200 | } |
2201 | } |
2202 | |
2203 | bool extractIndices(Arena& A, py::handle index, Slice<py::handle>& indices) { |
2204 | if (py::tuple_view::check(index)) { |
2205 | indices.extend(A, as_slice(py::tuple_view(index))); |
2206 | return true; |
2207 | } else if (THPVariable_Check(index.ptr())) { |
2208 | indices.append(A, index); |
2209 | return false; |
2210 | } else if (!py::is_sequence(index)) { |
2211 | indices.append(A, index); |
2212 | return false; |
2213 | } |
// a copy of treatSequenceAsTuple, modified to handle Dim and our wrapped tensors.
2215 | py::sequence_view sv(index); |
2216 | if (sv.size() >= 32) { |
2217 | indices.extend(A, slice_from_sequence(A, index)); |
2218 | return true; |
2219 | } |
2220 | for (auto i : sv.enumerate()) { |
2221 | py::handle item; |
2222 | try { |
2223 | item = sv[i]; |
2224 | } catch (py::exception_set & e) { |
2225 | PyErr_Clear(); |
2226 | indices.append(A, index); |
2227 | return false; |
2228 | } |
2229 | if (THPVariable_Check(item.ptr()) || py::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || py::is_none(item) || has_dims(item)) { |
2230 | indices.extend(A, slice_from_sequence(A, index)); |
2231 | return true; |
2232 | } |
2233 | } |
2234 | indices.append(A, index); |
2235 | return false; |
2236 | } |
2237 | |
2238 | static IndexingInfo getsetitem(Arena & A, py::handle self, py::handle index, bool tensors_have_dims) { |
2239 | bool can_call_original_getitem = !tensors_have_dims; |
2240 | |
2241 | Slice<py::handle> input; |
2242 | if (has_dims(index)) { |
2243 | input.append(A, index); |
2244 | } else { |
2245 | bool is_sequence = extractIndices(A, index, input); |
2246 | // nothing about first class dims here, fallback to getitem |
2247 | if (can_call_original_getitem && !is_sequence) { |
2248 | return { true }; |
2249 | } |
2250 | } |
2251 | |
2252 | int64_t dims_indexed = 0; |
2253 | int64_t expanding_object = -1; |
2254 | DimList* unbound_dim_list = nullptr; |
2255 | auto check_expanding = [&](int64_t i) { |
2256 | if (expanding_object != -1) { |
2257 | py::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d" , (int) expanding_object, (int) i); |
2258 | } |
2259 | expanding_object = i; |
2260 | }; |
2261 | Slice<int64_t> dimlists; |
2262 | |
// calculate how many dimensions have been indexed in order to compute the size of ...
// or expand a potentially unbound dimension list.
2265 | |
2266 | bool has_dimpacks_or_none = false; |
2267 | for (auto i : input.enumerate()) { |
2268 | py::handle s = input[i]; |
2269 | if (Dim::check_exact(s) || Tensor::check_exact(s)) { |
2270 | can_call_original_getitem = false; |
2271 | ++dims_indexed; |
2272 | } else if (s.ptr() == Py_Ellipsis) { |
2273 | check_expanding(i); |
2274 | } else if (DimList::check(s)) { |
2275 | can_call_original_getitem = false; |
2276 | auto dl = DimList::unchecked_wrap(s); |
2277 | if (!dl->is_bound()) { |
2278 | check_expanding(i); |
2279 | unbound_dim_list = dl.ptr(); |
2280 | } else { |
2281 | dims_indexed += dl->dims_.size(); |
2282 | } |
2283 | dimlists.append(A, i); |
2284 | } else if (py::is_none(s)) { |
2285 | has_dimpacks_or_none = true; |
2286 | } else if (is_dimpack(s)) { |
2287 | can_call_original_getitem = false; |
2288 | has_dimpacks_or_none = true; |
2289 | ++dims_indexed; |
2290 | } else { |
2291 | ++dims_indexed; |
2292 | } |
2293 | } |
2294 | |
2295 | // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem. |
2296 | if (can_call_original_getitem) { |
2297 | return {true}; |
2298 | } |
2299 | |
2300 | // std::cout << "__getitem__ " << self << " " << index << "\n"; |
2301 | |
2302 | TensorInfo self_info = TensorInfo::create(A, self, false, true); |
2303 | auto ndim = self_info.ndim(); |
2304 | if (dims_indexed > ndim) { |
2305 | py::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions" , (int) dims_indexed, (int) ndim); |
2306 | } |
2307 | // expand any unbound dimension list, or expand ... into individual : slices. |
2308 | auto expanding_dims = ndim - dims_indexed; |
2309 | if (expanding_object != -1) { |
2310 | if (unbound_dim_list) { |
2311 | unbound_dim_list->bind_len(expanding_dims); |
2312 | } else { |
2313 | // ... |
2314 | Slice<py::handle> no_slices; |
2315 | for (auto i : irange(expanding_dims)) { |
2316 | (void) i; |
2317 | no_slices.append(A, no_slice); |
2318 | } |
2319 | input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices); |
2320 | } |
2321 | } |
2322 | |
2323 | // flatten out any dimensions stored in dimlist elements directly into the inputs |
2324 | // std::cout << dimlists << " <- dim lists!\n"; |
2325 | for (int64_t i = dimlists.size() - 1; i >=0; --i) { |
2326 | auto idx = dimlists[i]; |
2327 | // we added more elements to input because of ... |
2328 | // so we need to also adjust the index to get back to where the |
2329 | // dimlist existed |
2330 | if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { |
2331 | idx += expanding_dims; |
2332 | } |
2333 | auto dl = DimList::unchecked_wrap(input[idx]); |
2334 | // XXX would be better if we used an OwnedSlice in DimList |
2335 | Slice<py::handle> more_dims((py::handle*) &*dl->dims_.begin(), (py::handle*) &*dl->dims_.end()); |
2336 | input.insert(A, input.slice(idx, idx + 1), more_dims); |
2337 | } |
2338 | |
2339 | return getsetitem_flat(A, self_info, input, Slice<DimEntry>(), Slice<py::handle>(), has_dimpacks_or_none); |
2340 | } |
2341 | |
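// An illustrative contrast (hedged) with i, j = dims(2):
//     A[i, j]   -- each dim used once: a pure binding, no indexing kernel
//     A[i, i]   -- a dim used twice: the diagonal, which requires a real
//                  THPVariable_getitem call with index tensors
// seen_dims_nuses below counts the uses to tell these cases apart.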
2342 | IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<py::handle> input, Slice<DimEntry> keys, Slice<py::handle> values, bool has_dimpacks_or_none) { |
2343 | // At this point: |
2344 | // ..., DimList have been eliminated |
2345 | // Dim, Tensor, Tuple[Dim,...], int, slice still remain |
2346 | |
2347 | |
2348 | // we have to count how many times we see a dimension. |
2349 | // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing. |
2350 | Slice<py::hdl<Dim>> seen_dims; |
2351 | Slice<int64_t> seen_dims_nuses; |
2352 | auto add_dim = [&](py::hdl<Dim> entry) { |
2353 | auto midx = seen_dims.index(entry); |
2354 | if (!midx) { |
2355 | seen_dims.append(A, entry); |
2356 | seen_dims_nuses.append(A, 1); |
2357 | } else { |
2358 | ++seen_dims_nuses[*midx]; |
2359 | } |
2360 | }; |
2361 | |
2362 | Slice<py::handle> input_it = input; |
2363 | |
2364 | Slice<py::handle> flat_inputs; |
// flat_inputs holds an empty py::handle when the actual value lives in
// the corresponding TensorInfo in tensor_inputs
2367 | Slice<TensorInfo> tensor_inputs; |
2368 | |
2369 | auto append_flat_handle = [&](py::handle h) { |
2370 | flat_inputs.append(A, h); |
2371 | tensor_inputs.append(A, TensorInfo()); |
2372 | }; |
2373 | TensorRef device_holding_tensor; |
2374 | auto append_tensor_input = [&](TensorInfo ti) { |
2375 | flat_inputs.append(A, py::handle()); |
2376 | tensor_inputs.append(A, ti); |
2377 | if (ti.has_device && !device_holding_tensor) { |
2378 | device_holding_tensor = ti.tensor; |
2379 | } |
2380 | }; |
2381 | |
2382 | Slice<int64_t> nsz; |
2383 | Slice<int64_t> nsd; |
2384 | at::IntArrayRef sz = self_info.tensor->sizes(); |
2385 | at::IntArrayRef sd = self_info.tensor->strides(); |
2386 | |
2387 | auto append_size = [&](int i) { |
2388 | if (has_dimpacks_or_none) { |
2389 | nsz.append(A, sz[i]); |
2390 | nsd.append(A, sd[i]); |
2391 | } |
2392 | }; |
2393 | // std::cout << "self levels: " << self_info.levels << "\n"; |
2394 | |
2395 | auto parse_nones = [&]() { |
2396 | while (input_it.size() && py::is_none(input_it[0])) { |
2397 | append_flat_handle(no_slice); |
2398 | nsz.append(A, 1); |
2399 | nsd.append(A, 0); |
2400 | input_it = input_it.slice(1); |
2401 | } |
2402 | }; |
2403 | |
2404 | |
2405 | auto append_item = [&](int i, py::handle arg) { |
2406 | if (Dim::check_exact(arg)) { |
2407 | auto d = Dim::unchecked_wrap(arg); |
2408 | d->set_size(sz[i]); |
2409 | add_dim(d); |
2410 | append_size(i); |
2411 | append_flat_handle(arg); |
2412 | return; |
2413 | } |
2414 | auto info = TensorInfo::create(A, arg, false, false); |
2415 | if (info) { |
2416 | append_size(i); |
2417 | append_tensor_input(info); |
2418 | for (auto il : info.levels) { |
2419 | if (!il.is_positional()) { |
2420 | add_dim(il.dim()); |
2421 | } |
2422 | } |
2423 | return; |
2424 | } |
2425 | |
2426 | if (has_dimpacks_or_none) { |
2427 | Slice<py::handle> mp; |
2428 | if (maybe_dimpack(mp, arg)) { |
2429 | // dim pack |
2430 | Slice<py::hdl<Dim>> dim_pack; |
2431 | for (auto d : mp) { |
2432 | dim_pack.append(A, Dim::wrap(d)); |
2433 | add_dim(dim_pack.back()); |
2434 | append_flat_handle(dim_pack.back()); |
2435 | } |
2436 | _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); |
2437 | return; |
2438 | } |
2439 | } |
2440 | |
2441 | append_size(i); |
2442 | append_flat_handle(arg); |
2443 | }; |
2444 | |
// pair up each indexing expression with the dimension of self it indexes;
// self may have first-class dims, which do not participate in the indexing.
2447 | for (auto i : self_info.levels.enumerate()) { |
2448 | auto l = self_info.levels[i]; |
2449 | auto idx = keys.index(l); |
2450 | if (idx) { |
2451 | append_item(i, values[*idx]); |
2452 | } else if (l.is_positional()) { |
2453 | // grab and index from the positional list |
2454 | parse_nones(); |
2455 | if (!input_it.size()) { |
2456 | // we might have fewer indices than tensor dimensions, |
2457 | // which implicitly indexes the remaining dimensions with : |
2458 | append_flat_handle(no_slice); |
2459 | append_size(i); |
2460 | } else { |
2461 | py::handle arg = input_it[0]; |
2462 | input_it = input_it.slice(1); |
2463 | append_item(i, arg); |
2464 | } |
2465 | } else { |
2466 | add_dim(l.dim()); |
2467 | append_flat_handle(l.dim()); |
2468 | append_size(i); |
2469 | } |
2470 | } |
// any trailing Nones may have no existing dimension associated with them in self.
2472 | parse_nones(); |
2473 | |
2474 | // we have to restride the tensor to collapse dimension packs and introduce our none dimensions. |
2475 | if (has_dimpacks_or_none) { |
2476 | self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset())); |
2477 | } |
2478 | |
2479 | |
2480 | // figure out what the shape of the indexing tensors will be |
2481 | // and what the shape of the resulting tensor will be |
2482 | Slice<DimEntry> result_levels; |
2483 | Slice<DimEntry> index_levels; |
2484 | int64_t tensor_insert_point = -1; |
2485 | bool requires_getindex = false; |
2486 | auto mark_tensor_index = [&] { |
2487 | if (tensor_insert_point == -1) { |
2488 | tensor_insert_point = result_levels.size(); |
2489 | } else if (tensor_insert_point != result_levels.size()) { |
2490 | tensor_insert_point = 0; |
2491 | } |
2492 | }; |
2493 | for (auto i : flat_inputs.enumerate()) { |
2494 | auto inp = flat_inputs[i]; |
2495 | if(tensor_inputs[i]) { |
2496 | requires_getindex = true; |
2497 | mark_tensor_index(); |
2498 | for (auto l : tensor_inputs[i].levels) { |
2499 | // std::cout << "Consider to add " << l << "\n"; |
2500 | if (!index_levels.contains(l)) { |
2501 | index_levels.append(A, l); |
2502 | } |
2503 | } |
2504 | } else if (Dim::check_exact(inp)) { |
2505 | auto d = Dim::unchecked_wrap(inp); |
// dimensions used once are just binding operations
2507 | if (1 == seen_dims_nuses[*seen_dims.index(d)]) { |
2508 | flat_inputs[i] = no_slice; |
2509 | result_levels.append(A, d); |
2510 | } else { |
2511 | requires_getindex = true; |
2512 | flat_inputs[i] = py::handle(); |
2513 | tensor_inputs[i] = TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, TensorRef()}; |
2514 | if (!index_levels.contains(d)) { |
2515 | index_levels.append(A, d); |
2516 | } |
2517 | mark_tensor_index(); |
2518 | } |
2519 | } else { |
2520 | if (inp.ptr() != no_slice.ptr()) { |
2521 | requires_getindex = true; |
2522 | } |
2523 | if (!py::is_int(inp)) { |
2524 | // note: actual positional indexes are accurately computed later |
2525 | result_levels.append(A, -1); |
2526 | } |
2527 | } |
2528 | } |
2529 | |
// indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert
// the index levels into the result levels at this spot
2532 | if (tensor_insert_point != -1) { |
2533 | result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels); |
2534 | } |
2535 | |
2536 | // std::cout << "flat inputs: " << flat_inputs << "\n"; |
2537 | // std::cout << "result_levels: " << result_levels << "\n"; |
2538 | // std::cout << "index_levels: " << index_levels << "\n"; |
2539 | |
2540 | // get all the tensors to be the right shape for indexing |
2541 | if (requires_getindex) { |
2542 | for (auto i : flat_inputs.enumerate()) { |
2543 | if (tensor_inputs[i]) { |
2544 | AT_ASSERT(!flat_inputs[i].ptr()); |
2545 | // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n"; |
2546 | TensorRef t = tensor_inputs[i].tensor; |
2547 | if (!tensor_inputs[i].has_device && device_holding_tensor) { |
2548 | t = A.autorelease(t->to(device_holding_tensor->device())); |
2549 | } |
2550 | flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); |
2551 | } |
2552 | } |
2553 | } |
2554 | |
// previously we didn't know how many positional dimensions there would be, so we couldn't
// number them correctly; fill them in now.
2557 | auto seen_positionals = 0; |
2558 | for (auto i : result_levels.reversed_enumerate()) { |
2559 | if (result_levels[i].is_positional()) { |
2560 | result_levels[i] = -(++seen_positionals); |
2561 | } |
2562 | } |
2563 | |
2564 | return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device}; |
2565 | } |
2566 | |
2567 | static py::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { |
2568 | at::Tensor rtensor; |
2569 | if (iinfo.advanced_indexing) { |
2570 | auto self_hdl = handle_from_tensor(A, iinfo.self); |
2571 | auto tup = slice_to_tuple(iinfo.flat_inputs); |
2572 | // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n"; |
2573 | auto pytensor = py::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr())); |
2574 | rtensor = THPVariable_Unpack(pytensor.ptr()); |
2575 | } else { |
2576 | // std::cout << "skipping original getindex\n"; |
2577 | rtensor = *iinfo.self; |
2578 | } |
2579 | // std::cout << "returning (from_positional)\n"; |
2580 | return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); |
2581 | } |
2582 | |
2583 | static py::object __getitem__(Arena & A, py::handle self, py::handle index) { |
2584 | maybeInitializeGlobals(); |
2585 | auto iinfo = getsetitem(A, self, index, has_dims(self)); |
2586 | if (iinfo.can_call_original) { |
2587 | return py::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr())); |
2588 | } |
2589 | |
2590 | return invoke_getitem(A, iinfo); |
2591 | } |
2592 | |
2593 | |
2594 | PyObject* Tensor_getitem(PyObject* self, PyObject* index) { |
2595 | Arena A; |
2596 | PY_BEGIN |
2597 | return __getitem__(A, self, index).release(); |
2598 | PY_END(nullptr); |
2599 | } |
2600 | |
2601 | static void __setitem__(Arena & A, py::handle self, py::handle index, py::handle rhs) { |
2602 | maybeInitializeGlobals(); |
2603 | auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); |
2604 | if (iinfo.can_call_original) { |
2605 | if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { |
2606 | throw py::exception_set(); |
2607 | } |
2608 | return; |
2609 | } |
2610 | |
2611 | auto rhs_info = TensorInfo::create(A, rhs, false, false); |
2612 | if (rhs_info) { // otherwise rhs can be a scalar... |
2613 | for (auto l : rhs_info.levels) { |
2614 | if (!iinfo.result_levels.contains(l)) { |
2615 | if (l.is_positional()) { |
2616 | py::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)" , ndim_of_levels(iinfo.result_levels), rhs_info.ndim()); |
2617 | } else { |
2618 | auto tup = levels_to_tuple(iinfo.result_levels); |
2619 | py::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)" , l.dim().ptr(), tup.ptr()); |
2620 | } |
2621 | } |
2622 | } |
2623 | auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); |
2624 | rhs = handle_from_tensor(A, rhs_matched); |
2625 | } |
2626 | self = handle_from_tensor(A, iinfo.self); |
2627 | |
2628 | if (iinfo.advanced_indexing) { |
2629 | auto tup = slice_to_tuple(iinfo.flat_inputs); |
2630 | if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { |
2631 | throw py::exception_set(); |
2632 | } |
2633 | } else { |
2634 | torch_Tensor_copy_.call(self, rhs); |
2635 | } |
2636 | } |
2637 | |
2638 | |
2639 | int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { |
2640 | Arena A; |
2641 | PY_BEGIN |
2642 | __setitem__(A, self, index, value); |
2643 | return 0; |
2644 | PY_END(-1); |
2645 | } |
2646 | |
2647 | static PyObject* py___getitem__(PyObject *_, |
2648 | PyObject *const *args, |
2649 | Py_ssize_t nargs, |
2650 | PyObject *kwnames) { |
2651 | Arena A; |
2652 | PY_BEGIN |
2653 | AT_ASSERT(nargs == 2); |
2654 | return __getitem__(A, args[0], args[1]).release(); |
2655 | PY_END(nullptr) |
2656 | } |
2657 | |
2658 | static PyObject* py___setitem__(PyObject *_, |
2659 | PyObject *const *args, |
2660 | Py_ssize_t nargs, |
2661 | PyObject *kwnames) { |
2662 | Arena A; |
2663 | PY_BEGIN |
2664 | AT_ASSERT(nargs == 3); |
2665 | __setitem__(A, args[0], args[1], args[2]); |
2666 | Py_RETURN_NONE; |
2667 | PY_END(nullptr) |
2668 | } |
2669 | |
2670 | |
2671 | static PyObject* py_index(PyObject *_, |
2672 | PyObject *const *args, |
2673 | Py_ssize_t nargs, |
2674 | PyObject *kwnames) { |
2675 | Arena A; |
2676 | PY_BEGIN |
2677 | py::vector_args va(args, nargs, kwnames); |
2678 | py::handle self, dims, indices; |
2679 | va.parse("index" , {"self" , "dims" , "indices" }, {&self, &dims, &indices}, 3); |
2680 | return index(A, self, dims, indices).release(); |
2681 | PY_END(nullptr) |
2682 | } |
2683 | |
2684 | |
2685 | static PyObject* py_stack(PyObject *_, |
2686 | PyObject *const *args, |
2687 | Py_ssize_t nargs, |
2688 | PyObject *kwnames) { |
2689 | Arena A; |
2690 | PY_BEGIN |
2691 | py::vector_args va(args, nargs, kwnames); |
2692 | py::handle tensors, new_dim, dim; |
2693 | va.parse("stack" , {"tensors" , "new_dim" , "dim" }, {&tensors, &new_dim, &dim}, 2); |
2694 | |
2695 | Slice<DimEntry> result_levels; |
2696 | Slice<TensorInfo> infos; |
2697 | py::sequence_view sv(tensors); |
2698 | auto new_dim_d = Dim::wrap(new_dim); |
2699 | for (auto i : sv.enumerate()) { |
2700 | infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); |
2701 | for (auto l : infos.back().levels) { |
2702 | if (!result_levels.contains(l)) { |
2703 | result_levels.append(A, l); |
2704 | } |
2705 | } |
2706 | } |
2707 | new_dim_d->set_size(infos.size()); |
2708 | std::vector<at::Tensor> inputs; |
2709 | inputs.reserve(infos.size()); |
2710 | for (auto in : infos) { |
2711 | inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); |
2712 | } |
2713 | auto ndim = ndim_of_levels(result_levels); |
2714 | int64_t rawdim = 0; |
2715 | if (dim.ptr()) { |
2716 | auto d = _wrap_dim(dim, ndim, false); |
2717 | auto idx = result_levels.index(d); |
2718 | if (!idx) { |
2719 | py::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs" , dim.ptr()); |
2720 | } |
2721 | rawdim = *idx; |
2722 | } |
2723 | auto result = at::stack(inputs, rawdim); |
2724 | result_levels.insert(A, rawdim, new_dim_d); |
2725 | return Tensor::from_positional(A, std::move(result), result_levels, true).release(); |
2726 | PY_END(nullptr) |
2727 | } |
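// Illustrative usage of stack (a hedged sketch assuming the functorch.dim
// API):
//     i, s = dims(2)
//     rows = [torch.rand(4)[i] for _ in range(3)]
//     stack(rows, s)   # binds s.size() == 3; result has levels (s, i)
// Inputs are first matched to a common set of levels so at::stack sees
// uniformly shaped tensors.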
2728 | |
2729 | static PyObject* py_split(PyObject *_, |
2730 | PyObject *const *args, |
2731 | Py_ssize_t nargs, |
2732 | PyObject *kwnames) { |
2733 | Arena A; |
2734 | PY_BEGIN |
2735 | maybeInitializeGlobals(); |
2736 | py::vector_args va(args, nargs, kwnames); |
2737 | py::handle self, split_size_or_sections, dim; |
2738 | va.parse("split" , {"self" , "split_size_or_sections" , "dim" }, {&self, &split_size_or_sections, &dim}, 2); |
2739 | bool dim_is_object = dim.ptr() && Dim::check_exact(dim); |
2740 | Slice<py::handle> sizes; |
2741 | |
2742 | bool all_dims = true; |
2743 | bool all_ints = true; |
2744 | |
2745 | if (!py::is_int(split_size_or_sections)) { |
2746 | py::sequence_view sv(split_size_or_sections); |
2747 | for (auto i : sv.enumerate()) { |
2748 | sizes.append(A, A.autorelease(sv[i])); |
2749 | if (Dim::check_exact(sizes.back())) { |
2750 | all_ints = false; |
2751 | } else { |
2752 | all_dims = false; |
2753 | } |
2754 | } |
2755 | } |
2756 | if (all_ints) { |
2757 | if (dim_is_object) { |
2758 | py::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions." ); |
2759 | } |
2760 | // call original split (if self has dimensions this will use torch function to do the split) |
2761 | return torch_Tensor_split.call_vector(py::vector_args(args, nargs, kwnames)).release(); |
2762 | } |
2763 | if (!all_dims) { |
2764 | py::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix" ); |
2765 | } |
2766 | |
2767 | auto self_info = TensorInfo::create(A, self, false); |
2768 | auto ndim = self_info.ndim(); |
if (!dim_is_object && ndim == 0) {
py::raise_error(PyExc_TypeError, "split expects at least a 1-dimensional tensor" );
2771 | } |
2772 | DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; |
2773 | |
2774 | auto idx = self_info.levels.index(dim_l); |
2775 | if (!idx) { |
2776 | if (!dim.ptr()) { |
2777 | dim = A.autorelease(py::from_int(0)); |
2778 | } |
py::raise_error(PyExc_TypeError, "tensor does not contain dimension %R" , dim.ptr());
2780 | } |
2781 | Slice<int64_t> indices; |
2782 | |
2783 | int64_t total_size = 0; |
2784 | Slice<int64_t> unbound; |
2785 | for (auto i : sizes.enumerate()) { |
2786 | auto d = Dim::unchecked_wrap(sizes[i]); |
2787 | if (d->is_bound()) { |
2788 | indices.append(A, d->size()); |
2789 | total_size += indices.back(); |
2790 | } else { |
2791 | indices.append(A, 0); |
2792 | unbound.append(A, i); |
2793 | } |
2794 | } |
2795 | auto tensor_size = self_info.tensor->sizes()[*idx]; |
2796 | |
2797 | if (unbound.size()) { |
2798 | if (total_size > tensor_size) { |
2799 | py::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)" , int(total_size), int(tensor_size)); |
2800 | } |
2801 | auto remaining_size = tensor_size - total_size; |
2802 | auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); |
2803 | for (auto u : unbound) { |
2804 | auto sz = std::min(chunk_size, remaining_size); |
2805 | Dim::unchecked_wrap(sizes[u])->set_size(sz); |
2806 | indices[u] = sz; |
2807 | remaining_size -= sz; |
2808 | } |
2809 | } else if (tensor_size != total_size) { |
py::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) does not match the source dim (%d)" , int(total_size), int(tensor_size));
2811 | } |
2812 | |
2813 | auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx); |
2814 | py::tuple result(result_tensors.size()); |
2815 | Slice<DimEntry> new_levels; |
2816 | new_levels.extend(A, self_info.levels); |
2817 | for (auto i : sizes.enumerate()) { |
2818 | new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); |
2819 | result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true)); |
2820 | } |
2821 | |
2822 | return result.release(); |
2823 | |
2824 | PY_END(nullptr) |
2825 | } |
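// Illustrative usage of split with dims (a hedged sketch assuming the
// functorch.dim API): unbound Dims are sized from the source dimension, so
//     a, b = dims(2)
//     front, back = split(torch.rand(10), [a, b])
// binds a.size() == 5 and b.size() == 5, and the results carry a and b as
// first-class dims in place of the split dimension.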
2826 | |
2827 | |
2828 | static DimEntry _wrap_dim(py::handle d, size_t N, bool keepdim) { |
2829 | if (Dim::check(d)) { |
2830 | if (keepdim) { |
2831 | py::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True" ); |
2832 | } |
2833 | return Dim::unchecked_wrap(d); |
2834 | } else if (py::is_int(d)) { |
2835 | auto i = py::to_int(d); |
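// normalize to a negative index so positional dims are counted from the
// end, e.g. with N == 3, d == 1 becomes -2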
2836 | while (i >= 0) { |
2837 | i -= N; |
2838 | } |
2839 | return i; |
2840 | } else { |
2841 | return DimEntry(); |
2842 | } |
2843 | } |
2844 | |
2845 | static Slice<DimEntry> _wrap_dims(Arena& A, py::handle d, size_t N, bool keepdim) { |
2846 | auto de = _wrap_dim(d, N, keepdim); |
2847 | Slice<DimEntry> r; |
2848 | if (!de.is_none()) { |
2849 | r.append(A, de); |
2850 | } else { |
2851 | py::sequence_view sq(d); |
2852 | for (auto i : sq.enumerate()) { |
2853 | r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); |
2854 | } |
2855 | } |
2856 | return r; |
2857 | } |
2858 | |
2859 | struct WrappedOperator : public py::base<WrappedOperator> { |
2860 | py::object orig; |
2861 | PyMethodDef method_def; |
2862 | py::object name, doc; |
2863 | |
2864 | bool is_pointwise = false; |
2865 | int64_t dim_offset = 0; |
2866 | int64_t keepdim_offset = 1; |
2867 | std::string dim_name; |
2868 | bool single_dim = false; |
2869 | bool reduce = true; |
2870 | |
2871 | static PyTypeObject Type; |
2872 | |
2873 | void init(py::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="" ) { |
2874 | orig = std::move(orig_); |
2875 | method_def.ml_meth = wrapper_implementation; |
2876 | name = orig.attr("__name__" ); |
2877 | doc = orig.attr("__doc__" ); |
2878 | dim_name = std::move(dim_name_); |
2879 | if (!py::is_none(doc) && !dim_name.empty()) { |
2880 | doc = py::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n" , doc.ptr(), dim_name.c_str()); |
2881 | } |
2882 | method_def.ml_name = py::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); |
2883 | method_def.ml_doc = py::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); |
2884 | method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; |
2885 | } |
2886 | |
2887 | py::object function() { |
2888 | return py::object::checked_steal(PyCFunction_New(&method_def, ptr())); |
2889 | } |
2890 | |
2891 | }; |
2892 | |
2893 | PyTypeObject WrappedOperator::Type = { |
2894 | PyVarObject_HEAD_INIT(NULL, 0) |
2895 | "_C.WrappedOperator" , /* tp_name */ |
2896 | sizeof(WrappedOperator), /* tp_basicsize */ |
2897 | 0, /* tp_itemsize */ |
2898 | WrappedOperator::dealloc_stub, /* tp_dealloc */ |
2899 | 0, /* tp_vectorcall_offset */ |
2900 | 0, /* tp_getattr */ |
2901 | 0, /* tp_setattr */ |
2902 | 0, /* tp_as_async */ |
2903 | 0, /* tp_repr */ |
2904 | 0, /* tp_as_number */ |
2905 | 0, /* tp_as_sequence */ |
2906 | 0, /* tp_as_mapping */ |
2907 | 0, /* tp_hash */ |
2908 | 0, /* tp_call */ |
2909 | 0, /* tp_str */ |
2910 | 0, /* tp_getattro */ |
2911 | 0, /* tp_setattro */ |
2912 | 0, /* tp_as_buffer */ |
2913 | Py_TPFLAGS_DEFAULT, /* tp_flags */ |
2914 | "Wrapped Object Holder" , /* tp_doc */ |
2915 | 0, /* tp_traverse */ |
2916 | 0, /* tp_clear */ |
2917 | 0, /* tp_richcompare */ |
2918 | 0, /* tp_weaklistoffset */ |
2919 | 0, /* tp_iter */ |
2920 | 0, /* tp_iternext */ |
2921 | 0, /* tp_methods */ |
2922 | 0, /* tp_members */ |
2923 | 0, /* tp_getset */ |
2924 | 0, /* tp_base */ |
2925 | 0, /* tp_dict */ |
2926 | 0, /* tp_descr_get */ |
2927 | 0, /* tp_descr_set */ |
2928 | 0, /* tp_dictoffset */ |
2929 | 0, /* tp_init */ |
2930 | 0, /* tp_alloc */ |
2931 | WrappedOperator::new_stub, /* tp_new */ |
2932 | }; |
2933 | |
2934 | static PyObject* patched_dim_method(PyObject * self_, |
2935 | PyObject *const *args, |
2936 | Py_ssize_t nargs, |
2937 | PyObject *kwnames) { |
2938 | Arena A; |
2939 | auto self = WrappedOperator::unchecked_wrap(self_); |
2940 | PY_BEGIN |
2941 | |
2942 | py::vector_args va(args, nargs, kwnames); |
2943 | |
2944 | auto _getarg = [&](const char* name, int64_t offset_) -> py::handle { |
2945 | auto offset = offset_ + 1; // do not include self |
2946 | auto idx = va.index(name, offset); |
2947 | return idx == -1 ? py::handle() : va[idx]; |
2948 | }; |
2949 | Slice<py::handle> patched_args; |
2950 | patched_args.extend(A, va.begin(), va.end()); |
2951 | auto _patcharg = [&](const char* name, int64_t offset_, py::handle value) { |
2952 | auto offset = offset_ + 1; // do not include self |
2953 | auto idx = va.index(name, offset); |
2954 | if (idx == -1) { |
2955 | py::raise_error(PyExc_ValueError, "Missing argument %s" , name); |
2956 | } |
2957 | patched_args[idx] = value; |
2958 | }; |
2959 | |
2960 | auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); |
2961 | if (!dim.ptr()) { |
2962 | auto info = TensorInfo::create(A, args[0], true); |
2963 | EnableAllLayers l(A, info.levels); |
2964 | l.inplace_update_layers(info.batchedtensor, info.levels); |
2965 | patched_args[0] = handle_from_tensor(A, info.batchedtensor); |
2966 | auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); |
2967 | return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release(); |
2968 | } |
2969 | |
2970 | auto info = TensorInfo::create(A, args[0]); |
2971 | auto keepdim = false; |
2972 | if (self->reduce) { |
2973 | auto py_keepdim = _getarg("keepdim" , self->keepdim_offset); |
2974 | if (py_keepdim.ptr()) { |
2975 | keepdim = py::to_bool(py_keepdim); |
2976 | } |
2977 | } |
2978 | |
2979 | auto ndim = info.ndim(); |
2980 | auto dims = _wrap_dims(A, dim, ndim, keepdim); |
2981 | Slice<int64_t> dim_indices; |
2982 | auto seen = A.allocate<bool>(info.levels.size()); |
2983 | std::fill(seen, seen + info.levels.size(), false); |
2984 | |
2985 | for (auto d : dims) { |
2986 | auto midx = info.levels.index(d); |
2987 | if (!midx) { |
2988 | auto tup = levels_to_tuple(info.levels); |
2989 | py::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n" , tup.ptr(), dim.ptr()); |
2990 | } |
2991 | seen[*midx] = true; |
2992 | dim_indices.append(A, *midx); |
2993 | } |
2994 | Slice<DimEntry> new_levels; |
2995 | if (self->reduce && !keepdim) { |
2996 | for (auto i : info.levels.enumerate()) { |
2997 | if (!seen[i]) { |
2998 | new_levels.append(A, info.levels[i]); |
2999 | } |
3000 | } |
3001 | } else { |
3002 | new_levels = info.levels; |
3003 | } |
3004 | py::object py_indices; |
3005 | if (dim_indices.size() == 1) { |
3006 | py_indices = py::from_int(dim_indices[0]); |
3007 | } else { |
3008 | py::tuple tup(dim_indices.size()); |
3009 | for (auto i : dim_indices.enumerate()) { |
3010 | tup.set(i, py::from_int(dim_indices[i])); |
3011 | } |
3012 | py_indices = std::move(tup); |
3013 | } |
3014 | _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); |
3015 | patched_args[0] = handle_from_tensor(A, info.tensor); |
3016 | auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); |
3017 | auto wrap = [&](py::handle h) { |
3018 | if (THPVariable_Check(h.ptr())) { |
3019 | return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); |
3020 | } |
3021 | return h; |
3022 | }; |
3023 | return tree_map(A, wrap, r).release(); |
3024 | PY_END(nullptr) |
3025 | } |
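// Illustrative effect (hedged): for a reduction wrapped via _wrap below
// (mean is a hypothetical example, not confirmed by this file),
//     i, j = dims(2)
//     t = torch.rand(3, 4)[i, j]
//     t.mean(i)    # the Dim argument is translated to the positional index of i
// and the reduced levels are dropped from the result (unless keepdim).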
3026 | |
3027 | static PyObject* _wrap(PyObject * self_, |
3028 | PyObject *const *args, |
3029 | Py_ssize_t nargs, |
3030 | PyObject *kwnames) { |
3031 | Arena A; |
3032 | PY_BEGIN |
3033 | |
3034 | #define ARGS(_) _(py::handle, orig) _(py::handle, dim_offset) _(py::handle, keepdim_offset) \ |
3035 | _(py::handle, dim_name) _(py::handle, single_dim) _(py::handle, reduce) |
3036 | MPY_PARSE_ARGS_KWNAMES("O|OOOOO" , ARGS) |
3037 | |
3038 | std::string dim_name_str; |
3039 | if (dim_name.ptr()) { |
3040 | dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); |
3041 | } else { |
3042 | dim_name_str = "dim" ; |
3043 | } |
3044 | auto info = WrappedOperator::create(py::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str)); |
3045 | if (dim_offset.ptr()) { |
3046 | info->dim_offset = py::to_int(dim_offset); |
3047 | } |
3048 | if (keepdim_offset.ptr()) { |
3049 | info->keepdim_offset = py::to_int(keepdim_offset); |
3050 | } |
3051 | |
3052 | if (single_dim.ptr()) { |
3053 | info->single_dim = py::to_bool(single_dim); |
3054 | } |
3055 | if (reduce.ptr()) { |
3056 | info->reduce = py::to_bool(reduce); |
3057 | } |
3058 | return info->function().release(); |
3059 | #undef ARGS |
3060 | |
3061 | PY_END(nullptr) |
3062 | } |
3063 | |
3064 | static PyObject* call_torch_function(PyObject *self, |
3065 | PyObject *const *args, |
3066 | Py_ssize_t nargs, |
3067 | PyObject *kwnames) { |
3068 | PY_BEGIN |
3069 | Arena A; |
3070 | maybeInitializeGlobals(); |
3071 | auto info = WrappedOperator::unchecked_wrap(self); |
3072 | return __torch_function__(A, info->orig, py::vector_args(args, nargs, kwnames), info->is_pointwise).release(); |
3073 | PY_END(nullptr) |
3074 | } |
3075 | |
3076 | static PyObject* _wrap_method(PyObject *self, |
3077 | PyObject *const *args, |
3078 | Py_ssize_t nargs, |
3079 | PyObject *kwnames) { |
3080 | PY_BEGIN |
3081 | AT_ASSERT(nargs == 2); |
// XXX - ignore the wrapped python function; we will call the torch function directly
3083 | py::handle orig = args[0]; |
3084 | if (!pointwise.ptr()) { |
3085 | auto dim = py::import("functorch.dim" ); |
3086 | pointwise = dim.attr("pointwise" ); |
3087 | } |
3088 | auto info = WrappedOperator::create(py::object::borrow(orig), (PyCFunction)(void*) call_torch_function); |
3089 | info->is_pointwise = pointwise.contains(orig); |
3090 | return PyInstanceMethod_New(info->function().release()); |
3091 | PY_END(nullptr); |
3092 | } |
3093 | |
3094 | |
3095 | static PyObject* Tensor_sum(PyObject * self_, |
3096 | PyObject *const *args, |
3097 | Py_ssize_t nargs, |
3098 | PyObject *kwnames) { |
3099 | Arena A; |
3100 | PY_BEGIN |
3101 | maybeInitializeGlobals(); |
3102 | py::vector_args va(args, nargs, kwnames); |
3103 | auto self_ = Tensor::unchecked_wrap(args[0]); |
3104 | auto d = self_->delayed(); |
3105 | if (!d) { |
3106 | return _Tensor_sum.call_vector(va).release(); |
3107 | } |
3108 | py::handle self, dim, keepdim, dtype; |
3109 | va.parse("sum" , {"self" , "dim" , "keepdim" , "dtype" }, {&self, &dim, &keepdim, &dtype}, 1, 1); |
3110 | |
3111 | if (dtype.ptr() || (keepdim.ptr() && py::to_bool(keepdim))) { |
3112 | // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; |
3113 | return _Tensor_sum.call_vector(va).release(); |
3114 | } |
3115 | auto levels = self_->levels(); |
3116 | |
3117 | auto N = ndim_of_levels(levels); |
3118 | auto reduced_dims = _wrap_dims(A, dim, N, false); |
3119 | |
3120 | return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release(); |
3121 | PY_END(nullptr) |
3122 | } |
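// Illustrative fusion (hedged): with the delayed pointwise multiply,
//     i, j, k = dims(3)
//     C = (A[i, j] * B[j, k]).sum(j)   # lowered to a single matrix multiply
// The multiply is recorded lazily (see delayed()), so the sum over j can be
// dispatched to dot() above instead of materializing the elementwise product.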
3123 | |
3124 | static PyObject* _parse_test(PyObject * self_, |
3125 | PyObject *const *args, |
3126 | Py_ssize_t nargs, |
3127 | PyObject *kwnames) { |
3128 | PY_BEGIN |
3129 | maybeInitializeGlobals(); |
3130 | |
3131 | int required = py::to_int(args[0]); |
3132 | int kwonly = py::to_int(args[1]); |
3133 | |
3134 | py::vector_args va(args + 2, nargs - 2, kwnames); |
3135 | |
3136 | |
3137 | py::handle a, b, c, d; |
3138 | va.parse("_parse_test" , {"a" , "b" , "c" , "d" }, {&a, &b, &c, &d}, required, kwonly); |
3139 | py::tuple r(4); |
3140 | r.set(0, py::object::borrow(a.ptr() ? a : Py_None)); |
3141 | r.set(1, py::object::borrow(b.ptr() ? b : Py_None)); |
3142 | r.set(2, py::object::borrow(c.ptr() ? c : Py_None)); |
3143 | r.set(3, py::object::borrow(d.ptr() ? d : Py_None)); |
3144 | return r.release(); |
3145 | |
3146 | PY_END(nullptr) |
3147 | } |
3148 | |
3149 | static PyObject* _set_pointwise_optimize(PyObject * self_, |
3150 | PyObject *const *args, |
3151 | Py_ssize_t nargs, |
3152 | PyObject *kwnames) { |
3153 | PY_BEGIN |
3154 | py::handle value; |
3155 | py::vector_args va(args, nargs, kwnames); |
3156 | va.parse("_set_pointwise_optimization" , {"value" }, {&value}, 1); |
3157 | pointwise_optimize = py::to_bool(value); |
3158 | Py_RETURN_NONE; |
3159 | PY_END(nullptr) |
3160 | } |
3161 | |
3162 | static PyObject* _patch_tensor_class(PyObject * self_, |
3163 | PyObject *const *args, |
3164 | Py_ssize_t nargs, |
3165 | PyObject *kwnames) { |
3166 | PY_BEGIN |
3167 | |
3168 | auto torch = py::import("torch" ); |
3169 | auto py_TensorBase = torch.attr("_C" ).attr("_TensorBase" ); |
3170 | replaceMappingIfMatches(py_TensorBase); |
3171 | |
3172 | Py_RETURN_NONE; |
3173 | PY_END(nullptr) |
3174 | } |
3175 | |
3177 | const char* dims_doc = R"""( |
3178 | dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] |
3179 | |
3180 | Creates and returns one or more Dim objects. |
3181 | |
Args:
    n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
    sizes (List[Optional[int]], optional): A list with one entry per dimension to be
      created, each specifying that dimension's size, or None to leave the size unset.
3186 | |
3187 | Example:: |
3188 | >>> batch, channel, width, height = dims(4) |
3189 | >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224]) |
3190 | )""" ; |
3191 | |
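// Every entry uses METH_FASTCALL | METH_KEYWORDS, which is CPython's calling
// convention for the (PyObject*, PyObject* const*, Py_ssize_t, PyObject*)
// signatures the functions above share.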
3192 | static PyMethodDef methods[] = { |
3193 | {"dims" , (PyCFunction)(void*) _dims<create_dim>, METH_FASTCALL | METH_KEYWORDS, dims_doc}, |
3194 | {"dimlists" , (PyCFunction)(void*) _dims<create_dimlist>, METH_FASTCALL | METH_KEYWORDS}, |
3195 | {"_test_c" , (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS}, |
3196 | {"_wrap_method" , (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS}, |
3197 | {"Tensor_from_positional" , (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS}, |
3198 | {"__torch_function__" , (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS}, |
3199 | {"tree_flatten" , (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS}, |
3200 | {"order" , (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS}, |
3201 | {"index" , (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS}, |
3202 | {"stack" , (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS}, |
3203 | {"split" , (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS}, |
3204 | {"expand" , (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS}, |
3205 | {"__getitem__" , (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS}, |
3206 | {"__setitem__" , (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS}, |
3207 | {"_wrap" , (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS}, |
3208 | {"Tensor_sum" , (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS}, |
3209 | {"_parse_test" , (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS}, |
3210 | {"_set_pointwise_optimize" , (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS}, |
3211 | {"_patch_tensor_class" , (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, |
3212 | {NULL, NULL, 0, NULL} /* Sentinel */ |
3213 | }; |
3214 | |
3215 | static struct PyModuleDef module_def = { |
3216 | PyModuleDef_HEAD_INIT, |
3217 | "_C" , /* name of module */ |
3218 | NULL, /* module documentation, may be NULL */ |
3219 | -1, /* size of per-interpreter state of the module, |
3220 | or -1 if the module keeps state in global variables. */ |
3221 | methods |
3222 | }; |
3223 | |
3224 | PyObject* Dim_init(void) { |
3225 | Arena A; |
3226 | try { |
3227 | py::object mod = py::object::checked_steal(PyModule_Create(&module_def)); |
3228 | Dim::ready(mod, "Dim" ); |
3229 | DimList::ready(mod, "DimList" ); |
3230 | Tensor::ready(mod, "Tensor" ); |
3231 | WrappedOperator::ready(mod, "_WrappedOperator" ); |
3232 | Py_INCREF(&PyInstanceMethod_Type); |
3233 | PyModule_AddObject(mod.ptr(), "_instancemethod" , (PyObject *)&PyInstanceMethod_Type); |
3234 | |
3235 | initializeGlobals(A); |
3236 | return mod.release(); |
    } catch (py::exception_set&) {
3238 | return nullptr; |
3239 | } |
3240 | } |
3241 | |