1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <c10/util/ArrayRef.h> |
5 | #include <caffe2/serialize/inline_container.h> |
6 | #include <torch/csrc/Export.h> |
7 | #include <torch/csrc/jit/serialization/pickler.h> |
8 | #include <torch/csrc/jit/serialization/unpickler.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | /// Pickle an IValue by calling a function to handle writing the data. |
14 | /// |
15 | /// `writer` is a function that takes in a pointer to a chunk of memory and its |
16 | /// size and consumes it. |
17 | /// |
18 | /// See `jit::pickle` for more details. |
19 | TORCH_API void pickle( |
20 | std::function<void(const char* data_start, size_t data_len)> writer, |
21 | const IValue& ivalue, |
22 | std::vector<at::Tensor>* tensor_table = nullptr); |
23 | |
24 | /// Save a `torch::IValue` in a format compatible with Python's `pickle` module |
25 | /// |
26 | /// If present, `tensor_table` is a pointer to a table in which tensors that |
27 | /// are contained within `ivalue` are stored, and the bytes returned by the |
28 | /// pickler will only include references to these tensors in the table. This can |
29 | /// be used to keep the binary blob size small. |
30 | /// If not provided, tensors are stored in the same byte stream as the pickle |
31 | /// data, similar to `torch.save()` in eager Python. |
32 | /// |
33 | /// Pickled values can be loaded in Python and C++: |
34 | /// \rst |
35 | /// .. code-block:: cpp |
36 | /// |
37 | /// torch::IValue float_value(2.3); |
38 | /// |
39 | /// // TODO: when tensors are stored in the pickle, delete this |
40 | /// std::vector<at::Tensor> tensor_table; |
41 | /// auto data = torch::jit::pickle(float_value, &tensor_table); |
42 | /// |
43 | /// std::vector<torch::IValue> ivalues = |
44 | /// torch::jit::unpickle(data.data(), data.size()); |
45 | /// |
46 | /// .. code-block:: python |
47 | /// |
48 | /// values = torch.load('data.pkl') |
49 | /// print(values) |
50 | /// |
51 | /// \endrst |
52 | TORCH_API std::vector<char> pickle( |
53 | const IValue& ivalue, |
54 | std::vector<at::Tensor>* tensor_table = nullptr); |
55 | |
56 | /// Save a `torch::IValue` in a format that can be loaded by both |
57 | /// `torch::pickle_load` in C++ and `torch.load` in Python. |
58 | TORCH_API std::vector<char> pickle_save(const IValue& ivalue); |
59 | |
60 | /// Deserialize a `torch::IValue` from bytes produced by either |
61 | /// `torch::pickle_save` in C++ or `torch.save` in Python |
62 | TORCH_API IValue pickle_load(const std::vector<char>& data); |
63 | |
64 | /// `reader` is a function that takes in a size to read from some pickled |
65 | /// binary. `reader` should remember where it last read, and return |
66 | /// the number of bytes read. |
67 | /// See `torch::pickle` for details. |
68 | /// type_resolver is used to resolve any JIT type based on type str |
69 | TORCH_API IValue unpickle( |
70 | std::function<size_t(char*, size_t)> reader, |
71 | TypeResolver type_resolver, |
72 | c10::ArrayRef<at::Tensor> tensor_table, |
73 | c10::TypePtr (*type_parser)(const std::string&) = |
74 | Unpickler::defaultTypeParser); |
75 | |
76 | /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. |
77 | /// |
78 | /// If any `torch::IValue`s in the pickled data are `Object`s, then a |
79 | /// `class_resolver` function must be provided. |
80 | /// |
81 | /// See `torch::pickle` for details. |
82 | TORCH_API IValue unpickle( |
83 | const char* data, |
84 | size_t size, |
85 | TypeResolver type_resolver = nullptr, |
86 | c10::ArrayRef<at::Tensor> tensor_table = {}, |
87 | c10::TypePtr (*type_parser)(const std::string&) = |
88 | Unpickler::defaultTypeParser); |
89 | |
90 | } // namespace jit |
91 | } // namespace torch |
92 | |