#pragma once

#include <ATen/Utils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/Export.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
// Opcodes understood by the (un)pickler. These mirror the opcodes of
// Python's `pickle` module; see pickletools.py in CPython for a detailed
// description of each code. The underlying type is `char` since every opcode
// is a single byte in the stream; values above 0x7f are uniformly wrapped in
// char(...) to make the conversion explicit.
enum class PickleOpCode : char {
  // --- Protocol 0 and 1 ---
  MARK = '(', // push special markobject on stack
  STOP = '.', // every pickle ends with STOP
  POP = '0', // discard topmost stack item
  POP_MARK = '1', // discard stack top through topmost markobject
  DUP = '2', // duplicate top stack item
  FLOAT = 'F', // push float object; decimal string argument
  INT = 'I', // push integer or bool; decimal string argument
  BININT = 'J', // push four-byte signed int
  BININT1 = 'K', // push 1-byte unsigned int
  LONG = 'L', // push long; decimal string argument
  BININT2 = 'M', // push 2-byte unsigned int
  NONE = 'N', // push None
  PERSID = 'P', // push persistent object; id is taken from string argument
  BINPERSID = 'Q', // push persistent object; id is taken from stack
  REDUCE = 'R', // apply callable to argtuple, both on stack
  STRING = 'S', // push string; NL-terminated string argument
  BINSTRING = 'T', // push string; counted binary string argument
  SHORT_BINSTRING = 'U', // push string; counted binary string < 256 bytes
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V', // push Unicode string; raw-unicode-escaped argument
  BINUNICODE = 'X', // push Unicode string; counted UTF-8 string argument
  APPEND = 'a', // append stack top to list below it
  BUILD = 'b', // call __setstate__ or __dict__.update()
  GLOBAL = 'c', // push self.find_class(modname, name); 2 string args
  DICT = 'd', // build a dict from stack items
  EMPTY_DICT = '}', // push empty dict
  APPENDS = 'e', // extend list on stack by topmost stack slice
  GET = 'g', // push item from memo on stack; index is string argument
  BINGET = 'h', // push item from memo on stack; index is 1-byte argument
  INST = 'i', // build & push class instance
  LONG_BINGET = 'j', // push item from memo on stack; index is 4-byte argument
  LIST = 'l', // build list from topmost stack items
  EMPTY_LIST = ']', // push empty list
  OBJ = 'o', // build & push class instance
  PUT = 'p', // store stack top in memo; index is string argument
  BINPUT = 'q', // store stack top in memo; index is 1-byte argument
  LONG_BINPUT = 'r', // store stack top in memo; index is 4-byte argument
  SETITEM = 's', // add key+value pair to dict
  TUPLE = 't', // build tuple from topmost stack items
  EMPTY_TUPLE = ')', // push empty tuple
  SETITEMS = 'u', // modify dict by adding topmost key+value pairs
  BINFLOAT = 'G', // push float; argument is 8-byte big-endian IEEE 754

  // --- Protocol 2 ---
  PROTO = char('\x80'), // identify pickle protocol version
  NEWOBJ = char('\x81'), // build object by applying cls.__new__ to argtuple
  EXT1 = char('\x82'), // push object from extension registry; 1-byte index
  EXT2 = char('\x83'), // ditto, but 2-byte index
  EXT4 = char('\x84'), // ditto, but 4-byte index
  TUPLE1 = char('\x85'), // build 1-tuple from stack top
  TUPLE2 = char('\x86'), // build 2-tuple from two topmost stack items
  TUPLE3 = char('\x87'), // build 3-tuple from three topmost stack items
  NEWTRUE = char('\x88'), // push True
  NEWFALSE = char('\x89'), // push False
  LONG1 = char('\x8a'), // push long from < 256 bytes
  LONG4 = char('\x8b'), // push really big long

  // --- Protocol 3 (Python 3.x) ---
  BINBYTES = 'B', // push bytes; counted binary string argument
  SHORT_BINBYTES = 'C', // push bytes; counted binary string < 256 bytes

