/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {

namespace {
// Returns a Status instead of using OP_REQUIRES directly, since OP_REQUIRES
// does not seem to work outside the main OpKernel functions.
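// Illustrative example: remapping = [3, -1, 5] yields
// *id_present = {true, false, true} and *old_id_to_new_id = {3 -> 0, 5 -> 2};
// negative entries are skipped, and duplicate old IDs are rejected.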
Status RemapVectorToMap(
    const TTypes<const int64_t>::Vec& remapping, std::vector<bool>* id_present,
    std::unordered_map<int64_t, int64_t>* old_id_to_new_id) {
  id_present->clear();
  id_present->resize(remapping.size(), false);
  for (int i = 0; i < remapping.size(); ++i) {
    const int64_t old_id = remapping(i);
    if (old_id < 0) continue;
    (*id_present)[i] = true;
    if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
      return errors::Unimplemented(
          strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
                          old_id_to_new_id->at(old_id), " and ", i,
                          ", which is not supported."));
    }
  }
  return OkStatus();
}
}  // anonymous namespace

// This op loads a rank-2 Tensor (matrix) from a TensorFlow V2 checkpoint and
// remaps its rows/columns according to row_remapping/col_remapping. "Missing"
// cells are initialized with values from initializing_values.
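// Illustrative example: for a 2x2 checkpoint tensor [[1, 2], [3, 4]],
// row_remapping = [1, 0], an empty col_remapping, and empty
// initializing_values, the output is [[3, 4], [1, 2]]: output row i is copied
// from old row row_remapping[i]. Cells in rows/columns whose remapping entry
// is negative are filled from initializing_values in row-major order.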
class LoadAndRemapMatrixOp : public OpKernel {
 public:
  explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
    OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
    OP_REQUIRES_OK(
        context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
  }

  void Compute(OpKernelContext* context) override {
    // Checks what we're remapping and inverts the relevant remapping Tensors
    // to be maps with key = old ID, value = new ID.
    std::unordered_map<int64_t, int64_t> old_row_to_new_row_map;
    std::vector<bool> row_id_present;
    const Tensor* row_remapping_t;
    OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
    OP_REQUIRES(
        context, row_remapping_t->dims() == 1,
        errors::InvalidArgument("The `row_remapping` tensor must be 1-D, got "
                                "a tensor of shape ",
                                row_remapping_t->shape().DebugString()));
    const auto row_remapping = row_remapping_t->vec<int64_t>();
    OP_REQUIRES(context, row_remapping.size() == num_rows_,
                errors::InvalidArgument(strings::StrCat(
                    "Size of row_remapping is ", row_remapping.size(),
                    " instead of being equal to num_rows=", num_rows_)));
    OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
                                             &old_row_to_new_row_map));

    // Calculates the min/max old row ID that we need to read, to save us from
    // reading some unnecessary slices of the old tensor.
    int64_t min_old_row = -1;
    int64_t max_old_row = -1;
    for (int i = 0; i < row_remapping.size(); ++i) {
      if (min_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
        min_old_row = row_remapping(i);
      }
      if (max_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
        max_old_row = row_remapping(i);
      }
    }

    // Processes the remapping for columns.
    std::unordered_map<int64_t, int64_t> old_col_to_new_col_map;
    std::vector<bool> col_id_present;
    const Tensor* col_remapping_t;
    OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
    const auto col_remapping = col_remapping_t->vec<int64_t>();
    // Note that we always "remap rows", even when the row vocabulary does
    // not change, because partitioning requires a mapping from partitioned
    // Variables to the full checkpoints we load.
    const bool remap_cols = col_remapping.size() > 0;
    if (remap_cols) {
      OP_REQUIRES(
          context, col_remapping.size() == num_cols_,
          errors::InvalidArgument(strings::StrCat(
              "Provided col_remapping, but its size is ", col_remapping.size(),
              " instead of being equal to num_cols=", num_cols_)));
      OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
                                               &old_col_to_new_col_map));
    } else {
      col_id_present.clear();
      col_id_present.resize(num_cols_, true);
    }

    // Processes the checkpoint source and the provided Tensor name.
    const Tensor* ckpt_path_t;
    OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
    OP_REQUIRES(
        context, ckpt_path_t->NumElements() == 1,
        errors::InvalidArgument("The `ckpt_path` tensor must have exactly one "
                                "element, got tensor of shape ",
                                ckpt_path_t->shape().DebugString()));
    const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
    const Tensor* old_tensor_name_t;
    OP_REQUIRES_OK(context,
                   context->input("old_tensor_name", &old_tensor_name_t));
    const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();

    LOG(INFO) << "Processing checkpoint: " << ckpt_path;
    BundleReader reader(context->env(), ckpt_path);
    OP_REQUIRES_OK(context, reader.status());

    DataType tensor_type;
    TensorShape tensor_shape;
    OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
                                old_tensor_name, &tensor_type, &tensor_shape));
    OP_REQUIRES(context, tensor_type == DT_FLOAT,
                errors::InvalidArgument(strings::StrCat(
                    "Tensor ", old_tensor_name, " has invalid type ",
                    DataTypeString(tensor_type), " instead of expected type ",
                    DataTypeString(DT_FLOAT))));
    // This op is limited to loading Tensors of rank 2 (matrices).
    OP_REQUIRES(
        context, tensor_shape.dims() == 2,
        errors::InvalidArgument(strings::StrCat(
            "Tensor ", old_tensor_name, " has shape ",
            tensor_shape.DebugString(), " of invalid rank ",
            tensor_shape.dims(), " instead of expected shape of rank 2.")));

    if (!remap_cols) {
      // TODO(weiho): Consider relaxing this restriction to allow partial
      // column loading (even when no column remapping is specified) if there
      // turns out to be a use case for it.
      OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
                  errors::InvalidArgument(strings::StrCat(
                      "Tensor ", old_tensor_name, " has shape ",
                      tensor_shape.DebugString(),
                      ", where the size of its 2nd dimension is ",
                      tensor_shape.dim_size(1),
                      " instead of being equal to num_cols=", num_cols_)));
    }

    // Uses TensorSlice to potentially load the old tensor in chunks in case
    // memory usage is a concern.
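    // Illustrative example: with min_old_row = 0, max_old_row = 9, and
    // max_rows_in_memory_ = 4, the loop below produces three slices covering
    // rows [0, 4), [4, 8), and [8, 10).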
    std::vector<TensorSlice> tensor_slices;
    TensorSlice slice(tensor_shape.dims());
    if (min_old_row >= 0 && max_old_row >= 0) {
      int64_t row_start = min_old_row;
      // TODO(weiho): Given the list of old row IDs of interest (the keys of
      // old_row_to_new_row_map), we could also try something smarter to
      // find some minimal set of covering ranges for the list of old row IDs
      // such that the size of each range is less than max_rows_in_memory_.
      while (row_start <= max_old_row) {
        const int64_t slice_length =
            max_rows_in_memory_ <= 0
                // If max_rows_in_memory_ <= 0, we just load the entire chunk.
                ? max_old_row - row_start + 1
                : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
        slice.set_start(0, row_start);
        slice.set_length(0, slice_length);
        tensor_slices.push_back(slice);
        row_start += slice_length;
      }
    }

    // Allocates the output matrix.
    Tensor* output_matrix_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output_matrix",
                                            TensorShape({num_rows_, num_cols_}),
                                            &output_matrix_t));
    auto output_matrix = output_matrix_t->matrix<float>();

    // Iterates through tensor slices and copies over values from the old
    // tensor to the output matrix.
    int64_t row_index = min_old_row;
    int64_t rows_copied = 0;
    Tensor loaded_tensor_t;
    for (const TensorSlice& tensor_slice : tensor_slices) {
      LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
      TensorShape slice_shape;
      OP_REQUIRES_OK(context,
                     tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
      // Potentially re-allocates the tensor buffer since the last slice may
      // have fewer rows than the other slices.
      if (loaded_tensor_t.shape() != slice_shape) {
        loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
      }
      OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
                                                 &loaded_tensor_t));

      // Iterates through the old loaded tensor slice row-by-row.
      for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
        // Logs progress for the first processed row and every 500000 rows
        // thereafter.
        if ((row_index - min_old_row) % 500000 == 0) {
          LOG(INFO) << "Processing old row " << row_index;
        }

        // If the old row ID is not found in old_row_to_new_row_map, continue
        // to the next row; otherwise, copy it to the output matrix.
        const int64_t* new_row_ptr =
            gtl::FindOrNull(old_row_to_new_row_map, row_index);
        if (new_row_ptr == nullptr) {
          continue;
        }
        ++rows_copied;
        const int64_t new_row = *new_row_ptr;

        // Copies over the row element-by-element, in case remapping is needed
        // along the column axis.
        const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
        for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
             ++old_col) {
          int64_t new_col = old_col;
          if (remap_cols) {
            const int64_t* new_col_ptr =
                gtl::FindOrNull(old_col_to_new_col_map, old_col);
            if (new_col_ptr == nullptr) {
              // Column remapping is specified, but this column is not found in
              // old_col_to_new_col_map, so we leave it uninitialized, to be
              // filled in with initializing_values later.
              continue;
            }
            new_col = *new_col_ptr;
          }

          OP_REQUIRES(context,
                      new_row < num_rows_ && new_col < num_cols_ &&
                          new_row >= 0 && new_col >= 0,
                      errors::Internal(strings::StrCat(
                          "new_row=", new_row, " and new_col=", new_col,
                          " should have been less than num_rows_=", num_rows_,
                          " and num_cols_=", num_cols_,
                          " and non-negative. This should never have happened "
                          "if the code were correct. Please file a bug.")));
          output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
        }
      }
    }
    LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
              << tensor_shape.dim_size(0) << " rows) to new matrix (with "
              << num_rows_ << " rows).";

    // At this point, there are potentially whole rows/columns uninitialized
    // (corresponding to the indices where row_id_present/col_id_present are
    // false). We fill this in cell-by-cell using row_id_present and
    // col_id_present while dequeuing from the initializing_values vector.
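    // Illustrative example: if row_id_present[2] == false, all num_cols_
    // cells of output row 2 are taken from initializing_values in column
    // order, before moving on to the next missing cell.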
    const Tensor* initializing_values_t;
    OP_REQUIRES_OK(
        context, context->input("initializing_values", &initializing_values_t));
    const auto initializing_values = initializing_values_t->flat<float>();
    int64_t initializing_values_index = 0;
    for (int i = 0; i < num_rows_; ++i) {
      for (int j = 0; j < num_cols_; ++j) {
        if (row_id_present[i] && col_id_present[j]) continue;
        OP_REQUIRES(
            context, initializing_values_index < initializing_values.size(),
            errors::InvalidArgument(
                "initializing_values contained ", initializing_values.size(),
                " elements, but more missing values remain."));
        output_matrix(i, j) = initializing_values(initializing_values_index);
        ++initializing_values_index;
      }
    }

    // Checks that we used all the given initializing values.
    OP_REQUIRES(
        context, initializing_values_index == initializing_values.size(),
        errors::InvalidArgument(
            "initializing_values contained ", initializing_values.size(),
            " elements, but only ", initializing_values_index,
            " elements were used to fill in missing values."));
  }

 private:
  // Expected number of rows and columns of the output matrix (set via attrs).
  int64_t num_rows_;
  int64_t num_cols_;
  // Maximum number of old rows to load into memory at a time; a value <= 0
  // means all needed rows are loaded in a single slice.
  int64_t max_rows_in_memory_;
};

REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);

}  // namespace tensorflow