1 | #pragma once |
2 | |
3 | #include <ATen/quantized/Quantizer.h> |
4 | #include <c10/core/TensorImpl.h> |
5 | #include <c10/util/Exception.h> |
6 | |
7 | namespace at { |
8 | |
/**
 * QTensorImpl is a TensorImpl for quantized Tensors. It stores a Quantizer,
 * which specifies the quantization scheme and its parameters. For more
 * information, see ATen/quantized/Quantizer.h.
 *
 * We use "QTensor" in code and documentation to refer to a Tensor backed by a
 * QTensorImpl.
 */
16 | struct TORCH_API QTensorImpl : public c10::TensorImpl { |
17 | public: |
18 | QTensorImpl( |
19 | Storage&& storage, |
20 | DispatchKeySet key_set, |
21 | const caffe2::TypeMeta data_type, |
22 | QuantizerPtr quantizer); |
23 | |
24 | // See Note [Enum ImplType] |
25 | QTensorImpl( |
26 | ImplType type, |
27 | Storage&& storage, |
28 | DispatchKeySet key_set, |
29 | const caffe2::TypeMeta data_type, |
30 | QuantizerPtr quantizer); |
31 | |
32 | |
33 | // TODO: Expose in PyTorch Frontend |
34 | QuantizerPtr quantizer() { |
35 | return quantizer_; |
36 | } |
37 | |
38 | void set_quantizer_(QuantizerPtr quantizer) { |
39 | quantizer_ = quantizer; |
40 | } |
41 | |
42 | /** |
43 | * Return a TensorImpl that is a shallow-copy of this TensorImpl. |
44 | * |
45 | * For usage of `version_counter` and `allow_tensor_metadata_change`, |
46 | * see NOTE [ TensorImpl Shallow-Copying ]. |
47 | */ |
48 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
49 | const c10::VariableVersion& version_counter, |
50 | bool allow_tensor_metadata_change) const override { |
51 | auto impl = c10::make_intrusive<QTensorImpl>( |
52 | Storage(storage()), key_set(), data_type_, quantizer_); |
53 | copy_tensor_metadata( |
54 | /*src_impl=*/this, |
55 | /*dest_impl=*/impl.get(), |
56 | /*version_counter=*/version_counter, |
57 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
58 | impl->refresh_numel(); |
59 | impl->refresh_contiguous(); |
60 | return impl; |
61 | } |
62 | |
63 | /** |
64 | * Return a TensorImpl that is a shallow-copy of this TensorImpl. |
65 | * |
66 | * For usage of `version_counter` and `allow_tensor_metadata_change`, |
67 | * see NOTE [ TensorImpl Shallow-Copying ]. |
68 | */ |
69 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
70 | c10::VariableVersion&& version_counter, |
71 | bool allow_tensor_metadata_change) const override { |
72 | auto impl = c10::make_intrusive<QTensorImpl>( |
73 | Storage(storage()), key_set(), data_type_, quantizer_); |
74 | copy_tensor_metadata( |
75 | /*src_impl=*/this, |
76 | /*dest_impl=*/impl.get(), |
77 | /*version_counter=*/std::move(version_counter), |
78 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
79 | impl->refresh_numel(); |
80 | impl->refresh_contiguous(); |
81 | return impl; |
82 | } |
83 | |
84 | /** |
85 | * Shallow-copies data from another TensorImpl into this TensorImpl. |
86 | * |
87 | * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`, |
88 | * see NOTE [ TensorImpl Shallow-Copying ]. |
89 | */ |
90 | void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override { |
91 | AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); |
92 | auto q_impl = static_cast<const QTensorImpl*>(impl.get()); |
93 | copy_tensor_metadata( |
94 | /*src_impl=*/q_impl, |
95 | /*dest_impl=*/this, |
96 | /*version_counter=*/version_counter(), |
97 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); |
98 | refresh_numel(); |
99 | refresh_contiguous(); |
100 | } |
101 | |
102 | private: |
103 | QuantizerPtr quantizer_; |
104 | |
105 | const char* tensorimpl_type_name() const override; |
106 | |
107 | /** |
108 | * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) |
109 | * from one TensorImpl to another TensorImpl. |
110 | * |
111 | * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. |
112 | */ |
113 | static void copy_tensor_metadata( |
114 | const QTensorImpl* src_q_impl, |
115 | QTensorImpl* dest_q_impl, |
116 | const c10::VariableVersion& version_counter, |
117 | bool allow_tensor_metadata_change) { |
118 | TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change); |
119 | |
120 | // OpaqueTensorImpl-specific fields. |
121 | dest_q_impl->quantizer_ = src_q_impl->quantizer_; |
122 | } |
123 | }; |
124 | |
125 | } // namespace at |
126 | |