1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/cc/dtensor_device_util.h" |
17 | |
18 | #include <cstddef> |
19 | #include <string> |
20 | #include <utility> |
21 | |
22 | #include "absl/container/flat_hash_map.h" |
23 | #include "absl/strings/str_cat.h" |
24 | #include "tensorflow/c/eager/c_api_internal.h" |
25 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
26 | #include "tensorflow/c/tf_status.h" |
27 | #include "tensorflow/compiler/xla/status_macros.h" |
28 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
29 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
30 | #include "tensorflow/core/framework/attr_value.pb.h" |
31 | #include "tensorflow/core/framework/function.h" |
32 | #include "tensorflow/core/framework/node_def.pb.h" |
33 | #include "tensorflow/core/framework/node_def_util.h" |
34 | #include "tensorflow/core/framework/tensor.pb.h" |
35 | #include "tensorflow/core/framework/types.pb.h" |
36 | #include "tensorflow/core/graph/graph.h" |
37 | #include "tensorflow/core/lib/strings/proto_serialization.h" |
38 | #include "tensorflow/core/platform/errors.h" |
39 | #include "tensorflow/core/platform/fingerprint.h" |
40 | #include "tensorflow/core/public/version.h" |
41 | #include "tensorflow/dtensor/cc/constants.h" |
42 | #include "tensorflow/dtensor/cc/dstatus.h" |
43 | #include "tensorflow/dtensor/cc/small_constant_optimization.h" |
44 | |
45 | namespace tensorflow { |
46 | namespace dtensor { |
47 | namespace { |
48 | // Represents an input node during graph construction. |
49 | // When executing a Function, `output` is used to align graph inputs |
50 | // with the inputs to the function call. |
51 | struct FunctionArgument { |
52 | Node* node; |
53 | NodeDefBuilder::NodeOut output; |
54 | }; |
55 | |
56 | std::unique_ptr<parallel_device::ParallelTensor> |
57 | BroadcastTensorHandleToParallelTensor(TFE_Context* context, |
58 | TFE_TensorHandle* tensor, |
59 | const MeshWithParallelDevice& mesh, |
60 | TF_Status* status) { |
61 | // Broadcast tensor value to local devices. |
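  // For example (hypothetical mesh): a 1-D mesh over two local CPU devices
  // yields two components below, one copy of `tensor` per device, which are
  // then packed into a single ParallelTensor.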
62 | const Mesh& target_mesh = mesh.mesh_config(); |
63 | absl::Span<const std::string> local_devices = target_mesh.local_devices(); |
64 | const int num_local_devices = local_devices.size(); |
65 | |
66 | std::vector<parallel_device::TensorHandlePtr> components; |
67 | components.reserve(num_local_devices); |
68 | for (int i = 0; i < num_local_devices; ++i) { |
    // Create a copy of the tensor on each local device specified by
    // `target_mesh`.
70 | components.emplace_back(TFE_TensorHandleCopyToDevice( |
71 | tensor, context, local_devices[i].c_str(), status)); |
72 | if (TF_GetCode(status) != TF_OK) { |
73 | TF_SetStatus( |
74 | status, TF_INTERNAL, |
75 | absl::StrCat( |
76 | "Unable to copy tensor value for broadcast. Original message: " , |
77 | TF_Message(status)) |
78 | .c_str()); |
79 | return nullptr; |
80 | } |
81 | } |
82 | |
83 | std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = |
84 | parallel_device::ParallelTensor::FromTensorHandles( |
85 | mesh.parallel_device(), std::move(components), status); |
86 | if (TF_GetCode(status) != TF_OK) return nullptr; |
87 | return parallel_tensor; |
88 | } |
89 | |
90 | // Broadcast a single non-parallel resource tensor onto `mesh` with a fully |
91 | // replicated sharding spec. Does not take ownership of `tensor`. |
92 | std::unique_ptr<TensorWithLayout> BroadcastResourceTensor( |
93 | TFE_Context* context, TFE_TensorHandle* tensor, |
94 | const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name, |
95 | TF_Status* status) { |
96 | // Only broadcast resource tensors that point to scalars since they are |
97 | // always replicated. We also still want to catch honest user errors so |
98 | // error out on non-scalars. |
99 | // Resolve the Tensor as resource handle and get the shape and dtype |
100 | // of the tensor it points to. |
101 | std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tf_tensor( |
102 | TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor); |
103 | Tensor t; |
104 | Status convert_status = TF_TensorToTensor(tf_tensor.get(), &t); |
105 | if (!convert_status.ok() || t.dtype() != DataType::DT_RESOURCE) { |
106 | TF_SetStatus(status, TF_INTERNAL, |
107 | absl::StrCat("TF_TensorToTensor() Conversion failed:" , |
108 | convert_status.error_message()) |
109 | .c_str()); |
110 | return nullptr; |
111 | } |
112 | // Replicate this resource handle to all devices without changing the |
113 | // associated device of the resource itself. |
114 | ResourceHandle r = t.flat<ResourceHandle>()(0); |
115 | |
116 | // Only broadcast resource tensors onto a CPU mesh. Copying |
117 | // resource tensors to non CPU device is not supported. |
118 | if (!mesh.mesh_config().is_cpu_mesh()) { |
119 | std::string error_message = |
120 | "Using a non-DTensor variable with DTensor is only supported for " |
121 | "copying to a CPU mesh. If you are using a scope " |
122 | "based API, create " |
123 | "variables inside the DTensor scope.\n" ; |
124 | |
125 | // Get the stack_trace and Summaries from the resource tensor. |
126 | absl::StrAppend( |
127 | &error_message, "Offending variable summary: " , r.SummarizeValue(), |
128 | "\nStack trace: " , DefinitionLocationMsg(r.definition_stack_trace())); |
129 | TF_SetStatus(status, TF_INVALID_ARGUMENT, error_message.c_str()); |
130 | return nullptr; |
131 | } |
132 | |
  LOG(INFO) << "Broadcasting resource tensor to a dtensor resource tensor.";
134 | if (mesh.mesh_config().is_remote()) { |
135 | TF_DataType dtype = TFE_TensorHandleDataType(tensor); |
136 | std::vector<int64_t> shape(TensorShapeAsVector(tensor, status)); |
137 | if (TF_GetCode(status) != TF_OK) return nullptr; |
138 | auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size()); |
139 | |
140 | auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout); |
141 | return ret; |
142 | } |
143 | |
144 | std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = |
145 | BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status); |
146 | if (TF_GetCode(status) != TF_OK) return nullptr; |
147 | |
148 | int rank = r.dtypes_and_shapes().empty() |
149 | ? 0 |
150 | : r.dtypes_and_shapes().begin()->shape.dims(); |
151 | |
152 | StatusOr<std::unique_ptr<TensorWithLayout>> result = TensorWithLayout::Wrap( |
153 | std::move(parallel_tensor), mesh, |
154 | Layout::ReplicatedOnMesh(mesh.mesh_config(), rank)); |
155 | if (!result.ok()) { |
156 | TF_SetStatus( |
157 | status, TF_INTERNAL, |
158 | absl::StrCat("Error creating a TensorWithLayout from a resource tensor " |
159 | "during broadcasting with original error message:" , |
160 | result.status().error_message()) |
161 | .c_str()); |
162 | return nullptr; |
163 | } |
164 | |
165 | if (!r.dtypes_and_shapes().empty()) { |
166 | PartialTensorShape partial_shape = r.dtypes_and_shapes().begin()->shape; |
167 | // Set the shape/type of the tensor that the resource points to |
168 | // so that the graph has correct shape/type information that we can use. |
169 | (*result)->UpdateShapeAndDType( |
170 | partial_shape.AsProto(), r.dtypes_and_shapes().begin()->dtype, status); |
171 | } |
172 | |
173 | if (TF_GetCode(status) != TF_OK) { |
174 | TF_SetStatus(status, TF_INTERNAL, |
175 | "Error updating shape and dtype for resource tensor during " |
176 | "broadcasting." ); |
177 | return nullptr; |
178 | } |
179 | return std::move(*result); |
180 | } |
181 | |
182 | bool LayoutsAreCompatible(absl::optional<Layout> first_layout, |
183 | absl::optional<Layout> second_layout) { |
184 | if (!first_layout.has_value() && !second_layout.has_value()) { |
185 | return true; |
186 | } |
187 | if (!first_layout.has_value() || !second_layout.has_value()) { |
188 | return false; |
189 | } |
190 | return first_layout.value() == second_layout.value(); |
191 | } |
192 | |
// Parses a pair of attributes (indices, layouts) into a map.
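// For example (hypothetical values): if `indices_attr` holds the int tensor
// [0, 2] and `layout_attr` holds two serialized layout strings, the result
// maps 0 to the first parsed layout and 2 to the second.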
194 | Status ParseAttrMap(const Node& node, absl::string_view indices_attr, |
195 | absl::string_view layout_attr, |
196 | std::map<int, Layout>* indices_layout_map) { |
197 | std::vector<std::string> layouts; |
198 | if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) { |
199 | return OkStatus(); |
200 | } |
201 | const TensorProto* indices; |
202 | if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) { |
203 | return errors::Internal( |
204 | "Arg indices must be set when setting inferred resource layouts." ); |
205 | } |
206 | if (indices->int_val_size() != layouts.size()) { |
    return errors::Internal(
        "Number of arg indices for inferred resource arguments must match "
        "the number of inferred resource layouts.");
210 | } |
211 | for (int i = 0; i < indices->int_val_size(); ++i) { |
212 | const auto arg_index = indices->int_val(i); |
213 | const auto& arg_layout = layouts[i]; |
214 | indices_layout_map->emplace( |
215 | arg_index, tensorflow::dtensor::Layout::FromString(arg_layout).value()); |
216 | } |
217 | return OkStatus(); |
218 | } |
219 | |
220 | Status ParseResourceArgumentLayouts( |
221 | const Node& node, std::map<int, Layout>* inferred_resource_input_layouts) { |
222 | return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts, |
223 | inferred_resource_input_layouts); |
224 | } |
225 | |
226 | Status ParseShapeInputLayouts(const Node& node, |
227 | std::map<int, Layout>* shape_output_metadata) { |
228 | return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout, |
229 | shape_output_metadata); |
230 | } |
231 | |
232 | // Gets the layout attached to a specific node at a given index, ignoring any |
233 | // Identity ops. |
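// For example, if `op` is an Identity fed by another Identity that is in turn
// fed by SomeOp, the loop below walks the input edges back to SomeOp and
// reads the `output_index`-th entry of its kLayoutAttr list.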
234 | StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) { |
235 | while (op->op_def().name() == "Identity" || |
236 | op->op_def().name() == "IdentityN" ) { |
237 | const Edge* edge; |
238 | TF_RETURN_IF_ERROR(op->input_edge(output_index, &edge)); |
239 | op = edge->src(); |
240 | output_index = edge->src_output(); |
241 | } |
242 | const auto serialized_layouts = op->attrs().Find(kLayoutAttr); |
243 | |
244 | if (!serialized_layouts) { |
245 | return errors::InvalidArgument( |
246 | op->op_def().name(), " doesn't contain attribute : " , kLayoutAttr); |
247 | } |
248 | |
249 | // We assume that there is one layout for each output. |
250 | if (serialized_layouts->list().s_size() != op->num_outputs()) { |
251 | return errors::InvalidArgument( |
252 | "Number of outputs to " , op->op_def().name(), |
253 | " does not match number of layouts attached" ); |
254 | } |
255 | |
256 | return Layout::FromString(serialized_layouts->list().s(output_index)); |
257 | } |
258 | |
259 | } // namespace |
260 | |
261 | tensorflow::Fprint128 TensorWithLayout::CacheKey() const { |
262 | tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout_.ToString()); |
263 | // Use exact shape to compute the key. |
264 | for (const int64_t dim : local_shape()) { |
265 | f = FingerprintCat128(f, dim); |
266 | } |
267 | if (const_value_.has_value()) { |
268 | std::string serialized; |
269 | SerializeToStringDeterministic(const_value_.value(), &serialized); |
270 | f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized)); |
271 | } |
272 | return f; |
273 | } |
274 | |
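// Broadcasts a single non-parallel eager tensor onto `mesh` with a fully
// replicated sharding spec. A hypothetical call site (names are illustrative,
// not from this file) might look like:
//
//   std::unique_ptr<TensorWithLayout> t = TensorWithLayout::Broadcast(
//       context, input_handle, mesh, dtensor_device_name, status);
//
// where `input_handle` is an eager tensor that does not already live on the
// DTensor device.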
275 | std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast( |
276 | TFE_Context* context, TFE_TensorHandle* tensor, |
277 | const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name, |
278 | TF_Status* status) { |
279 | const char* input_device = TFE_TensorHandleDeviceName(tensor, status); |
280 | if (TF_GetCode(status) != TF_OK) return nullptr; |
281 | |
282 | if (dtensor_device_name == input_device) { |
283 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
284 | "Input to Broadcast must be eager tensor." ); |
285 | return nullptr; |
286 | } |
287 | |
288 | // Handle resource tensor broadcasting to the mesh. |
289 | if (TFE_TensorHandleDataType(tensor) == TF_RESOURCE) { |
290 | return BroadcastResourceTensor(context, tensor, mesh, dtensor_device_name, |
291 | status); |
292 | } |
293 | |
294 | if (mesh.mesh_config().is_remote()) { |
295 | TF_DataType dtype = TFE_TensorHandleDataType(tensor); |
296 | std::vector<int64_t> shape(TensorShapeAsVector(tensor, status)); |
297 | if (TF_GetCode(status) != TF_OK) return nullptr; |
298 | auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size()); |
299 | |
300 | auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout); |
301 | absl::optional<NodeDef> const_value = |
302 | ExtractSmallTensorValue(context, tensor, layout, status); |
303 | if (TF_GetCode(status) != TF_OK) return nullptr; |
304 | if (const_value) { |
305 | ret->set_const_value(const_value.value()); |
306 | } |
307 | return ret; |
308 | } |
309 | |
310 | std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = |
311 | BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status); |
312 | if (TF_GetCode(status) != TF_OK) return nullptr; |
313 | |
314 | const std::vector<int64_t>* shape; |
315 | Status s = parallel_tensor->Shape(&shape); |
316 | if (!s.ok()) { |
317 | TF_SetStatus(status, static_cast<TF_Code>(s.code()), |
318 | s.error_message().c_str()); |
319 | return nullptr; |
320 | } |
321 | size_t num_dims = shape->size(); |
322 | const Layout layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), num_dims); |
323 | |
324 | absl::optional<NodeDef> const_value = |
325 | ExtractSmallTensorValue(context, tensor, layout, status); |
326 | if (TF_GetCode(status) != TF_OK) return nullptr; |
327 | |
328 | std::unique_ptr<TensorWithLayout> result(new TensorWithLayout( |
329 | std::move(parallel_tensor), mesh, std::move(layout), *shape, |
330 | /*dtype=*/absl::nullopt, std::move(const_value))); |
331 | return result; |
332 | } |
333 | |
334 | StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap( |
335 | std::unique_ptr<parallel_device::ParallelTensor> tensor, |
336 | const MeshWithParallelDevice& mesh, const Layout& layout) { |
337 | const std::vector<int64_t>* shape; |
338 | TF_RETURN_IF_ERROR(tensor->Shape(&shape)); |
339 | |
340 | if (tensor->dtype() != TF_RESOURCE) { |
341 | return std::unique_ptr<TensorWithLayout>( |
342 | new TensorWithLayout(std::move(tensor), mesh, layout, *shape)); |
343 | } else { |
344 | return std::unique_ptr<TensorWithLayout>( |
345 | new ResourceHandleWithLayout(std::move(tensor), mesh, layout, *shape)); |
346 | } |
347 | } |
348 | |
349 | std::unique_ptr<TensorWithLayout> TensorWithLayout::Dummy( |
350 | const std::vector<int64_t>& local_shape, const TF_DataType dtype, |
351 | const MeshWithParallelDevice& mesh, const Layout& layout) { |
352 | if (dtype != TF_RESOURCE) { |
353 | return std::unique_ptr<TensorWithLayout>(new TensorWithLayout( |
354 | /*tensor=*/nullptr, mesh, layout, local_shape, dtype)); |
355 | } else { |
356 | return std::unique_ptr<TensorWithLayout>(new ResourceHandleWithLayout( |
357 | /*tensor=*/nullptr, mesh, layout, local_shape)); |
358 | } |
359 | } |
360 | |
361 | std::string TensorWithLayout::SummarizeValue() const { |
362 | std::string value_summary; |
363 | Status status; |
364 | if (dtype() != TF_RESOURCE && layout().IsFullyReplicated()) { |
365 | status = |
366 | tensorflow::unwrap(tensor()->tensor(0))->SummarizeValue(value_summary); |
367 | } else { |
368 | // Note that this just prints the local values for sharded tensors. We could |
369 | // instead run a collective here to relayout to replicated. |
370 | status = tensor()->SummarizeValue(value_summary); |
371 | } |
372 | if (!status.ok()) { |
373 | value_summary = "<error computing value>" ; |
374 | } |
  return absl::StrCat(value_summary, ", layout=\"", layout().ToString(),
                      "\"");
376 | } |
377 | |
378 | std::string TensorWithLayout::DebugString() const { |
379 | auto dtype = static_cast<DataType>(tensor()->dtype()); |
380 | |
381 | const auto& shape_vector = global_shape(); |
382 | return absl::StrCat("DTensor(" , SummarizeValue(), |
383 | ", shape=" , ShapeToDebugString(shape_vector), |
384 | ", type=" , DataTypeString(dtype), ")" ); |
385 | } |
386 | |
387 | void ResourceHandleWithLayout::EncodeAttributes( |
388 | tensorflow::NodeDefBuilder& builder) const { |
389 | // If set, attach shape and dtype to the given node def. |
390 | if (dereferenced_shape().has_value()) { |
391 | builder.Attr("_handle_shapes" , {*dereferenced_shape()}); |
392 | } |
393 | if (dereferenced_dtype().has_value()) { |
394 | builder.Attr("_handle_dtypes" , {*dereferenced_dtype()}); |
395 | } |
396 | } |
397 | |
398 | tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const { |
399 | tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout().ToString()); |
400 | if (dereferenced_shape().has_value()) { |
401 | std::string serialized; |
402 | SerializeToStringDeterministic(dereferenced_shape().value(), &serialized); |
403 | f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized)); |
404 | } |
405 | if (dereferenced_dtype().has_value()) { |
406 | f = FingerprintCat128(f, dereferenced_dtype().value()); |
407 | } |
408 | return f; |
409 | } |
410 | |
411 | void ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout, |
412 | TF_Status* status) { |
  // Only set the value for the dereferenced layout if the incoming layout is
  // not empty. This is still hacky, as we use an empty layout as a placeholder
  // for an eagerly placed VarHandleOp.
416 | if (!dereferenced_layout_.has_value() && new_layout.IsEmpty()) return; |
417 | if (dereferenced_layout_.has_value() && |
418 | !LayoutsAreCompatible(dereferenced_layout_, new_layout)) { |
419 | // TODO(xiejw, allenl): Consider allowing variables to switch layouts. |
420 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
421 | "Attempted to overwrite an existing Layout." ); |
422 | } |
423 | dereferenced_layout_.emplace(new_layout); |
424 | } |
425 | |
426 | void ResourceHandleWithLayout::UpdateAttrs(const EmbeddingResourceAttrs& attrs, |
427 | TF_Status* status) { |
428 | if (attrs_.has_value()) { |
429 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
430 | "Attepted to overwrite an existing embedding resource " |
431 | "attribute." ); |
432 | } |
433 | attrs_.emplace(attrs); |
434 | } |
435 | |
436 | StatusOr<std::unique_ptr<TensorWithLayout>> SparseTensorWithLayout::Wrap( |
437 | std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, |
438 | std::unique_ptr<parallel_device::ParallelTensor> values_tensor, |
439 | std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, |
440 | const MeshWithParallelDevice& mesh, const Layout& layout, |
441 | std::vector<int64_t> local_shape) { |
442 | return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout( |
443 | std::move(indices_tensor), std::move(values_tensor), |
444 | std::move(shapes_tensor), mesh, layout, local_shape)); |
445 | } |
446 | |
447 | std::string SparseTensorWithLayout::SummarizeValue() const { |
448 | std::string indices_summary; |
449 | std::string values_summary; |
450 | std::string dense_shapes_summary; |
451 | |
452 | Status indices_status; |
453 | Status values_status; |
454 | Status dense_shapes_status; |
455 | |
456 | if (layout().IsFullyReplicated()) { |
457 | indices_status = tensorflow::unwrap(indices_->tensor(0)) |
458 | ->SummarizeValue(indices_summary); |
459 | values_status = |
460 | tensorflow::unwrap(values_->tensor(0))->SummarizeValue(values_summary); |
461 | dense_shapes_status = tensorflow::unwrap(dense_shapes_->tensor(0)) |
462 | ->SummarizeValue(dense_shapes_summary); |
463 | } else { |
464 | indices_status = indices_->SummarizeValue(indices_summary); |
465 | values_status = values_->SummarizeValue(values_summary); |
466 | dense_shapes_status = dense_shapes_->SummarizeValue(dense_shapes_summary); |
467 | } |
468 | |
  if (!indices_status.ok())
    indices_summary = "<error computing summary for indices>";
  if (!values_status.ok())
    values_summary = "<error computing summary for values>";
  if (!dense_shapes_status.ok())
    dense_shapes_summary = "<error computing summary for dense_shapes>";
475 | |
476 | return absl::StrCat("indices: " , indices_summary, ", " , |
477 | "values: " , values_summary, ", " , |
478 | "dense_shapes: " , dense_shapes_summary, ", layout=\"" , |
479 | layout().ToString(), "\"" ); |
480 | } |
481 | |
482 | std::string SparseTensorWithLayout::DebugString() const { |
483 | auto dtype = static_cast<DataType>(values_->dtype()); |
484 | |
485 | const auto& shape_vector = global_shape(); |
486 | return absl::StrCat("DTensor(" , SummarizeValue(), |
487 | ", shape=" , ShapeToDebugString(shape_vector), |
488 | ", type=" , DataTypeString(dtype), ")" ); |
489 | } |
490 | |
491 | TF_DataType SparseTensorWithLayout::dtype() const { |
492 | if (dtype_.has_value()) { |
493 | return dtype_.value(); |
494 | } else { |
495 | return values_->dtype(); |
496 | } |
497 | } |
498 | |
499 | TFE_TensorHandle* SparseTensorWithLayout::get_tensor(size_t index) const { |
500 | int num_sparse_tensors = num_tensors() / 3; |
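  // Components are ordered as [indices..., values..., dense_shapes...], one
  // entry per local device. For example, with 2 local devices num_tensors()
  // is 6: indices 0-1 resolve to indices_, 2-3 to values_, and 4-5 to
  // dense_shapes_.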
501 | if (index < num_sparse_tensors) { |
502 | return indices()->tensor(index); |
503 | } else if (index < 2 * num_sparse_tensors) { |
504 | return values()->tensor(index % num_sparse_tensors); |
505 | } else { |
506 | return dense_shapes()->tensor(index % num_sparse_tensors); |
507 | } |
508 | } |
509 | |
510 | absl::flat_hash_map<int, NodeDef> GetConstantFoldableTensors( |
511 | const std::vector<TensorWithLayout*>& inputs) { |
512 | absl::flat_hash_map<int, NodeDef> small_tensors; |
513 | for (auto index = 0; index < inputs.size(); ++index) { |
514 | if (inputs[index]->const_value().has_value()) { |
515 | small_tensors.insert({index, inputs[index]->const_value().value()}); |
516 | } |
517 | } |
518 | return small_tensors; |
519 | } |
520 | |
521 | // Thread unsafe method. go/thread-unsafe |
// Cache key computation should consider all features of an op that affect
// the SPMD lowering. The cache keys of two ops must be different if the
524 | // translated functions are different. |
525 | // - op name and attr |
526 | // - input shapes and layouts |
527 | // - default layout of outputs. |
528 | // - values of constant foldable inputs. |
529 | tensorflow::Fprint128 FunctionManager::CacheKeyForGraph( |
530 | const DTensorOperation& doperation, const NameAttrList& attributes, |
531 | const std::vector<TensorWithLayout*>& inputs, |
532 | const std::vector<const Layout*>& output_layouts) { |
533 | tensorflow::Fprint128 cache_key = tensorflow::Fingerprint128(doperation.name); |
534 | std::string serialized; |
535 | SerializeToStringDeterministic(attributes, &serialized); |
536 | cache_key = |
537 | FingerprintCat128(cache_key, tensorflow::Fingerprint128(serialized)); |
  // Mix in each input's cache key, which covers its shape, layout, and any
  // small constant value. Inputs that are not constant foldable have their
  // const value cleared first so that stale values do not affect the key.
539 | for (auto i = 0; i < inputs.size(); ++i) { |
540 | if (!IsConstantFoldable(doperation, i)) { |
541 | inputs[i]->reset_const_value(); |
542 | } |
543 | cache_key = FingerprintCat128(cache_key, inputs[i]->CacheKey()); |
544 | } |
545 | for (int output_index = 0; output_index < output_layouts.size(); |
546 | ++output_index) { |
547 | if (output_layouts[output_index]) { |
548 | cache_key = FingerprintCat128(cache_key, output_index); |
549 | cache_key = FingerprintCat128( |
550 | cache_key, |
551 | tensorflow::Fingerprint128(output_layouts[output_index]->ToString())); |
552 | } |
553 | } |
554 | return cache_key; |
555 | } |
556 | |
557 | // Thread-unsafe method go/thread-unsafe. |
558 | std::pair<tensorflow::Fprint128, const ExecutionFunctions*> |
559 | FunctionManager::GetCachedFunction( |
560 | const DTensorOperation& doperation, const NameAttrList& attributes, |
561 | const std::vector<TensorWithLayout*>& inputs, |
562 | const std::vector<const Layout*>& output_layouts) { |
563 | tensorflow::Fprint128 cache_key = |
564 | CacheKeyForGraph(doperation, attributes, inputs, output_layouts); |
565 | auto iter = function_cache_.find(cache_key); |
566 | |
567 | // Early return if we have a cache hit. |
568 | if (iter != function_cache_.end()) { |
569 | return std::pair<Fprint128, ExecutionFunctions*>(cache_key, &iter->second); |
570 | } |
571 | |
572 | // For eager ops we early return the cache miss and do not make further |
573 | // optimizations. |
574 | if (!doperation.is_func()) { |
575 | return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr); |
576 | } |
577 | |
578 | const tensorflow::Fprint128 doperation_hash = |
579 | CacheKeyForDTensorOperation(doperation); |
580 | |
581 | // Save the constant folded inputs to this doperation if we have not seen this |
582 | // before. This is needed so that in the next call to this operation, we |
583 | // can compare these inputs to confirm which one is indeed a constant. |
584 | auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash); |
585 | if (doperation_iter == dtensor_op_and_small_inputs_.end()) { |
586 | dtensor_op_and_small_inputs_.insert( |
587 | {doperation_hash, GetConstantFoldableTensors(inputs)}); |
588 | return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr); |
589 | } |
590 | |
  // If we are here, then we have run this function before but constant folded
  // some input(s) that turned out not to be constant, i.e. one of the small
  // input values to this function changed. Mark those changed values as
  // non-constant.
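  // For example (hypothetical values): if the first call recorded input 2
  // with small constant value 7 and this call sees 8 instead, input 2 is
  // erased from the saved small inputs below and is no longer treated as a
  // constant.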
595 | absl::flat_hash_map<int, NodeDef>& previous_small_inputs = |
596 | doperation_iter->second; |
597 | std::vector<int> non_constant_indices; |
598 | |
599 | for (auto const& [index, previous_small_input] : previous_small_inputs) { |
600 | if (inputs[index]->const_value().has_value()) { |
601 | if (NodeDefsHaveDifferentTensorProto( |
602 | previous_small_input, inputs[index]->const_value().value())) { |
603 | inputs[index]->reset_const_value(); |
604 | non_constant_indices.push_back(index); |
605 | } |
606 | } |
607 | } |
608 | for (int non_constant_index : non_constant_indices) { |
609 | previous_small_inputs.erase(non_constant_index); |
610 | } |
  // Recompute the cache key, since resetting the small const inputs above
  // changes it.
613 | cache_key = CacheKeyForGraph(doperation, attributes, inputs, output_layouts); |
614 | return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr); |
615 | } |
616 | |
617 | const ExecutionFunctions* FunctionManager::AddCachedFunction( |
618 | const DTensorOperation& op, tensorflow::Fprint128 cache_key, |
619 | ExecutionFunctions function) { |
620 | return &function_cache_.insert({cache_key, std::move(function)}) |
621 | .first->second; |
622 | } |
623 | |
624 | bool FunctionManager::IsConstantFoldable(const DTensorOperation& doperation, |
625 | const int input_index) const { |
626 | // For eager ops, assume the inputs are constant foldable. |
627 | if (!doperation.is_func()) return true; |
628 | const tensorflow::Fprint128 doperation_hash = |
629 | CacheKeyForDTensorOperation(doperation); |
  // If we didn't see this doperation before, then optimistically assume it is
  // foldable. Otherwise the input at `input_index` is foldable only if it is
  // one of the indices we have saved as the small inputs.
633 | auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash); |
634 | return doperation_iter == dtensor_op_and_small_inputs_.end() || |
635 | doperation_iter->second.contains(input_index); |
636 | } |
637 | |
638 | const tensorflow::Fprint128 FunctionManager::CacheKeyForDTensorOperation( |
639 | const DTensorOperation& doperation) const { |
640 | return tensorflow::Fingerprint128(doperation.name); |
641 | } |
642 | |
643 | std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor, |
644 | TF_Status* status) { |
645 | std::vector<int64_t> shape(TFE_TensorHandleNumDims(tensor, status)); |
646 | if (TF_GetCode(status) != TF_OK) return {}; |
647 | for (int i = 0; i < shape.size(); ++i) { |
648 | shape[i] = TFE_TensorHandleDim(tensor, i, status); |
649 | if (TF_GetCode(status) != TF_OK) return {}; |
650 | } |
651 | return shape; |
652 | } |
653 | |
654 | Status PrepareGraphForMlir( |
655 | const FunctionManager& function_manager, |
656 | const std::vector<TensorWithLayout*>& inputs, |
657 | const DTensorOperation& doperation, |
658 | const tensorflow::FunctionLibraryDefinition& flib_def, |
659 | const NameAttrList& attributes, |
660 | const absl::optional<Layout>& default_layout, tensorflow::Graph* graph, |
661 | std::vector<PartialTensorShape>* global_output_shapes, |
662 | std::vector<const Layout*>* output_layouts) { |
663 | // We run shape inference on the graph to find output shapes, which may |
664 | // determine default layouts. |
665 | ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def); |
666 | shape_refiner.set_function_library_for_shape_inference(&flib_def); |
667 | tensorflow::Status status; |
668 | { |
669 | // We include an _Arg node for the device ID, but this isn't used by the |
670 | // initial function. It will be provided a value, though, so it's available |
671 | // for use in rewrites. |
672 | tensorflow::NodeDefBuilder builder("device_id" , "_Arg" ); |
673 | tensorflow::PartialTensorShape partial_shape; |
674 | TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape( |
675 | static_cast<int*>(nullptr), 0, &partial_shape)); |
676 | tensorflow::NodeDef arg_node_def; |
677 | TF_RETURN_IF_ERROR(builder.Attr("shape" , partial_shape) |
678 | .Attr("T" , tensorflow::DT_INT32) |
679 | .Attr("index" , 0) |
680 | .Finalize(&arg_node_def, /*consume=*/true)); |
681 | tensorflow::Node* arg_node = graph->AddNode(arg_node_def, &status); |
682 | TF_RETURN_IF_ERROR(status); |
683 | graph->AddControlEdge(graph->source_node(), arg_node); |
684 | TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node)); |
685 | } |
686 | std::vector<FunctionArgument> graph_op_inputs; |
687 | graph_op_inputs.reserve(inputs.size()); |
688 | for (int i = 0; i < inputs.size(); ++i) { |
689 | const TensorWithLayout* input = inputs[i]; |
690 | // TODO(allenl): This will block until async execution is complete, which |
691 | // will be slow. We should find a non-blocking way of fetching the shape, |
692 | // at least pre-cache. |
693 | // The shape passed into MLIR transformation represents the global shape of |
694 | // the tensor. Ideally, the local shape on each parallel device should not |
695 | // be consulted at all and we should use the shape on our input tensor |
696 | // directly. |
697 | const auto& shape = input->global_shape(); |
698 | std::vector<tensorflow::int64> cast_shape(shape.begin(), shape.end()); |
699 | tensorflow::PartialTensorShape partial_shape; |
    // For resource tensors, the `shape` attribute should not be specified, as
    // the shape of a resource tensor is carried by the resource shape subtype,
    // not by the shape attribute.
703 | auto* resource = dynamic_cast<const ResourceHandleWithLayout*>(input); |
704 | if (!resource) { |
705 | TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape( |
706 | cast_shape.data(), cast_shape.size(), &partial_shape)); |
707 | } |
708 | |
709 | tensorflow::NodeDef arg_node_def; |
710 | auto dtype = static_cast<tensorflow::DataType>(input->dtype()); |
711 | tensorflow::NodeDefBuilder builder(absl::StrCat("op_input_" , i), "_Arg" ); |
712 | |
713 | // Delegate TensorWithLayout to encode attributes if applicable. |
714 | input->EncodeAttributes(builder); |
715 | |
716 | // Here we set each arg node's `index` attribute to the position of |
717 | // the dtensor inputs. This is important for later use when we create |
718 | // a mapping from the graph argument node to the corresponding argument |
719 | // index of the list of dtensor inputs. Thus, even if the argument node |
720 | // orderings change within the graph, we can always correctly |
721 | // find the dtensor input corresponding to that arg node. |
722 | // |
723 | // This assumes that the dtensor inputs stay unchanged in ordering, |
724 | // and if there is an ordering change of dtensor inputs, then special |
725 | // care must be taken. |
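    // For example, dtensor input 0 becomes the _Arg node carrying
    // `index == 1`, since index 0 is reserved for the device ID argument
    // added above.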
726 | TF_RETURN_IF_ERROR( |
727 | builder.Attr("shape" , partial_shape) |
728 | .Attr("T" , dtype) |
729 | .Attr("index" , i + 1) // Indices are offset by 1 for device_id |
730 | .Attr(kLayoutAttr, input->layout().ToString()) |
731 | .Attr(kMeshAttr, input->mesh().mesh_config().ToString()) |
732 | .Finalize(&arg_node_def, /*consume=*/true)); |
733 | Node* arg_node = graph->AddNode(arg_node_def, &status); |
734 | TF_RETURN_IF_ERROR(status); |
735 | TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node)); |
736 | |
737 | shape_inference::InferenceContext* inference_context = |
738 | shape_refiner.GetContext(arg_node); |
739 | shape_inference::ShapeHandle shape_handle; |
740 | TF_RETURN_IF_ERROR(inference_context->MakeShapeFromPartialTensorShape( |
741 | partial_shape, &shape_handle)); |
742 | TF_RETURN_IF_ERROR(shape_refiner.SetShape(arg_node, 0, shape_handle)); |
743 | |
744 | // Small constants are converted into constant graph nodes, instead of being |
745 | // passed in as input arguments. This provides more information to the SPMD |
746 | // and layout propagation passes. |
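    // For example (illustrative): a small shape tensor such as [2, 4] feeding
    // a Reshape would be emitted as a Const node below, so layout propagation
    // can see the concrete target shape instead of an opaque argument.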
747 | if (!input->const_value().has_value() || |
748 | !function_manager.IsConstantFoldable(doperation, i)) { |
749 | graph_op_inputs.push_back(FunctionArgument{ |
750 | arg_node, NodeDefBuilder::NodeOut{arg_node->name(), i, dtype}}); |
751 | graph->AddControlEdge(graph->source_node(), arg_node); |
752 | } else { |
753 | // TODO(xiejw): Refactor the TensorWithLayout representation to avoid |
754 | // special code here. |
755 | NodeDef const_node = input->const_value().value(); |
756 | const_node.set_name(absl::StrCat("input_" , i, "_const_value" )); |
757 | Node* const_value_n = graph->AddNode(const_node, &status); |
758 | TF_RETURN_IF_ERROR(status); |
759 | TF_RETURN_IF_ERROR(shape_refiner.AddNode(const_value_n)); |
760 | graph_op_inputs.push_back(FunctionArgument{ |
761 | const_value_n, tensorflow::NodeDefBuilder::NodeOut{ |
762 | const_value_n->name(), i, dtype}}); |
763 | } |
764 | } |
765 | |
766 | tensorflow::NodeDef op_node_def; |
767 | const FunctionDef* function_def = doperation.function_def; |
768 | if (function_def) { |
769 | AttrValue func_attr; |
770 | func_attr.mutable_func()->set_name(doperation.name); |
771 | std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs; |
772 | std::vector<tensorflow::DataType> inputs_types; |
773 | for (const auto& in : graph_op_inputs) { |
774 | func_inputs.emplace_back(in.output); |
775 | inputs_types.emplace_back(in.output.data_type); |
776 | } |
777 | |
778 | std::vector<tensorflow::DataType> output_types; |
779 | for (const auto& out : function_def->signature().output_arg()) |
780 | output_types.emplace_back(out.type()); |
781 | |
782 | TF_RETURN_IF_ERROR( |
783 | NodeDefBuilder("eager_operation" , "StatefulPartitionedCall" ) |
784 | .Attr("Tin" , inputs_types) |
785 | .Attr("Tout" , output_types) |
786 | .Attr("f" , func_attr) |
787 | .Input(func_inputs) |
788 | .Finalize(&op_node_def, true)); |
789 | } else { |
790 | op_node_def.set_op(doperation.name); |
791 | op_node_def.set_name("eager_operation" ); |
792 | } |
793 | |
794 | op_node_def.mutable_attr()->insert(attributes.attr().begin(), |
795 | attributes.attr().end()); |
796 | |
797 | tensorflow::Node* op_node = graph->AddNode(op_node_def, &status); |
798 | TF_RETURN_IF_ERROR(status); |
799 | |
800 | for (int i = 0; i < graph_op_inputs.size(); ++i) { |
801 | graph->AddEdge(graph_op_inputs[i].node, 0, op_node, i); |
802 | } |
803 | TF_RETURN_IF_ERROR(shape_refiner.AddNode(op_node)); |
804 | |
805 | output_layouts->clear(); |
806 | output_layouts->reserve(op_node->num_outputs()); |
807 | global_output_shapes->reserve(op_node->num_outputs()); |
808 | for (int output_index = 0; output_index < op_node->num_outputs(); |
809 | ++output_index) { |
810 | tensorflow::NodeDefBuilder builder(absl::StrCat("op_output_" , output_index), |
811 | "_Retval" ); |
812 | tensorflow::NodeDef ret_node_def; |
813 | tensorflow::DataType output_type = op_node->output_type(output_index); |
814 | |
815 | TF_RETURN_IF_ERROR(builder.Attr("T" , output_type) |
816 | .Attr("index" , output_index) |
817 | .Input("eager_operation" , output_index, output_type) |
818 | .Finalize(&ret_node_def, /*consume=*/true)); |
819 | tensorflow::Node* ret_node = graph->AddNode(ret_node_def, &status); |
820 | TF_RETURN_IF_ERROR(status); |
821 | graph->AddEdge(op_node, output_index, ret_node, 0); |
822 | graph->AddControlEdge(ret_node, graph->sink_node()); |
823 | |
824 | shape_inference::InferenceContext* inference_context = |
825 | shape_refiner.GetContext(op_node); |
826 | shape_inference::ShapeHandle output_shape_handle = |
827 | inference_context->output(output_index); |
828 | TensorShapeProto output_shape_proto; |
829 | inference_context->ShapeHandleToProto(output_shape_handle, |
830 | &output_shape_proto); |
831 | PartialTensorShape global_output_shape(output_shape_proto); |
    VLOG(3) << "Inferred shape for operation '" << doperation.name
            << "': " << global_output_shape.DebugString();
834 | global_output_shapes->push_back(global_output_shape); |
835 | |
836 | const Layout* layout = nullptr; |
837 | if (default_layout.has_value() && output_index == 0) { |
838 | // Record the user's requested output layout. The scope currently only |
839 | // covers the first output of an op. |
840 | layout = &default_layout.value(); |
841 | ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString()); |
842 | } |
843 | output_layouts->push_back(layout); |
844 | } |
845 | return OkStatus(); |
846 | } |
847 | |
848 | // Returns set of functions to run to execute DTensor computation. |
849 | StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute( |
850 | const tensorflow::Graph& graph, |
851 | const std::vector<PartialTensorShape>& global_output_shapes) { |
852 | ExecutionFunctions execution_functions; |
853 | execution_functions.function_list = std::vector<TranslatedFunction>(); |
854 | for (Node* node : graph.nodes()) { |
855 | if (node->op_def().name() != "StatefulPartitionedCall" ) continue; |
856 | // Extract mesh to execute the function. |
857 | std::string serialized_mesh; |
858 | TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kMeshAttr, &serialized_mesh)); |
859 | Mesh mesh; |
860 | TF_ASSIGN_OR_RETURN(mesh, Mesh::FromString(serialized_mesh)); |
861 | |
862 | TranslatedFunction function; |
863 | function.function_mesh = std::move(mesh); |
864 | function.node_to_execute = node; |
865 | |
866 | // Identify input arg information. |
867 | TF_RETURN_IF_ERROR( |
868 | ParseResourceArgumentLayouts(*node, &function.resource_input_layouts)); |
869 | |
870 | TF_RETURN_IF_ERROR( |
871 | ParseShapeInputLayouts(*node, &function.shape_output_metadata)); |
872 | |
873 | function.input_index_map.resize(node->num_inputs()); |
874 | // Identity mapping between local mesh function input index and global |
875 | // input index. |
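    // For example, if the function's 0th input is produced by the _Arg node
    // whose `index` attribute is 3, then input_index_map[0] == 3.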
876 | for (int in_index = 0; in_index < node->num_inputs(); ++in_index) { |
877 | Node* input_node; |
878 | |
879 | TF_RETURN_IF_ERROR(node->input_node(in_index, &input_node)); |
880 | if (!input_node->IsArg()) |
        return errors::InvalidArgument(
            "Input node to mesh computation must be an arg node.");
883 | |
884 | int global_index; |
885 | TF_RETURN_IF_ERROR( |
886 | GetNodeAttr(input_node->attrs(), "index" , &global_index)); |
887 | function.input_index_map[in_index] = global_index; |
888 | } |
889 | |
    // Identify output mappings and layouts for each output.
891 | std::map<int, const Edge*> output_edges; |
892 | for (const Edge* out_edge : node->out_edges()) { |
893 | if (out_edge->IsControlEdge()) continue; |
894 | |
895 | const Node* retval_or_identity_node = out_edge->dst(); |
896 | while (retval_or_identity_node->IsIdentity()) { |
897 | retval_or_identity_node = |
898 | *(retval_or_identity_node->out_nodes().begin()); |
899 | } |
900 | |
901 | TF_RET_CHECK(retval_or_identity_node->IsRetval()); |
902 | int global_index; |
      TF_RETURN_IF_ERROR(GetNodeAttr(retval_or_identity_node->attrs(), "index",
                                     &global_index));
905 | output_edges[global_index] = out_edge; |
906 | } |
907 | |
908 | for (auto it = output_edges.begin(); it != output_edges.end(); it++) { |
909 | const int global_index = it->first; |
910 | function.output_index_map.emplace_back(global_index); |
911 | |
912 | const Edge* retval_edge = it->second; |
913 | const int output_index = retval_edge->src_output(); |
914 | |
915 | // Add output layout and shape information. |
916 | TF_ASSIGN_OR_RETURN( |
917 | const Layout output_layout, |
918 | GetLayoutThroughIdentityOps(retval_edge->src(), output_index)); |
919 | |
920 | function.output_layouts.emplace_back(output_layout); |
921 | function.local_output_shapes.emplace_back( |
922 | output_layout.LocalShapeFromGlobalShape( |
923 | global_output_shapes[global_index])); |
924 | } |
925 | |
926 | execution_functions.function_list.emplace_back(std::move(function)); |
927 | } |
928 | |
929 | if (execution_functions.function_list.empty()) { |
    return errors::InvalidArgument(
        "MLIR transformed graph does not have any functions to execute for "
        "the mesh.");
933 | } |
934 | |
935 | return execution_functions; |
936 | } |
937 | |
938 | // For functions with control outputs, add identity nodes between |
939 | // StatefulPartitionedCall and _Retvals, in order to preserve control output |
940 | // dependencies after StatefulPartitionedCall is inlined at runtime. |
941 | // Consider calling this in PrepareGraphForMlir, once the identity nodes won't |
942 | // be dropped during MLIR lowering. |
943 | // TODO(b/171265131): fix the underlying issue to avoid inserting identity |
944 | // nodes. |
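// Schematically, the rewrite turns
//   StatefulPartitionedCall --> _Retval
// into
//   StatefulPartitionedCall --> Identity --> _Retval
// with an additional control edge from the call node to the Identity node.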
945 | Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) { |
946 | if (function_def == nullptr || function_def->control_ret().empty()) { |
947 | return OkStatus(); |
948 | } |
949 | tensorflow::Status status; |
950 | for (Node* n : graph->nodes()) { |
951 | if (!n->IsRetval()) { |
952 | continue; |
953 | } |
954 | const Edge* edge; |
955 | TF_RETURN_IF_ERROR(n->input_edge(0, &edge)); |
956 | int ret_index; |
957 | TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index" , &ret_index)); |
    tensorflow::NodeDefBuilder identity_builder(
        absl::StrCat("op_output_identity_", ret_index), "Identity");
960 | tensorflow::NodeDef ret_identity_node_def; |
961 | tensorflow::DataType output_type = n->input_type(0); |
962 | TF_RETURN_IF_ERROR( |
963 | identity_builder.Attr("T" , output_type) |
964 | .Input(edge->src()->name(), edge->src_output(), output_type) |
965 | .Finalize(&ret_identity_node_def, /*consume=*/true)); |
966 | Node* ret_identity_node = graph->AddNode(ret_identity_node_def, &status); |
967 | TF_RETURN_IF_ERROR(status); |
    // Capture the source endpoint before removing the edge, since `edge` may
    // be recycled by RemoveEdge.
    Node* src = edge->src();
    const int src_output = edge->src_output();
    // Delete the edge between StatefulPartitionedCall and _Retval.
    graph->RemoveEdge(edge);
    // Add an edge between StatefulPartitionedCall and Identity.
    graph->AddEdge(src, src_output, ret_identity_node, 0);
    graph->AddControlEdge(src, ret_identity_node);
973 | // Add an edge between Identity and _Retval. |
974 | graph->AddEdge(ret_identity_node, 0, n, 0); |
975 | } |
976 | return OkStatus(); |
977 | } |
978 | |
979 | void AddDTensorFunctionAttr(FunctionDef& function_def) { |
  // Do not XLA-compile the function returned by the DTensor MLIR graph
  // transformation, as it already returns a compiled graph.
  AttrValue xla_must_compile_val;
  xla_must_compile_val.set_b(false);
  function_def.mutable_attr()->insert(
      {"_XlaMustCompile", xla_must_compile_val});
986 | |
987 | // Explicitly place function outputs on the default function device to avoid |
988 | // redundant host <-> device copies (Placer may place outputs on the host |
989 | // CPU). |
990 | AttrValue outputs_on_op_device; |
991 | outputs_on_op_device.set_b(true); |
992 | function_def.mutable_attr()->insert( |
993 | {"_OutputsOnOpDevice" , outputs_on_op_device}); |
994 | } |
995 | |
996 | StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs( |
997 | const std::vector<TensorWithLayout*>& inputs) { |
998 | absl::flat_hash_map<int64_t, std::vector<int64_t>> table_vars_input_index; |
999 | for (int64_t i = 0; i < inputs.size(); ++i) { |
1000 | if (inputs[i]->tensor_type() != kResource) continue; |
1001 | |
1002 | const absl::optional<EmbeddingResourceAttrs>& resource_attrs = |
1003 | inputs[i]->attrs(); |
1004 | if (resource_attrs.has_value()) { |
1005 | table_vars_input_index[resource_attrs->table_id].push_back(i); |
1006 | } |
1007 | } |
1008 | |
  // Error out if no embedding resource input was found.
  if (table_vars_input_index.empty()) {
    return errors::Internal("No TPU embedding resource inputs were found.");
1012 | } |
1013 | std::vector<parallel_device::ParallelTensor*> parallel_inputs; |
  // Iterating the std::map visits table ids in ascending order, so the
  // parallel inputs end up ordered numerically by table id.
1015 | for (const auto& [table_id, table_vars_indices] : table_vars_input_index) { |
1016 | for (const int64_t input_index : table_vars_indices) { |
1017 | parallel_inputs.push_back(inputs[input_index]->tensor()); |
1018 | } |
1019 | } |
1020 | return parallel_inputs; |
1021 | } |
1022 | |
1023 | StatusOr<std::map<int64_t, std::vector<Node*>>> GetTPUEmbeddingInputNodes( |
1024 | TF_Status* s, const Graph& graph, |
1025 | const std::vector<TensorWithLayout*>& inputs) { |
1026 | // After the graph is lowered, the sparse tensors live at the end of the |
1027 | // argument list, so process the dtensor dense inputs only so that |
1028 | // we index correctly. |
1029 | std::vector<TensorWithLayout*> non_sparse_inputs; |
1030 | non_sparse_inputs.reserve(inputs.size()); |
1031 | for (TensorWithLayout* input : inputs) { |
1032 | if (input->tensor_type() != TensorType::kSparse) { |
1033 | non_sparse_inputs.push_back(input); |
1034 | } |
1035 | } |
1036 | std::map<int64_t, std::vector<Node*>> table_id_node_map; |
1037 | for (Node* node : graph.nodes()) { |
1038 | if (!node->IsArg()) continue; |
1039 | |
1040 | const int64_t& arg_id = node->attrs().Find("index" )->i(); |
1041 | const AttrValue* embedding_attr = |
1042 | node->attrs().Find("_tpu_embedding_table_id" ); |
1043 | |
1044 | if (embedding_attr == nullptr) continue; |
1045 | EmbeddingResourceAttrs embedding_input_attrs; |
1046 | |
1047 | // Add embedding table id. |
1048 | const int64_t table_id = embedding_attr->i(); |
1049 | embedding_input_attrs.table_id = table_id; |
1050 | |
1051 | // Add embedding slot id if there is one. |
1052 | const AttrValue* embedding_slot_attr = |
1053 | node->attrs().Find("_tpu_embedding_slot_id" ); |
1054 | if (embedding_slot_attr != nullptr) { |
1055 | const int64_t slot_id = embedding_slot_attr->i(); |
1056 | embedding_input_attrs.slot_id = slot_id; |
1057 | } |
1058 | |
1059 | table_id_node_map[table_id].push_back(node); |
1060 | |
    // _Arg index 0 is the device id, so the dtensor input for this arg node
    // is at position `arg_id - 1`.
1062 | if (non_sparse_inputs[arg_id - 1]->attrs().has_value()) continue; |
1063 | non_sparse_inputs[arg_id - 1]->UpdateAttrs(embedding_input_attrs, s); |
1064 | if (!s->status.ok()) { |
1065 | return errors::Internal( |
1066 | "Failed to set embedding resource attrs. \n Got error: " , |
1067 | s->status.error_message()); |
1068 | } |
1069 | } |
1070 | return table_id_node_map; |
1071 | } |
1072 | |
1073 | StatusOr<std::string> ValidateResourceMeshConsistency( |
1074 | const std::vector<TensorWithLayout*>& inputs) { |
1075 | std::string mesh_str; |
1076 | for (TensorWithLayout* inp : inputs) { |
1077 | if ((inp->tensor_type() != kResource) || !inp->attrs().has_value()) |
1078 | continue; |
1079 | const std::string& input_mesh_str = inp->layout().mesh().ToString(); |
1080 | if (mesh_str.empty()) { |
1081 | mesh_str = input_mesh_str; |
1082 | } else if (mesh_str != input_mesh_str) { |
      return errors::Internal(absl::StrCat(
          "All embedding resource inputs must be on the same mesh, but got: ",
          mesh_str, " != ", input_mesh_str));
1086 | } |
1087 | } |
  VLOG(1) << "Resource input mesh is: " << mesh_str;
1089 | return mesh_str; |
1090 | } |
1091 | |
1092 | Status InsertFunctionForTPUEmbeddingCheckpoint( |
1093 | TF_Status* status, Graph* graph, |
1094 | const std::vector<TensorWithLayout*>& inputs, |
1095 | const std::string& checkpoint_fn_name) { |
1096 | if (checkpoint_fn_name != kLoadEmbeddingFn && |
1097 | checkpoint_fn_name != kRetrieveEmbeddingFn) { |
    return errors::InvalidArgument(absl::StrCat(
        "Found wrong function name: ", checkpoint_fn_name,
        "\nexpects: ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
1101 | } |
1102 | |
1103 | StatusOr<std::map<int64_t, std::vector<Node*>>> table_id_node_map = |
1104 | GetTPUEmbeddingInputNodes(status, *graph, inputs); |
1105 | if (!table_id_node_map.ok()) { |
1106 | return errors::Internal(table_id_node_map.status().error_message()); |
1107 | } |
1108 | |
1109 | StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs); |
1110 | |
1111 | const int64_t& num_tables = table_id_node_map->size(); |
1112 | NodeDef func_node_def; |
1113 | std::vector<NodeDefBuilder::NodeOut> func_inputs; |
1114 | std::vector<DataType> input_types, output_types; |
1115 | |
1116 | func_inputs.reserve(num_tables); |
1117 | input_types.reserve(num_tables); |
1118 | |
1119 | for (int i = 0; i < num_tables; ++i) { |
1120 | auto node_vec_ptr = table_id_node_map->find(i); |
1121 | if (node_vec_ptr == table_id_node_map->end()) { |
1122 | return errors::Internal( |
1123 | absl::StrCat("Embedding table id " , i, " is not found." )); |
1124 | } |
1125 | for (const Node* n : node_vec_ptr->second) { |
1126 | const std::string& node_name = n->name(); |
1127 | func_inputs.push_back({node_name, i, DT_RESOURCE}); |
1128 | input_types.push_back(DT_RESOURCE); |
1129 | } |
1130 | } |
1131 | |
1132 | AttrValue mesh_attr; |
1133 | *mesh_attr.mutable_s() = *mesh_str; |
1134 | NameAttrList func_attr; |
1135 | func_attr.set_name(checkpoint_fn_name); |
1136 | TF_RETURN_IF_ERROR( |
      NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
          .Attr("Tin", input_types)
          .Attr("Tout", output_types)
          .Attr("f", func_attr)
          .Attr(kMeshAttr, mesh_attr)
          .Attr("config", mesh_attr)
1143 | .Input(func_inputs) |
1144 | .Finalize(&func_node_def, true)); |
1145 | |
1146 | TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def)); |
1147 | for (int i = 0; i < num_tables; ++i) { |
1148 | const std::vector<Node*>& node_vec = table_id_node_map->find(i)->second; |
1149 | for (int j = 0; j < node_vec.size(); ++j) { |
1150 | graph->AddEdge(node_vec[j], 0, func_node, j + i); |
1151 | } |
1152 | } |
1153 | |
1154 | return OkStatus(); |
1155 | } |
1156 | |
1157 | } // namespace dtensor |
1158 | } // namespace tensorflow |
1159 | |