1#pragma once
2
3#include <ATen/core/qualified_name.h>
4#include <string>
5#include <utility>
6#include <vector>
7
8#include <ATen/Utils.h>
9#include <ATen/core/ivalue.h>
10#include <ATen/core/jit_type.h>
11#include <c10/util/ArrayRef.h>
12#include <torch/csrc/Export.h>
13
14namespace torch {
15namespace jit {
16
// See Python's pickletools.py for a detailed description of each of these codes
enum class PickleOpCode : char {
  // Protocol 0/1 opcodes
  MARK = '(',
  STOP = '.',
  POP = '0',
  POP_MARK = '1',
  DUP = '2',
  FLOAT = 'F',
  INT = 'I',
  BININT = 'J',
  BININT1 = 'K',
  LONG = 'L',
  BININT2 = 'M',
  NONE = 'N',
  PERSID = 'P',
  BINPERSID = 'Q',
  REDUCE = 'R',
  STRING = 'S',
  BINSTRING = 'T',
  SHORT_BINSTRING = 'U',
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V',
  BINUNICODE = 'X',
  APPEND = 'a',
  BUILD = 'b',
  GLOBAL = 'c',
  DICT = 'd',
  EMPTY_DICT = '}',
  APPENDS = 'e',
  GET = 'g',
  BINGET = 'h',
  INST = 'i',
  LONG_BINGET = 'j',
  LIST = 'l',
  EMPTY_LIST = ']',
  OBJ = 'o',
  PUT = 'p',
  BINPUT = 'q',
  LONG_BINPUT = 'r',
  SETITEM = 's',
  TUPLE = 't',
  EMPTY_TUPLE = ')',
  SETITEMS = 'u',
  BINFLOAT = 'G',

  // Protocol 2
  PROTO = char('\x80'),
  NEWOBJ = '\x81',
  EXT1 = '\x82',
  EXT2 = '\x83',
  EXT4 = '\x84',
  TUPLE1 = '\x85',
  TUPLE2 = '\x86',
  TUPLE3 = '\x87',
  NEWTRUE = '\x88',
  NEWFALSE = '\x89',
  LONG1 = '\x8a',
  LONG4 = '\x8b',

  // Protocol 3 (Python 3.x)
  BINBYTES = 'B',
  SHORT_BINBYTES = 'C',

  // Protocol 4
  SHORT_BINUNICODE = char('\x8c'),
  BINUNICODE8 = '\x8d',
  BINBYTES8 = '\x8e',
  EMPTY_SET = '\x8f',
  ADDITEMS = '\x90',
  FROZENSET = '\x91',
  NEWOBJ_EX = '\x92',
  STACK_GLOBAL = '\x93',
  MEMOIZE = '\x94',
  FRAME = '\x95'
};
92
93using ::c10::IValue;
94
95// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
96struct WriteableTensorData {
97 const char* data() const {
98 return static_cast<const char*>(tensor_.storage().data());
99 }
100 size_t sizeInBytes() const {
101 return size_;
102 }
103 size_t nbytes() const {
104 return tensor_.storage().nbytes();
105 }
106 bool storageHasDeleter() const {
107 return tensor_.storage().data_ptr().get_context() != nullptr;
108 }
109
110 private:
111 friend TORCH_API WriteableTensorData
112 getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
113 at::Tensor tensor_;
114 uint64_t size_;
115};
116
// NOTE(review): setter/getter for a process-wide flag; from the names it
// presumably controls whether type tags are emitted during pickling (cf.
// Pickler's tag_aggregates_) — confirm against the definitions in pickler.cpp.
void setTypeTags(bool state);
bool getTypeTags();
119
// Serializes IValues into a Python-pickle-compatible byte stream, delivered
// incrementally through a caller-supplied writer callback. Typical usage:
// protocol(); pushIValue(...); stop(); then read tensorData() for any tensor
// storages that must be written alongside the pickle.
class TORCH_API Pickler {
  AT_DISALLOW_COPY_AND_ASSIGN(Pickler);

 public:
  // Minimal constructor: no external tensor table, no type renaming, and no
  // class-type memoization (tensors are serialized inline into the pickle).
  Pickler(std::function<void(const char*, size_t)> writer)
      : Pickler(std::move(writer), nullptr, nullptr, nullptr) {}

  // writer: sink invoked with (data, size) chunks of the produced pickle.
  // tensor_table: if non-null, tensors are appended here and only referenced
  //   from the pickle; if null, tensor data is written into the pickle itself.
  // type_renamer: optional callback mapping a class type to the qualified
  //   name to emit for it.
  // memoized_class_types: if non-null, collects every class type written.
  // get_tensor_id: optional callback returning a unique id for a tensor's
  //   storage (see get_tensor_id_ below).
  // tag_aggregates: when true, Lists/Dicts are wrapped in
  //   torch.jit._pickle.restore_type_tag calls (see tag_aggregates_ below).
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Pickler(
      std::function<void(const char*, size_t)> writer,
      std::vector<at::Tensor>* tensor_table,
      std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer,
      std::vector<c10::ClassTypePtr>* memoized_class_types,
      std::function<std::string(const at::Tensor&)> get_tensor_id = nullptr,
      bool tag_aggregates = true)
      : writer_(std::move(writer)),
        tensor_table_(tensor_table),
        type_renamer_(std::move(type_renamer)),
        memoized_class_types_(memoized_class_types),
        get_tensor_id_(std::move(get_tensor_id)),
        tag_aggregates_(tag_aggregates) {}
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~Pickler();

  // Push protocol onto the stack
  void protocol();

  // Push STOP PickleOpCode onto the stack
  void stop();

  // Serialize one IValue (and, transitively, everything it references).
  void pushIValue(const IValue& ivalue);

  void startTuple();
  void endTuple();

  // Tensor storages that must be serialized alongside the pickle (only
  // populated when no external tensor_table was supplied).
  const std::vector<at::Tensor>& tensorData() {
    return tensor_data_;
  }

  void pushEmptyDict();
  void pushDict(const IValue& ivalue);
  void pushInt(int64_t value);
  void pushLong(const std::string& data);

 private:
  void pushIValueImpl(const IValue& ivalue);
  void startTypeTag();
  void endTypeTag(const IValue& value);
  void pushBool(bool value);
  void pushDouble(double value);
  void pushComplexDouble(const IValue& value);
  void pushGenericList(const IValue& ivalue);
  void pushIntList(const IValue& ivalue);
  void pushList(const IValue& ivalue);
  void pushTensor(const IValue& ivalue);
  void pushTensorReference(const IValue& ivalue);
  void pushLiteralTensor(const IValue& ivalue);
  void pushLiteralSparseTensor(const at::Tensor& tensor);
  void pushTuple(const IValue& ivalue);
  void pushString(const std::string& string);
  void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
  void pushRRef(const IValue& ivalue);
#endif
  // unmemoized version
  void pushStringImpl(const std::string& string);
  void pushStorageOfTensor(const at::Tensor& tensor);

  void pushBinGet(uint32_t memo_id);
  void pushSpecializedList(
      const IValue& ivalue,
      const char* list_name,
      const std::function<void(const IValue&)>& item_pusher);
  void pushGlobal(
      const std::string& module_name,
      const std::string& class_name);
  // raw string data is appended directly to the byte stream
  void pushBytes(const std::string& string);
  void pushTensorData(const at::Tensor& tensor);

  // Add a BINPUT op and return the memoization id used
  size_t pushNextBinPut();

  const void* getPointer(const IValue& ivalue);

  // Caller checks that bufferPos_ > 0
  void flushNonEmpty() {
    writer_(buffer_.data(), bufferPos_);
    bufferPos_ = 0;
  }

  void flush() {
    if (bufferPos_ != 0) {
      flushNonEmpty();
    }
  }

  // These convert values to bytes and add them to the stack (NB: since T is to
  // the left of a '::', its type cannot be deduced by the compiler so one must
  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
  // does not)
  static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256;
  template <typename T>
  void push(typename std::common_type<T>::type value) {
    const char* begin = reinterpret_cast<const char*>(&value);
    // Flush first if the value would not fit in the remaining buffer space.
    // (If bufferPos_ is 0 the value always fits, per the static_assert below,
    // so flushNonEmpty()'s bufferPos_ > 0 precondition holds.)
    if (bufferPos_ + sizeof(T) > buffer_.size()) {
      flushNonEmpty();
    }
    static_assert(sizeof(T) <= kBufferSize, "Buffer size assumption");
    memcpy(buffer_.data() + bufferPos_, begin, sizeof(T));
    bufferPos_ += sizeof(T);
  }

  // Stream to write binary data to
  // Code shouldn't call writer_ directly without first flush()ing.
  std::function<void(const char*, size_t)> writer_;

  // Buffer to avoid calling a writer_ on a per-byte basis.
  std::array<char, kBufferSize> buffer_;
  size_t bufferPos_{0};

  // Stack of opcodes/data
  std::vector<char> stack_;

  // External table of tensors to serialize. If this is missing, then tensors
  // are serialized directly into the pickle
  std::vector<at::Tensor>* tensor_table_;

  // TODO: only use this if necessary (add a pass to find all shared ivalues,
  // and only memoize those)
  uint32_t memo_id_ = 0;

  // Memoization of IValues that have been written (index in table is used for
  // BINPUT opcodes) to enable shared references
  std::unordered_map<const void*, uint32_t> memoized_ivalue_map_;

  // because we de-dup ivalues based on their raw pointer address in the above
  // map we need to keep all the memoized values alive during the pickle.
  // Otherwise, it is possible that a raw address gets reused for another
  // object, and we will alias it to the old object at that address.
  std::vector<IValue> memoized_ivalues_;

  std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer_;

  // List of all the types that it wrote, inspect from the IValues it wrote.
  std::vector<c10::ClassTypePtr>* memoized_class_types_;

  // Function to grab next id_name for tensor storage, function is responsible
  // for returning unique ids
  std::function<std::string(const at::Tensor&)> get_tensor_id_;

  // List of tensor storages to serialize in the same binary as the pickle data
  // similar to ivalues, they are memoized using BINPUT
  std::vector<at::Tensor> tensor_data_;
  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  std::unordered_map<std::string, uint32_t> memoized_globals_map_;
  std::unordered_map<std::string, uint32_t> memoized_strings_map_;
  std::unordered_map<std::string, uint32_t> memoized_devices_map_;
  // when true, List and Dict objects will be wrapped in a
  // torch.jit._pickle.restore_type_tag call to correctly set the dynamic
  // TorchScript type for the object. When true the thing unpickling must have
  // torch installed.
  bool tag_aggregates_;
};
285
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
// if it was CUDA and to_cpu is true. This is the sole populator of
// WriteableTensorData (declared friend above).
TORCH_API WriteableTensorData
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);

// return the value of the tensor's storage pointer
uint64_t getStorageKey(const at::Tensor& tensor);

// if the cls has __getstate__/__setstate__
// assert they have the right schema and return true,
// otherwise return false
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
298
299// Return a map of Tensor Metadata for serialization.
300// For now, it only takes care of `conj` and `neg` bit.
301inline std::unordered_map<std::string, bool> getTensorMetadata(
302 const at::Tensor& t) {
303 // We don't support serializing `ZeroTensor` as it is not public
304 // facing yet.
305 TORCH_CHECK(
306 !t._is_zerotensor(),
307 "ZeroTensor is not serializable,",
308 " please file an issue if required.");
309 std::unordered_map<std::string, bool> metadata{};
310
311 // Only add meta-data if the value is not default.
312 if (t.is_conj()) {
313 metadata["conj"] = true;
314 }
315 if (t.is_neg()) {
316 metadata["neg"] = true;
317 }
318 return metadata;
319}
320
321// set Tensor Metadata based on the map.
322// Refer: getTensorMathdata
323inline void setTensorMetadata(
324 const at::Tensor& t,
325 std::unordered_map<std::string, bool> metadata) {
326 for (auto& key_value_pair : metadata) {
327 if (key_value_pair.first == "conj") {
328 t._set_conj(true);
329 } else if (key_value_pair.first == "neg") {
330 t._set_neg(true);
331 } else {
332 TORCH_CHECK(
333 false,
334 "Unexpected key `",
335 key_value_pair.first,
336 "` passed to setTensorMetadata.");
337 }
338 }
339}
340
341// set Tensor metadata based on the map.
342// NOTE: This overload is required by unpickler.cpp
343inline void setTensorMetadata(
344 const at::Tensor& t,
345 c10::Dict<c10::IValue, c10::IValue> metadata_idict) {
346 std::unordered_map<std::string, bool> metadata;
347 for (auto& pair : metadata_idict) {
348 auto key = *pair.key().toString();
349 metadata[key] = pair.value().toBool();
350 }
351 setTensorMetadata(t, std::move(metadata));
352}
353
354} // namespace jit
355} // namespace torch
356