#pragma once

#include <vector>

#include <ATen/Tensor.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>

#include <torch/csrc/lazy/core/tensor.h>

9 | namespace torch { |
10 | namespace lazy { |
11 | |
12 | // Tensor implementation class used to be fed to the at::Tensor. |
13 | // Its scope is just to handle an LazyTensor. |
14 | class TORCH_API LTCTensorImpl final : public c10::TensorImpl { |
15 | public: |
16 | explicit LTCTensorImpl(const LazyTensorPtr& tensor); |
17 | explicit LTCTensorImpl(const LazyTensor& tensor); |
18 | explicit LTCTensorImpl(LazyTensor&& tensor); |
19 | |
20 | LazyTensorPtr tensor() { |
21 | return tensor_; |
22 | } |
23 | |
24 | void set_tensor(const LazyTensorPtr& lazy_tensor); |
25 | |
26 | void force_refresh_sizes() { |
27 | generation_ = 0; |
28 | } |
29 | |
30 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
31 | const c10::VariableVersion& version_counter, |
32 | bool allow_tensor_metadata_change) const override; |
33 | |
34 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
35 | c10::VariableVersion&& version_counter, |
36 | bool allow_tensor_metadata_change) const override; |
37 | |
38 | void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override; |
39 | |
40 | at::IntArrayRef sizes_custom() const override; |
41 | at::IntArrayRef strides_custom() const override; |
42 | int64_t numel_custom() const override; |
43 | int64_t storage_offset_custom() const override; |
44 | int64_t dim_custom() const override; |
45 | bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |
46 | bool is_strides_like_custom(at::MemoryFormat memory_format) const override; |
47 | bool is_non_overlapping_and_dense_custom() const override; |
48 | |
49 | c10::SymIntArrayRef sym_sizes_custom() const override; |
50 | c10::SymIntArrayRef sym_strides_custom() const override; |
51 | c10::SymInt sym_numel_custom() const override; |
52 | |
53 | private: |
54 | void setup_size_properties(); |
55 | |
56 | LazyTensorPtr tensor_; |
57 | mutable c10::optional<std::vector<c10::SymInt>> sym_sizes_; |
58 | size_t generation_{0}; |
59 | }; |
60 | |
61 | } // namespace lazy |
62 | } // namespace torch |
63 |