1 | #pragma once |
2 | |
3 | #include <c10/core/Device.h> |
4 | #include <c10/core/Layout.h> |
5 | #include <c10/core/MemoryFormat.h> |
6 | #include <c10/core/ScalarType.h> |
7 | #include <c10/core/ScalarTypeToTypeMeta.h> |
8 | #include <c10/core/Storage.h> |
9 | #include <c10/core/TensorImpl.h> |
10 | #include <c10/core/TensorOptions.h> |
11 | #include <c10/core/UndefinedTensorImpl.h> |
12 | #include <c10/core/WrapDimMinimal.h> |
13 | #include <c10/util/Exception.h> |
14 | #include <c10/util/ExclusivelyOwnedTensorTraits.h> |
15 | #include <c10/util/MaybeOwned.h> |
16 | #include <c10/util/Optional.h> |
17 | #include <c10/util/intrusive_ptr.h> |
18 | |
#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <c10/core/SymIntArrayRef.h>
23 | |
24 | namespace c10 { |
25 | class Scalar; |
26 | } |
27 | |
28 | namespace torch { namespace autograd { |
29 | |
30 | struct Node; |
31 | |
32 | }} // namespace torch::autograd |
33 | |
34 | namespace at { |
35 | |
36 | class Tensor; |
37 | class TensorBase; |
38 | |
39 | // Convert Tensor to TensorBase without any need to include Tensor.h |
40 | TORCH_API const TensorBase& get_tensor_base(const Tensor& t); |
41 | |
42 | namespace impl { |
43 | inline bool variable_excluded_from_dispatch() { |
44 | #ifdef C10_MOBILE |
45 | // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change. |
46 | return true; |
47 | #else |
48 | return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset); |
49 | #endif |
50 | } |
51 | |
52 | } |
53 | |
54 | // NOTE: [Tensor vs. TensorBase] |
55 | // |
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies, and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code generated
// methods in its API and thus no dependence on native_functions.yaml.
67 | // |
68 | // Usage tips |
69 | // ---------- |
70 | // - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp |
71 | // or .cu file to ensure it has no header dependencies on |
72 | // native_functions.yaml (direct or indirect). |
73 | // - Tensor inherits from TensorBase, so functions taking |
74 | // `const TensorBase &` are callable with Tensor as well. |
// - TensorBase can be converted to a Tensor with `Tensor(tensor_base)`,
76 | // but this requires a reference-count bump. OptionalTensorRef on |
77 | // the other hand can materialize a `const Tensor &` without |
78 | // touching the reference-count. |
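//
// Example (illustrative sketch): a helper that only needs tensor metadata
// can take `const TensorBase &` and stay free of any native_functions.yaml
// dependency:
//
//   void check_same_dtype(const TensorBase& a, const TensorBase& b) {
//     TORCH_CHECK(a.scalar_type() == b.scalar_type(),
//                 "expected tensors to have the same dtype");
//   }
//
// Both at::Tensor and at::TensorBase arguments bind to this signature.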
79 | class TORCH_API TensorBase { |
80 | public: |
81 | struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; }; |
82 | |
83 | protected: |
84 | // Create a Tensor with a +0 reference count. Special care must be |
85 | // taken to avoid decrementing this reference count at destruction |
86 | // time. Intended to support MaybeOwnedTraits<Tensor>. |
87 | explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs) |
88 | : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {} |
89 | friend MaybeOwnedTraits<TensorBase>; |
90 | |
91 | public: |
92 | TensorBase() = default; |
93 | // This constructor should not be used by end users and is an implementation |
94 | // detail invoked by autogenerated code. |
95 | explicit TensorBase( |
96 | c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) |
97 | : impl_(std::move(tensor_impl)) { |
98 | if (impl_.get() == nullptr) { |
99 | throw std::runtime_error("TensorImpl with nullptr is not supported" ); |
100 | } |
101 | } |
102 | TensorBase(const TensorBase&) = default; |
103 | TensorBase(TensorBase&&) = default; |
104 | |
105 | public: |
  // Creates a new wrapper from TensorImpl. Intentionally a static factory
  // function because it should be used with care. Checks necessary invariants.
108 | static TensorBase wrap_tensor_impl( |
109 | c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) { |
110 | TensorBase r(std::move(tensor_impl)); |
111 | r.enforce_invariants(); |
112 | return r; |
113 | } |
114 | |
115 | int64_t dim() const { |
116 | return impl_->dim(); |
117 | } |
118 | int64_t storage_offset() const { |
119 | return impl_->storage_offset(); |
120 | } |
121 | |
122 | TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { |
123 | if (is_contiguous(memory_format)) { |
124 | return *this; |
125 | } else { |
126 | return __dispatch_contiguous(memory_format); |
127 | } |
128 | } |
129 | |
130 | /// Should be used if *this can reasonably be expected to be contiguous and |
131 | /// performance is important. |
132 | /// Compared to contiguous, it saves a reference count |
133 | /// increment/decrement if *this is already contiguous, at the cost |
134 | /// in all cases of an extra pointer of stack usage, an extra branch |
135 | /// to access, and an extra branch at destruction time. |
136 | c10::MaybeOwned<TensorBase> expect_contiguous( |
137 | MemoryFormat memory_format=MemoryFormat::Contiguous) const &; |
138 | |
139 | // Use .contiguous() instead. Trying to borrow from a prvalue |
140 | // will only lead to trouble and dangling references. |
141 | c10::MaybeOwned<TensorBase> expect_contiguous( |
142 | MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; |
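
  // Usage sketch (hypothetical caller): bind the result to a local so the
  // borrow cannot dangle, then dereference to use it.
  //
  //   c10::MaybeOwned<TensorBase> c = t.expect_contiguous();
  //   const TensorBase& contig = *c;  // borrowed if t was already contiguous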
143 | |
144 | const TensorBase& fill_(const c10::Scalar& scalar) const; |
145 | const TensorBase& zero_() const; |
146 | |
  TensorBase to(
      at::TensorOptions options={},
      bool non_blocking=false,
      bool copy=false,
      c10::optional<at::MemoryFormat> memory_format=c10::nullopt) const;
148 | |
149 | bool is_complex() const { |
150 | return at::isComplexType(this->scalar_type()); |
151 | } |
152 | |
153 | bool is_floating_point() const { |
154 | return at::isFloatingType(this->scalar_type()); |
155 | } |
156 | |
157 | bool is_signed() const { |
158 | return at::isSignedType(this->scalar_type()); |
159 | } |
160 | |
161 | c10::SymInt sym_size(int64_t dim) const { |
162 | return impl_->sym_size(dim); |
163 | } |
164 | |
  c10::SymInt sym_stride(int64_t dim) const {
    const auto strides = this->sym_strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }
172 | |
173 | int64_t size(int64_t dim) const { |
174 | return impl_->size(dim); |
175 | } |
176 | |
177 | int64_t stride(int64_t dim) const { |
178 | const auto strides = this->strides(); |
179 | const auto ndim = static_cast<int64_t>(strides.size()); |
180 | // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) |
181 | return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; |
182 | } |
183 | |
184 | TensorImpl * unsafeGetTensorImpl() const { |
185 | return impl_.get(); |
186 | } |
187 | TensorImpl * unsafeReleaseTensorImpl() { |
188 | return impl_.release(); |
189 | } |
190 | const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const { |
191 | return impl_; |
192 | } |
193 | |
194 | c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() { |
195 | return std::move(impl_); |
196 | } |
197 | |
198 | bool defined() const { |
199 | return impl_; |
200 | } |
201 | |
202 | void reset() { |
203 | impl_.reset(); |
204 | } |
205 | |
  TensorBase& operator=(const TensorBase& x) & {
    impl_ = x.impl_;
    return *this;
  }
210 | TensorBase& operator=(TensorBase&& x) & noexcept { |
211 | impl_ = std::move(x.impl_); |
212 | return *this; |
213 | } |
214 | |
215 | // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here |
216 | TensorBase& operator=(const TensorBase&) && = delete; |
217 | TensorBase& operator=(TensorBase&&) && noexcept = delete; |
218 | |
219 | bool is_same(const TensorBase& other) const noexcept { |
220 | return impl_ == other.impl_; |
221 | } |
222 | size_t use_count() const noexcept { |
223 | return impl_.use_count(); |
224 | } |
225 | size_t weak_use_count() const noexcept { |
226 | return impl_.weak_use_count(); |
227 | } |
228 | |
229 | std::string toString() const; |
230 | |
231 | IntArrayRef sizes() const { |
232 | return impl_->sizes(); |
233 | } |
234 | c10::SymIntArrayRef sym_sizes() const { |
235 | return impl_->sym_sizes(); |
236 | } |
237 | c10::SymIntArrayRef sym_strides() const { |
238 | return impl_->sym_strides(); |
239 | } |
240 | IntArrayRef strides() const { |
241 | return impl_->strides(); |
242 | } |
243 | // See impl::get_opt_names in ATen/NamedTensor.h for docs. |
244 | c10::optional<DimnameList> opt_names() const { |
245 | return impl::get_opt_names(unsafeGetTensorImpl()); |
246 | } |
247 | // See impl::get_names in ATen/NamedTensor.h for docs. |
248 | DimnameList names() const { |
249 | return impl::get_names(unsafeGetTensorImpl()); |
250 | } |
251 | int64_t ndimension() const { |
252 | return dim(); |
253 | } |
254 | |
255 | bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { |
256 | return impl_->is_contiguous(memory_format); |
257 | } |
258 | |
259 | bool is_non_overlapping_and_dense() const { |
260 | return impl_->is_non_overlapping_and_dense(); |
261 | } |
262 | |
263 | at::MemoryFormat suggest_memory_format( |
264 | bool channels_last_strides_exact_match = false) const { |
    // Setting channels_last_strides_exact_match to true forces the function
    // to also check the strides of dimensions of size 0 or 1.
267 | if (layout() == at::kStrided) { |
268 | if (impl_->is_strides_like_channels_last()) { |
269 | if (!channels_last_strides_exact_match || |
270 | get_channels_last_strides_2d(sizes()) == strides()) { |
271 | return at::MemoryFormat::ChannelsLast; |
272 | } |
273 | } |
274 | else if (impl_->is_strides_like_channels_last_3d()) { |
275 | if (!channels_last_strides_exact_match || |
276 | get_channels_last_strides_3d(sizes()) == strides()) { |
277 | return at::MemoryFormat::ChannelsLast3d; |
278 | } |
279 | } |
280 | } |
281 | return at::MemoryFormat::Contiguous; |
282 | } |
283 | |
284 | // Total bytes consumed by the "view" of elements of the array. Does not |
285 | // include size of metadata. The number reported here does not necessarily |
286 | // correspond to the true physical memory consumed by a tensor; instead, |
287 | // it reports the memory the tensor would take *if* it were contiguous. |
  // Defined to be numel() * itemsize().
  size_t nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent "
                "tensors, add the nbytes of the indices and values. If you want the size of the "
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }
296 | |
  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent "
                "tensors, add the nbytes of the indices and values. If you want the size of the "
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }
304 | |
305 | int64_t numel() const { |
306 | return impl_->numel(); |
307 | } |
308 | |
309 | c10::SymInt sym_numel() const { |
310 | return impl_->sym_numel(); |
311 | } |
312 | |
313 | c10::SymInt sym_storage_offset() const { |
314 | return impl_->sym_storage_offset(); |
315 | } |
316 | |
317 | // Length of one array element in bytes. This is the traditional |
318 | // Numpy naming. |
319 | size_t itemsize() const { |
320 | return impl_->itemsize(); |
321 | } |
322 | |
323 | // Same as itemsize(). This is the PyTorch naming. |
324 | int64_t element_size() const { |
325 | return static_cast<int64_t>(impl_->itemsize()); |
326 | } |
327 | |
328 | DispatchKeySet key_set() const { |
329 | return impl_->key_set(); |
330 | } |
331 | ScalarType scalar_type() const { |
332 | return typeMetaToScalarType(impl_->dtype()); |
333 | } |
334 | bool has_storage() const { |
335 | return defined() && impl_->has_storage(); |
336 | } |
337 | const Storage& storage() const { |
338 | return impl_->storage(); |
339 | } |
  bool is_alias_of(const at::TensorBase& other) const {
341 | return impl_->storage().is_alias_of(other.storage()); |
342 | } |
343 | |
344 | inline bool _is_zerotensor() const { |
345 | return impl_->_is_zerotensor(); |
346 | } |
347 | |
348 | inline void _set_zero(bool zero) const { |
349 | impl_->_set_zero(zero); |
350 | } |
351 | |
352 | inline bool is_conj() const { |
353 | return impl_->is_conj(); |
354 | } |
355 | |
356 | // sets the conjugate bit of a tensor. |
357 | // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure |
358 | // that's what you want. Changing this might lead to incorrect behavior since conjugation is |
359 | // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized. |
360 | inline void _set_conj(bool conjugate) const { |
361 | impl_->_set_conj(conjugate); |
362 | } |
363 | |
364 | inline bool is_neg() const { |
365 | return impl_->is_neg(); |
366 | } |
367 | |
368 | // sets the negative bit of a tensor. |
369 | // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure |
370 | // that's what you want. Changing this might lead to incorrect behavior since we rely on this |
371 | // bit to determine if a negation needs to be materialized. |
372 | inline void _set_neg(bool negative) const { |
373 | impl_->_set_neg(negative); |
374 | } |
375 | |
376 | /// Returns a `Tensor`'s layout. |
377 | Layout layout() const { |
378 | return impl_->layout(); |
379 | } |
380 | |
381 | /// Returns a `Tensor`'s dtype (`TypeMeta`). |
382 | caffe2::TypeMeta dtype() const { |
383 | return impl_->dtype(); |
384 | } |
385 | |
386 | /// Returns a `Tensor`'s device. |
387 | inline Device device() const { |
388 | return impl_->device(); |
389 | } |
390 | |
391 | /// Returns a `Tensor`'s device index. |
392 | int64_t get_device() const { |
393 | // NB: this is not a native function to avoid dispatching overhead. |
394 | return impl_->get_device(); |
395 | } |
396 | |
397 | /// Returns if a `Tensor` has CPU backend. |
398 | bool is_cpu() const { |
399 | // NB: this is not a native function to avoid dispatching overhead. |
400 | return impl_->is_cpu(); |
401 | } |
402 | |
403 | /// Returns if a `Tensor` has CUDA backend. |
404 | bool is_cuda() const { |
405 | // NB: this is not a native function to avoid dispatching overhead. |
406 | return impl_->is_cuda(); |
407 | } |
408 | |
409 | /// Returns if a `Tensor` has IPU backend. |
410 | bool is_ipu() const { |
411 | // NB: this is not a native function to avoid dispatching overhead. |
412 | return impl_->is_ipu(); |
413 | } |
414 | |
415 | /// Returns if a `Tensor` has XPU backend. |
416 | bool is_xpu() const { |
417 | // NB: this is not a native function to avoid dispatching overhead. |
418 | return impl_->is_xpu(); |
419 | } |
420 | |
421 | /// Returns if a `Tensor` has XLA backend. |
422 | bool is_xla() const { |
423 | return impl_->is_xla(); |
424 | } |
425 | |
426 | /// Returns if a `Tensor` has HPU backend. |
427 | bool is_hpu() const { |
428 | return impl_->is_hpu(); |
429 | } |
430 | |
431 | /// Returns if a `Tensor` has Lazy backend. |
432 | bool is_lazy() const { |
433 | return impl_->is_lazy(); |
434 | } |
435 | |
436 | /// Returns if a `Tensor` has HIP backend. |
437 | bool is_hip() const { |
438 | // NB: this is not a native function to avoid dispatching overhead. |
439 | return impl_->is_hip(); |
440 | } |
441 | |
442 | /// Returns if a `Tensor` has VE backend. |
443 | bool is_ve() const { |
444 | // NB: this is not a native function to avoid dispatching overhead. |
445 | return impl_->is_ve(); |
446 | } |
447 | |
448 | /// Returns if a `Tensor` has sparse backend. |
449 | bool is_sparse() const { |
450 | // NB: this is not a native function to avoid dispatching overhead. |
451 | return impl_->is_sparse(); |
452 | } |
453 | |
  /// Returns if a `Tensor` has a sparse CSR backend.
455 | bool is_sparse_csr() const { |
456 | // NB: this is not a native function to avoid dispatching overhead. |
457 | return impl_->is_sparse_csr(); |
458 | } |
459 | |
  /// Returns if a `Tensor` is an mkldnn tensor.
461 | bool is_mkldnn() const { |
462 | // NB: this is not a native function to avoid dispatching overhead. |
463 | return impl_->is_mkldnn(); |
464 | } |
465 | |
  /// Returns if a `Tensor` is an MPS tensor.
467 | bool is_mps() const { |
468 | // NB: this is not a native function to avoid dispatching overhead. |
469 | return impl_->is_mps(); |
470 | } |
471 | |
  /// Returns if a `Tensor` is an ORT tensor.
473 | bool is_ort() const { |
474 | // NB: this is not a native function to avoid dispatching overhead. |
475 | return impl_->is_ort(); |
476 | } |
477 | |
  /// Returns if a `Tensor` is a Vulkan tensor.
479 | bool is_vulkan() const { |
480 | // NB: this is not a native function to avoid dispatching overhead. |
481 | return impl_->is_vulkan(); |
482 | } |
483 | |
  /// Returns if a `Tensor` is a Metal tensor.
485 | bool is_metal() const { |
486 | // NB: this is not a native function to avoid dispatching overhead. |
487 | return impl_->is_metal(); |
488 | } |
489 | |
490 | /// Returns if a `Tensor` has quantized backend. |
491 | bool is_quantized() const { |
492 | // NB: this is not a native function to avoid dispatching overhead. |
493 | return impl_->is_quantized(); |
494 | } |
495 | |
496 | /// Returns if a `Tensor` is a meta tensor. Meta tensors can |
497 | /// also have other designations. |
498 | bool is_meta() const { |
499 | return impl_->is_meta(); |
500 | } |
501 | |
502 | /// Returns if a `Tensor` is an inference tensor. |
503 | bool is_inference() const { |
504 | return impl_->is_inference(); |
505 | } |
506 | |
  /// Returns if a `Tensor` is a NestedTensor.
508 | bool is_nested() const { |
509 | return impl_->is_nested(); |
510 | } |
511 | |
512 | /// If a tensor is a quantized tensor, returns its quantizer |
513 | /// TODO: it's not in native_functions.yaml yet as it's not exposed to python |
514 | QuantizerPtr quantizer() const; |
515 | |
516 | /// Returns if a `Tensor` has any dimension names |
517 | bool has_names() const { |
518 | // If a user is using unnamed tensors, then we can short-circuit right here. |
519 | // Otherwise, impl::has_names attempts to retrieve names. |
520 | if (!impl_->has_named_tensor_meta()) { |
521 | return false; |
522 | } |
523 | return impl::has_names(unsafeGetTensorImpl()); |
524 | } |
525 | |
526 | /// Returns a `Tensor`'s dimension names data structure |
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<const NamedTensorMeta*>(impl_->named_tensor_meta());
  }
530 | |
531 | NamedTensorMeta* get_named_tensor_meta() { |
532 | return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta()); |
533 | } |
534 | |
535 | /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in |
536 | /// TensorOptions.h. |
537 | TensorOptions options() const { |
538 | return TensorOptions().dtype(dtype()) |
539 | .device(device()) |
540 | .layout(layout()); |
541 | } |
542 | |
543 | void* data_ptr() const { |
544 | return this->unsafeGetTensorImpl()->data(); |
545 | } |
546 | |
547 | template <typename T> |
548 | T * data_ptr() const; |
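
  // Example (illustrative): raw element access. Only meaningful when the
  // data lives in memory the caller can dereference (e.g. a CPU tensor),
  // and indexing must account for strides if it isn't contiguous.
  //
  //   float* p = t.data_ptr<float>();
  //   float first = p[0];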
549 | |
550 | // Purposely not defined here to avoid inlining |
551 | void print() const; |
552 | |
553 | // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and |
554 | // dimension. |
555 | template<typename T, size_t N> |
556 | TensorAccessor<T,N> accessor() const& { |
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return TensorAccessor<T,N>(data_ptr<T>(), sizes().data(), strides().data());
560 | } |
561 | template<typename T, size_t N> |
562 | TensorAccessor<T,N> accessor() && = delete; |
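
  // Example (illustrative): strided element access on a 2-D float CPU tensor.
  //
  //   auto a = t.accessor<float, 2>();
  //   for (int64_t i = 0; i < a.size(0); i++) {
  //     for (int64_t j = 0; j < a.size(1); j++) {
  //       a[i][j] += 1.0f;  // follows the tensor's strides, no copy
  //     }
  //   }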
563 | |
564 | // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and |
565 | // dimension. You can optionally specify RestrictPtrTraits as a template parameter to |
566 | // cast the data pointer to a __restrict__ pointer. |
567 | // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor |
568 | // as an argument. |
569 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> |
570 | GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& { |
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(
        static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),
        sizes().data(), strides().data());
574 | } |
575 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> |
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() && = delete;
577 | |
578 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits> |
579 | PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& { |
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
584 | return generic_packed_accessor<T,N,PtrTraits,int32_t>(); |
585 | } |
586 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits> |
587 | PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete; |
588 | |
589 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits> |
590 | PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& { |
591 | return generic_packed_accessor<T,N,PtrTraits,int64_t>(); |
592 | } |
593 | template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits> |
594 | PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete; |
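
  // Example kernel sketch (assumes CUDA; `add_one` is a hypothetical kernel):
  // the accessor is passed by value, so it can cross the host/device boundary.
  //
  //   __global__ void add_one(PackedTensorAccessor32<float, 2> acc) {
  //     // ... index via acc[i][j], which follows the tensor's strides ...
  //   }
  //   add_one<<<grid, block>>>(t.packed_accessor32<float, 2>());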
595 | |
596 | // ~~~~~ Autograd API ~~~~~ |
597 | |
598 | /// \fn bool is_leaf() const; |
599 | /// |
600 | /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. |
601 | /// |
602 | /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were |
603 | /// created by the user. This means that they are not the result of an operation and so |
604 | /// `grad_fn()` is `nullptr`. |
605 | /// |
606 | /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. |
607 | /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. |
608 | /// |
609 | /// Example: |
610 | /// @code |
611 | /// auto a = torch::rand(10, torch::requires_grad()); |
612 | /// std::cout << a.is_leaf() << std::endl; // prints `true` |
613 | /// |
614 | /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); |
615 | /// std::cout << b.is_leaf() << std::endl; // prints `false` |
  /// // b was created by the operation that cast a CPU Tensor into a CUDA Tensor
617 | /// |
618 | /// auto c = torch::rand(10, torch::requires_grad()) + 2; |
619 | /// std::cout << c.is_leaf() << std::endl; // prints `false` |
620 | /// // c was created by the addition operation |
621 | /// |
622 | /// auto d = torch::rand(10).cuda(); |
623 | /// std::cout << d.is_leaf() << std::endl; // prints `true` |
624 | /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) |
625 | /// |
626 | /// auto e = torch::rand(10).cuda().requires_grad_(); |
627 | /// std::cout << e.is_leaf() << std::endl; // prints `true` |
628 | /// // e requires gradients and has no operations creating it |
629 | /// |
630 | /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); |
631 | /// std::cout << f.is_leaf() << std::endl; // prints `true` |
632 | /// // f requires grad, has no operation creating it |
633 | /// @endcode |
634 | |
635 | /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const; |
636 | /// |
637 | /// Computes the gradient of current tensor with respect to graph leaves. |
638 | /// |
639 | /// The graph is differentiated using the chain rule. If the tensor is |
640 | /// non-scalar (i.e. its data has more than one element) and requires |
641 | /// gradient, the function additionally requires specifying ``gradient``. |
642 | /// It should be a tensor of matching type and location, that contains |
643 | /// the gradient of the differentiated function w.r.t. this Tensor. |
644 | /// |
645 | /// This function accumulates gradients in the leaves - you might need to |
646 | /// zero them before calling it. |
647 | /// |
648 | /// \param gradient Gradient w.r.t. the |
649 | /// tensor. If it is a tensor, it will be automatically converted |
650 | /// to a Tensor that does not require grad unless ``create_graph`` is True. |
651 | /// None values can be specified for scalar Tensors or ones that |
652 | /// don't require grad. If a None value would be acceptable then |
653 | /// this argument is optional. |
654 | /// \param retain_graph If ``false``, the graph used to compute |
655 | /// the grads will be freed. Note that in nearly all cases setting |
656 | /// this option to True is not needed and often can be worked around |
657 | /// in a much more efficient way. Defaults to the value of |
658 | /// ``create_graph``. |
  /// \param create_graph If ``true``, graph of the derivative will
  /// be constructed, allowing one to compute higher order derivative
  /// products. Defaults to ``false``.
662 | /// \param inputs Inputs w.r.t. which the gradient will be accumulated into |
663 | /// ``at::Tensor::grad``. All other Tensors will be ignored. If not |
664 | /// provided, the gradient is accumulated into all the leaf Tensors |
665 | /// that were used to compute the current tensor. |
666 | /// When inputs are provided and a given input is not a leaf, |
  /// the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
668 | /// It is an implementation detail on which the user should not rely. |
669 | /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. |
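  ///
  /// Example (illustrative):
  /// @code
  /// auto x = torch::ones({2, 2}, torch::requires_grad());
  /// auto y = (x * x).sum();
  /// y.backward();  // x.grad() now holds dy/dx, i.e. 2 * x
  /// @endcode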
670 | |
671 | /// \fn Tensor detach() const; |
672 | /// |
673 | /// Returns a new Tensor, detached from the current graph. |
674 | /// The result will never require gradient. |
675 | |
676 | /// \fn Tensor & detach_() const; |
677 | /// |
678 | /// Detaches the Tensor from the graph that created it, making it a leaf. |
679 | /// Views cannot be detached in-place. |
680 | |
681 | /// \fn void retain_grad() const; |
682 | /// |
  /// Enables this Tensor to have its :attr:`grad` populated during
684 | /// :func:`backward`. This is a no-op for leaf tensors. |
685 | |
686 | /// \fn bool retains_grad() const; |
687 | /// |
688 | /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be |
689 | /// populated during :func:`backward`, ``false`` otherwise. |
690 | |
691 | const TensorBase& set_requires_grad(bool requires_grad) const { |
692 | impl_->set_requires_grad(requires_grad); |
693 | return *this; |
694 | } |
695 | bool requires_grad() const { |
696 | return impl_->requires_grad(); |
697 | } |
698 | |
699 | // The Forward AD API functions below are low level and are not to be used by end |
700 | // users who should use the API provided in torch/csrc/autograd.h |
701 | |
702 | /// This function returns the forward gradient for this Tensor at the given level. |
703 | const Tensor& _fw_grad(uint64_t level) const { |
704 | return impl_->_fw_grad(level, *this); |
705 | } |
706 | |
707 | /// This function can be used to set the value of the forward grad. |
708 | /// Note that the given new_grad might not be used directly if it has different |
709 | /// metadata (size/stride/storage offset) compared to this Tensor. In that case, |
710 | /// new_grad content will be copied into a new Tensor |
711 | void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { |
712 | impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); |
713 | } |
714 | |
715 | /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended |
716 | /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` |
717 | /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`). |
718 | /// |
719 | /// One notable difference with the legacy `.data()` function is that changes to the |
720 | /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) |
721 | /// will not update the original `Variable`, due to the fact that this function |
722 | /// shallow-copies the `Variable`'s underlying TensorImpl. |
723 | at::TensorBase tensor_data() const; |
724 | |
725 | /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` |
726 | /// in Python, which create a new `Variable` that shares the same storage and |
727 | /// tensor metadata with the original `Variable`, but with a completely new |
728 | /// autograd history. |
729 | /// |
730 | /// NOTE: If we change the tensor metadata (e.g. sizes / strides / |
731 | /// storage / storage_offset) of a variable created from `var.variable_data()`, those |
732 | /// changes will not update the original variable `var`. In `.variable_data()`, we set |
733 | /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, |
734 | /// in order to prevent users from changing metadata of `var.variable_data()` |
735 | /// and expecting the original variable `var` to also be updated. |
736 | at::TensorBase variable_data() const; |
737 | |
738 | // Gradient Node and Edges |
739 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
740 | |
741 | /// Gets the gradient function of the `Variable`. If this is a leaf variable, |
742 | /// the pointer returned will be null. |
743 | /// |
744 | /// For View Variables: |
745 | /// Gets the up-to-date grad_fn. If the shared data or base was modified, we |
746 | /// re-create the grad_fn to express the up-to-date view relationship between |
747 | /// this and the base Variable. |
748 | const std::shared_ptr<torch::autograd::Node>& grad_fn() const; |
749 | |
750 | // Hooks |
751 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
752 | |
753 | template <typename T> |
754 | using hook_return_void_t = std::enable_if_t<std::is_void<typename c10::invoke_result_t<T&, TensorBase>>::value, unsigned>; |
755 | template <typename T> |
756 | using hook_return_var_t = std::enable_if_t<std::is_same<typename c10::invoke_result_t<T&, TensorBase>, TensorBase>::value, unsigned>; |
757 | |
758 | /// Registers a backward hook. |
759 | /// |
760 | /// The hook will be called every time a gradient with respect to the Tensor is computed. |
  /// The hook should have one of the following signatures:
762 | /// ``` |
763 | /// hook(TensorBase grad) -> TensorBase |
764 | /// ``` |
765 | /// ``` |
766 | /// hook(TensorBase grad) -> void |
767 | /// ``` |
768 | /// The hook should not modify its argument, but it can optionally return a new gradient |
769 | /// which will be used in place of `grad`. |
770 | /// |
  /// This function returns the index of the hook in the list which can be used to remove the hook.
772 | /// |
773 | /// Example: |
774 | /// @code |
775 | /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); |
776 | /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient |
777 | /// v.backward(torch::tensor({1., 2., 3.})); |
778 | /// // This prints: |
779 | /// // ``` |
780 | /// // 2 |
781 | /// // 4 |
782 | /// // 6 |
783 | /// // [ CPUFloatType{3} ] |
784 | /// // ``` |
785 | /// std::cout << v.grad() << std::endl; |
786 | /// v.remove_hook(h); // removes the hook |
787 | /// @endcode |
788 | template <typename T> |
789 | hook_return_void_t<T> register_hook(T&& hook) const; |
790 | template <typename T> |
791 | hook_return_var_t<T> register_hook(T&& hook) const; |
792 | |
793 | protected: |
794 | unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const; |
795 | |
796 | public: |
797 | |
798 | /// Remove hook at given position |
799 | void remove_hook(unsigned pos) const; |
800 | |
801 | // Variable methods |
802 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
803 | |
804 | bool is_leaf() const; |
805 | |
806 | int64_t output_nr() const; |
807 | |
808 | void set_data(const TensorBase & new_data) const; |
809 | |
810 | TensorBase data() const; |
811 | |
812 | int64_t _version() const; |
813 | |
814 | void retain_grad() const; |
815 | |
816 | bool retains_grad() const; |
817 | |
818 | const TensorBase& requires_grad_(bool _requires_grad=true) const; |
819 | |
820 | // View Variables |
821 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
822 | |
823 | /// Returns true if this `Variable` is a view of another `Variable`. |
824 | bool is_view() const; |
825 | |
826 | /// Returns the `Variable` that this `Variable` is a view of. If this |
827 | /// `Variable` is not a view, throw a `std::runtime_error`. |
828 | const TensorBase& _base() const; |
829 | |
830 | // Miscellaneous |
831 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
832 | |
833 | const std::string& name() const; |
834 | |
835 | protected: |
836 | void enforce_invariants(); |
837 | c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_; |
838 | |
839 | private: |
840 | TensorBase __dispatch_contiguous(c10::MemoryFormat) const; |
841 | }; |
842 | |
843 | inline int64_t get_device(const TensorBase& self) { |
844 | return self.get_device(); |
845 | } |
846 | |
847 | template <typename T> |
848 | auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> { |
  // Wrap a void-returning hook in a lambda that returns an (undefined)
  // TensorBase, so that we still get a std::function with a TensorBase
  // return type.
  static_assert(std::is_same<decltype(hook(TensorBase())), void>::value,
                "Expected hook to return void");
853 | return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) { |
854 | fn(grad); |
855 | return TensorBase(); |
856 | }); |
857 | } |
858 | |
859 | template <typename T> |
860 | auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> { |
861 | return _register_hook(std::forward<T>(hook)); |
862 | } |
863 | |
864 | namespace detail { |
// Helper creator for the Tensor class which doesn't require the user to pass
// in an intrusive_ptr; instead it converts the arguments passed in to the
// requested intrusive_ptr type.
868 | template <typename T, typename... Args> |
869 | TensorBase make_tensor_base(Args&&... args) { |
870 | return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...)); |
871 | } |
872 | |
873 | } // namespace detail |
874 | |
static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
876 | return legacyExtractDispatchKey(t.key_set()); |
877 | } |
878 | |
879 | } // namespace at |
880 | |
881 | namespace c10 { |
882 | template <> |
883 | struct MaybeOwnedTraits<at::TensorBase> { |
884 | using owned_type = at::TensorBase; |
885 | using borrow_type = at::TensorBase; |
886 | |
887 | static borrow_type createBorrow(const owned_type& from) { |
888 | // NOTE: this can be implemented without the special |
889 | // unsafe_borrow_t Tensor constructor as |
890 | // |
891 | // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl())); |
892 | // |
893 | // but that hurts inlining due to the nullptr check in the |
894 | // Tensor(c10::intrusive_ptr<...>) constructor. We already know |
895 | // that from.impl_ isn't null because from is a valid Tensor, so |
896 | // we needn't do the check again. (using __builtin_assume can |
897 | // avoid this, but wouldn't be portable to MSVC.) |
898 | return borrow_type(borrow_type::unsafe_borrow_t{}, from); |
899 | } |
900 | |
901 | static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { |
902 | lhs.unsafeReleaseTensorImpl(); |
903 | // See above note: this can be implemented with public API |
904 | // similarly to createBorrow(), but that would hurt inlining. |
905 | lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); |
906 | } |
907 | |
908 | static void destroyBorrow(borrow_type& toDestroy) { |
909 | toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. |
910 | } |
911 | |
912 | static const owned_type& referenceFromBorrow(const borrow_type& borrow) { |
913 | return borrow; |
914 | } |
915 | |
916 | static const owned_type* pointerFromBorrow(const borrow_type& borrow) { |
917 | return &borrow; |
918 | } |
919 | |
920 | static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { |
921 | return true; |
922 | } |
923 | }; |
924 | |
925 | template <> |
926 | struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {}; |
927 | } // namespace c10 |
928 | |
929 | namespace at { |
930 | |
931 | inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor( |
932 | const c10::optional<TensorBase>& opt) { |
933 | return opt.has_value() |
934 | ? c10::MaybeOwned<TensorBase>::borrowed(*opt) |
935 | : c10::MaybeOwned<TensorBase>::owned(c10::in_place); |
936 | } |
937 | |
938 | inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & { |
939 | if (is_contiguous(memory_format)) { |
940 | return c10::MaybeOwned<TensorBase>::borrowed(*this); |
941 | } else { |
942 | return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format)); |
943 | } |
944 | } |
945 | |
946 | namespace symint { |
947 | |
948 | template <typename T> |
949 | using enable_if_symint = std::enable_if_t<std::is_same<T, c10::SymInt>::value>; |
950 | template <typename T> |
951 | using enable_if_int = std::enable_if_t<std::is_same<T, int64_t>::value>; |
952 | |
953 | template <typename T, typename = enable_if_symint<T>> |
954 | c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); } |
955 | template <typename T, typename = enable_if_int<T>> |
956 | IntArrayRef sizes(const TensorBase& t) { return t.sizes(); } |
957 | |
958 | template <typename T, typename = enable_if_symint<T>> |
959 | c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); } |
960 | template <typename T, typename = enable_if_int<T>> |
961 | int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); } |
962 | |
963 | template <typename T, typename = enable_if_symint<T>> |
964 | c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); } |
965 | template <typename T, typename = enable_if_int<T>> |
966 | IntArrayRef strides(const TensorBase& t) { return t.strides(); } |
967 | |
968 | template <typename T, typename = enable_if_symint<T>> |
969 | c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); } |
970 | template <typename T, typename = enable_if_int<T>> |
971 | int64_t numel(const TensorBase& t) { return t.numel(); } |
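
// Example (illustrative): code templated over int64_t vs. c10::SymInt can
// query tensor extents generically.
//
//   template <typename T>
//   T count_elements(const TensorBase& t) {
//     return at::symint::numel<T>(t);  // int64_t or c10::SymInt
//   }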
972 | |
973 | } // namespace symint |
974 | |
975 | } // namespace at |
976 | |