#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <c10/core/SymIntArrayRef.h>
#include <ATen/core/TensorAccessor.h>

namespace c10 {
class Scalar;
}

namespace torch { namespace autograd {

struct Node;

}} // namespace torch::autograd

namespace at {

class Tensor;
class TensorBase;

// Convert Tensor to TensorBase without any need to include Tensor.h
TORCH_API const TensorBase& get_tensor_base(const Tensor& t);

namespace impl {
inline bool variable_excluded_from_dispatch() {
#ifdef C10_MOBILE
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}

} // namespace impl

// NOTE: [Tensor vs. TensorBase]
//
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference-counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code-generated
// methods in its API and thus no dependence on native_functions.yaml.
//
// Usage tips
// ----------
// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
//   or .cu file to ensure it has no header dependencies on
//   native_functions.yaml (direct or indirect).
// - Tensor inherits from TensorBase, so functions taking
//   `const TensorBase &` are callable with Tensor as well.
// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
//   but this requires a reference-count bump. OptionalTensorRef, on
//   the other hand, can materialize a `const Tensor &` without
//   touching the reference count.
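//
// For example (editor's sketch, not part of ATen): a translation unit that
// only needs tensor metadata can opt out of the generated operator headers
// and still accept both Tensor and TensorBase arguments. The helper
// `same_shape` below is hypothetical.
//
//   #define TORCH_ASSERT_NO_OPERATORS
//   #include <ATen/core/TensorBase.h>
//
//   bool same_shape(const at::TensorBase& a, const at::TensorBase& b) {
//     return a.sizes().equals(b.sizes());  // also callable with at::Tensor
//   }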
class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };

 protected:
  // Create a TensorBase with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<TensorBase>.
  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
  friend MaybeOwnedTraits<TensorBase>;

 public:
  TensorBase() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit TensorBase(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : impl_(std::move(tensor_impl)) {
    if (impl_.get() == nullptr) {
      throw std::runtime_error("TensorImpl with nullptr is not supported");
    }
  }
  TensorBase(const TensorBase&) = default;
  TensorBase(TensorBase&&) = default;

 public:
  // Creates a new wrapper from TensorImpl. Intentionally a static factory
  // function rather than a constructor because it should be used with care.
  // Checks the necessary invariants.
  static TensorBase wrap_tensor_impl(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
    TensorBase r(std::move(tensor_impl));
    r.enforce_invariants();
    return r;
  }

  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
    if (is_contiguous(memory_format)) {
      return *this;
    } else {
      return __dispatch_contiguous(memory_format);
    }
  }

  /// Should be used if *this can reasonably be expected to be contiguous and
  /// performance is important.
  /// Compared to contiguous, it saves a reference count
  /// increment/decrement if *this is already contiguous, at the cost
  /// in all cases of an extra pointer of stack usage, an extra branch
  /// to access, and an extra branch at destruction time.
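  ///
  /// For example (editor's sketch, assuming `t` is a float CPU tensor whose
  /// data is only read; `sum_all` is a hypothetical helper, not part of ATen):
  /// @code
  /// float sum_all(const at::TensorBase& t) {
  ///   c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  ///   const float* p = c->data_ptr<float>();
  ///   float acc = 0.f;
  ///   for (int64_t i = 0; i < c->numel(); ++i) {
  ///     acc += p[i];
  ///   }
  ///   return acc;
  /// }
  /// @endcode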
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;

  // Use .contiguous() instead. Trying to borrow from a prvalue
  // will only lead to trouble and dangling references.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;

  const TensorBase& fill_(const c10::Scalar& scalar) const;
  const TensorBase& zero_() const;

  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) const;

  bool is_complex() const {
    return at::isComplexType(this->scalar_type());
  }

  bool is_floating_point() const {
    return at::isFloatingType(this->scalar_type());
  }

  bool is_signed() const {
    return at::isSignedType(this->scalar_type());
  }

  c10::SymInt sym_size(int64_t dim) const {
    return impl_->sym_size(dim);
  }

  c10::SymInt sym_stride(int64_t dim) const {
    const auto strides = this->sym_strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  int64_t size(int64_t dim) const {
    return impl_->size(dim);
  }

  int64_t stride(int64_t dim) const {
    const auto strides = this->strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  TensorImpl * unsafeGetTensorImpl() const {
    return impl_.get();
  }
  TensorImpl * unsafeReleaseTensorImpl() {
    return impl_.release();
  }
  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
    return impl_;
  }

  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
    return std::move(impl_);
  }

  bool defined() const {
    return impl_;
  }

  void reset() {
    impl_.reset();
  }

  TensorBase& operator=(const TensorBase& x) & {
    impl_ = x.impl_;
    return *this;
  }
  TensorBase& operator=(TensorBase&& x) & noexcept {
    impl_ = std::move(x.impl_);
    return *this;
  }

  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
  TensorBase& operator=(const TensorBase&) && = delete;
  TensorBase& operator=(TensorBase&&) && noexcept = delete;

  bool is_same(const TensorBase& other) const noexcept {
    return impl_ == other.impl_;
  }
  size_t use_count() const noexcept {
    return impl_.use_count();
  }
  size_t weak_use_count() const noexcept {
    return impl_.weak_use_count();
  }

  std::string toString() const;

  IntArrayRef sizes() const {
    return impl_->sizes();
  }
  c10::SymIntArrayRef sym_sizes() const {
    return impl_->sym_sizes();
  }
  c10::SymIntArrayRef sym_strides() const {
    return impl_->sym_strides();
  }
  IntArrayRef strides() const {
    return impl_->strides();
  }
  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
  c10::optional<DimnameList> opt_names() const {
    return impl::get_opt_names(unsafeGetTensorImpl());
  }
  // See impl::get_names in ATen/NamedTensor.h for docs.
  DimnameList names() const {
    return impl::get_names(unsafeGetTensorImpl());
  }
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }

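  // A typical use of suggest_memory_format() (editor's sketch) is to preserve
  // a channels-last layout when materializing a contiguous copy:
  //
  //   at::TensorBase dense = src.contiguous(src.suggest_memory_format());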
  at::MemoryFormat suggest_memory_format(
      bool channels_last_strides_exact_match = false) const {
    // Setting channels_last_strides_exact_match to true forces the function
    // to also check the strides of size-0 and size-1 dimensions.
    if (layout() == at::kStrided) {
      if (impl_->is_strides_like_channels_last()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_2d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast;
        }
      }
      else if (impl_->is_strides_like_channels_last_3d()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_3d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast3d;
        }
      }
    }
    return at::MemoryFormat::Contiguous;
  }

  // Total bytes consumed by the "view" of elements of the array. Does not
  // include size of metadata. The number reported here does not necessarily
  // correspond to the true physical memory consumed by a tensor; instead,
  // it reports the memory the tensor would take *if* it were contiguous.
  // Defined to be numel() * itemsize().
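  // For example, a float32 tensor with sizes {2, 3} reports
  // nbytes() == 2 * 3 * sizeof(float) == 24.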
  size_t nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
        "nbytes is not defined for sparse tensors. If you want the size of the constituent "
        "tensors, add the nbytes of the indices and values. If you want the size of the "
        "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }

  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
        "nbytes is not defined for sparse tensors. If you want the size of the constituent "
        "tensors, add the nbytes of the indices and values. If you want the size of the "
        "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }

  int64_t numel() const {
    return impl_->numel();
  }

  c10::SymInt sym_numel() const {
    return impl_->sym_numel();
  }

  c10::SymInt sym_storage_offset() const {
    return impl_->sym_storage_offset();
  }

  // Length of one array element in bytes. This is the traditional
  // Numpy naming.
  size_t itemsize() const {
    return impl_->itemsize();
  }

  // Same as itemsize(). This is the PyTorch naming.
  int64_t element_size() const {
    return static_cast<int64_t>(impl_->itemsize());
  }

  DispatchKeySet key_set() const {
    return impl_->key_set();
  }
  ScalarType scalar_type() const {
    return typeMetaToScalarType(impl_->dtype());
  }
  bool has_storage() const {
    return defined() && impl_->has_storage();
  }
  const Storage& storage() const {
    return impl_->storage();
  }
  bool is_alias_of(const at::TensorBase& other) const {
    return impl_->storage().is_alias_of(other.storage());
  }

  inline bool _is_zerotensor() const {
    return impl_->_is_zerotensor();
  }

  inline void _set_zero(bool zero) const {
    impl_->_set_zero(zero);
  }

  inline bool is_conj() const {
    return impl_->is_conj();
  }

  // Sets the conjugate bit of a tensor.
  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
  inline void _set_conj(bool conjugate) const {
    impl_->_set_conj(conjugate);
  }

  inline bool is_neg() const {
    return impl_->is_neg();
  }

  // Sets the negative bit of a tensor.
  // NOTE: Negative bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
  // bit to determine if a negation needs to be materialized.
  inline void _set_neg(bool negative) const {
    impl_->_set_neg(negative);
  }

  /// Returns a `Tensor`'s layout.
  Layout layout() const {
    return impl_->layout();
  }

  /// Returns a `Tensor`'s dtype (`TypeMeta`).
  caffe2::TypeMeta dtype() const {
    return impl_->dtype();
  }

  /// Returns a `Tensor`'s device.
  inline Device device() const {
    return impl_->device();
  }

  /// Returns a `Tensor`'s device index.
  int64_t get_device() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->get_device();
  }

  /// Returns if a `Tensor` has CPU backend.
  bool is_cpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cpu();
  }

  /// Returns if a `Tensor` has CUDA backend.
  bool is_cuda() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cuda();
  }

  /// Returns if a `Tensor` has IPU backend.
  bool is_ipu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ipu();
  }

  /// Returns if a `Tensor` has XPU backend.
  bool is_xpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_xpu();
  }

  /// Returns if a `Tensor` has XLA backend.
  bool is_xla() const {
    return impl_->is_xla();
  }

  /// Returns if a `Tensor` has HPU backend.
  bool is_hpu() const {
    return impl_->is_hpu();
  }

  /// Returns if a `Tensor` has Lazy backend.
  bool is_lazy() const {
    return impl_->is_lazy();
  }

  /// Returns if a `Tensor` has HIP backend.
  bool is_hip() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_hip();
  }

  /// Returns if a `Tensor` has VE backend.
  bool is_ve() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ve();
  }

  /// Returns if a `Tensor` has sparse backend.
  bool is_sparse() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse();
  }

  /// Returns if a `Tensor` has a sparse CSR backend.
  bool is_sparse_csr() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse_csr();
  }

  /// Returns if a `Tensor` is an mkldnn tensor.
  bool is_mkldnn() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mkldnn();
  }

  /// Returns if a `Tensor` is an mps tensor.
  bool is_mps() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mps();
  }

  /// Returns if a `Tensor` is an ort tensor.
  bool is_ort() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ort();
  }

  /// Returns if a `Tensor` is a vulkan tensor.
  bool is_vulkan() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_vulkan();
  }

  /// Returns if a `Tensor` is a metal tensor.
  bool is_metal() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_metal();
  }

  /// Returns if a `Tensor` has quantized backend.
  bool is_quantized() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_quantized();
  }

  /// Returns if a `Tensor` is a meta tensor. Meta tensors can
  /// also have other designations.
  bool is_meta() const {
    return impl_->is_meta();
  }

  /// Returns if a `Tensor` is an inference tensor.
  bool is_inference() const {
    return impl_->is_inference();
  }

  // Returns if a `Tensor` is a NestedTensor.
  bool is_nested() const {
    return impl_->is_nested();
  }

  /// If a tensor is a quantized tensor, returns its quantizer.
  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
  QuantizerPtr quantizer() const;

  /// Returns if a `Tensor` has any dimension names.
  bool has_names() const {
    // If a user is using unnamed tensors, then we can short-circuit right here.
    // Otherwise, impl::has_names attempts to retrieve names.
    if (!impl_->has_named_tensor_meta()) {
      return false;
    }
    return impl::has_names(unsafeGetTensorImpl());
  }

  /// Returns a `Tensor`'s dimension names data structure.
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  NamedTensorMeta* get_named_tensor_meta() {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
  /// TensorOptions.h.
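  ///
  /// For example (editor's sketch; `src` is a hypothetical existing tensor),
  /// this is the usual way to create a new tensor with the same dtype, device,
  /// and layout as an existing one:
  /// @code
  /// at::Tensor like = at::empty({2, 2}, src.options());
  /// @endcode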
  TensorOptions options() const {
    return TensorOptions().dtype(dtype())
                          .device(device())
                          .layout(layout());
  }

  void* data_ptr() const {
    return this->unsafeGetTensorImpl()->data();
  }

  template <typename T>
  T * data_ptr() const;

  // Purposely not defined here to avoid inlining
  void print() const;

  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
  // dimension.
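  // For example (editor's sketch), element-wise access into a 2-D float CPU
  // tensor `t`:
  //
  //   auto a = t.accessor<float, 2>();
  //   float sum = 0;
  //   for (int64_t i = 0; i < t.size(0); ++i) {
  //     for (int64_t j = 0; j < t.size(1); ++j) {
  //       sum += a[i][j];
  //     }
  //   }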
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return TensorAccessor<T,N>(data_ptr<T>(),sizes().data(),strides().data());
  }
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() && = delete;

  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
  // as an argument.
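  // For example (editor's sketch using the packed_accessor32 wrapper declared
  // below; `scale_kernel` is a hypothetical kernel):
  //
  //   __global__ void scale_kernel(
  //       at::PackedTensorAccessor32<float, 2> a, float alpha) {
  //     // index into a[i][j] from threadIdx/blockIdx and multiply by alpha
  //   }
  //
  //   // host side:
  //   scale_kernel<<<blocks, threads>>>(t.packed_accessor32<float, 2>(), 2.0f);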
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;

  // ~~~~~ Autograd API ~~~~~

  /// \fn bool is_leaf() const;
  ///
  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
  ///
  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
  /// created by the user. This means that they are not the result of an operation and so
  /// `grad_fn()` is `nullptr`.
  ///
  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
  ///
  /// Example:
  /// @code
  /// auto a = torch::rand(10, torch::requires_grad());
  /// std::cout << a.is_leaf() << std::endl; // prints `true`
  ///
  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
  /// std::cout << b.is_leaf() << std::endl; // prints `false`
  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
  ///
  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
  /// std::cout << c.is_leaf() << std::endl; // prints `false`
  /// // c was created by the addition operation
  ///
  /// auto d = torch::rand(10).cuda();
  /// std::cout << d.is_leaf() << std::endl; // prints `true`
  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
  ///
  /// auto e = torch::rand(10).cuda().requires_grad_();
  /// std::cout << e.is_leaf() << std::endl; // prints `true`
  /// // e requires gradients and has no operations creating it
  ///
  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
  /// std::cout << f.is_leaf() << std::endl; // prints `true`
  /// // f requires grad, has no operation creating it
  /// @endcode

  /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const;
  ///
  /// Computes the gradient of current tensor with respect to graph leaves.
  ///
  /// The graph is differentiated using the chain rule. If the tensor is
  /// non-scalar (i.e. its data has more than one element) and requires
  /// gradient, the function additionally requires specifying ``gradient``.
  /// It should be a tensor of matching type and location, that contains
  /// the gradient of the differentiated function w.r.t. this Tensor.
  ///
  /// This function accumulates gradients in the leaves - you might need to
  /// zero them before calling it.
  ///
  /// \param gradient Gradient w.r.t. the
  ///     tensor. If it is a tensor, it will be automatically converted
  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
  ///     None values can be specified for scalar Tensors or ones that
  ///     don't require grad. If a None value would be acceptable then
  ///     this argument is optional.
  /// \param retain_graph If ``false``, the graph used to compute
  ///     the grads will be freed. Note that in nearly all cases setting
  ///     this option to True is not needed and often can be worked around
  ///     in a much more efficient way. Defaults to the value of
  ///     ``create_graph``.
  /// \param create_graph If ``true``, graph of the derivative will
  ///     be constructed, allowing the computation of higher order derivative
  ///     products. Defaults to ``false``.
  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
  ///     provided, the gradient is accumulated into all the leaf Tensors
  ///     that were used to compute the current tensor.
  ///     When inputs are provided and a given input is not a leaf,
  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
  ///     It is an implementation detail on which the user should not rely.
  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

  /// \fn Tensor detach() const;
  ///
  /// Returns a new Tensor, detached from the current graph.
  /// The result will never require gradient.

  /// \fn Tensor & detach_() const;
  ///
  /// Detaches the Tensor from the graph that created it, making it a leaf.
  /// Views cannot be detached in-place.

  /// \fn void retain_grad() const;
  ///
  /// Enables this Tensor to have its :attr:`grad` populated during
  /// :func:`backward`. This is a no-op for leaf tensors.

  /// \fn bool retains_grad() const;
  ///
  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
  /// populated during :func:`backward`, ``false`` otherwise.

  const TensorBase& set_requires_grad(bool requires_grad) const {
    impl_->set_requires_grad(requires_grad);
    return *this;
  }
  bool requires_grad() const {
    return impl_->requires_grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& _fw_grad(uint64_t level) const {
    return impl_->_fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  /// Note that the given new_grad might not be used directly if it has different
  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
  /// new_grad content will be copied into a new Tensor.
  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
  ///
  /// One notable difference with the legacy `.data()` function is that changes to the
  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
  /// will not update the original `Variable`, due to the fact that this function
  /// shallow-copies the `Variable`'s underlying TensorImpl.
  at::TensorBase tensor_data() const;

  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
  /// in Python, which creates a new `Variable` that shares the same storage and
  /// tensor metadata with the original `Variable`, but with a completely new
  /// autograd history.
  ///
  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
  /// changes will not update the original variable `var`. In `.variable_data()`, we set
  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
  /// in order to prevent users from changing metadata of `var.variable_data()`
  /// and expecting the original variable `var` to also be updated.
  at::TensorBase variable_data() const;

  // Gradient Node and Edges
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
  /// the pointer returned will be null.
  ///
  /// For View Variables:
  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
  /// re-create the grad_fn to express the up-to-date view relationship between
  /// this and the base Variable.
  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;

  // Hooks
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  template <typename T>
  using hook_return_void_t = std::enable_if_t<std::is_void<typename c10::invoke_result_t<T&, TensorBase>>::value, unsigned>;
  template <typename T>
  using hook_return_var_t = std::enable_if_t<std::is_same<typename c10::invoke_result_t<T&, TensorBase>, TensorBase>::value, unsigned>;

  /// Registers a backward hook.
  ///
  /// The hook will be called every time a gradient with respect to the Tensor is computed.
  /// The hook should have one of the following signatures:
  /// ```
  ///   hook(TensorBase grad) -> TensorBase
  /// ```
  /// ```
  ///   hook(TensorBase grad) -> void
  /// ```
  /// The hook should not modify its argument, but it can optionally return a new gradient
  /// which will be used in place of `grad`.
  ///
  /// This function returns the index of the hook in the list, which can be used to remove the hook.
  ///
  /// Example:
  /// @code
  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
  /// v.backward(torch::tensor({1., 2., 3.}));
  /// // This prints:
  /// // ```
  /// // 2
  /// // 4
  /// // 6
  /// // [ CPUFloatType{3} ]
  /// // ```
  /// std::cout << v.grad() << std::endl;
  /// v.remove_hook(h); // removes the hook
  /// @endcode
  template <typename T>
  hook_return_void_t<T> register_hook(T&& hook) const;
  template <typename T>
  hook_return_var_t<T> register_hook(T&& hook) const;

 protected:
  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;

 public:

  /// Removes the hook at the given position.
  void remove_hook(unsigned pos) const;

  // Variable methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  bool is_leaf() const;

  int64_t output_nr() const;

  void set_data(const TensorBase & new_data) const;

  TensorBase data() const;

  int64_t _version() const;

  void retain_grad() const;

  bool retains_grad() const;

  const TensorBase& requires_grad_(bool _requires_grad=true) const;

  // View Variables
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const;

  /// Returns the `Variable` that this `Variable` is a view of. If this
  /// `Variable` is not a view, throws a `std::runtime_error`.
  const TensorBase& _base() const;

  // Miscellaneous
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  const std::string& name() const;

 protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

 private:
  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
};

inline int64_t get_device(const TensorBase& self) {
  return self.get_device();
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type, so that we
  // still produce a std::function with Tensor return type.
  static_assert(std::is_same<decltype(hook(TensorBase())), void>::value,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
    fn(grad);
    return TensorBase();
  });
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
  return _register_hook(std::forward<T>(hook));
}

namespace detail {
// Helper creator for the Tensor class which doesn't require the user to pass
// in an intrusive_ptr; instead, it converts the arguments passed in into the
// requested intrusive_ptr type.
template <typename T, typename... Args>
TensorBase make_tensor_base(Args&&... args) {
  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

} // namespace detail

static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
  return legacyExtractDispatchKey(t.key_set());
}

} // namespace at

namespace c10 {
template <>
struct MaybeOwnedTraits<at::TensorBase> {
  using owned_type = at::TensorBase;
  using borrow_type = at::TensorBase;

  static borrow_type createBorrow(const owned_type& from) {
    // NOTE: this can be implemented without the special
    // unsafe_borrow_t Tensor constructor as
    //
    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
    //
    // but that hurts inlining due to the nullptr check in the
    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
    // that from.impl_ isn't null because from is a valid Tensor, so
    // we needn't do the check again. (using __builtin_assume can
    // avoid this, but wouldn't be portable to MSVC.)
    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseTensorImpl();
    // See above note: this can be implemented with public API
    // similarly to createBorrow(), but that would hurt inlining.
    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
} // namespace c10

namespace at {

inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
    const c10::optional<TensorBase>& opt) {
  return opt.has_value()
      ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
      : c10::MaybeOwned<TensorBase>::owned(c10::in_place);
}

inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
  if (is_contiguous(memory_format)) {
    return c10::MaybeOwned<TensorBase>::borrowed(*this);
  } else {
    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
  }
}

namespace symint {

template <typename T>
using enable_if_symint = std::enable_if_t<std::is_same<T, c10::SymInt>::value>;
template <typename T>
using enable_if_int = std::enable_if_t<std::is_same<T, int64_t>::value>;

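// These overload pairs let templated code select between the SymInt and
// int64_t views of a tensor's geometry with a single tag type. For example
// (editor's sketch; `first_dim` is a hypothetical helper):
//
//   template <typename T>
//   T first_dim(const at::TensorBase& t) {
//     return at::symint::size<T>(t, 0);  // c10::SymInt or int64_t, per T
//   }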
template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
template <typename T, typename = enable_if_int<T>>
int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef strides(const TensorBase& t) { return t.strides(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
template <typename T, typename = enable_if_int<T>>
int64_t numel(const TensorBase& t) { return t.numel(); }

} // namespace symint

} // namespace at