1 | #define PY_SSIZE_T_CLEAN |
2 | #include <torch/csrc/dynamo/guards.h> |
3 | #include <torch/csrc/utils/python_numbers.h> |
4 | #include <torch/extension.h> |
5 | #include <sstream> |
6 | |
7 | namespace { |
8 | |
9 | struct LocalState { |
10 | // TLS state that changes operators |
11 | c10::impl::LocalDispatchKeySet dispatch_modifier; |
12 | bool grad_mode_enabled; |
13 | |
14 | at::DispatchKeySet apply(at::DispatchKeySet ks) const { |
15 | return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_; |
16 | } |
17 | |
18 | LocalState() |
19 | : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()), |
20 | grad_mode_enabled(at::GradMode::is_enabled()) {} |
21 | }; |
22 | |
23 | class TensorCheck { |
24 | public: |
25 | TensorCheck( |
26 | const LocalState& state, |
27 | PyTypeObject* pt, |
28 | const at::Tensor& v, |
29 | bool dynamic_shapes) |
30 | : pytype(pt), |
31 | dispatch_key_(state.apply(v.key_set()).raw_repr()), |
32 | dtype_(v.dtype().toScalarType()), |
33 | device_index_(v.device().index()), |
34 | requires_grad_(state.grad_mode_enabled && v.requires_grad()), |
35 | dynamic_shapes_(dynamic_shapes) { |
36 | auto ndim = v.ndimension(); |
37 | const auto& sizes = v.sizes(); |
38 | const auto& strides = v.strides(); |
39 | sizes_.reserve(ndim); |
40 | strides_.reserve(ndim); |
41 | for (auto i : c10::irange(ndim)) { |
42 | sizes_.emplace_back(sizes[i]); |
43 | strides_.emplace_back(strides[i]); |
44 | } |
45 | } |
46 | |
47 | bool check(const LocalState& state, const at::Tensor& v) { |
48 | if (dispatch_key_ != state.apply(v.key_set()).raw_repr() || |
49 | dtype_ != v.dtype().toScalarType() || |
50 | device_index_ != v.device().index() || |
51 | requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) { |
52 | return false; |
53 | } |
54 | auto ndim = static_cast<size_t>(v.ndimension()); |
55 | if (ndim != sizes_.size()) { |
56 | return false; |
57 | } |
58 | if (!dynamic_shapes_) { |
59 | const auto& sizes = v.sizes(); |
60 | const auto& strides = v.strides(); |
61 | for (auto i : c10::irange(ndim)) { |
62 | if (sizes_[i] != sizes[i] || strides_[i] != strides[i]) { |
63 | return false; |
64 | } |
65 | } |
66 | } |
67 | return true; |
68 | } |
69 | |
70 | std::string check_verbose( |
71 | const LocalState& state, |
72 | const at::Tensor& v, |
73 | std::string tensor_name) { |
74 | std::stringstream fail_reason; |
75 | fail_reason << "tensor '" << tensor_name << "' " ; |
76 | if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) { |
77 | // return fmt::format("tensor dispatch key mismatch. expected {}, actual |
78 | // {}", dispatch_key_, state.apply(v.key_set()).raw_repr()); |
79 | fail_reason << "dispatch key set mismatch. expected " |
80 | << c10::DispatchKeySet( |
81 | c10::DispatchKeySet::RAW, dispatch_key_) |
82 | << ", actual " << state.apply(v.key_set()); |
83 | return fail_reason.str(); |
84 | } else if (dtype_ != v.dtype().toScalarType()) { |
85 | // return fmt::format("tensor dtype mismatch. expected {}, actual {}", |
86 | // dtype_, v.dtype().toScalarType()); |
87 | fail_reason << "dtype mismatch. expected " << dtype_ << ", actual " |
88 | << v.dtype().toScalarType(); |
89 | return fail_reason.str(); |
90 | } else if (device_index_ != v.device().index()) { |
91 | fail_reason |
92 | << "Tensor device index mismatch. Expected device index to be " |
93 | << device_index_ << ", actual " << v.device().index(); |
94 | return fail_reason.str(); |
95 | } else if ( |
96 | requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) { |
97 | // return fmt::format("tensor requires_grad mismatch. expected {}", |
98 | // requires_grad_); |
99 | fail_reason << "requires_grad mismatch. expected requires_grad=" |
100 | << requires_grad_; |
101 | return fail_reason.str(); |
102 | } |
103 | size_t ndim = static_cast<size_t>(v.ndimension()); |
104 | if (ndim != sizes_.size()) { |
105 | // return fmt::format("tensor rank mismatch. expected {}, actual {}", |
106 | // sizes_.size(), ndim); |
107 | fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual " |
108 | << ndim; |
109 | return fail_reason.str(); |
110 | } |
111 | if (!dynamic_shapes_) { |
112 | const auto& sizes = v.sizes(); |
113 | const auto& strides = v.strides(); |
114 | for (auto i : c10::irange(ndim)) { |
115 | if (sizes_[i] != sizes[i]) { |
116 | // return fmt::format("tensor size mismatch at index {}. expected {}, |
117 | // actual {}", i, sizes_[i], sizes[i]); |
118 | fail_reason << "size mismatch at index " << i << ". expected " |
119 | << sizes_[i] << ", actual " << sizes[i]; |
120 | return fail_reason.str(); |
121 | } else if (strides_[i] != strides[i]) { |
122 | // return fmt::format("tensor strides mismatch at index {}. expected |
123 | // {}, actual {}", i, strides_[i]); |
124 | fail_reason << "strides mismatch at index " << i << ". expected " |
125 | << strides_[i] << ", actual " << strides[i]; |
126 | return fail_reason.str(); |
127 | } |
128 | } |
129 | } |
130 | return "" ; |
131 | } |
132 | |
133 | PyTypeObject* pytype; |
134 | |
135 | private: |
136 | uint64_t dispatch_key_; // DispatchKeySet includes device/layout |
137 | at::ScalarType dtype_; |
138 | // Note(voz): While dispatch_key_ is sufficiently representative of a device |
139 | // In that keys are more granular AND device specific - they do not |
140 | // necessarily capture device indices correctly. |
141 | at::DeviceIndex device_index_; |
142 | bool requires_grad_; |
143 | bool dynamic_shapes_; |
144 | std::vector<int64_t> sizes_; |
145 | std::vector<int64_t> strides_; |
146 | }; |
147 | |
148 | typedef std::vector<TensorCheck> ChecksList; |
149 | |
150 | typedef struct { |
151 | PyObject_HEAD; |
152 | ChecksList* checks; |
153 | } TensorGuards; |
154 | |
155 | static void TensorGuards_dealloc(TensorGuards* self) { |
156 | if (self->checks != NULL) { |
157 | delete self->checks; |
158 | self->checks = NULL; |
159 | } |
160 | Py_TYPE(self)->tp_free((PyObject*)self); |
161 | } |
162 | |
163 | static PyObject* TensorGuards_new( |
164 | PyTypeObject* type, |
165 | PyObject* args, |
166 | PyObject* kwds) { |
167 | TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0); |
168 | if (self != NULL) { |
169 | self->checks = new ChecksList(); |
170 | } |
171 | return (PyObject*)self; |
172 | } |
173 | |
174 | static int TensorGuards_init( |
175 | TensorGuards* self, |
176 | PyObject* args, |
177 | PyObject* kwds) { |
178 | if (!PyTuple_CheckExact(args)) { |
179 | PyErr_SetString(PyExc_TypeError, "expected tuple()" ); |
180 | return -1; |
181 | } |
182 | PyObject* dynamic_shapes_py = PyDict_GetItemString(kwds, "dynamic_shapes" ); |
183 | if (dynamic_shapes_py == NULL) { |
184 | PyErr_SetString(PyExc_TypeError, "missing dynamic_shapes=..." ); |
185 | return -1; |
186 | } |
187 | bool dynamic_shapes = PyObject_IsTrue(dynamic_shapes_py); |
188 | |
189 | auto& checks = *self->checks; |
190 | auto len = PyTuple_GET_SIZE(args); |
191 | checks.reserve(len); |
192 | LocalState state; |
193 | for (auto i : c10::irange(len)) { |
194 | PyObject* item = PyTuple_GET_ITEM(args, i); |
195 | if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { |
196 | PyErr_SetString(PyExc_TypeError, "expected Tensor()" ); |
197 | return -1; |
198 | } |
199 | checks.emplace_back( |
200 | state, Py_TYPE(item), THPVariable_Unpack(item), dynamic_shapes); |
201 | } |
202 | return 0; |
203 | } |
204 | |
205 | PyObject* TensorGuards_check(TensorGuards* self, PyObject* args) { |
206 | if (!PyTuple_CheckExact(args)) { |
207 | PyErr_SetString(PyExc_TypeError, "expected tuple()" ); |
208 | return NULL; |
209 | } |
210 | auto& checks = *self->checks; |
211 | auto len = PyTuple_GET_SIZE(args); |
212 | |
213 | if (static_cast<decltype(len)>(checks.size()) != len) { |
214 | PyErr_SetString(PyExc_TypeError, "wrong length" ); |
215 | return NULL; |
216 | } |
217 | |
218 | LocalState state; |
219 | |
220 | for (auto i : c10::irange(len)) { |
221 | PyObject* item = PyTuple_GET_ITEM(args, i); |
222 | if (Py_TYPE(item) != checks[i].pytype) { |
223 | Py_RETURN_FALSE; |
224 | } |
225 | if (!checks[i].check(state, THPVariable_Unpack(item))) { |
226 | Py_RETURN_FALSE; |
227 | } |
228 | } |
229 | |
230 | Py_RETURN_TRUE; |
231 | } |
232 | |
233 | PyObject* TensorGuards_check_verbose( |
234 | TensorGuards* self, |
235 | PyObject* args, |
236 | PyObject* kwargs) { |
237 | if (!PyTuple_CheckExact(args)) { |
238 | PyErr_SetString(PyExc_TypeError, "expected tuple()" ); |
239 | return NULL; |
240 | } |
241 | auto& checks = *self->checks; |
242 | auto len = PyTuple_GET_SIZE(args); |
243 | |
244 | if (static_cast<decltype(len)>(checks.size()) != len) { |
245 | PyErr_SetString(PyExc_TypeError, "wrong length" ); |
246 | return NULL; |
247 | } |
248 | |
249 | PyObject* tensor_check_names_py = |
250 | PyDict_GetItemString(kwargs, "tensor_check_names" ); |
251 | if (tensor_check_names_py == NULL) { |
252 | PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg" ); |
253 | return NULL; |
254 | } |
255 | |
256 | if (!PyList_Check(tensor_check_names_py)) { |
257 | PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list" ); |
258 | return NULL; |
259 | } |
260 | |
261 | auto names_size = PyList_Size(tensor_check_names_py); |
262 | if (names_size != static_cast<decltype(names_size)>(checks.size())) { |
263 | PyErr_SetString( |
264 | PyExc_TypeError, |
265 | "tensor_check_names should be the same size as # tensors" ); |
266 | return NULL; |
267 | } |
268 | |
269 | std::vector<std::string> tensor_check_names; |
270 | tensor_check_names.reserve(names_size); |
271 | for (auto i : c10::irange(names_size)) { |
272 | PyObject* value = PyList_GetItem(tensor_check_names_py, i); |
273 | if (!PyUnicode_Check(value)) { |
274 | PyErr_SetString( |
275 | PyExc_TypeError, "tensor_check_names must only contain strings" ); |
276 | return NULL; |
277 | } |
278 | tensor_check_names.emplace_back(PyUnicode_AsUTF8(value)); |
279 | } |
280 | |
281 | LocalState state; |
282 | for (auto i : c10::irange(len)) { |
283 | PyObject* item = PyTuple_GET_ITEM(args, i); |
284 | if (Py_TYPE(item) != checks[i].pytype) { |
285 | std::stringstream fail_reason; |
286 | PyObject* type_str = PyObject_Str(PyObject_Type(item)); |
287 | fail_reason << "expected type of '" << tensor_check_names[i] |
288 | << "' to be a tensor type, " ; |
289 | if (!type_str) { |
290 | fail_reason << "but found a different type" ; |
291 | } else { |
292 | fail_reason << "' but found " << PyUnicode_AsUTF8(type_str); |
293 | } |
294 | return Py_BuildValue("s" , fail_reason.str().c_str()); |
295 | } |
296 | std::string fail_reason = checks[i].check_verbose( |
297 | state, THPVariable_Unpack(item), tensor_check_names[i]); |
298 | if (fail_reason.length() > 0) { |
299 | return Py_BuildValue("s" , fail_reason.c_str()); |
300 | } |
301 | } |
302 | |
303 | Py_RETURN_TRUE; |
304 | } |
305 | |
// Method table for the TensorGuards type: fast boolean `check` and the
// slower `check_verbose`, which also takes keyword arguments.
static PyMethodDef TensorGuards_methods[] = {
    {"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""},
    {"check_verbose",
     // extra (void*) hop silences the function-pointer-cast warning for the
     // METH_KEYWORDS (three-argument) signature
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {NULL} /* Sentinel */
};
314 | |
// Type object for TensorGuards. Only the header is initialized statically;
// the remaining slots are filled in by torch_c_dynamo_guards_init() before
// PyType_Ready is called.
static PyTypeObject TensorGuardsType = {
    // NOLINTNEXTLINE
    PyVarObject_HEAD_INIT(NULL, 0)};
318 | |
319 | static PyObject* check_type_id(PyObject* dummy, PyObject* args) { |
320 | // faster `lambda obj, expected: id(type(obj)) == expected` |
321 | PyObject* obj; |
322 | unsigned long expected; |
323 | if (!PyArg_ParseTuple(args, "Ok" , &obj, &expected)) { |
324 | return NULL; |
325 | } |
326 | if (Py_TYPE(obj) == (void*)expected) { |
327 | Py_RETURN_TRUE; |
328 | } else { |
329 | Py_RETURN_FALSE; |
330 | } |
331 | } |
332 | |
333 | static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { |
334 | // faster `lambda obj, expected: id(obj) == expected` |
335 | PyObject* obj; |
336 | unsigned long expected; |
337 | if (!PyArg_ParseTuple(args, "Ok" , &obj, &expected)) { |
338 | return NULL; |
339 | } |
340 | if (obj == (void*)expected) { |
341 | Py_RETURN_TRUE; |
342 | } else { |
343 | Py_RETURN_FALSE; |
344 | } |
345 | } |
346 | |
347 | static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { |
348 | /* |
349 | Assert that a given tensor has a given size/stride, but ignore strides |
350 | of size==1 dimensions. Implemented in C++ as this is on the hot path. |
351 | */ |
352 | PyObject* item; |
353 | PyObject* size; |
354 | PyObject* stride; |
355 | if (!PyArg_ParseTuple(args, "OOO" , &item, &size, &stride)) { |
356 | return NULL; |
357 | } |
358 | if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { |
359 | PyErr_SetString(PyExc_TypeError, "expected Tensor()" ); |
360 | return NULL; |
361 | } |
362 | if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) { |
363 | PyErr_SetString(PyExc_TypeError, "expected tuple()" ); |
364 | return NULL; |
365 | } |
366 | at::Tensor tensor = THPVariable_Unpack(item); |
367 | int64_t ndim = tensor.ndimension(); |
368 | if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) { |
369 | PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions" ); |
370 | return NULL; |
371 | } |
372 | for (auto i : c10::irange(ndim)) { |
373 | int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i)); |
374 | int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i)); |
375 | int64_t actual_size = tensor.size(i); |
376 | int64_t actual_stride = tensor.stride(i); |
377 | if (want_size != actual_size || |
378 | // ignore stride differences when size is 1 |
379 | (want_stride != actual_stride && actual_size > 1)) { |
380 | std::stringstream msg; |
381 | msg << "expected size " << actual_size << "==" << want_size << ", stride " |
382 | << actual_stride << "==" << want_stride << " at dim=" << i; |
383 | PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); |
384 | return NULL; |
385 | } |
386 | } |
387 | Py_RETURN_TRUE; |
388 | } |
389 | |
// Module-level functions exposed on torch._C._dynamo.guards.
static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, NULL},
    {"check_obj_id", check_obj_id, METH_VARARGS, NULL},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}};
