1 | #include <c10/core/TensorImpl.h> |
2 | |
3 | #include <c10/core/Backend.h> |
4 | #include <c10/core/InferenceMode.h> |
5 | #include <c10/core/SymIntArrayRef.h> |
6 | #include <c10/core/WrapDimMinimal.h> |
7 | #include <c10/core/impl/LocalDispatchKeySet.h> |
8 | #include <c10/core/impl/PyInterpreter.h> |
9 | #include <c10/core/impl/TorchDispatchModeTLS.h> |
10 | #include <c10/util/Optional.h> |
11 | #include <c10/util/irange.h> |
12 | |
13 | #include <utility> |
14 | |
C10_DEFINE_bool(
    caffe2_keep_on_shrink,
    true,
    "If set, keeps memory when a tensor is shrinking its size.");
19 | |
C10_DEFINE_int64(
    caffe2_max_keep_on_shrink_memory,
    LLONG_MAX,
    "The maximum memory in bytes to keep on shrink; if the difference between "
    "tensor sizes is bigger than this, the tensor will be reset.");
25 | |
26 | namespace c10 { |
27 | |
28 | const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed = |
29 | "is not allowed on a Tensor created from .data or .detach().\n" |
30 | "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n" |
31 | "without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n" |
32 | "For example, change:\n" |
    "    x.data.set_(y)\n"
    "to:\n"
    "    with torch.no_grad():\n"
    "        x.set_(y)";
37 | |
38 | at::Tensor& TensorImpl::mutable_grad() { |
39 | if (!autograd_meta_) |
40 | autograd_meta_ = impl::GetAutogradMetaFactory()->make(); |
41 | return autograd_meta_->mutable_grad(); |
42 | } |
43 | |
44 | const at::Tensor& TensorImpl::grad() const { |
45 | // Yes, I know this looks really weird. But I don't really have a choice as |
46 | // long as this function returns a const reference to Tensor. I'm not |
47 | // really sure how I would have designed this API differently, but it |
48 | // is not so easy to fix right now because the mutable counterpart of |
49 | // this function must keep working so that "x.grad() = ..." keeps working |
50 | // (part of public API). |
51 | if (!autograd_meta_) |
52 | return impl::GetAutogradMetaFactory()->undefined_tensor(); |
53 | return autograd_meta_->grad(); |
54 | } |
55 | |
56 | const at::Tensor& TensorImpl::_fw_grad( |
57 | uint64_t level, |
58 | const at::TensorBase& self) const { |
59 | // See TensorImpl::grad() above for explanation about the line below |
60 | if (!autograd_meta_) |
61 | return impl::GetAutogradMetaFactory()->undefined_tensor(); |
62 | return autograd_meta_->fw_grad(level, self); |
63 | } |
64 | |
65 | void TensorImpl::_set_fw_grad( |
66 | const at::TensorBase& new_grad, |
67 | const at::TensorBase& self, |
68 | uint64_t level, |
69 | bool is_inplace_op) { |
70 | if (!autograd_meta_) |
71 | autograd_meta_ = impl::GetAutogradMetaFactory()->make(); |
72 | autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op); |
73 | } |
74 | |
75 | TensorImpl::~TensorImpl() { |
76 | pyobj_slot_.destroy_pyobj_if_needed(); |
77 | } |
78 | |
79 | TensorImpl::TensorImpl( |
80 | Storage&& storage, |
81 | DispatchKeySet key_set, |
82 | const caffe2::TypeMeta data_type) |
83 | // Use std::forward to suppress static analyzer false positive. |
84 | : TensorImpl( |
85 | std::forward<Storage>(storage), |
86 | key_set, |
87 | data_type, |
88 | storage.device()) {} |
89 | |
90 | // [Note: Python key removal] |
91 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
92 | // In most constructors for TensorImpl, you will see Python and |
93 | // PythonTLSSnapshot keys are removed from the passed in DispatchKeySet. Why? |
94 | // |
95 | // INVARIANT: Python and PythonTLSSnapshot dispatch keys are set iff PyObject |
96 | // for the Tensor has a nontrivial __torch_dispatch__ implementation. |
97 | // |
98 | // When a fresh TensorImpl is created, there is *no* PyObject (this only gets |
99 | // initialized lazily at the first point in time the Tensor passes into Python). |
100 | // So we would violate the invariant. |
101 | // |
102 | // In practice, what will happen shortly afterwards is that the TensorImpl |
103 | // will get its PyObject initialized by Tensor._make_subclass; at this point |
104 | // the Python and PythonTLSSnapshot dispatch keys will be set and all is well. |
105 | // The point is to delay the dispatch key setting until that point. |
106 | |
107 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
108 | TensorImpl::TensorImpl( |
109 | ImplType type, |
110 | Storage&& storage, |
111 | DispatchKeySet key_set, |
112 | const caffe2::TypeMeta data_type) |
113 | : storage_(std::move(storage)), |
114 | |
115 | numel_(0), |
116 | data_type_(data_type), |
117 | device_opt_(storage_.device()), |
118 | key_set_(key_set - c10::python_ks) { // See [Note: Python key removal] |
119 | init_bitfields(); |
120 | // Inference tensor doesn't have version counter. |
121 | if (!is_inference()) { |
122 | version_counter_ = VariableVersion(/*version=*/0); |
123 | } |
124 | } |
125 | |
126 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
127 | TensorImpl::TensorImpl( |
128 | DispatchKeySet key_set, |
129 | const caffe2::TypeMeta data_type, |
130 | c10::optional<c10::Device> device_opt) |
131 | : TensorImpl({}, key_set, data_type, device_opt) {} |
132 | |
133 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
134 | TensorImpl::TensorImpl( |
135 | Storage&& storage, |
136 | DispatchKeySet key_set, |
137 | const caffe2::TypeMeta data_type, |
138 | c10::optional<c10::Device> device_opt) |
139 | : storage_(std::move(storage)), |
140 | |
141 | numel_(0), |
142 | data_type_(data_type), |
143 | device_opt_(device_opt) { |
144 | init_bitfields(); |
145 | |
146 | if (!key_set.empty()) { |
147 | TORCH_INTERNAL_ASSERT( |
148 | data_type == ScalarType::Undefined || device_opt_.has_value()); |
149 | // UndefinedTensorImpl is a singleton, so we skip logging it |
    C10_LOG_API_USAGE_ONCE("tensor.create");
151 | } |
152 | |
153 | // XXX: if updating keyset logic here also update |
154 | // _change_backend_component_keys |
155 | bool inference_mode = c10::InferenceMode::is_enabled(); |
156 | |
157 | // TODO: be more explicit about the full key set at call sites so we |
158 | // don't have to keep recomputing it here |
159 | auto k = key_set.highestBackendKey(); |
160 | |
161 | key_set = key_set | getAutocastRelatedKeySetFromBackend(k); |
162 | |
163 | // See [Note: Python key removal] |
164 | key_set = key_set - c10::python_ks; |
165 | |
166 | // Inference tensor doesn't have autograd related keys. |
167 | if (inference_mode) { |
168 | // See Note [Expected TLS state in InferenceMode] for why we exclude |
169 | // Autograd & ADInplaceOrView keys. Normally key_set only contains backend |
    // keys, but we do the subtraction here to make sure.
171 | key_set_ = key_set - c10::autograd_dispatch_keyset_with_ADInplaceOrView; |
172 | } else { |
173 | // TODO: Ideally we only add AutogradBackend key when the tensor requires |
174 | // grad. |
175 | // See Note [Dream: skip VariableType kernel when requires_grad=false] |
176 | key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k); |
177 | } |
178 | |
179 | // Inference tensor doesn't have version counter. |
180 | if (!is_inference()) { |
181 | version_counter_ = VariableVersion(/*version=*/0); |
182 | } |
183 | // we would also like to check that non-cpu devices have an index, but some |
184 | // Caffe2 operators create Storages with default devices. |
185 | } |
186 | |
187 | void TensorImpl::_change_backend_component_keys(c10::Device device) { |
188 | BackendComponent new_backend = toBackendComponent(device.type()); |
189 | BackendComponent old_backend = key_set_.highestBackendKey(); |
190 | |
  // Following the logic in TensorImpl::TensorImpl, update the
  // BackendComponent-related keys to correspond to the new device.
193 | |
  // TODO: Autocast should be a per-backend functionality key; once that change
  // is made, this key swap will not be necessary.
196 | auto key_set = |
197 | key_set_ - c10::getAutocastRelatedKeySetFromBackend(old_backend); |
198 | key_set = key_set | c10::getAutocastRelatedKeySetFromBackend(new_backend); |
199 | |
200 | // See note [Removing keys from DispatchKeySet Only Affects Functionality |
201 | // Keys] |
202 | key_set = key_set.remove_backend(old_backend); |
203 | key_set_ = key_set | DispatchKeySet(new_backend); |
204 | } |
205 | |
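// Worked example (illustrative, assuming the default flag values): a float
// tensor backed by a 4000-byte storage that is resized down to 500 elements
// needs only 2000 bytes; the 2000 bytes of slack are below
// FLAGS_caffe2_max_keep_on_shrink_memory, so the allocation is kept. Resizing
// it up to 2000 elements needs 8000 bytes, so the storage is freed and the
// next mutable_data() call reallocates.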
206 | void TensorImpl::HandleResize() { |
  // If needed, we will free the data. The next mutable_data() call
  // will create the data storage.
209 | bool reset_tensor = false; |
210 | if (reserved_) { |
    // If the tensor is reserved then don't claim its memory unless nbytes()
    // is smaller than the new size
213 | reset_tensor = |
214 | storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize(); |
215 | } else { |
216 | reset_tensor = storage_.nbytes() < |
217 | (storage_offset_ + numel_) * data_type_.itemsize() || |
218 | !FLAGS_caffe2_keep_on_shrink || |
219 | storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() > |
220 | static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory); |
221 | } |
222 | |
223 | if (reset_tensor && storage_initialized()) { |
224 | FreeMemory(); |
225 | } |
226 | } |
227 | |
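// Illustrative example: sizes [2, 3, 4] with strides [12, 4, 1] are
// contiguous, because walking the dims right-to-left each stride equals the
// running product of the sizes to its right; strides [12, 4, 2] would not be.
// Size-1 dims are skipped, so their strides are irrelevant.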
228 | template <typename T> |
229 | bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { |
230 | bool is_contiguous = true; |
231 | if (numel == 0) |
232 | return is_contiguous; |
233 | T z = 1; |
234 | // NB: make sure we do signed arithmetic |
235 | for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { |
236 | const auto& size_d = sizes[d]; |
237 | if (size_d != 1) { |
238 | if (strides[d] == z) { |
239 | z *= size_d; |
240 | } else { |
241 | is_contiguous = false; |
242 | break; |
243 | } |
244 | } |
245 | } |
246 | return is_contiguous; |
247 | } |
248 | |
249 | bool TensorImpl::compute_contiguous(identity<bool>) const { |
250 | if (is_sparse()) { |
251 | return false; |
252 | } |
253 | return _compute_contiguous<int64_t>( |
254 | sizes_and_strides_.sizes_arrayref(), |
255 | sizes_and_strides_.strides_arrayref(), |
256 | numel_); |
257 | } |
258 | |
259 | SymBool TensorImpl::compute_contiguous(identity<SymBool>) const { |
260 | if (is_sparse()) { |
261 | return false; |
262 | } |
263 | return _compute_contiguous<c10::SymInt>( |
264 | extra_meta_->sizes_, extra_meta_->strides_, extra_meta_->numel_); |
265 | } |
266 | |
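// Illustrative example: an NCHW tensor with sizes [2, 3, 4, 5] is
// channels-last (NHWC) contiguous iff its strides are [60, 1, 15, 3]; the loop
// below walks the dims in the order {C, W, H, N} and checks each stride
// against a running product of the sizes, skipping size-1 dims.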
267 | template <typename T> |
268 | bool _compute_channels_last_contiguous_2d( |
269 | ArrayRef<T> sizes, |
270 | ArrayRef<T> strides) { |
  // Please don't combine these code paths; the constant index array is used
  // here to let the compiler fully unroll the loop for better performance.
273 | switch (sizes.size()) { |
274 | case 4: { |
275 | T expected = 1; |
276 | for (auto& d : {1, 3, 2, 0}) { |
277 | const auto& size_d = sizes[d]; |
278 | if (size_d != 1) { |
279 | if (strides[d] != expected) { |
280 | return false; |
281 | } |
282 | expected *= size_d; |
283 | } |
284 | } |
285 | return true; |
286 | } |
287 | // NOLINTNEXTLINE(bugprone-branch-clone) |
288 | case 3: |
289 | // TODO dim == 3 case will be enabled once it is fully tested |
290 | return false; |
291 | default: |
292 | return false; |
293 | } |
294 | } |
295 | |
296 | bool TensorImpl::compute_channels_last_contiguous_2d(identity<bool>) const { |
297 | if (is_sparse()) { |
298 | return false; |
299 | } |
300 | return _compute_channels_last_contiguous_2d<int64_t>( |
301 | sizes_and_strides_.sizes_arrayref(), |
302 | sizes_and_strides_.strides_arrayref()); |
303 | } |
304 | |
305 | SymBool TensorImpl::compute_channels_last_contiguous_2d( |
306 | identity<SymBool>) const { |
307 | if (is_sparse()) { |
308 | return false; |
309 | } |
310 | return _compute_channels_last_contiguous_2d<c10::SymInt>( |
311 | extra_meta_->sizes_, extra_meta_->strides_); |
312 | } |
313 | |
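// Illustrative example: an NCDHW tensor with sizes [2, 3, 4, 5, 6] is
// channels-last-3d (NDHWC) contiguous iff its strides are [360, 1, 90, 18, 3];
// the dims are visited in the order {C, W, H, D, N}.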
314 | template <typename T> |
315 | bool _compute_channels_last_contiguous_3d( |
316 | ArrayRef<T> sizes, |
317 | ArrayRef<T> strides) { |
  // Please don't combine these code paths; the constant index array is used
  // here to let the compiler fully unroll the loop for better performance.
320 | switch (sizes.size()) { |
321 | case 5: { |
322 | T expected = 1; |
323 | for (auto& d : {1, 4, 3, 2, 0}) { |
324 | const auto& size_d = sizes[d]; |
325 | if (size_d != 1) { |
326 | if (strides[d] != expected) { |
327 | return false; |
328 | } |
329 | expected *= size_d; |
330 | } |
331 | } |
332 | return true; |
333 | } |
334 | // NOLINTNEXTLINE(bugprone-branch-clone) |
335 | case 4: |
336 | // TODO dim == 4 case will be enabled once it is fully tested |
337 | return false; |
338 | default: |
339 | return false; |
340 | } |
341 | } |
342 | |
343 | bool TensorImpl::compute_channels_last_contiguous_3d(identity<bool>) const { |
344 | if (is_sparse()) { |
345 | return false; |
346 | } |
347 | return _compute_channels_last_contiguous_3d<int64_t>( |
348 | sizes_and_strides_.sizes_arrayref(), |
349 | sizes_and_strides_.strides_arrayref()); |
350 | } |
351 | |
352 | SymBool TensorImpl::compute_channels_last_contiguous_3d( |
353 | identity<SymBool>) const { |
354 | if (is_sparse()) { |
355 | return false; |
356 | } |
357 | return _compute_channels_last_contiguous_3d<c10::SymInt>( |
358 | extra_meta_->sizes_, extra_meta_->strides_); |
359 | } |
360 | |
361 | bool TensorImpl::compute_strides_like_channels_last_2d(identity<bool>) const { |
362 | if (is_sparse()) { |
363 | return false; |
364 | } |
365 | return is_channels_last_strides_2d<int64_t>( |
366 | sizes_and_strides_.sizes_arrayref(), |
367 | sizes_and_strides_.strides_arrayref()); |
368 | } |
369 | |
370 | SymBool TensorImpl::compute_strides_like_channels_last_2d( |
371 | identity<SymBool>) const { |
372 | if (is_sparse()) { |
373 | return false; |
374 | } |
375 | return is_channels_last_strides_2d<c10::SymInt>( |
376 | extra_meta_->sizes_, extra_meta_->strides_); |
377 | } |
378 | |
379 | bool TensorImpl::compute_strides_like_channels_last_3d(identity<bool>) const { |
380 | if (is_sparse()) { |
381 | return false; |
382 | } |
383 | return is_channels_last_strides_3d<int64_t>( |
384 | sizes_and_strides_.sizes_arrayref(), |
385 | sizes_and_strides_.strides_arrayref()); |
386 | } |
387 | |
388 | SymBool TensorImpl::compute_strides_like_channels_last_3d( |
389 | identity<SymBool>) const { |
390 | if (is_sparse()) { |
391 | return false; |
392 | } |
393 | return is_channels_last_strides_3d<c10::SymInt>( |
394 | extra_meta_->sizes_, extra_meta_->strides_); |
395 | } |
396 | |
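// Illustrative example: sizes [3, 4] with strides [1, 3] (the transpose of a
// contiguous [4, 3] tensor) are non-overlapping and dense, because sorting the
// dims by stride recovers an exact contiguous stride sequence (1, then 3);
// sizes [3, 4] with strides [8, 2] are not, since no permutation of the dims
// makes each stride equal the running product of the preceding sizes.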
397 | template <typename T> |
398 | bool _compute_non_overlapping_and_dense( |
399 | ArrayRef<T> sizes, |
400 | ArrayRef<T> strides) { |
401 | auto dim = sizes.size(); |
402 | if (dim == 1) { |
403 | return sizes[0] < 2 || strides[0] == 1; |
404 | } |
405 | SmallVector<int64_t, 5> perm; |
406 | perm.resize(dim); |
407 | for (const auto i : c10::irange(dim)) { |
408 | perm[i] = i; |
409 | } |
410 | // Sort by strides, leaving 0 and 1 sized dims at the end of the array |
411 | std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { |
412 | if (sizes[a] < 2) { |
413 | return false; |
414 | } else if (sizes[b] < 2) { |
415 | return true; |
416 | } |
417 | return strides[a] < strides[b]; |
418 | }); |
419 | T require_stride = 1; |
420 | for (const auto i : c10::irange(dim)) { |
421 | const auto& size_perm_i = sizes[perm[i]]; |
422 | if (size_perm_i < 2) { |
423 | return true; |
424 | } |
425 | if (strides[perm[i]] != require_stride) { |
426 | return false; |
427 | } |
428 | require_stride *= size_perm_i; |
429 | } |
430 | return true; |
431 | } |
432 | |
433 | bool TensorImpl::compute_non_overlapping_and_dense(identity<bool>) const { |
434 | if (is_sparse()) { |
435 | return false; |
436 | } |
437 | return _compute_non_overlapping_and_dense<int64_t>( |
438 | sizes_and_strides_.sizes_arrayref(), |
439 | sizes_and_strides_.strides_arrayref()); |
440 | } |
441 | |
442 | SymBool TensorImpl::compute_non_overlapping_and_dense(identity<SymBool>) const { |
443 | if (is_sparse()) { |
444 | return false; |
445 | } |
446 | return _compute_non_overlapping_and_dense<c10::SymInt>( |
447 | extra_meta_->sizes_, extra_meta_->strides_); |
448 | } |
449 | |
450 | // Glue compute |
451 | // NB: intentionally not using bitwise operators. Using bitwise operators |
452 | // currently impedes ShapeEnv from getting crucial equalities which cause |
453 | // python test/functorch/test_aotdispatch.py -k |
454 | // test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run |
455 | // very slowly. I think probably we just need to be able to reason through |
456 | // And/Or, and then we can switch these to be symbolic. |
457 | |
458 | SymBool TensorImpl::compute_is_non_overlapping_and_dense_dim4( |
459 | identity<SymBool> type_id) { |
460 | return extra_meta_->is_contiguous_.guard_bool(__FILE__, __LINE__) || |
461 | extra_meta_->is_channels_last_contiguous_.guard_bool( |
462 | __FILE__, __LINE__) || |
463 | compute_non_overlapping_and_dense(type_id).guard_bool(__FILE__, __LINE__); |
464 | } |
465 | |
466 | SymBool TensorImpl::compute_channels_last_contiguous_3d_dim5( |
467 | identity<SymBool> type_id) { |
468 | return !extra_meta_->is_channels_last_contiguous_.guard_bool( |
469 | __FILE__, __LINE__) && |
470 | compute_channels_last_contiguous_3d(type_id).guard_bool( |
471 | __FILE__, __LINE__); |
472 | } |
473 | |
474 | SymBool TensorImpl::compute_channels_last_2d_dim5(identity<SymBool> type_id) { |
475 | return !extra_meta_->is_channels_last_3d_contiguous_.guard_bool( |
476 | __FILE__, __LINE__) && |
477 | compute_strides_like_channels_last_2d(type_id).guard_bool( |
478 | __FILE__, __LINE__); |
479 | } |
480 | |
481 | SymBool TensorImpl::compute_channels_last_3d_dim5(identity<SymBool> type_id) { |
482 | return !extra_meta_->is_channels_last_.guard_bool(__FILE__, __LINE__) && |
483 | compute_strides_like_channels_last_3d(type_id).guard_bool( |
484 | __FILE__, __LINE__); |
485 | } |
486 | |
487 | SymBool TensorImpl::compute_is_non_overlapping_and_dense_dim5( |
488 | identity<SymBool> type_id) { |
489 | return extra_meta_->is_contiguous_.guard_bool(__FILE__, __LINE__) || |
490 | extra_meta_->is_channels_last_contiguous_.guard_bool( |
491 | __FILE__, __LINE__) || |
492 | extra_meta_->is_channels_last_3d_contiguous_.guard_bool( |
493 | __FILE__, __LINE__) || |
494 | compute_non_overlapping_and_dense(type_id).guard_bool(__FILE__, __LINE__); |
495 | } |
496 | |
497 | SymBool TensorImpl::compute_is_non_overlapping_and_dense_anydim( |
498 | identity<SymBool> type_id) { |
499 | return extra_meta_->is_contiguous_.guard_bool(__FILE__, __LINE__) || |
500 | compute_non_overlapping_and_dense(type_id).guard_bool(__FILE__, __LINE__); |
501 | } |
502 | |
503 | void TensorImpl::release_resources() { |
504 | autograd_meta_.reset(); |
505 | if (storage_) { |
506 | storage_ = {}; |
507 | } |
508 | pyobj_slot_.destroy_pyobj_if_needed(); |
509 | } |
510 | |
511 | #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY |
512 | bool TensorImpl::has_storage() const { |
513 | return storage_; |
514 | } |
515 | #endif |
516 | |
517 | void TensorImpl::throw_storage_access_error() const { |
518 | TORCH_CHECK_NOT_IMPLEMENTED( |
      false, "Cannot access storage of ", tensorimpl_type_name());
520 | } |
521 | |
522 | bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { |
523 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |
524 | return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( |
525 | this, memory_format); |
526 | } |
527 | return is_contiguous_default(memory_format); |
528 | } |
529 | |
530 | bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { |
531 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |
532 | return pyobj_slot_.load_pyobj_interpreter()->is_strides_like( |
533 | this, memory_format); |
534 | } |
535 | return is_strides_like_default(memory_format); |
536 | } |
537 | |
538 | bool TensorImpl::is_non_overlapping_and_dense_custom() const { |
539 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |
540 | return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( |
541 | this); |
542 | } |
543 | return is_non_overlapping_and_dense_default(); |
544 | } |
545 | |
546 | IntArrayRef TensorImpl::sizes_custom() const { |
547 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
548 | return pyobj_slot_.load_pyobj_interpreter()->sizes(this); |
549 | } |
550 | return sizes_default(); |
551 | } |
552 | |
553 | c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const { |
554 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
555 | return pyobj_slot_.load_pyobj_interpreter()->sym_sizes(this); |
556 | } |
557 | return sym_sizes_default(); |
558 | } |
559 | |
560 | c10::SymInt TensorImpl::sym_numel_custom() const { |
561 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
562 | return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this); |
563 | } |
564 | return sym_numel_default(); |
565 | } |
566 | |
567 | c10::SymIntArrayRef TensorImpl::sym_strides_custom() const { |
568 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |
569 | return pyobj_slot_.load_pyobj_interpreter()->sym_strides(this); |
570 | } |
571 | return sym_strides_default(); |
572 | } |
573 | |
574 | c10::Device TensorImpl::device_custom() const { |
575 | if (C10_UNLIKELY(python_custom_device_)) { |
576 | return pyobj_slot_.load_pyobj_interpreter()->device(this); |
577 | } |
578 | return device_default(); |
579 | } |
580 | |
581 | IntArrayRef TensorImpl::strides_custom() const { |
582 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |
583 | return pyobj_slot_.load_pyobj_interpreter()->strides(this); |
584 | } |
585 | return strides_default(); |
586 | } |
587 | |
588 | int64_t TensorImpl::dim_custom() const { |
589 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
590 | return pyobj_slot_.load_pyobj_interpreter()->dim(this); |
591 | } |
592 | return dim_default(); |
593 | } |
594 | |
595 | int64_t TensorImpl::numel_custom() const { |
596 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
597 | // TODO: fix this |
598 | return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this).expect_int(); |
599 | } |
600 | return numel_default(); |
601 | } |
602 | |
603 | c10::Layout TensorImpl::layout_custom() const { |
604 | if (C10_UNLIKELY(python_custom_layout_)) { |
605 | return pyobj_slot_.load_pyobj_interpreter()->layout(this); |
606 | } |
607 | // TODO: fix this |
608 | TORCH_CHECK( |
      0, "Tensors of type ", tensorimpl_type_name(), " do not have layout")
610 | // return layout_default(); |
611 | } |
612 | |
613 | int64_t TensorImpl::storage_offset_custom() const { |
614 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
615 | // TODO: fix this |
616 | return pyobj_slot_.load_pyobj_interpreter() |
617 | ->sym_storage_offset(this) |
618 | .expect_int(); |
619 | } |
620 | return storage_offset_default(); |
621 | } |
622 | |
623 | c10::SymInt TensorImpl::sym_storage_offset_custom() const { |
624 | if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { |
625 | return pyobj_slot_.load_pyobj_interpreter()->sym_storage_offset(this); |
626 | } |
627 | return sym_storage_offset_default(); |
628 | } |
629 | |
630 | static void deletePlacementDeleteContext(void* ptr) { |
631 | delete static_cast<PlacementDeleteContext*>(ptr); |
632 | } |
633 | |
634 | at::DataPtr PlacementDeleteContext::makeDataPtr( |
635 | at::DataPtr&& data_ptr, |
636 | PlacementDtor placement_dtor, |
637 | size_t size, |
638 | at::Device device) { |
639 | auto* ptr = data_ptr.get(); |
640 | return { |
641 | ptr, |
642 | new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size), |
643 | &deletePlacementDeleteContext, |
644 | device}; |
645 | } |
646 | |
647 | AutogradMetaInterface::~AutogradMetaInterface() = default; |
648 | |
// Setting requires_grad to true on an inference tensor outside InferenceMode
// is forbidden. Ideally it would also be illegal inside InferenceMode.
// But there's no way to directly allocate a tensor with requires_grad = true
// in a C++ constructor, so set_requires_grad is widely used in the C++
// frontend. Forbidding it inside InferenceMode would force users to delete
// this setter code from their code, which is not ideal.
655 | void TensorImpl::set_requires_grad(bool requires_grad) { |
656 | TORCH_CHECK( |
657 | !(requires_grad && is_inference() && !c10::InferenceMode::is_enabled()), |
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
659 | if (!requires_grad && !autograd_meta_) |
660 | return; |
661 | if (!autograd_meta_) |
662 | autograd_meta_ = impl::GetAutogradMetaFactory()->make(); |
663 | // NB: In principle, setting requires_grad to false could result in |
664 | // the AutogradMeta becoming equal to a default constructed state, |
665 | // in which case we could apply the nullptr AutogradMeta optimization |
666 | // (see autograd_meta_ docs). But we don't do this right now. Note |
  // that it is unsound to unconditionally reset AutogradMeta to nullptr
  // when you set requires_grad to false, as there may be nontrivial
669 | // information content in the other fields; for example, we may |
670 | // have set the string name for a Variable, or there may be hooks |
671 | // registered for it. |
672 | autograd_meta_->set_requires_grad(requires_grad, this); |
673 | } |
674 | |
675 | bool TensorImpl::requires_grad() const { |
676 | if (!autograd_meta_) |
677 | return false; |
678 | return autograd_meta_->requires_grad(); |
679 | } |
680 | |
681 | void TensorImpl::set_autograd_meta( |
682 | std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) { |
683 | // NB: autograd_meta may be null! That just means it's the default |
684 | // constructor |
685 | autograd_meta_ = std::move(autograd_meta); |
686 | } |
687 | |
688 | c10::AutogradMetaInterface* TensorImpl::autograd_meta() const { |
689 | // NB: Might return null! |
690 | return autograd_meta_.get(); |
691 | } |
692 | |
693 | template <typename VariableVersion> |
694 | c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core( |
695 | VariableVersion&& version_counter, |
696 | bool allow_tensor_metadata_change) const { |
697 | c10::intrusive_ptr<TensorImpl> r; |
698 | const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); |
699 | // TODO: do we have to exclude after Python dispatch key set? |
700 | if (mode_stack_len > 0 && |
701 | !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { |
702 | const auto& cur_torch_dispatch_mode_state = |
703 | c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); |
704 | r = cur_torch_dispatch_mode_state->pyinterpreter()->detach(this); |
705 | } else if ( |
706 | key_set_.has(DispatchKey::Python) && |
707 | !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { |
708 | r = (pyobj_slot_.load_pyobj_interpreter())->detach(this); |
709 | } |
710 | if (r) { |
711 | r->set_version_counter(std::forward<VariableVersion>(version_counter)); |
712 | r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); |
713 | return r; |
714 | } |
715 | // otherwise just copy the TensorImpl and not the PyObject. Since |
716 | // the interpreter is dead no one can call us out on it |
717 | auto impl = c10::make_intrusive<TensorImpl>( |
718 | // No need to populate Storage; copy_tensor_metadata will do it for us. |
719 | key_set_, |
720 | data_type_, |
721 | device_opt_); |
722 | copy_tensor_metadata( |
723 | /*src_impl=*/this, |
724 | /*dest_impl=*/impl.get(), |
725 | /*version_counter=*/std::forward<VariableVersion>(version_counter), |
726 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
727 | |
728 | impl->refresh_numel(); |
729 | impl->refresh_contiguous(); |
730 | return impl; |
731 | } |
732 | |
733 | c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach( |
734 | const c10::VariableVersion& version_counter, |
735 | bool allow_tensor_metadata_change) const { |
736 | return shallow_copy_and_detach_core( |
737 | version_counter, allow_tensor_metadata_change); |
738 | } |
739 | |
740 | c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach( |
741 | c10::VariableVersion&& version_counter, |
742 | bool allow_tensor_metadata_change) const { |
743 | return shallow_copy_and_detach_core( |
744 | std::move(version_counter), allow_tensor_metadata_change); |
745 | } |
746 | |
747 | // This function copies all of the metadata from the src tensor except for: |
748 | // - key_set_ |
749 | // - storage_ |
750 | // - storage_access_should_throw_ |
751 | // - sizes_strides_policy_ |
752 | // - version_counter_ |
753 | // - allow_tensor_metadata_change_ |
754 | // The idea is that if we have a "wrapper tensor" (like in functionalization), |
755 | // all of the above are properties that the wrapper will want to customize, |
756 | // while everything else should be mirrored between the wrapper and the inner |
757 | // tensor. |
758 | void TensorImpl::copy_generic_tensor_metadata( |
759 | const TensorImpl* src_impl, |
760 | TensorImpl* dest_impl) { |
761 | dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_; |
762 | dest_impl->has_symbolic_sizes_strides_ = |
763 | src_impl->has_symbolic_sizes_strides_; |
764 | |
765 | dest_impl->storage_offset_ = src_impl->storage_offset_; |
766 | dest_impl->data_type_ = src_impl->data_type_; |
767 | dest_impl->device_opt_ = src_impl->device_opt_; |
768 | dest_impl->is_contiguous_ = src_impl->is_contiguous_; |
769 | dest_impl->is_channels_last_contiguous_ = |
770 | src_impl->is_channels_last_contiguous_; |
771 | dest_impl->is_channels_last_3d_contiguous_ = |
772 | src_impl->is_channels_last_3d_contiguous_; |
773 | dest_impl->is_channels_last_ = src_impl->is_channels_last_; |
774 | dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_; |
775 | dest_impl->is_non_overlapping_and_dense_ = |
776 | src_impl->is_non_overlapping_and_dense_; |
777 | dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_; |
778 | dest_impl->reserved_ = src_impl->reserved_; |
779 | if (src_impl->extra_meta_ != nullptr) { |
780 | dest_impl->extra_meta_ = src_impl->extra_meta_->clone(); |
781 | } |
782 | |
  // NB: symbolic sizes and strides are copied, as is the custom sizes/strides
  // policy, but the python policy is NOT (there is no Python object to
  // dispatch to!)
785 | // NB: subclass relevant policy doesn't have to be copied; the |
786 | // constructor sets this up |
787 | |
788 | dest_impl->refresh_sizes_strides_policy(); |
789 | dest_impl->refresh_layout_policy(); |
790 | dest_impl->refresh_device_policy(); |
791 | } |
792 | |
793 | void TensorImpl::copy_tensor_metadata_except_version_counter( |
794 | const TensorImpl* src_impl, |
795 | TensorImpl* dest_impl, |
796 | bool allow_tensor_metadata_change) { |
797 | // First call the generic copy function |
798 | copy_generic_tensor_metadata(src_impl, dest_impl); |
799 | // Then copy everything else (see the comment at copy_generic_tensor_metadata |
800 | // for the list of metadata that it does not directly copy). |
801 | dest_impl->storage_ = src_impl->storage_; |
802 | // Copying tensor metadata doesn't change the PyObject (maybe |
803 | // it should), which means that we have to preserve whatever the |
804 | // original Python keyset was (as it's associated with the PyObject |
805 | // being a tensor subclass or not) |
806 | dest_impl->key_set_ = (src_impl->key_set_ - c10::python_ks) | |
807 | (dest_impl->key_set_ & c10::python_ks); |
808 | dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); |
809 | dest_impl->storage_access_should_throw_ = |
810 | src_impl->storage_access_should_throw_; |
811 | } |
812 | |
813 | void TensorImpl::copy_tensor_metadata( |
814 | const TensorImpl* src_impl, |
815 | TensorImpl* dest_impl, |
816 | const c10::VariableVersion& version_counter, |
817 | bool allow_tensor_metadata_change) { |
818 | copy_tensor_metadata_except_version_counter( |
819 | src_impl, dest_impl, allow_tensor_metadata_change); |
820 | // TODO: In the ideal end state, it's okay to set disabled version_counter |
821 | // on inference tensor since it's a no-op. This requires refactor on call |
822 | // sites. |
823 | if (!dest_impl->is_inference()) { |
824 | dest_impl->set_version_counter(version_counter); |
825 | } |
826 | } |
827 | |
828 | void TensorImpl::copy_tensor_metadata( |
829 | const TensorImpl* src_impl, |
830 | TensorImpl* dest_impl, |
831 | c10::VariableVersion&& version_counter, |
832 | bool allow_tensor_metadata_change) { |
833 | copy_tensor_metadata_except_version_counter( |
834 | src_impl, dest_impl, allow_tensor_metadata_change); |
835 | if (!dest_impl->is_inference()) { |
836 | dest_impl->set_version_counter(std::move(version_counter)); |
837 | } |
838 | } |
839 | |
840 | // Legacy Caffe2 operations |
841 | |
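// Worked example (illustrative): given a contiguous [10, 3] float tensor that
// already owns data, Extend(5, 100.0f) grows the outer dimension to 15. If the
// existing storage cannot hold 15 * 3 floats, a fresh buffer sized for
// max(15, ceil(10 * 2.0)) = 20 rows is allocated, the old 10 rows are copied
// over, and reserved_ is set so that later Extend() calls can grow into the
// spare capacity without reallocating.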
842 | void TensorImpl::Extend(int64_t num, float growthPct) { |
843 | TORCH_CHECK(sizes_and_strides_.size() >= 1u); |
  TORCH_CHECK(num >= 0, "`num` must be non-negative for Extend");
845 | TORCH_CHECK( |
846 | is_contiguous_, |
      "Right now Extend is only supported for contiguous Tensor.");
848 | TORCH_CHECK( |
849 | !has_symbolic_sizes_strides_, |
      "Extend() called on tensor with symbolic shape")
851 | |
852 | using SizesVector = SmallVector<int64_t, 5>; |
853 | IntArrayRef sizes_and_strides = sizes_and_strides_.sizes_arrayref(); |
854 | SizesVector newDims(sizes_and_strides.begin(), sizes_and_strides.end()); |
855 | newDims[0] += num; |
856 | if (!storage_.data()) { |
857 | Resize(newDims); |
858 | return; |
859 | } |
860 | const auto newNumel = c10::multiply_integers(newDims.begin(), newDims.end()); |
861 | if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { |
862 | sizes_and_strides_.set_sizes(newDims); |
863 | numel_ = newNumel; |
864 | return; |
865 | } |
866 | SizesVector newCapacity(sizes_and_strides.begin(), sizes_and_strides.end()); |
867 | newCapacity[0] = std::max( |
868 | newDims[0], |
869 | static_cast<int64_t>(std::ceil( |
870 | static_cast<float>(sizes_and_strides_.size_at_unchecked(0)) * |
871 | (1 + growthPct / 100)))); |
872 | auto oldData = std::move(storage_.data_ptr()); |
873 | auto oldSize = numel_; |
874 | Resize(std::move(newCapacity)); |
875 | auto* newData = raw_mutable_data(data_type_); |
876 | if (data_type_.copy()) { |
877 | TORCH_CHECK( |
        device_type() == DeviceType::CPU, "non-POD types work only on CPU");
879 | data_type_.copy()(oldData.get(), newData, oldSize); |
880 | } else { |
881 | // The following copy uses the current (thread local) stream for copying |
882 | // and also takes the GPU id from the device() field passed in. |
883 | // |
884 | // TODO: Potentially more enforcements are necessary to avoid accidental |
885 | // switch to sync copy if the currently set device is wrong. |
886 | // |
887 | // Specifically, we might need to switch to a different context device |
888 | // here explicitly to avoid relying on user synchronizing things |
889 | // properly. |
890 | CopyBytes( |
891 | oldSize * itemsize(), |
892 | oldData.get(), |
893 | device(), |
894 | newData, |
895 | device(), |
896 | true); // non-blocking |
897 | } |
898 | reserved_ = true; |
899 | sizes_and_strides_.set_sizes(newDims); |
900 | numel_ = newNumel; |
901 | } |
902 | |
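// Illustrative example: for a contiguous [10, 3] float tensor,
// ReserveSpace(100) allocates storage large enough for a [100, 3] tensor but
// restores the visible sizes to [10, 3]; the old contents are discarded, and a
// later Extend() can grow into the reserved capacity without reallocating.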
903 | void TensorImpl::ReserveSpace(int64_t outer_dim) { |
904 | TORCH_CHECK( |
905 | is_contiguous_, |
      "Right now ReserveSpace is only supported for contiguous Tensor.");
907 | TORCH_CHECK( |
908 | !has_symbolic_sizes_strides_, |
      "ReserveSpace() called on tensor with symbolic shape")
910 | |
  TORCH_CHECK(storage_.unique(), "Can't call ReserveSpace on shared storage.");
912 | // TODO: eliminate newCapacity. |
913 | IntArrayRef sizes_and_strides = sizes_and_strides_.sizes_arrayref(); |
914 | SmallVector<int64_t, 5> newCapacity( |
915 | sizes_and_strides.begin(), sizes_and_strides.end()); |
916 | newCapacity[0] = outer_dim; |
917 | auto newNumel = c10::multiply_integers(newCapacity); |
918 | if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { |
919 | return; |
920 | } |
921 | // Old data is discarded |
922 | storage_.data_ptr().clear(); |
923 | auto oldSize = numel_; |
924 | SmallVector<int64_t, 5> oldDims( |
925 | sizes_and_strides.begin(), sizes_and_strides.end()); |
926 | Resize(std::move(newCapacity)); |
927 | // Allocate new memory but don't copy over the data |
928 | raw_mutable_data(data_type_); |
929 | sizes_and_strides_.set_sizes(oldDims); |
930 | numel_ = oldSize; |
931 | reserved_ = true; |
932 | } |
933 | |
934 | void TensorImpl::Reshape(const std::vector<int64_t>& dims) { |
935 | TORCH_CHECK( |
936 | is_contiguous_, |
      "Right now Reshape is only supported for contiguous Tensor.");
938 | TORCH_CHECK( |
939 | !has_symbolic_sizes_strides_, |
      "Reshape() called on tensor with symbolic shape")
941 | |
942 | int64_t new_size = 1; |
943 | for (auto d : dims) { |
944 | TORCH_CHECK(d >= 0); |
945 | new_size *= d; |
946 | } |
947 | TORCH_CHECK( |
948 | new_size == numel_, |
949 | "New size and old size are not equal. You cannot use Reshape, " |
950 | "but should use Resize." |
951 | // TODO(jiayq): remove the following warning after pending diffs |
952 | // stabilize. |
953 | " The old caffe2 mixes Reshape and Resize but this behavior has " |
954 | "been changed. If you find this error, most likely you will need " |
      "to change corresponding code from Reshape to Resize.");
956 | sizes_and_strides_.set_sizes(dims); |
957 | empty_tensor_restride(MemoryFormat::Contiguous); |
958 | } |
959 | |
960 | void TensorImpl::FreeMemory() { |
961 | // We'll detach from the old Storage and create a new one |
962 | if (storage_.use_count() != 1 || !storage_.resizable() || |
963 | !storage_.allocator()) { |
964 | storage_ = Storage::create_legacy(storage_.device()); |
965 | } else { |
966 | storage_.reset_legacy(); |
967 | } |
968 | storage_offset_ = 0; |
969 | } |
970 | |
971 | void TensorImpl::ShareData(const TensorImpl& src) { |
  // Right now, we are assuming the device_types are the same, since they are
  // inherently the same in the non-templatized code. We should probably add
  // an assert here, although it might affect perf a little bit.
975 | TORCH_CHECK( |
976 | src.numel_ == numel_, |
      "Size mismatch - did you call reshape before sharing the data?");
978 | // It is possible that the source tensor hasn't called mutable_data() yet, |
979 | // in which case ShareData() doesn't make much sense since we don't really |
980 | // know what to share yet. |
981 | // TODO: Add the assert after all uninitialized states are eliminated |
982 | // TORCH_CHECK(src.dtype_initialized(), |
  //             "Source tensor doesn't have a data type (did you call
  //             mutable_data<T> on the tensor?)");
985 | if (!src.dtype_initialized()) { |
986 | C10_LOG_EVERY_MS(WARNING, 1000) |
        << "Source tensor doesn't have a data type (did you call mutable_data<T> on the tensor?)";
988 | } |
989 | TORCH_CHECK( |
990 | src.storage_initialized(), |
      "Source tensor has no content and has size > 0");
992 | // Finally, do sharing. |
993 | /* Since we create new Storage whenever we need to change data_type/nbytes |
994 | * this still keeps the original semantics |
995 | */ |
996 | storage_ = src.storage(); |
997 | data_type_ = src.dtype(); |
998 | device_opt_ = src.device_opt(); |
999 | storage_offset_ = src.storage_offset(); |
1000 | } |
1001 | |
1002 | void TensorImpl::ShareExternalPointer( |
1003 | DataPtr&& data_ptr, |
1004 | const caffe2::TypeMeta data_type, |
1005 | size_t size_bytes) { |
1006 | TORCH_CHECK( |
1007 | data_type != ScalarType::Undefined, |
1008 | "To share with a raw external pointer you need to pass in an " |
      "initialized data_type(TypeMeta).");
1010 | TORCH_CHECK( |
1011 | !has_symbolic_sizes_strides_, |
      "ShareExternalPointer() called on tensor with symbolic shape");
1013 | if (!size_bytes) { |
1014 | size_bytes = numel_ * data_type.itemsize(); |
1015 | } |
1016 | if (storage_.unique()) { |
1017 | storage_.UniqueStorageShareExternalPointer(std::move(data_ptr), size_bytes); |
1018 | data_type_ = data_type; |
1019 | device_opt_ = storage_.device(); |
1020 | storage_offset_ = 0; |
1021 | } else { |
1022 | // Create a new Storage |
1023 | storage_ = Storage( |
1024 | Storage::use_byte_size_t(), |
1025 | size_bytes, |
1026 | std::move(data_ptr), |
1027 | /*allocator=*/nullptr, |
1028 | /*resizable=*/false); |
1029 | data_type_ = data_type; |
1030 | device_opt_ = storage_.device(); |
1031 | storage_offset_ = 0; |
1032 | } |
1033 | } |
1034 | |
1035 | void clone_symvec(SymIntArrayRef src, SymDimVector& dst) { |
1036 | dst.clear(); |
1037 | dst.reserve(src.size()); |
1038 | for (const auto& i : src) { |
1039 | dst.emplace_back(i.clone()); |
1040 | } |
1041 | } |
1042 | |
1043 | // NB: this doesn't check that the sizes/strides/offset are in bound for the |
1044 | // storage, and furthermore, it CANNOT do so as in some cases we temporarily |
1045 | // violate invariants by first setting sizes/strides, and then updating the |
1046 | // storage |
1047 | void TensorImpl::set_sizes_and_strides( |
1048 | c10::SymIntArrayRef sizes, |
1049 | c10::SymIntArrayRef strides, |
1050 | c10::optional<c10::SymInt> storage_offset) { |
1051 | auto int_sizes = asIntArrayRefSlowOpt(sizes); |
1052 | auto int_strides = asIntArrayRefSlowOpt(strides); |
1053 | if (int_sizes && int_strides && |
1054 | (!storage_offset.has_value() || !storage_offset->is_symbolic()) && |
1055 | !has_symbolic_sizes_strides_) { |
1056 | set_sizes_and_strides(*int_sizes, *int_strides); |
1057 | if (storage_offset.has_value()) |
1058 | set_storage_offset(storage_offset->as_int_unchecked()); |
1059 | return; |
1060 | } |
1061 | TORCH_CHECK( |
1062 | allow_tensor_metadata_change(), |
      "set_sizes_and_strides ",
1064 | err_msg_tensor_metadata_change_not_allowed); |
1065 | |
1066 | has_symbolic_sizes_strides_ = true; |
1067 | refresh_sizes_strides_policy(); |
1068 | if (!extra_meta_) { |
1069 | extra_meta_ = std::make_unique<ExtraMeta>(); |
1070 | if (!storage_offset.has_value()) { |
1071 | extra_meta_->storage_offset_ = storage_offset_; |
1072 | } |
1073 | } |
1074 | clone_symvec(sizes, extra_meta_->sizes_); |
1075 | clone_symvec(strides, extra_meta_->strides_); |
1076 | if (storage_offset.has_value()) |
1077 | extra_meta_->storage_offset_ = storage_offset->clone(); |
1078 | |
1079 | refresh_numel(); |
1080 | refresh_contiguous(); |
1081 | } |
1082 | |
1083 | void TensorImpl::generic_set_sizes_contiguous(SymIntArrayRef sizes) { |
1084 | auto int_sizes = asIntArrayRefSlowOpt(sizes); |
1085 | if (int_sizes.has_value()) { |
1086 | set_sizes_contiguous(*int_sizes); |
1087 | return; |
1088 | } |
1089 | |
1090 | TORCH_CHECK( |
1091 | allow_tensor_metadata_change(), |
      "generic_set_sizes_contiguous ",
1093 | err_msg_tensor_metadata_change_not_allowed); |
1094 | |
1095 | has_symbolic_sizes_strides_ = true; |
1096 | refresh_sizes_strides_policy(); |
1097 | if (!extra_meta_) { |
1098 | extra_meta_ = std::make_unique<ExtraMeta>(); |
1099 | extra_meta_->storage_offset_ = storage_offset_; |
1100 | } |
1101 | |
1102 | clone_symvec(sizes, extra_meta_->sizes_); |
1103 | refresh_numel(); |
1104 | empty_tensor_restride_symint( |
1105 | MemoryFormat::Contiguous); // calls refresh_contiguous() |
1106 | } |
1107 | |
1108 | void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { |
1109 | TORCH_INTERNAL_ASSERT(has_symbolic_sizes_strides_); |
1110 | #ifdef DEBUG |
1111 | TORCH_INTERNAL_ASSERT( |
1112 | compute_numel() == numel_, |
1113 | "If you are seeing this error, that means empty_tensor_restride was " |
      "called before setting correct numel");
1115 | #endif |
1116 | switch (memory_format) { |
1117 | case MemoryFormat::Contiguous: { |
1118 | // dim_ is a virtual call, don't repeat it |
1119 | const auto dim_ = dim(); |
1120 | extra_meta_->strides_.resize(dim_); |
1121 | if (dim_ > 0) { |
1122 | const auto last_idx = dim_ - 1; |
1123 | extra_meta_->strides_[last_idx] = c10::SymInt(1); |
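        // Illustrative example: for sizes [2, 3, 4] this backward pass
        // produces strides [12, 4, 1]; each stride is the product of the
        // sizes to its right, with size-0/1 dims clamped via .max(1).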
1124 | for (auto i = last_idx - 1; i >= 0; --i) { |
          extra_meta_->strides_[i] =
1126 | extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1); |
1127 | } |
1128 | } |
1129 | break; |
1130 | } |
1131 | case MemoryFormat::ChannelsLast: { |
1132 | TORCH_CHECK( |
          dim() == 4, "required rank 4 tensor to use channels_last format");
1134 | set_sizes_and_strides( |
1135 | sym_sizes(), get_channels_last_strides_2d(sym_sizes())); |
1136 | break; |
1137 | } |
1138 | case MemoryFormat::ChannelsLast3d: { |
1139 | TORCH_CHECK( |
          dim() == 5, "required rank 5 tensor to use channels_last_3d format");
1141 | set_sizes_and_strides( |
1142 | sym_sizes(), get_channels_last_strides_3d(sym_sizes())); |
1143 | break; |
1144 | } |
1145 | case MemoryFormat::Preserve: |
      TORCH_CHECK(false, "unsupported memory format ", memory_format);
1147 | // Cleaning warning messages, no need to break as TORCH_CHECK(false) |
1148 | // terminates flow. |
1149 | // break; |
1150 | case MemoryFormat::NumOptions: |
      TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
1152 | } |
1153 | // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually |
1154 | // exclusive see #24090 |
1155 | refresh_contiguous(); |
1156 | } |
1157 | |
1158 | namespace impl { |
1159 | |
1160 | namespace { |
1161 | AutogradMetaFactory* meta_factory = nullptr; |
1162 | } // namespace |
1163 | |
1164 | void SetAutogradMetaFactory(AutogradMetaFactory* factory) { |
1165 | meta_factory = factory; |
1166 | } |
1167 | AutogradMetaFactory* GetAutogradMetaFactory() { |
1168 | TORCH_CHECK( |
1169 | meta_factory, |
      "Support for autograd has not been loaded; have you linked against libtorch.so?")
1171 | return meta_factory; |
1172 | } |
1173 | |
1174 | } // namespace impl |
1175 | |
1176 | } // namespace c10 |
1177 | |