1#pragma once
2
#include <ATen/quantized/Quantizer.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
6
7namespace at {
8
9/**
10 * QTensorImpl is a TensorImpl for Quantized Tensors, it stores Quantizer which
11 * specifies the quantization scheme and parameters, for more information please
12 * see ATen/quantized/Quantizer.h
13 *
14 * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl.
15 */
16struct TORCH_API QTensorImpl : public c10::TensorImpl {
17 public:
18 QTensorImpl(
19 Storage&& storage,
20 DispatchKeySet key_set,
21 const caffe2::TypeMeta data_type,
22 QuantizerPtr quantizer);
23
24 // See Note [Enum ImplType]
25 QTensorImpl(
26 ImplType type,
27 Storage&& storage,
28 DispatchKeySet key_set,
29 const caffe2::TypeMeta data_type,
30 QuantizerPtr quantizer);
31
32
33 // TODO: Expose in PyTorch Frontend
34 QuantizerPtr quantizer() {
35 return quantizer_;
36 }
37
38 void set_quantizer_(QuantizerPtr quantizer) {
39 quantizer_ = quantizer;
40 }
41
42 /**
43 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
44 *
45 * For usage of `version_counter` and `allow_tensor_metadata_change`,
46 * see NOTE [ TensorImpl Shallow-Copying ].
47 */
48 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
49 const c10::VariableVersion& version_counter,
50 bool allow_tensor_metadata_change) const override {
51 auto impl = c10::make_intrusive<QTensorImpl>(
52 Storage(storage()), key_set(), data_type_, quantizer_);
53 copy_tensor_metadata(
54 /*src_impl=*/this,
55 /*dest_impl=*/impl.get(),
56 /*version_counter=*/version_counter,
57 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
58 impl->refresh_numel();
59 impl->refresh_contiguous();
60 return impl;
61 }
62
63 /**
64 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
65 *
66 * For usage of `version_counter` and `allow_tensor_metadata_change`,
67 * see NOTE [ TensorImpl Shallow-Copying ].
68 */
69 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
70 c10::VariableVersion&& version_counter,
71 bool allow_tensor_metadata_change) const override {
72 auto impl = c10::make_intrusive<QTensorImpl>(
73 Storage(storage()), key_set(), data_type_, quantizer_);
74 copy_tensor_metadata(
75 /*src_impl=*/this,
76 /*dest_impl=*/impl.get(),
77 /*version_counter=*/std::move(version_counter),
78 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
79 impl->refresh_numel();
80 impl->refresh_contiguous();
81 return impl;
82 }
83
84 /**
85 * Shallow-copies data from another TensorImpl into this TensorImpl.
86 *
87 * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
88 * see NOTE [ TensorImpl Shallow-Copying ].
89 */
90 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
91 AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
92 auto q_impl = static_cast<const QTensorImpl*>(impl.get());
93 copy_tensor_metadata(
94 /*src_impl=*/q_impl,
95 /*dest_impl=*/this,
96 /*version_counter=*/version_counter(),
97 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
98 refresh_numel();
99 refresh_contiguous();
100 }
101
102 private:
103 QuantizerPtr quantizer_;
104
105 const char* tensorimpl_type_name() const override;
106
107 /**
108 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset)
109 * from one TensorImpl to another TensorImpl.
110 *
111 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
112 */
113 static void copy_tensor_metadata(
114 const QTensorImpl* src_q_impl,
115 QTensorImpl* dest_q_impl,
116 const c10::VariableVersion& version_counter,
117 bool allow_tensor_metadata_change) {
118 TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
119
120 // OpaqueTensorImpl-specific fields.
121 dest_q_impl->quantizer_ = src_q_impl->quantizer_;
122 }
123};
124
125} // namespace at
126