1#pragma once
2
3#include <ATen/Tensor.h>
4#include <c10/core/SymIntArrayRef.h>
5#include <c10/core/TensorImpl.h>
6
7#include <torch/csrc/lazy/core/tensor.h>
8
9namespace torch {
10namespace lazy {
11
12// Tensor implementation class used to be fed to the at::Tensor.
13// Its scope is just to handle an LazyTensor.
14class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
15 public:
16 explicit LTCTensorImpl(const LazyTensorPtr& tensor);
17 explicit LTCTensorImpl(const LazyTensor& tensor);
18 explicit LTCTensorImpl(LazyTensor&& tensor);
19
20 LazyTensorPtr tensor() {
21 return tensor_;
22 }
23
24 void set_tensor(const LazyTensorPtr& lazy_tensor);
25
26 void force_refresh_sizes() {
27 generation_ = 0;
28 }
29
30 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
31 const c10::VariableVersion& version_counter,
32 bool allow_tensor_metadata_change) const override;
33
34 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
35 c10::VariableVersion&& version_counter,
36 bool allow_tensor_metadata_change) const override;
37
38 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
39
40 at::IntArrayRef sizes_custom() const override;
41 at::IntArrayRef strides_custom() const override;
42 int64_t numel_custom() const override;
43 int64_t storage_offset_custom() const override;
44 int64_t dim_custom() const override;
45 bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
46 bool is_strides_like_custom(at::MemoryFormat memory_format) const override;
47 bool is_non_overlapping_and_dense_custom() const override;
48
49 c10::SymIntArrayRef sym_sizes_custom() const override;
50 c10::SymIntArrayRef sym_strides_custom() const override;
51 c10::SymInt sym_numel_custom() const override;
52
53 private:
54 void setup_size_properties();
55
56 LazyTensorPtr tensor_;
57 mutable c10::optional<std::vector<c10::SymInt>> sym_sizes_;
58 size_t generation_{0};
59};
60
61} // namespace lazy
62} // namespace torch
63