1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BASE_TENSOR_SERIALIZATION_H
17#define GLOW_BASE_TENSOR_SERIALIZATION_H
18
#ifdef WITH_PNG
#include "glow/Base/Image.h"
#endif // WITH_PNG
#include "glow/Base/Tensor.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include <fstream>
#include <functional>
#include <string>
26
27namespace glow {
28
/// Settings that control how a tensor is dumped to, or loaded from, a file.
struct TensorSerializationOptions {
  /// When true (the default), the tensor is dumped/loaded together with its
  /// tensor type.
  bool withType = true;
};
34
35/// Dump the content of \p tensor to a binary file \p filename using the options
36/// \p opts. The binary representation of data is guaranteed to preserve data
37/// precision (bit exactness) upon round-trips (dump/load).
38void dumpTensorToBinaryFile(const Tensor &tensor, llvm::StringRef filename,
39 const TensorSerializationOptions &opts);
40
41/// Dump the content of \p tensor to \p fs using the options \p opts. The binary
42/// representation of data is guaranteed to preserve data precision (bit
43/// exactness) upon round-trips (dump/load).
44void dumpTensorToBinaryFile(const Tensor &tensor, std::ofstream &fs,
45 const TensorSerializationOptions &opts);
46
47/// Load the content of \p tensor from a binary file \p filename using the
48/// options \p opts. The binary representation of data is guaranteed to preserve
49/// data precision (bit exactness) upon round-trips (dump/load).
50void loadTensorFromBinaryFile(Tensor &tensor, llvm::StringRef filename,
51 const TensorSerializationOptions &opts);
52
53/// Dump the content of \p tensor to a text file \p filename using the options
54/// \p opts. The data will be listed as a 1D array of values separated by comma
55/// (",") without other formatting. The text representation of data is NOT
56/// guaranteed to preserve data precision (bit exactness) upon round-trips
57/// (dump/load) and is used mainly for human readability.
58void dumpTensorToTextFile(Tensor &tensor, llvm::StringRef filename,
59 const TensorSerializationOptions &opts);
60
61/// Load the content of \p tensor from a text file \p filename using the options
62/// \p opts. The values in the text file are expected to be listed as a 1D array
63/// of values, separated by comma (",") without other formatting. The text
64/// representation of data is NOT guaranteed to preserve data precision (bit
65/// exactness) upon round-trips (dump/load) but is used for human readability.
66void loadTensorFromTextFile(Tensor &tensor, llvm::StringRef filename,
67 const TensorSerializationOptions &opts);
68
69#ifdef WITH_PNG
70
71/// Load network input tensor (same one images are loaded into) from the tensor
72/// blobs files in \p filenames. All the loaded tensors are concatenated along
73/// the batch dimension. The default loader expected the following tensor blob
74/// format:
75/// -- first line: 4 space separated values representing NCHW dimensions (W is
76/// the fastest).
77/// -- second line: Space separated float value that are loaded into \p
78/// tensor.
79/// TODO: Default tensor loader could be extended to support various data types
80/// or tensor layout.
81void loadInputImageFromFileWithType(
82 const llvm::ArrayRef<std::string> &filenames, Tensor *tensor,
83 ImageLayout tensorLayout);
84/// Helper function to aid testing of loadFromFileWithShapeLayout.
85void dumpInputTensorToFileWithType(const llvm::ArrayRef<std::string> &filenames,
86 const Tensor &, ImageLayout);
87
88/// Input tensor loader function. The function needs to set \p tensor type
89/// according to the data provided in the tensor blob. Expected layout is
90/// provided in \p imageLayout.
91using InputTensorFileLoaderFn = std::function<void(
92 Tensor &tensor, llvm::StringRef filename, ImageLayout imageLayout)>;
93
94/// Register input tensor loader function.
95void registerInputTensorFileLoader(InputTensorFileLoaderFn loader);
96
97#endif // WITH_PNG
98} // namespace glow
99
100#endif // GLOW_BASE_TENSOR_SERIALIZATION_H
101