1#pragma once
2
3#include <c10/core/MemoryFormat.h>
4#include <c10/core/SymIntArrayRef.h>
5#include <c10/core/TensorImpl.h>
6#include <c10/util/Exception.h>
7
8namespace at {
9
10// An "Opaque" TensorImpl -- there are no strides and (for now)
11// even data() is not supported (thus no pointer arithmetic).
12
13// NOTE: We could allow data() in the future, but would have to ensure pointer
14// arithmetic code is properly guarded.
15//
16// NOTE: This does not support resize_ (and other metadata-changing ops) because
17// of `shallow_copy_and_detach`. We would need to define an interface to
18// "shallow copy" in order to add support.
19
20template <typename OpaqueHandle>
21struct TORCH_API OpaqueTensorImpl : public TensorImpl {
22 // public constructor for now...
23 OpaqueTensorImpl(
24 at::DispatchKeySet key_set,
25 const caffe2::TypeMeta data_type,
26 c10::Device device,
27 OpaqueHandle opaque_handle,
28 c10::IntArrayRef sizes,
29 bool is_non_overlapping_and_dense = true)
30 : TensorImpl(key_set, data_type, device),
31 opaque_handle_(std::move(opaque_handle)) {
32 set_storage_access_should_throw();
33 set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
34 sizes_and_strides_.set_sizes(sizes);
35 refresh_numel();
36 is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
37 }
38
39 // Destructor doesn't call release_resources because it's
40 // unnecessary; don't forget to change that if needed!
41 void release_resources() override {
42 TensorImpl::release_resources();
43 opaque_handle_ = {};
44 }
45
46 void set_size(int64_t dim, int64_t new_size) override {
47 AT_ERROR("opaque tensors do not have set_size");
48 }
49
50 void set_stride(int64_t dim, int64_t new_stride) override {
51 AT_ERROR("opaque tensors do not have set_stride");
52 }
53
54 void set_storage_offset(int64_t storage_offset) override {
55 AT_ERROR("opaque tensors do not have set_storage_offset");
56 }
57
58#ifdef DEBUG
59 bool has_storage() const override {
60 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
61 !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
62 return false;
63 }
64#endif
65
66 /**
67 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
68 *
69 * For usage of `version_counter` and `allow_tensor_metadata_change`,
70 * see NOTE [ TensorImpl Shallow-Copying ].
71 */
72 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
73 const c10::VariableVersion& version_counter,
74 bool allow_tensor_metadata_change) const override {
75 auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
76 key_set(),
77 dtype(),
78 device(),
79 opaque_handle_,
80 sizes_and_strides_.sizes_arrayref());
81 copy_tensor_metadata(
82 /*src_opaque_impl=*/this,
83 /*dest_opaque_impl=*/impl.get(),
84 /*version_counter=*/version_counter,
85 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
86 impl->refresh_numel();
87 return impl;
88 }
89
90 /**
91 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
92 *
93 * For usage of `version_counter` and `allow_tensor_metadata_change`,
94 * see NOTE [ TensorImpl Shallow-Copying ].
95 */
96 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
97 c10::VariableVersion&& version_counter,
98 bool allow_tensor_metadata_change) const override {
99 auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
100 key_set(),
101 dtype(),
102 device(),
103 opaque_handle_,
104 sizes_and_strides_.sizes_arrayref());
105 copy_tensor_metadata(
106 /*src_opaque_impl=*/this,
107 /*dest_opaque_impl=*/impl.get(),
108 /*version_counter=*/std::move(version_counter),
109 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
110 impl->refresh_numel();
111 return impl;
112 }
113
114 /**
115 * Shallow-copies data from another TensorImpl into this TensorImpl.
116 *
117 * For why this function doesn't check this TensorImpl's
118 * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
119 */
120 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
121 AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
122 auto opaque_impl =
123 static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
124 copy_tensor_metadata(
125 /*src_impl=*/opaque_impl,
126 /*dest_impl=*/this,
127 /*version_counter=*/version_counter(),
128 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
129 refresh_numel();
130 }
131
132 const OpaqueHandle& opaque_handle() const {
133 return opaque_handle_;
134 }
135
136 OpaqueHandle& unsafe_opaque_handle() {
137 return opaque_handle_;
138 }
139
140 protected:
141 /**
142 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
143 * storage_offset) from one TensorImpl to another TensorImpl.
144 *
145 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
146 * [ TensorImpl Shallow-Copying ].
147 */
148 static void copy_tensor_metadata(
149 const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
150 OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
151 const c10::VariableVersion& version_counter,
152 bool allow_tensor_metadata_change) {
153 TensorImpl::copy_tensor_metadata(
154 src_opaque_impl,
155 dest_opaque_impl,
156 version_counter,
157 allow_tensor_metadata_change);
158
159 // OpaqueTensorImpl-specific fields.
160 dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
161 }
162
163 static void copy_tensor_metadata(
164 const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
165 OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
166 c10::VariableVersion&& version_counter,
167 bool allow_tensor_metadata_change) {
168 TensorImpl::copy_tensor_metadata(
169 src_opaque_impl,
170 dest_opaque_impl,
171 std::move(version_counter),
172 allow_tensor_metadata_change);
173
174 // OpaqueTensorImpl-specific fields.
175 dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
176 }
177
178 private:
179 const char* tensorimpl_type_name() const override {
180 return "OpaqueTensorImpl";
181 }
182
183 OpaqueHandle opaque_handle_;
184};
185
186} // namespace at
187