1 | #include <torch/csrc/MemoryFormat.h> |
2 | |
3 | #include <torch/csrc/Exceptions.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | #include <torch/csrc/utils/python_strings.h> |
6 | |
7 | #include <c10/core/MemoryFormat.h> |
8 | |
9 | #include <structmember.h> |
10 | #include <cstring> |
11 | #include <string> |
12 | |
13 | PyObject* THPMemoryFormat_New( |
14 | at::MemoryFormat memory_format, |
15 | const std::string& name) { |
16 | auto type = (PyTypeObject*)&THPMemoryFormatType; |
17 | auto self = THPObjectPtr{type->tp_alloc(type, 0)}; |
18 | if (!self) |
19 | throw python_error(); |
20 | auto self_ = reinterpret_cast<THPMemoryFormat*>(self.get()); |
21 | self_->memory_format = memory_format; |
22 | std::strncpy(self_->name, name.c_str(), MEMORY_FORMAT_NAME_LEN); |
23 | self_->name[MEMORY_FORMAT_NAME_LEN] = '\0'; |
24 | return self.release(); |
25 | } |
26 | |
27 | PyObject* THPMemoryFormat_repr(THPMemoryFormat* self) { |
28 | return THPUtils_packString(self->name); |
29 | } |
30 | |
31 | PyObject* THPMemoryFormat_reduce(PyObject* _self, PyObject* noargs) { |
32 | auto* self = (THPMemoryFormat*)_self; |
33 | return THPUtils_packString(self->name); |
34 | } |
35 | |
36 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
37 | static PyMethodDef THPMemoryFormat_methods[] = { |
38 | {"__reduce__" , THPMemoryFormat_reduce, METH_NOARGS, nullptr}, |
39 | {nullptr} /* Sentinel */ |
40 | }; |
41 | |
// Static Python type object for `torch.memory_format`.
//
// Only tp_repr and tp_methods are customized. Every other slot is left
// null/zero, and PyType_Ready (invoked from THPMemoryFormat_init) fills in the
// inheritable ones from the base type. The base is `object` because tp_base is
// nullptr. As a result, deallocation uses object's default, and since both
// tp_hash and tp_richcompare are null, equality and hashing are object's
// identity-based defaults.
// NOTE(review): tp_new is null. Presumably instances are meant to be created
// only from C++ via THPMemoryFormat_New; confirm before relying on Python-side
// construction.
// The entries below are positional and must stay in PyTypeObject slot order.
PyTypeObject THPMemoryFormatType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.memory_format" , /* tp_name */
    sizeof(THPMemoryFormat), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_as_async (named tp_reserved before Python 3.5) */
    (reprfunc)THPMemoryFormat_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPMemoryFormat_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
81 | |
82 | void THPMemoryFormat_init(PyObject* module) { |
83 | if (PyType_Ready(&THPMemoryFormatType) < 0) { |
84 | throw python_error(); |
85 | } |
86 | Py_INCREF(&THPMemoryFormatType); |
87 | if (PyModule_AddObject( |
88 | module, "memory_format" , (PyObject*)&THPMemoryFormatType) != 0) { |
89 | throw python_error(); |
90 | } |
91 | } |
92 | |