#pragma once

#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>

namespace torch {
namespace jit {

/// Pickle an IValue by calling a function to handle writing the data.
///
/// `writer` is a function that takes in a pointer to a chunk of memory and its
/// size and consumes it.
///
/// See `jit::pickle` for more details.
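///
/// For example, a minimal sketch that collects the pickled bytes into an
/// in-memory buffer (the `buffer` variable and the lambda are illustrative,
/// not part of the API):
///
/// \rst
/// .. code-block:: cpp
///
///   std::vector<char> buffer;
///   torch::jit::pickle(
///       [&](const char* data, size_t len) {
///         buffer.insert(buffer.end(), data, data + len);
///       },
///       torch::IValue(2.3));
/// \endrst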
TORCH_API void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table = nullptr);

/// Save a `torch::IValue` in a format compatible with Python's `pickle` module.
///
/// If present, `tensor_table` is a pointer to a table in which tensors that
/// are contained within `ivalue` are stored, and the bytes returned by the
/// pickler will only include references to these tensors in the table. This can
/// be used to keep the binary blob size small.
/// If not provided, tensors are stored in the same byte stream as the pickle
/// data, similar to `torch.save()` in eager Python.
///
/// Pickled values can be loaded in Python and C++:
/// \rst
/// .. code-block:: cpp
///
///   torch::IValue float_value(2.3);
///
///   // TODO: when tensors are stored in the pickle, delete this
///   std::vector<at::Tensor> tensor_table;
///   auto data = torch::jit::pickle(float_value, &tensor_table);
///
///   torch::IValue ivalue =
///       torch::jit::unpickle(data.data(), data.size());
///
/// .. code-block:: python
///
///   # assuming the bytes produced above were written to 'data.pkl'
///   values = torch.load('data.pkl')
///   print(values)
///
/// \endrst
TORCH_API std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table = nullptr);

/// Save a `torch::IValue` in a format that can be loaded by both
/// `torch::pickle_load` in C++ and `torch.load` in Python.
TORCH_API std::vector<char> pickle_save(const IValue& ivalue);

/// Deserialize a `torch::IValue` from bytes produced by either
/// `torch::pickle_save` in C++ or `torch.save` in Python.
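///
/// A minimal round-trip sketch (the tensor value and the variable names are
/// illustrative):
///
/// \rst
/// .. code-block:: cpp
///
///   at::Tensor tensor = torch::ones({2, 3});
///   std::vector<char> data = torch::jit::pickle_save(tensor);
///
///   // `data` could be written to disk here and read back later
///   torch::IValue ivalue = torch::jit::pickle_load(data);
///   at::Tensor restored = ivalue.toTensor();
/// \endrst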
TORCH_API IValue pickle_load(const std::vector<char>& data);

/// `reader` is a function that takes in a destination buffer and a number of
/// bytes to read from some pickled binary. `reader` should remember where it
/// last read from, and return the number of bytes actually read.
/// See `torch::pickle` for details.
/// `type_resolver` is used to resolve any JIT type based on its type string.
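///
/// A minimal sketch of a `reader`, assuming `data` is a `std::vector<char>`
/// produced by `torch::jit::pickle`:
///
/// \rst
/// .. code-block:: cpp
///
///   size_t offset = 0;
///   auto reader = [&](char* buffer, size_t len) -> size_t {
///     size_t to_copy = std::min(len, data.size() - offset);
///     std::memcpy(buffer, data.data() + offset, to_copy);
///     offset += to_copy;
///     return to_copy;
///   };
///   torch::IValue ivalue = torch::jit::unpickle(
///       reader, /*type_resolver=*/nullptr, /*tensor_table=*/{});
/// \endrst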
TORCH_API IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&) =
        Unpickler::defaultTypeParser);

/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
///
/// If any `torch::IValue`s in the pickled data are `Object`s, then a
/// `type_resolver` function must be provided.
///
/// See `torch::pickle` for details.
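///
/// A minimal sketch for data that contains no `Object`s or tensors, so the
/// default `type_resolver` and `tensor_table` suffice:
///
/// \rst
/// .. code-block:: cpp
///
///   std::vector<char> data =
///       torch::jit::pickle(torch::IValue(std::string("a string")));
///   torch::IValue ivalue = torch::jit::unpickle(data.data(), data.size());
/// \endrst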
TORCH_API IValue unpickle(
    const char* data,
    size_t size,
    TypeResolver type_resolver = nullptr,
    c10::ArrayRef<at::Tensor> tensor_table = {},
    c10::TypePtr (*type_parser)(const std::string&) =
        Unpickler::defaultTypeParser);

} // namespace jit
} // namespace torch