1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <string> |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/kernel_def_builder.h" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/tensor_types.h" |
26 | #include "tensorflow/core/framework/types.h" |
27 | #include "tensorflow/core/lib/core/errors.h" |
28 | #include "tensorflow/core/lib/gtl/map_util.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | namespace { |
35 | // Returning a Status instead of using OP_REQUIRES directly since that doesn't |
36 | // seem to work outside the main OpKernel functions. |
37 | Status RemapVectorToMap( |
38 | const TTypes<const int64_t>::Vec& remapping, std::vector<bool>* id_present, |
39 | std::unordered_map<int64_t, int64_t>* old_id_to_new_id) { |
40 | id_present->clear(); |
41 | id_present->resize(remapping.size(), false); |
42 | for (int i = 0; i < remapping.size(); ++i) { |
43 | const int64_t old_id = remapping(i); |
44 | if (old_id < 0) continue; |
45 | (*id_present)[i] = true; |
46 | if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) { |
47 | return errors::Unimplemented( |
48 | strings::StrCat("Old ID " , old_id, " is mapped to both new ID " , |
49 | old_id_to_new_id->at(old_id), " and " , i, |
50 | ", which is not supported." )); |
51 | } |
52 | } |
53 | return OkStatus(); |
54 | } |
55 | } // anonymous namespace |
56 | |
57 | // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and |
58 | // swaps around the rows/columns according to row_remapping/col_remapping. |
59 | // "Missing" cells are initialized with values from initializing_values. |
60 | class LoadAndRemapMatrixOp : public OpKernel { |
61 | public: |
62 | explicit LoadAndRemapMatrixOp(OpKernelConstruction* context) |
63 | : OpKernel(context) { |
64 | OP_REQUIRES_OK(context, context->GetAttr("num_rows" , &num_rows_)); |
65 | OP_REQUIRES_OK(context, context->GetAttr("num_cols" , &num_cols_)); |
66 | OP_REQUIRES_OK( |
67 | context, context->GetAttr("max_rows_in_memory" , &max_rows_in_memory_)); |
68 | } |
69 | |
70 | void Compute(OpKernelContext* context) override { |
71 | // Checks what we're remapping and inverts the relevant remapping Tensors to |
72 | // be maps with key = old ID, value = new ID. |
73 | std::unordered_map<int64_t, int64_t> old_row_to_new_row_map; |
74 | std::vector<bool> row_id_present; |
75 | const Tensor* row_remapping_t; |
76 | OP_REQUIRES_OK(context, context->input("row_remapping" , &row_remapping_t)); |
77 | OP_REQUIRES( |
78 | context, row_remapping_t->dims() == 1, |
79 | errors::InvalidArgument("The `row_remapping` tensor must be 1-D, got " |
80 | "a tensor of shape " , |
81 | row_remapping_t->shape().DebugString())); |
82 | const auto row_remapping = row_remapping_t->vec<int64_t>(); |
83 | OP_REQUIRES(context, row_remapping.size() == num_rows_, |
84 | errors::InvalidArgument(strings::StrCat( |
85 | "Size of row_remapping is " , row_remapping.size(), |
86 | " instead of being equal to num_rows=" , num_rows_))); |
87 | OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present, |
88 | &old_row_to_new_row_map)); |
89 | |
90 | // Calculates the min/max old row ID that we need to read, to save us from |
91 | // reading some unnecessary slices of the old tensor. |
92 | int64_t min_old_row = -1; |
93 | int64_t max_old_row = -1; |
94 | for (int i = 0; i < row_remapping.size(); ++i) { |
95 | if (min_old_row < 0 || |
96 | (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) { |
97 | min_old_row = row_remapping(i); |
98 | } |
99 | if (max_old_row < 0 || |
100 | (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) { |
101 | max_old_row = row_remapping(i); |
102 | } |
103 | } |
104 | |
105 | // Processes the remapping for columns. |
106 | std::unordered_map<int64_t, int64_t> old_col_to_new_col_map; |
107 | std::vector<bool> col_id_present; |
108 | const Tensor* col_remapping_t; |
109 | OP_REQUIRES_OK(context, context->input("col_remapping" , &col_remapping_t)); |
110 | const auto col_remapping = col_remapping_t->vec<int64_t>(); |
111 | // Note that we always "remap rows", even when the row vocabulary does |
112 | // not change, because partitioning requires a mapping from partitioned |
113 | // Variables to the full checkpoints we load. |
114 | const bool remap_cols = col_remapping.size() > 0; |
115 | if (remap_cols) { |
116 | OP_REQUIRES( |
117 | context, col_remapping.size() == num_cols_, |
118 | errors::InvalidArgument(strings::StrCat( |
119 | "Provided col_remapping, but its size is " , col_remapping.size(), |
120 | " instead of being equal to num_cols=" , num_cols_))); |
121 | OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present, |
122 | &old_col_to_new_col_map)); |
123 | } else { |
124 | col_id_present.clear(); |
125 | col_id_present.resize(num_cols_, true); |
126 | } |
127 | |
128 | // Processes the checkpoint source and the provided Tensor name. |
129 | const Tensor* ckpt_path_t; |
130 | OP_REQUIRES_OK(context, context->input("ckpt_path" , &ckpt_path_t)); |
131 | OP_REQUIRES( |
132 | context, ckpt_path_t->NumElements() == 1, |
133 | errors::InvalidArgument("The `ckpt_path` tensor must have exactly one " |
134 | "element, got tensor of shape " , |
135 | ckpt_path_t->shape().DebugString())); |
136 | const string& ckpt_path = ckpt_path_t->scalar<tstring>()(); |
137 | const Tensor* old_tensor_name_t; |
138 | OP_REQUIRES_OK(context, |
139 | context->input("old_tensor_name" , &old_tensor_name_t)); |
140 | const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()(); |
141 | |
142 | LOG(INFO) << "Processing checkpoint : " << ckpt_path; |
143 | BundleReader reader(context->env(), ckpt_path); |
144 | OP_REQUIRES_OK(context, reader.status()); |
145 | |
146 | DataType tensor_type; |
147 | TensorShape tensor_shape; |
148 | OP_REQUIRES_OK(context, reader.LookupDtypeAndShape( |
149 | old_tensor_name, &tensor_type, &tensor_shape)); |
150 | OP_REQUIRES(context, tensor_type == DT_FLOAT, |
151 | errors::InvalidArgument(strings::StrCat( |
152 | "Tensor " , old_tensor_name, " has invalid type " , |
153 | DataTypeString(tensor_type), " instead of expected type " , |
154 | DataTypeString(DT_FLOAT)))); |
155 | // This op is limited to loading Tensors of rank 2 (matrices). |
156 | OP_REQUIRES( |
157 | context, tensor_shape.dims() == 2, |
158 | errors::InvalidArgument(strings::StrCat( |
159 | "Tensor " , old_tensor_name, " has shape " , |
160 | tensor_shape.DebugString(), " of invalid rank " , |
161 | tensor_shape.dims(), " instead of expected shape of rank 2." ))); |
162 | |
163 | if (!remap_cols) { |
164 | // TODO(weiho): Consider relaxing this restriction to allow partial column |
165 | // loading (even when no column remapping is specified) if there turns out |
166 | // to be a use case for it. |
167 | OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1), |
168 | errors::InvalidArgument(strings::StrCat( |
169 | "Tensor " , old_tensor_name, " has shape " , |
170 | tensor_shape.DebugString(), |
171 | ", where the size of its 2nd dimension is " , |
172 | tensor_shape.dim_size(1), |
173 | " instead of being equal to num_cols=" , num_cols_))); |
174 | } |
175 | |
176 | // Uses TensorSlice to potentially load the old tensor in chunks in case |
177 | // memory usage is a concern. |
178 | std::vector<TensorSlice> tensor_slices; |
179 | TensorSlice slice(tensor_shape.dims()); |
180 | if (min_old_row >= 0 && max_old_row >= 0) { |
181 | int64_t row_start = min_old_row; |
182 | // TODO(weiho): Given the list of old row IDs of interest (the keys of |
183 | // old_row_to_new_row_map), we could also try something smarter to |
184 | // find some minimal set of covering ranges for the list of old row IDs |
185 | // such that the size of each range is less than max_rows_in_memory_. |
186 | while (row_start <= max_old_row) { |
187 | const int64_t slice_length = |
188 | max_rows_in_memory_ <= 0 |
189 | // If max_rows_in_memory_ <= 0, we just load the entire chunk. |
190 | ? max_old_row - row_start + 1 |
191 | : std::min(max_rows_in_memory_, max_old_row - row_start + 1); |
192 | slice.set_start(0, row_start); |
193 | slice.set_length(0, slice_length); |
194 | tensor_slices.push_back(slice); |
195 | row_start += slice_length; |
196 | } |
197 | } |
198 | |
199 | // Allocates the output matrix. |
200 | Tensor* output_matrix_t = nullptr; |
201 | OP_REQUIRES_OK(context, |
202 | context->allocate_output("output_matrix" , |
203 | TensorShape({num_rows_, num_cols_}), |
204 | &output_matrix_t)); |
205 | auto output_matrix = output_matrix_t->matrix<float>(); |
206 | |
207 | // Iterates through tensor slices and copies over values from the old tensor |
208 | // to the output matrix. |
209 | int64_t row_index = min_old_row; |
210 | int64_t rows_copied = 0; |
211 | Tensor loaded_tensor_t; |
212 | for (const TensorSlice& tensor_slice : tensor_slices) { |
213 | LOG(INFO) << "Loading slice " << tensor_slice.DebugString(); |
214 | TensorShape slice_shape; |
215 | OP_REQUIRES_OK(context, |
216 | tensor_slice.SliceTensorShape(tensor_shape, &slice_shape)); |
217 | // Potentially re-allocates the tensor buffer since the last slice may |
218 | // have fewer rows than the other slices. |
219 | if (loaded_tensor_t.shape() != slice_shape) { |
220 | loaded_tensor_t = Tensor(DT_FLOAT, slice_shape); |
221 | } |
222 | OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice, |
223 | &loaded_tensor_t)); |
224 | |
225 | // Iterates through the old loaded tensor slice row-by-row. |
226 | for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) { |
227 | if (row_index % 500000 == min_old_row) { |
228 | LOG(INFO) << "Processing old row " << row_index; |
229 | } |
230 | |
231 | // If the old row ID is not found in old_row_to_new_row_map, continue |
232 | // to the next row; otherwise, copy it to the output matrix. |
233 | const int64_t* new_row_ptr = |
234 | gtl::FindOrNull(old_row_to_new_row_map, row_index); |
235 | if (new_row_ptr == nullptr) { |
236 | continue; |
237 | } |
238 | ++rows_copied; |
239 | const int64_t new_row = *new_row_ptr; |
240 | |
241 | // Copies over the row element-by-element, in case remapping is needed |
242 | // along the column axis. |
243 | const auto& loaded_tensor = loaded_tensor_t.matrix<float>(); |
244 | for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1); |
245 | ++old_col) { |
246 | int64_t new_col = old_col; |
247 | if (remap_cols) { |
248 | const int64_t* new_col_ptr = |
249 | gtl::FindOrNull(old_col_to_new_col_map, old_col); |
250 | if (new_col_ptr == nullptr) { |
251 | // Column remapping is specified, but this column is not found in |
252 | // old_col_to_new_col_map, so we leave it uninitialized, to be |
253 | // filled in with initializing_values later. |
254 | continue; |
255 | } |
256 | new_col = *new_col_ptr; |
257 | } |
258 | |
259 | OP_REQUIRES(context, |
260 | new_row < num_rows_ && new_col < num_cols_ && |
261 | new_row >= 0 && new_col >= 0, |
262 | errors::Internal(strings::StrCat( |
263 | "new_row=" , new_row, " and new_col=" , new_col, |
264 | " should have been less than num_rows_=" , num_rows_, |
265 | " and num_cols_=" , num_cols_, |
266 | " and non-negative. This should never have happened " |
267 | "if the code were correct. Please file a bug." ))); |
268 | output_matrix(new_row, new_col) = loaded_tensor(row, old_col); |
269 | } |
270 | } |
271 | } |
272 | LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with " |
273 | << tensor_shape.dim_size(0) << " rows) to new matrix (with " |
274 | << num_rows_ << " rows)." ; |
275 | |
276 | // At this point, there are potentially whole rows/columns uninitialized |
277 | // (corresponding to the indices where row_id_present/col_id_present are |
278 | // false). We fill this in cell-by-cell using row_id_present and |
279 | // col_id_present while dequeuing from the initializing_values vector. |
280 | const Tensor* initializing_values_t; |
281 | OP_REQUIRES_OK( |
282 | context, context->input("initializing_values" , &initializing_values_t)); |
283 | const auto initializing_values = initializing_values_t->flat<float>(); |
284 | int64_t initializing_values_index = 0; |
285 | for (int i = 0; i < num_rows_; ++i) { |
286 | for (int j = 0; j < num_cols_; ++j) { |
287 | if (row_id_present[i] && col_id_present[j]) continue; |
288 | OP_REQUIRES( |
289 | context, initializing_values_index < initializing_values.size(), |
290 | errors::InvalidArgument( |
291 | "initializing_values contained " , initializing_values.size(), |
292 | " elements, but more missing values remain." )); |
293 | output_matrix(i, j) = initializing_values(initializing_values_index); |
294 | ++initializing_values_index; |
295 | } |
296 | } |
297 | |
298 | // Checks that we used all the given initializing values. |
299 | OP_REQUIRES( |
300 | context, initializing_values_index == initializing_values.size(), |
301 | errors::InvalidArgument( |
302 | "initializing_values contained " , initializing_values.size(), |
303 | " elements, but only " , initializing_values_index, |
304 | " elements were used to fill in missing values." )); |
305 | } |
306 | |
307 | private: |
308 | int64_t num_rows_; |
309 | int64_t num_cols_; |
310 | int64_t max_rows_in_memory_; |
311 | }; |
312 | |
// Registers the kernel for CPU only: the op's cost is dominated by checkpoint
// I/O via BundleReader, and no device-specific variant exists.
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);
315 | |
316 | } // namespace tensorflow |
317 | |