1// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#pragma once
8#define PY_SSIZE_T_CLEAN
9#include <Python.h>
10#include <utility>
11#include <iostream>
12#include <memory>
13
// Bracket the body of a C-API entry point: PY_BEGIN opens a try block and
// PY_END(v) closes it, returning `v` when a py::exception_set escapes the body.
// Code that throws exception_set is expected to have set the Python error
// indicator already, so `v` should be the C-API failure value (e.g. nullptr).
#define PY_BEGIN try {
#define PY_END(v) } catch(py::exception_set & err) { return (v); }

// Keyword-aware fast-call entry point: headers older than 3.8 expose
// _PyObject_FastCallKeywords, 3.8+ expose _PyObject_Vectorcall.
#if PY_VERSION_HEX < 0x03080000
  #define PY_VECTORCALL _PyObject_FastCallKeywords
#else
  #define PY_VECTORCALL _PyObject_Vectorcall
#endif
22
// Half-open integer range [begin, end) advancing by `step`, usable in
// range-based for loops. The range object doubles as its own iterator:
// begin() returns a copy positioned at the start and end() returns a
// sentinel positioned at `end`.
struct irange {
 public:
  irange(int64_t end)
      : irange(0, end, 1) {}
  irange(int64_t begin, int64_t end, int64_t step = 1)
      : begin_(begin), end_(end), step_(step) {}
  int64_t operator*() const {
    return begin_;
  }
  irange& operator++() {
    begin_ += step_;
    return *this;
  }
  // Loop-continuation test against the end sentinel. An ordered comparison
  // (instead of plain inequality) terminates even when `step` does not divide
  // `end - begin` evenly, or when the range is empty; a non-positive step
  // counts downward toward `end`.
  bool operator!=(const irange& other) const {
    return step_ > 0 ? begin_ < other.begin_ : begin_ > other.begin_;
  }
  irange begin() const {
    return *this;
  }
  irange end() const {
    return irange{end_, end_, step_};
  }
 private:
  int64_t begin_;
  int64_t end_;
  int64_t step_;
};
50
51namespace py {
52
// Marker exception thrown by the wrappers in this namespace; the Python error
// indicator is expected to be set already when it is thrown. PY_END catches it
// and turns it into the C-API failure return value.
struct exception_set {};

// Forward declarations: `handle` refers to both before they are defined.
struct object;
struct vector_args;
58
59struct handle {
60 handle(PyObject* ptr)
61 : ptr_(ptr) {}
62 handle() = default;
63
64
65 PyObject* ptr() const {
66 return ptr_;
67 }
68 object attr(const char* key);
69 bool hasattr(const char* key);
70 handle type() const {
71 return (PyObject*) Py_TYPE(ptr());
72 }
73
74 template<typename... Args>
75 object call(Args&&... args);
76 object call_object(py::handle args);
77 object call_object(py::handle args, py::handle kwargs);
78 object call_vector(py::handle* begin, Py_ssize_t nargs, py::handle kwnames);
79 object call_vector(vector_args args);
80 bool operator==(handle rhs) {
81 return ptr_ == rhs.ptr_;
82 }
83
84 static handle checked(PyObject* ptr) {
85 if (!ptr) {
86 throw exception_set();
87 }
88 return ptr;
89 }
90
91protected:
92 PyObject* ptr_ = nullptr;
93};
94
95
96template<typename T>
97struct obj;
98
99template<typename T>
100struct hdl : public handle {
101 T* ptr() {
102 return (T*) handle::ptr();
103 }
104 T* operator->() {
105 return ptr();
106 }
107 hdl(T* ptr)
108 : hdl((PyObject*) ptr) {}
109 hdl(const obj<T>& o)
110 : hdl(o.ptr()) {}
111private:
112 hdl(handle h) : handle(h) {}
113};
114
115struct object : public handle {
116 object() = default;
117 object(const object& other)
118 : handle(other.ptr_) {
119 Py_XINCREF(ptr_);
120 }
121 object(object&& other) noexcept
122 : handle(other.ptr_) {
123 other.ptr_ = nullptr;
124 }
125 object& operator=(const object& other) {
126 return *this = object(other);
127 }
128 object& operator=(object&& other) noexcept {
129 PyObject* tmp = ptr_;
130 ptr_ = other.ptr_;
131 other.ptr_ = tmp;
132 return *this;
133 }
134 ~object() {
135 Py_XDECREF(ptr_);
136 }
137 static object steal(handle o) {
138 return object(o.ptr());
139 }
140 static object checked_steal(handle o) {
141 if (!o.ptr()) {
142 throw exception_set();
143 }
144 return steal(o);
145 }
146 static object borrow(handle o) {
147 Py_XINCREF(o.ptr());
148 return steal(o);
149 }
150 PyObject* release() {
151 auto tmp = ptr_;
152 ptr_ = nullptr;
153 return tmp;
154 }
155protected:
156 explicit object(PyObject* ptr)
157 : handle(ptr) {}
158};
159
160template<typename T>
161struct obj : public object {
162 obj() = default;
163 obj(const obj& other)
164 : object(other.ptr_) {
165 Py_XINCREF(ptr_);
166 }
167 obj(obj&& other) noexcept
168 : object(other.ptr_) {
169 other.ptr_ = nullptr;
170 }
171 obj& operator=(const obj& other) {
172 return *this = obj(other);
173 }
174 obj& operator=(obj&& other) noexcept {
175 PyObject* tmp = ptr_;
176 ptr_ = other.ptr_;
177 other.ptr_ = tmp;
178 return *this;
179 }
180 static obj steal(hdl<T> o) {
181 return obj(o.ptr());
182 }
183 static obj checked_steal(hdl<T> o) {
184 if (!o.ptr()) {
185 throw exception_set();
186 }
187 return steal(o);
188 }
189 static obj borrow(hdl<T> o) {
190 Py_XINCREF(o.ptr());
191 return steal(o);
192 }
193 T* ptr() const {
194 return (T*) object::ptr();
195 }
196 T* operator->() {
197 return ptr();
198 }
199protected:
200 explicit obj(T* ptr)
201 : object((PyObject*)ptr) {}
202};
203
204
205bool isinstance(handle h, handle c) {
206 return PyObject_IsInstance(h.ptr(), c.ptr());
207}
208
209[[ noreturn ]] void raise_error(handle exception, const char *format, ...) {
210 va_list args;
211 va_start(args, format);
212 PyErr_FormatV(exception.ptr(), format, args);
213 va_end(args);
214 throw exception_set();
215}
216
217template<typename T>
218struct base {
219 PyObject_HEAD
220 PyObject* ptr() const {
221 return (PyObject*) this;
222 }
223 static obj<T> alloc(PyTypeObject* type = nullptr) {
224 if (!type) {
225 type = &T::Type;
226 }
227 auto self = (T*) type->tp_alloc(type, 0);
228 if (!self) {
229 throw py::exception_set();
230 }
231 new (self) T;
232 return obj<T>::steal(self);
233 }
234 template<typename ... Args>
235 static obj<T> create(Args ... args) {
236 auto self = alloc();
237 self->init(std::forward<Args>(args)...);
238 return self;
239 }
240 static bool check(handle v) {
241 return isinstance(v, (PyObject*)&T::Type);
242 }
243
244 static hdl<T> unchecked_wrap(handle self_) {
245 return hdl<T>((T*)self_.ptr());
246 }
247 static hdl<T> wrap(handle self_) {
248 if (!check(self_)) {
249 raise_error(PyExc_ValueError, "not an instance of %S", &T::Type);
250 }
251 return unchecked_wrap(self_);
252 }
253
254 static obj<T> unchecked_wrap(object self_) {
255 return obj<T>::steal(unchecked_wrap(self_.release()));
256 }
257 static obj<T> wrap(object self_) {
258 return obj<T>::steal(wrap(self_.release()));
259 }
260
261 static PyObject* new_stub(PyTypeObject *type, PyObject *args, PyObject *kwds) {
262 PY_BEGIN
263 return (PyObject*) alloc(type).release();
264 PY_END(nullptr)
265 }
266 static void dealloc_stub(PyObject *self) {
267 ((T*)self)->~T();
268 Py_TYPE(self)->tp_free(self);
269 }
270 static void ready(py::handle mod, const char* name) {
271 if (PyType_Ready(&T::Type)) {
272 throw exception_set();
273 }
274 if(PyModule_AddObject(mod.ptr(), name, (PyObject*) &T::Type) < 0) {
275 throw exception_set();
276 }
277 }
278};
279
280inline object handle::attr(const char* key) {
281 return object::checked_steal(PyObject_GetAttrString(ptr(), key));
282}
283
284inline bool handle::hasattr(const char* key) {
285 return PyObject_HasAttrString(ptr(), key);
286}
287
288inline object import(const char* module) {
289 return object::checked_steal(PyImport_ImportModule(module));
290}
291
292template<typename... Args>
293inline object handle::call(Args&&... args) {
294 return object::checked_steal(PyObject_CallFunctionObjArgs(ptr_, args.ptr()..., nullptr));
295}
296
297inline object handle::call_object(py::handle args) {
298 return object::checked_steal(PyObject_CallObject(ptr(), args.ptr()));
299}
300
301
302inline object handle::call_object(py::handle args, py::handle kwargs) {
303 return object::checked_steal(PyObject_Call(ptr(), args.ptr(), kwargs.ptr()));
304}
305
306inline object handle::call_vector(py::handle* begin, Py_ssize_t nargs, py::handle kwnames) {
307 return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) begin, nargs, kwnames.ptr()));
308}
309
310struct tuple : public object {
311 void set(int i, object v) {
312 PyTuple_SET_ITEM(ptr_, i, v.release());
313 }
314 tuple(int size)
315 : object(checked_steal(PyTuple_New(size))) {}
316};
317
318struct list : public object {
319 void set(int i, object v) {
320 PyList_SET_ITEM(ptr_, i, v.release());
321 }
322 list(int size)
323 : object(checked_steal(PyList_New(size))) {}
324};
325
326py::object unicode_from_format(const char* format, ...) {
327 va_list args;
328 va_start(args, format);
329 auto r = PyUnicode_FromFormatV(format, args);
330 va_end(args);
331 return py::object::checked_steal(r);
332}
333py::object unicode_from_string(const char * str) {
334 return py::object::checked_steal(PyUnicode_FromString(str));
335}
336
337py::object from_int(Py_ssize_t s) {
338 return py::object::checked_steal(PyLong_FromSsize_t(s));
339}
340py::object from_bool(bool b) {
341 return py::object::borrow(b ? Py_True : Py_False);
342}
343
344bool is_sequence(handle h) {
345 return PySequence_Check(h.ptr());
346}
347
348
349struct sequence_view : public handle {
350 sequence_view(handle h)
351 : handle(h) {}
352 Py_ssize_t size() const {
353 auto r = PySequence_Size(ptr());
354 if (r == -1 && PyErr_Occurred()) {
355 throw py::exception_set();
356 }
357 return r;
358 }
359 irange enumerate() const {
360 return irange(size());
361 }
362 static sequence_view wrap(handle h) {
363 if (!is_sequence(h)) {
364 raise_error(PyExc_ValueError, "expected a sequence");
365 }
366 return sequence_view(h);
367 }
368 py::object operator[](Py_ssize_t i) const {
369 return py::object::checked_steal(PySequence_GetItem(ptr(), i));
370 }
371};
372
373
374py::object repr(handle h) {
375 return py::object::checked_steal(PyObject_Repr(h.ptr()));
376}
377
378py::object str(handle h) {
379 return py::object::checked_steal(PyObject_Str(h.ptr()));
380}
381
382
383bool is_int(handle h) {
384 return PyLong_Check(h.ptr());
385}
386
387bool is_float(handle h) {
388 return PyFloat_Check(h.ptr());
389}
390
391bool is_none(handle h) {
392 return h.ptr() == Py_None;
393}
394
395bool is_bool(handle h) {
396 return PyBool_Check(h.ptr());
397}
398
399Py_ssize_t to_int(handle h) {
400 Py_ssize_t r = PyLong_AsSsize_t(h.ptr());
401 if (r == -1 && PyErr_Occurred()) {
402 throw py::exception_set();
403 }
404 return r;
405}
406
407double to_float(handle h) {
408 double r = PyFloat_AsDouble(h.ptr());
409 if (PyErr_Occurred()) {
410 throw py::exception_set();
411 }
412 return r;
413}
414
415bool to_bool_unsafe(handle h) {
416 return h.ptr() == Py_True;
417}
418
419bool to_bool(handle h) {
420 return PyObject_IsTrue(h.ptr()) != 0;
421}
422
423struct slice_view {
424 slice_view(handle h, Py_ssize_t size) {
425 if(PySlice_Unpack(h.ptr(), &start, &stop, &step) == -1) {
426 throw py::exception_set();
427 }
428 slicelength = PySlice_AdjustIndices(size, &start, &stop, step);
429 }
430 Py_ssize_t start, stop, step, slicelength;
431};
432
433bool is_slice(handle h) {
434 return PySlice_Check(h.ptr());
435}
436
437inline std::ostream& operator<<(std::ostream& ss, handle h) {
438 ss << PyUnicode_AsUTF8(str(h).ptr());
439 return ss;
440}
441
442struct tuple_view : public handle {
443 tuple_view() = default;
444 tuple_view(handle h) : handle(h) {}
445
446 Py_ssize_t size() const {
447 return PyTuple_GET_SIZE(ptr());
448 }
449
450 irange enumerate() const {
451 return irange(size());
452 }
453
454 handle operator[](Py_ssize_t i) {
455 return PyTuple_GET_ITEM(ptr(), i);
456 }
457
458 static bool check(handle h) {
459 return PyTuple_Check(h.ptr());
460 }
461};
462
463struct list_view : public handle {
464 list_view() = default;
465 list_view(handle h) : handle(h) {}
466 Py_ssize_t size() const {
467 return PyList_GET_SIZE(ptr());
468 }
469
470 irange enumerate() const {
471 return irange(size());
472 }
473
474 handle operator[](Py_ssize_t i) {
475 return PyList_GET_ITEM(ptr(), i);
476 }
477
478 static bool check(handle h) {
479 return PyList_Check(h.ptr());
480 }
481};
482
483struct dict_view : public handle {
484 dict_view() = default;
485 dict_view(handle h) : handle(h) {}
486 object keys() const {
487 return py::object::checked_steal(PyDict_Keys(ptr()));
488 }
489 object values() const {
490 return py::object::checked_steal(PyDict_Values(ptr()));
491 }
492 object items() const {
493 return py::object::checked_steal(PyDict_Items(ptr()));
494 }
495 bool contains(handle k) const {
496 return PyDict_Contains(ptr(), k.ptr());
497 }
498 handle operator[](handle k) {
499 return py::handle::checked(PyDict_GetItem(ptr(), k.ptr()));
500 }
501 static bool check(handle h) {
502 return PyDict_Check(h.ptr());
503 }
504 bool next(Py_ssize_t* pos, py::handle* key, py::handle* value) {
505 PyObject *k = nullptr, *v = nullptr;
506 auto r = PyDict_Next(ptr(), pos, &k, &v);
507 *key = k;
508 *value = v;
509 return r;
510 }
511 void set(handle k, handle v) {
512 if (-1 == PyDict_SetItem(ptr(), k.ptr(), v.ptr())) {
513 throw exception_set();
514 }
515 }
516};
517
518
519struct kwnames_view : public handle {
520 kwnames_view() = default;
521 kwnames_view(handle h) : handle(h) {}
522
523 Py_ssize_t size() const {
524 return PyTuple_GET_SIZE(ptr());
525 }
526
527 irange enumerate() const {
528 return irange(size());
529 }
530
531 const char* operator[](Py_ssize_t i) const {
532 PyObject* obj = PyTuple_GET_ITEM(ptr(), i);
533 return PyUnicode_AsUTF8(obj);
534 }
535
536 static bool check(handle h) {
537 return PyTuple_Check(h.ptr());
538 }
539};
540
541inline py::object funcname(py::handle func) {
542 if (func.hasattr("__name__")) {
543 return func.attr("__name__");
544 } else {
545 return py::str(func);
546 }
547}
548
549struct vector_args {
550 vector_args(PyObject *const *a,
551 Py_ssize_t n,
552 PyObject *k)
553 : vector_args((py::handle*)a, n, k) {}
554 vector_args(py::handle* a,
555 Py_ssize_t n,
556 py::handle k)
557 : args((py::handle*)a), nargs(n), kwnames(k) {}
558 py::handle* args;
559 Py_ssize_t nargs;
560 kwnames_view kwnames;
561
562 py::handle* begin() {
563 return args;
564 }
565 py::handle* end() {
566 return args + size();
567 }
568
569 py::handle operator[](int64_t i) const {
570 return args[i];
571 }
572 bool has_keywords() const {
573 return kwnames.ptr();
574 }
575 irange enumerate_positional() {
576 return irange(nargs);
577 }
578 irange enumerate_all() {
579 return irange(size());
580 }
581 int64_t size() const {
582 return nargs + (has_keywords() ? kwnames.size() : 0);
583 }
584
585 // bind a test function so this can be tested, first two args for required/kwonly, then return what was parsed...
586
587 // provide write kwarg
588 // don't provide a required arg
589 // don't provide an optional arg
590 // provide a kwarg that is the name of already provided positional
591 // provide a kwonly argument positionally
592 // provide keyword arguments in the wrong order
593 // provide only keyword arguments
594 void parse(const char * fname_cstr, std::initializer_list<const char*> names, std::initializer_list<py::handle*> values, int required, int kwonly=0) {
595 auto error = [&]() {
596 // rather than try to match the slower infrastructure with error messages exactly, once we have detected an error, just use that
597 // infrastructure to format it and throw it
598
599 // have to leak this, because python expects these to last
600 const char** names_buf = new const char*[names.size() + 1];
601 std::copy(names.begin(), names.end(), &names_buf[0]);
602 names_buf[names.size()] = nullptr;
603
604#if PY_VERSION_HEX < 0x03080000
605 char* format_str = new char[names.size() + 3];
606 int i = 0;
607 char* format_it = format_str;
608 for (auto it = names.begin(); it != names.end(); ++it, ++i) {
609 if (i == required) {
610 *format_it++ = '|';
611 }
612 if (i == (int)names.size() - kwonly) {
613 *format_it++ = '$';
614 }
615 *format_it++ = 'O';
616 }
617 *format_it++ = '\0';
618 _PyArg_Parser* _parser = new _PyArg_Parser{format_str, &names_buf[0], fname_cstr, 0};
619 PyObject *dummy = NULL;
620 _PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
621#else
622 _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
623 std::unique_ptr<PyObject*[]> buf(new PyObject*[names.size()]);
624 _PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
625#endif
626 throw exception_set();
627 };
628
629 auto values_it = values.begin();
630 auto names_it = names.begin();
631 auto npositional = values.size() - kwonly;
632
633 if (nargs > (Py_ssize_t)npositional) {
634 // TOO MANY ARGUMENTS
635 error();
636 }
637 for (auto i : irange(nargs)) {
638 *(*values_it++) = args[i];
639 ++names_it;
640 }
641
642 if (!kwnames.ptr()) {
643 if (nargs < required) {
644 // not enough positional arguments
645 error();
646 }
647 } else {
648 int consumed = 0;
649 for (auto i : irange(nargs, values.size())) {
650 bool success = i >= required;
651 const char* target_name = *(names_it++);
652 for (auto j : kwnames.enumerate()) {
653 if (!strcmp(target_name,kwnames[j])) {
654 *(*values_it) = args[nargs + j];
655 ++consumed;
656 success = true;
657 break;
658 }
659 }
660 ++values_it;
661 if (!success) {
662 // REQUIRED ARGUMENT NOT SPECIFIED
663 error();
664 }
665 }
666 if (consumed != kwnames.size()) {
667 // NOT ALL KWNAMES ARGUMENTS WERE USED
668 error();
669 }
670 }
671 }
672 int index(const char* name, int pos) {
673 if (pos < nargs) {
674 return pos;
675 }
676 if (kwnames.ptr()) {
677 for (auto j : kwnames.enumerate()) {
678 if (!strcmp(name, kwnames[j])) {
679 return nargs + j;
680 }
681 }
682 }
683 return -1;
684 }
685};
686
687inline object handle::call_vector(vector_args args) {
688 return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) args.args, args.nargs, args.kwnames.ptr()));
689}
690
691
692}
693
// X-macro helpers: FORALL_ARGS(M) is expected to expand M(type, name) once per
// parameter. These turn each pair into a keyword-list entry, a local variable
// declaration, or the address of that variable, respectively. Each list item
// ends with a comma, so the call sites below terminate the list with nullptr.
#define MPY_ARGS_NAME(typ, name) #name ,
#define MPY_ARGS_DECLARE(typ, name) typ name;
#define MPY_ARGS_POINTER(typ, name) &name ,

// Parse (args, kwargs) — both must be in scope — into freshly declared locals,
// throwing py::exception_set on failure. The keyword list holds string
// literals, which ISO C++ does not allow to bind to char*, so it is declared
// const and cast only at the call, whose parameter is char** on older Python
// releases (char* const* since 3.13).
#define MPY_PARSE_ARGS_KWARGS(fmt, FORALL_ARGS) \
  static const char* const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
  FORALL_ARGS(MPY_ARGS_DECLARE) \
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, fmt, const_cast<char**>(kwlist), FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
    throw py::exception_set(); \
  }

// Vectorcall variant: `args`, `nargs` and `kwnames` must be in scope. Uses a
// function-local static _PyArg_Parser, aggregate-initialised positionally
// (NOTE(review): private CPython struct; confirm its layout per version).
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
  static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
  FORALL_ARGS(MPY_ARGS_DECLARE) \
  static _PyArg_Parser parser = {fmt, kwlist, 0}; \
  if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
    throw py::exception_set(); \
  }
711