395 | |
// Module definition; m_size of -1 means the module keeps state in globals
// and does not support multiple sub-interpreters.
static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.guards",
    "Module containing checks on tensors",
    -1,
    _methods};
402 | |
403 | } // namespace |
404 | |
405 | PyObject* torch_c_dynamo_guards_init() { |
406 | // initialize TensorGuardsType |
407 | TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards" ; |
408 | TensorGuardsType.tp_basicsize = sizeof(TensorGuards); |
409 | TensorGuardsType.tp_itemsize = 0; |
410 | TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc; |
411 | TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT; |
412 | TensorGuardsType.tp_doc = "Check properties of a torch.Tensor" ; |
413 | TensorGuardsType.tp_methods = TensorGuards_methods; |
414 | TensorGuardsType.tp_init = (initproc)TensorGuards_init; |
415 | TensorGuardsType.tp_new = TensorGuards_new; |
416 | |
417 | PyObject* m; |
418 | if (PyType_Ready(&TensorGuardsType) < 0) |
419 | return NULL; |
420 | |
421 | m = PyModule_Create(&_module); |
422 | if (m == NULL) |
423 | return NULL; |
424 | |
425 | Py_INCREF(&TensorGuardsType); |
426 | if (PyModule_AddObject(m, "TensorGuards" , (PyObject*)&TensorGuardsType) < 0) { |
427 | Py_DECREF(&TensorGuardsType); |
428 | Py_DECREF(m); |
429 | return NULL; |
430 | } |
431 | |
432 | return m; |
433 | } |
434 | |