  // --- Protocol 4 ---
  SHORT_BINUNICODE = char('\x8c'), // push short string; UTF-8 length < 256
  BINUNICODE8 = char('\x8d'), // push very long string
  BINBYTES8 = char('\x8e'), // push very long bytes string
  EMPTY_SET = char('\x8f'), // push empty set on the stack
  ADDITEMS = char('\x90'), // modify set by adding topmost stack items
  FROZENSET = char('\x91'), // build frozenset from topmost stack items
  NEWOBJ_EX = char('\x92'), // like NEWOBJ but works with keyword-only args
  STACK_GLOBAL = char('\x93'), // same as GLOBAL but using names on the stack
  MEMOIZE = char('\x94'), // store top of the stack in memo
  FRAME = char('\x95') // indicate the beginning of a new frame
};
92 | |
93 | using ::c10::IValue; |
94 | |
// Bundles a tensor (possibly a CPU copy of the original — see
// getWriteableTensorData below) with the number of storage bytes that should
// be written out for it. Instances can only be created by
// getWriteableTensorData(), the friend declared below; size_ is deliberately
// left uninitialized until that function fills it in.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct WriteableTensorData {
  // Pointer to the first byte of the tensor's underlying storage.
  const char* data() const {
    return static_cast<const char*>(tensor_.storage().data());
  }
  // Number of bytes to serialize for this tensor (the record size computed
  // by getWriteableTensorData); may differ from nbytes().
  size_t sizeInBytes() const {
    return size_;
  }
  // Total size of the underlying storage in bytes.
  size_t nbytes() const {
    return tensor_.storage().nbytes();
  }
  // True if the storage's DataPtr carries a deleter context, i.e. the
  // storage has a non-trivial deleter.
  bool storageHasDeleter() const {
    return tensor_.storage().data_ptr().get_context() != nullptr;
  }

 private:
  // Sole factory; sets both tensor_ and size_.
  friend TORCH_API WriteableTensorData
  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
  at::Tensor tensor_;
  uint64_t size_;
};
116 | |
117 | void setTypeTags(bool state); |
118 | bool getTypeTags(); |
119 | |
// Pickler serializes IValues into Python pickle-format bytes, which it hands
// to the `writer_` callback. Output is staged in a small fixed-size buffer so
// the callback is not invoked on a per-byte basis.
class TORCH_API Pickler {
  AT_DISALLOW_COPY_AND_ASSIGN(Pickler);

 public:
  // Convenience constructor: no external tensor table (tensor data is
  // written into the pickle itself), no type renaming, and no class-type
  // memoization list.
  Pickler(std::function<void(const char*, size_t)> writer)
      : Pickler(std::move(writer), nullptr, nullptr, nullptr) {}

  // writer: sink invoked with (data, size) chunks of the pickle stream.
  // tensor_table / type_renamer / memoized_class_types / get_tensor_id /
  // tag_aggregates: see the comments on the corresponding members below.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Pickler(
      std::function<void(const char*, size_t)> writer,
      std::vector<at::Tensor>* tensor_table,
      std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer,
      std::vector<c10::ClassTypePtr>* memoized_class_types,
      std::function<std::string(const at::Tensor&)> get_tensor_id = nullptr,
      bool tag_aggregates = true)
      : writer_(std::move(writer)),
        tensor_table_(tensor_table),
        type_renamer_(std::move(type_renamer)),
        memoized_class_types_(memoized_class_types),
        get_tensor_id_(std::move(get_tensor_id)),
        tag_aggregates_(tag_aggregates) {}
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~Pickler();

  // Push protocol onto the stack
  void protocol();

  // Push STOP PickleOpCode onto the stack
  void stop();

  // Serialize one IValue (entry point for users of this class).
  void pushIValue(const IValue& ivalue);

  // Begin / end a tuple that is written element-by-element between the two
  // calls.
  void startTuple();
  void endTuple();

  // Tensors whose storage should be serialized alongside the pickle bytes
  // (populated when tensor data is not routed to an external tensor_table_).
  const std::vector<at::Tensor>& tensorData() {
    return tensor_data_;
  }

  void pushEmptyDict();
  void pushDict(const IValue& ivalue);
  void pushInt(int64_t value);
  // Push an arbitrary-precision integer; `data` holds its encoded payload
  // (encoding defined in pickler.cpp — presumably the pickle LONG1/LONG4
  // byte format; confirm against the implementation).
  void pushLong(const std::string& data);

 private:
  // Unmemoized core of pushIValue().
  void pushIValueImpl(const IValue& ivalue);
  // Open / close the type-tag wrapper written around aggregates when
  // tag_aggregates_ is true (see tag_aggregates_ below).
  void startTypeTag();
  void endTypeTag(const IValue& value);
  // Emitters for the individual IValue kinds.
  void pushBool(bool value);
  void pushDouble(double value);
  void pushComplexDouble(const IValue& value);
  void pushGenericList(const IValue& ivalue);
  void pushIntList(const IValue& ivalue);
  void pushList(const IValue& ivalue);
  void pushTensor(const IValue& ivalue);
  void pushTensorReference(const IValue& ivalue);
  void pushLiteralTensor(const IValue& ivalue);
  void pushLiteralSparseTensor(const at::Tensor& tensor);
  void pushTuple(const IValue& ivalue);
  void pushString(const std::string& string);
  void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
  void pushRRef(const IValue& ivalue);
