#define PY_SSIZE_T_CLEAN
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/extension.h>
#include <sstream>

namespace {

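// LocalState captures the thread-local dispatcher state that can change
// which kernel an operator dispatches to: the TLS dispatch-key
// includes/excludes and whether grad mode is enabled. Guards recorded under
// one state must be re-evaluated against the state active at check time.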
struct LocalState {
  // TLS state that changes how operators dispatch
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  bool grad_mode_enabled;

  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
    return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
  }

  LocalState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        grad_mode_enabled(at::GradMode::is_enabled()) {}
};

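// TensorCheck records the guard-relevant properties of one tensor at
// compile time (dispatch key set, dtype, device index, requires_grad, and,
// unless dynamic_shapes is set, sizes and strides) so that later calls can
// cheaply verify a new tensor still matches.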
class TensorCheck {
 public:
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
      bool dynamic_shapes)
      : pytype(pt),
        dispatch_key_(state.apply(v.key_set()).raw_repr()),
        dtype_(v.dtype().toScalarType()),
        device_index_(v.device().index()),
        requires_grad_(state.grad_mode_enabled && v.requires_grad()),
        dynamic_shapes_(dynamic_shapes) {
    auto ndim = v.ndimension();
    const auto& sizes = v.sizes();
    const auto& strides = v.strides();
    sizes_.reserve(ndim);
    strides_.reserve(ndim);
    for (auto i : c10::irange(ndim)) {
      sizes_.emplace_back(sizes[i]);
      strides_.emplace_back(strides[i]);
    }
  }

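  // Fast path evaluated on every guard check: returns true iff `v` still
  // matches the recorded properties. Sizes and strides are only compared
  // when dynamic_shapes is disabled.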
  bool check(const LocalState& state, const at::Tensor& v) {
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
        dtype_ != v.dtype().toScalarType() ||
        device_index_ != v.device().index() ||
        requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
      return false;
    }
    auto ndim = static_cast<size_t>(v.ndimension());
    if (ndim != sizes_.size()) {
      return false;
    }
    if (!dynamic_shapes_) {
      const auto& sizes = v.sizes();
      const auto& strides = v.strides();
      for (auto i : c10::irange(ndim)) {
        if (sizes_[i] != sizes[i] || strides_[i] != strides[i]) {
          return false;
        }
      }
    }
    return true;
  }

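  // Slow path mirroring check(): reports why the first property mismatched,
  // or returns the empty string when everything matches.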
  std::string check_verbose(
      const LocalState& state,
      const at::Tensor& v,
      std::string tensor_name) {
    std::stringstream fail_reason;
    fail_reason << "tensor '" << tensor_name << "' ";
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
      fail_reason << "dispatch key set mismatch. expected "
                  << c10::DispatchKeySet(
                         c10::DispatchKeySet::RAW, dispatch_key_)
                  << ", actual " << state.apply(v.key_set());
      return fail_reason.str();
    } else if (dtype_ != v.dtype().toScalarType()) {
      fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
                  << v.dtype().toScalarType();
      return fail_reason.str();
    } else if (device_index_ != v.device().index()) {
      fail_reason << "device index mismatch. expected " << device_index_
                  << ", actual " << v.device().index();
      return fail_reason.str();
    } else if (
        requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
      fail_reason << "requires_grad mismatch. expected requires_grad="
                  << requires_grad_;
      return fail_reason.str();
    }
    size_t ndim = static_cast<size_t>(v.ndimension());
    if (ndim != sizes_.size()) {
      fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
                  << ndim;
      return fail_reason.str();
    }
    if (!dynamic_shapes_) {
      const auto& sizes = v.sizes();
      const auto& strides = v.strides();
      for (auto i : c10::irange(ndim)) {
        if (sizes_[i] != sizes[i]) {
          fail_reason << "size mismatch at index " << i << ". expected "
                      << sizes_[i] << ", actual " << sizes[i];
          return fail_reason.str();
        } else if (strides_[i] != strides[i]) {
          fail_reason << "strides mismatch at index " << i << ". expected "
                      << strides_[i] << ", actual " << strides[i];
          return fail_reason.str();
        }
      }
    }
    return "";
  }

  PyTypeObject* pytype;

 private:
  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
  at::ScalarType dtype_;
  // Note(voz): Although dispatch keys are more granular than a device and
  // are device specific, dispatch_key_ does not reliably capture the device
  // *index*, so it is tracked separately.
  at::DeviceIndex device_index_;
  bool requires_grad_;
  bool dynamic_shapes_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;
};

typedef std::vector<TensorCheck> ChecksList;

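// TensorGuards is the Python-visible object: it owns one TensorCheck per
// tensor passed to its constructor and replays them from check() and
// check_verbose().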
typedef struct {
  PyObject_HEAD;
  ChecksList* checks;
} TensorGuards;

static void TensorGuards_dealloc(TensorGuards* self) {
  if (self->checks != NULL) {
    delete self->checks;
    self->checks = NULL;
  }
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* TensorGuards_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
  if (self != NULL) {
    self->checks = new ChecksList();
  }
  return (PyObject*)self;
}

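// Invoked as TensorGuards(*tensors, dynamic_shapes=...): every positional
// argument must be a tensor and dynamic_shapes is a required keyword.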
static int TensorGuards_init(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwds) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return -1;
  }
  // kwds is NULL when no keyword arguments are passed; guard before calling
  // PyDict_GetItemString, which requires a valid dict
  PyObject* dynamic_shapes_py =
      kwds ? PyDict_GetItemString(kwds, "dynamic_shapes") : NULL;
  if (dynamic_shapes_py == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_shapes=...");
    return -1;
  }
  bool dynamic_shapes = PyObject_IsTrue(dynamic_shapes_py);

  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);
  checks.reserve(len);
  LocalState state;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "expected Tensor()");
      return -1;
    }
    checks.emplace_back(
        state, Py_TYPE(item), THPVariable_Unpack(item), dynamic_shapes);
  }
  return 0;
}

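// check(*tensors) -> bool: True iff every argument still has the exact
// Python type and tensor properties recorded at construction time.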
PyObject* TensorGuards_check(TensorGuards* self, PyObject* args) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return NULL;
  }

  LocalState state;

  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      Py_RETURN_FALSE;
    }
    if (!checks[i].check(state, THPVariable_Unpack(item))) {
      Py_RETURN_FALSE;
    }
  }

  Py_RETURN_TRUE;
}

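// check_verbose(*tensors, tensor_check_names=[...]): returns True when all
// guards pass, otherwise a string naming the first failing tensor and the
// property that mismatched.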
PyObject* TensorGuards_check_verbose(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return NULL;
  }

  // kwargs is NULL when no keyword arguments are passed
  PyObject* tensor_check_names_py =
      kwargs ? PyDict_GetItemString(kwargs, "tensor_check_names") : NULL;
  if (tensor_check_names_py == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
    return NULL;
  }

  if (!PyList_Check(tensor_check_names_py)) {
    PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
    return NULL;
  }

  auto names_size = PyList_Size(tensor_check_names_py);
  if (names_size != static_cast<decltype(names_size)>(checks.size())) {
    PyErr_SetString(
        PyExc_TypeError,
        "tensor_check_names should be the same size as # tensors");
    return NULL;
  }

  std::vector<std::string> tensor_check_names;
  tensor_check_names.reserve(names_size);
  for (auto i : c10::irange(names_size)) {
    PyObject* value = PyList_GetItem(tensor_check_names_py, i);
    if (!PyUnicode_Check(value)) {
      PyErr_SetString(
          PyExc_TypeError, "tensor_check_names must only contain strings");
      return NULL;
    }
    tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
  }

  LocalState state;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      std::stringstream fail_reason;
      PyObject* type = PyObject_Type(item);
      PyObject* type_str = type ? PyObject_Str(type) : NULL;
      fail_reason << "expected type of '" << tensor_check_names[i]
                  << "' to be a tensor type, ";
      if (!type_str) {
        fail_reason << "but found a different type";
      } else {
        fail_reason << "but found " << PyUnicode_AsUTF8(type_str);
      }
      // PyObject_Type and PyObject_Str both return new references
      Py_XDECREF(type_str);
      Py_XDECREF(type);
      return Py_BuildValue("s", fail_reason.str().c_str());
    }
    std::string fail_reason = checks[i].check_verbose(
        state, THPVariable_Unpack(item), tensor_check_names[i]);
    if (fail_reason.length() > 0) {
      return Py_BuildValue("s", fail_reason.c_str());
    }
  }

  Py_RETURN_TRUE;
}

static PyMethodDef TensorGuards_methods[] = {
    {"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {NULL} /* Sentinel */
};

static PyTypeObject TensorGuardsType = {
    // NOLINTNEXTLINE
    PyVarObject_HEAD_INIT(NULL, 0)};

static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(type(obj)) == expected`
  PyObject* obj;
  // use unsigned long long ("K") so 64-bit pointers survive on LLP64
  // platforms (e.g. Windows), where unsigned long is only 32 bits
  unsigned long long expected;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return NULL;
  }
  if (Py_TYPE(obj) == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(obj) == expected`
  PyObject* obj;
  unsigned long long expected;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return NULL;
  }
  if (obj == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
  /*
   Assert that a given tensor has a given size/stride, but ignore strides
   of size==1 dimensions. Implemented in C++ as this is on the hot path.
  */
  PyObject* item;
  PyObject* size;
  PyObject* stride;
  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
    return NULL;
  }
  if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
    return NULL;
  }
  if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  at::Tensor tensor = THPVariable_Unpack(item);
  int64_t ndim = tensor.ndimension();
  if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
    return NULL;
  }
  for (auto i : c10::irange(ndim)) {
    int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
    int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
    int64_t actual_size = tensor.size(i);
    int64_t actual_stride = tensor.stride(i);
    if (want_size != actual_size ||
        // ignore stride differences when size is 1
        (want_stride != actual_stride && actual_size > 1)) {
      std::stringstream msg;
      msg << "expected size " << actual_size << "==" << want_size << ", stride "
          << actual_stride << "==" << want_stride << " at dim=" << i;
      PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
      return NULL;
    }
  }
  Py_RETURN_TRUE;
}

static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, NULL},
    {"check_obj_id", check_obj_id, METH_VARARGS, NULL},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.guards",
    "Module containing checks on tensors",
    -1,
    _methods};

} // namespace

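// Registers this extension module; called from the top-level torch._C init.
// A rough usage sketch from Python (illustrative only, assuming the module
// is exposed as torch._C._dynamo.guards):
//
//   guards = TensorGuards(x, y, dynamic_shapes=False)
//   guards.check(x, y)       # -> True while properties are unchanged
//   guards.check_verbose(x, y, tensor_check_names=["x", "y"])
//   # -> True, or e.g. "tensor 'y' dtype mismatch. expected ..., actual ..."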
PyObject* torch_c_dynamo_guards_init() {
  // initialize TensorGuardsType
  TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
  TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
  TensorGuardsType.tp_itemsize = 0;
  TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
  TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
  TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
  TensorGuardsType.tp_methods = TensorGuards_methods;
  TensorGuardsType.tp_init = (initproc)TensorGuards_init;
  TensorGuardsType.tp_new = TensorGuards_new;

  PyObject* m;
  if (PyType_Ready(&TensorGuardsType) < 0)
    return NULL;

  m = PyModule_Create(&_module);
  if (m == NULL)
    return NULL;

  Py_INCREF(&TensorGuardsType);
  if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
    Py_DECREF(&TensorGuardsType);
    Py_DECREF(m);
    return NULL;
  }

  return m;
}