#endif
  // unmemoized version of pushString()
  void pushStringImpl(const std::string& string);
  void pushStorageOfTensor(const at::Tensor& tensor);

  // Emit a BINGET/LONG_BINGET referencing a previously memoized value.
  void pushBinGet(uint32_t memo_id);
  // Write a list as a call to a named restore function, pushing each element
  // through `item_pusher`.
  void pushSpecializedList(
      const IValue& ivalue,
      const char* list_name,
      const std::function<void(const IValue&)>& item_pusher);
  // Emit (and memoize, via memoized_globals_map_) a GLOBAL reference to
  // module_name.class_name.
  void pushGlobal(
      const std::string& module_name,
      const std::string& class_name);
  // raw string data is appended directly to the byte stream
  void pushBytes(const std::string& string);
  void pushTensorData(const at::Tensor& tensor);

  // Add a BINPUT op and return the memoization id used
  size_t pushNextBinPut();

  // Raw pointer used as the memoization key for an IValue.
  const void* getPointer(const IValue& ivalue);

  // Caller checks that bufferPos_ > 0
  void flushNonEmpty() {
    writer_(buffer_.data(), bufferPos_);
    bufferPos_ = 0;
  }

  // Hand any buffered bytes to writer_; no-op when the buffer is empty.
  void flush() {
    if (bufferPos_ != 0) {
      flushNonEmpty();
    }
  }

  // These convert values to bytes and add them to the stack (NB: since T is to
  // the left of a '::', its type cannot be deduced by the compiler so one must
  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
  // does not)
  static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256;
  template <typename T>
  void push(typename std::common_type<T>::type value) {
    const char* begin = reinterpret_cast<const char*>(&value);
    // Flush first if the value would not fit in the remaining buffer space.
    // The static_assert below guarantees sizeof(T) <= kBufferSize, so this
    // branch can only be taken when bufferPos_ > 0, satisfying
    // flushNonEmpty()'s precondition.
    if (bufferPos_ + sizeof(T) > buffer_.size()) {
      flushNonEmpty();
    }
    static_assert(sizeof(T) <= kBufferSize, "Buffer size assumption");
    memcpy(buffer_.data() + bufferPos_, begin, sizeof(T));
    bufferPos_ += sizeof(T);
  }

  // Stream to write binary data to
  // Code shouldn't call writer_ directly without first flush()ing.
  std::function<void(const char*, size_t)> writer_;

  // Buffer to avoid calling a writer_ on a per-byte basis.
  std::array<char, kBufferSize> buffer_;
  size_t bufferPos_{0};

  // Stack of opcodes/data
  std::vector<char> stack_;

  // External table of tensors to serialize. If this is missing, then tensors
  // are serialized directly into the pickle
  std::vector<at::Tensor>* tensor_table_;

  // TODO: only use this if necessary (add a pass to find all shared ivalues,
  // and only memoize those)
  uint32_t memo_id_ = 0;

  // Memoization of IValues that have been written (index in table is used for
  // BINPUT opcodes) to enable shared references
  std::unordered_map<const void*, uint32_t> memoized_ivalue_map_;

  // because we de-dup ivalues based on their raw pointer address in the above
  // map we need to keep all the memoized values alive during the pickle.
  // Otherwise, it is possible that a raw address gets reused for another
  // object, and we will alias it to the old object at that address.
  std::vector<IValue> memoized_ivalues_;

  // Optional hook mapping a class type to the qualified name under which it
  // should be pickled.
  std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer_;

  // List of all the types that it wrote, inspect from the IValues it wrote.
  std::vector<c10::ClassTypePtr>* memoized_class_types_;

  // Function to grab next id_name for tensor storage, function is responsible
  // for returning unique ids
  std::function<std::string(const at::Tensor&)> get_tensor_id_;

  // List of tensor storages to serialize in the same binary as the pickle data
  // similar to ivalues, they are memoized using BINPUT
  std::vector<at::Tensor> tensor_data_;
  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  // Memo tables mapping already-written globals / strings / devices to their
  // BINPUT ids so repeats become cheap BINGET references.
  std::unordered_map<std::string, uint32_t> memoized_globals_map_;
  std::unordered_map<std::string, uint32_t> memoized_strings_map_;
  std::unordered_map<std::string, uint32_t> memoized_devices_map_;
  // when true, List and Dict objects will be wrapped in a
  // torch.jit._pickle.restore_type_tag call to correctly set the dynamic
  // TorchScript type for the object. When true the thing unpickling must have
  // torch installed.
  bool tag_aggregates_;
};
285 | |
// returns a (tensor, record_size) pair for a tensor, converting it to a CPU
// tensor if it was CUDA and `to_cpu` is true.
288 | TORCH_API WriteableTensorData |
289 | getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true); |
290 | |
291 | // return the value of the tensor's storage pointer |
292 | uint64_t getStorageKey(const at::Tensor& tensor); |
293 | |
294 | // if the cls has __getstate__/__setstate__ |
295 | // assert they have the right schema and return true, |
296 | // otherwise return false |
297 | bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls); |
298 | |
299 | // Return a map of Tensor Metadata for serialization. |
300 | // For now, it only takes care of `conj` and `neg` bit. |
301 | inline std::unordered_map<std::string, bool> getTensorMetadata( |
302 | const at::Tensor& t) { |
303 | // We don't support serializing `ZeroTensor` as it is not public |
304 | // facing yet. |
305 | TORCH_CHECK( |
306 | !t._is_zerotensor(), |
307 | "ZeroTensor is not serializable," , |
308 | " please file an issue if required." ); |
309 | std::unordered_map<std::string, bool> metadata{}; |
310 | |
311 | // Only add meta-data if the value is not default. |
312 | if (t.is_conj()) { |
313 | metadata["conj" ] = true; |
314 | } |
315 | if (t.is_neg()) { |
316 | metadata["neg" ] = true; |
317 | } |
318 | return metadata; |
319 | } |
320 | |
321 | // set Tensor Metadata based on the map. |
322 | // Refer: getTensorMathdata |
323 | inline void setTensorMetadata( |
324 | const at::Tensor& t, |
325 | std::unordered_map<std::string, bool> metadata) { |
326 | for (auto& key_value_pair : metadata) { |
327 | if (key_value_pair.first == "conj" ) { |
328 | t._set_conj(true); |
329 | } else if (key_value_pair.first == "neg" ) { |
330 | t._set_neg(true); |
331 | } else { |
332 | TORCH_CHECK( |
333 | false, |
334 | "Unexpected key `" , |
335 | key_value_pair.first, |
336 | "` passed to setTensorMetadata." ); |
337 | } |
338 | } |
339 | } |
340 | |
341 | // set Tensor metadata based on the map. |
342 | // NOTE: This overload is required by unpickler.cpp |
343 | inline void setTensorMetadata( |
344 | const at::Tensor& t, |
345 | c10::Dict<c10::IValue, c10::IValue> metadata_idict) { |
346 | std::unordered_map<std::string, bool> metadata; |
347 | for (auto& pair : metadata_idict) { |
348 | auto key = *pair.key().toString(); |
349 | metadata[key] = pair.value().toBool(); |
350 | } |
351 | setTensorMetadata(t, std::move(metadata)); |
352 | } |
353 | |
354 | } // namespace jit |
355 | } // namespace torch |
356 | |