/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/dtensor_device_util.h"

#include <cstddef>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/small_constant_optimization.h"

namespace tensorflow {
namespace dtensor {
namespace {
// Represents an input node during graph construction.
// When executing a Function, `output` is used to align graph inputs
// with the inputs to the function call.
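// For example, PrepareGraphForMlir below records the i-th DTensor input as
// NodeOut{arg_node->name(), i, dtype} so the StatefulPartitionedCall builder
// can consume the inputs in their original order.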
struct FunctionArgument {
  Node* node;
  NodeDefBuilder::NodeOut output;
};

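// Copies `tensor` to every local device of `mesh` and wraps the per-device
// copies into a single ParallelTensor. For instance, on a two-device CPU mesh
// the result holds two components, one per local CPU device, each containing
// the same value as `tensor`.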
std::unique_ptr<parallel_device::ParallelTensor>
BroadcastTensorHandleToParallelTensor(TFE_Context* context,
                                      TFE_TensorHandle* tensor,
                                      const MeshWithParallelDevice& mesh,
                                      TF_Status* status) {
  // Broadcast tensor value to local devices.
  const Mesh& target_mesh = mesh.mesh_config();
  absl::Span<const std::string> local_devices = target_mesh.local_devices();
  const int num_local_devices = local_devices.size();

  std::vector<parallel_device::TensorHandlePtr> components;
  components.reserve(num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    // Create a copy of the tensor on each local device specified by
    // `target_mesh`.
    components.emplace_back(TFE_TensorHandleCopyToDevice(
        tensor, context, local_devices[i].c_str(), status));
    if (TF_GetCode(status) != TF_OK) {
      TF_SetStatus(
          status, TF_INTERNAL,
          absl::StrCat(
              "Unable to copy tensor value for broadcast. Original message: ",
              TF_Message(status))
              .c_str());
      return nullptr;
    }
  }

  std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
      parallel_device::ParallelTensor::FromTensorHandles(
          mesh.parallel_device(), std::move(components), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return parallel_tensor;
}

// Broadcast a single non-parallel resource tensor onto `mesh` with a fully
// replicated sharding spec. Does not take ownership of `tensor`.
std::unique_ptr<TensorWithLayout> BroadcastResourceTensor(
    TFE_Context* context, TFE_TensorHandle* tensor,
    const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
    TF_Status* status) {
  // Only broadcast resource tensors that point to scalars, since they are
  // always replicated. We also want to catch honest user errors, so we error
  // out on non-scalars.
  // Resolve the tensor as a resource handle and get the shape and dtype
  // of the tensor it points to.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tf_tensor(
      TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
  Tensor t;
  Status convert_status = TF_TensorToTensor(tf_tensor.get(), &t);
  if (!convert_status.ok() || t.dtype() != DataType::DT_RESOURCE) {
    TF_SetStatus(status, TF_INTERNAL,
                 absl::StrCat("TF_TensorToTensor() conversion failed: ",
                              convert_status.error_message())
                     .c_str());
    return nullptr;
  }
  // Replicate this resource handle to all devices without changing the
  // associated device of the resource itself.
  ResourceHandle r = t.flat<ResourceHandle>()(0);

  // Only broadcast resource tensors onto a CPU mesh. Copying resource tensors
  // to a non-CPU device is not supported.
  if (!mesh.mesh_config().is_cpu_mesh()) {
    std::string error_message =
        "Using a non-DTensor variable with DTensor is only supported for "
        "copying to a CPU mesh. If you are using a scope-based API, create "
        "variables inside the DTensor scope.\n";

    // Get the stack trace and summaries from the resource tensor.
    absl::StrAppend(
        &error_message, "Offending variable summary: ", r.SummarizeValue(),
        "\nStack trace: ", DefinitionLocationMsg(r.definition_stack_trace()));
    TF_SetStatus(status, TF_INVALID_ARGUMENT, error_message.c_str());
    return nullptr;
  }

  LOG(INFO) << "Broadcasting resource tensor to a DTensor resource tensor.";
  if (mesh.mesh_config().is_remote()) {
    TF_DataType dtype = TFE_TensorHandleDataType(tensor);
    std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());

    auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
    return ret;
  }

  std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
      BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  int rank = r.dtypes_and_shapes().empty()
                 ? 0
                 : r.dtypes_and_shapes().begin()->shape.dims();

  StatusOr<std::unique_ptr<TensorWithLayout>> result = TensorWithLayout::Wrap(
      std::move(parallel_tensor), mesh,
      Layout::ReplicatedOnMesh(mesh.mesh_config(), rank));
  if (!result.ok()) {
    TF_SetStatus(
        status, TF_INTERNAL,
        absl::StrCat("Error creating a TensorWithLayout from a resource tensor "
                     "during broadcasting with original error message: ",
                     result.status().error_message())
            .c_str());
    return nullptr;
  }

  if (!r.dtypes_and_shapes().empty()) {
    PartialTensorShape partial_shape = r.dtypes_and_shapes().begin()->shape;
    // Set the shape/type of the tensor that the resource points to
    // so that the graph has correct shape/type information that we can use.
    (*result)->UpdateShapeAndDType(
        partial_shape.AsProto(), r.dtypes_and_shapes().begin()->dtype, status);
  }

  if (TF_GetCode(status) != TF_OK) {
    TF_SetStatus(status, TF_INTERNAL,
                 "Error updating shape and dtype for resource tensor during "
                 "broadcasting.");
    return nullptr;
  }
  return std::move(*result);
}

bool LayoutsAreCompatible(absl::optional<Layout> first_layout,
                          absl::optional<Layout> second_layout) {
  if (!first_layout.has_value() && !second_layout.has_value()) {
    return true;
  }
  if (!first_layout.has_value() || !second_layout.has_value()) {
    return false;
  }
  return first_layout.value() == second_layout.value();
}

// Parses a pair of attributes (indices, layouts) into a map from argument
// index to layout.
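// For example, if the indices attribute holds [0, 2] and the layouts
// attribute holds two serialized layout strings, the resulting map is
// {0: Layout(layouts[0]), 2: Layout(layouts[1])}. (The concrete layout
// strings are whatever the compiler passes attached; they are shown here
// only for illustration.)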
Status ParseAttrMap(const Node& node, absl::string_view indices_attr,
                    absl::string_view layout_attr,
                    std::map<int, Layout>* indices_layout_map) {
  std::vector<std::string> layouts;
  if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) {
    return OkStatus();
  }
  const TensorProto* indices;
  if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) {
    return errors::Internal(
        "Arg indices must be set when setting inferred resource layouts.");
  }
  if (indices->int_val_size() != layouts.size()) {
    return errors::Internal(
        "Arg indices for inferred resource argument must match the "
        "size of inferred resource layout.");
  }
  for (int i = 0; i < indices->int_val_size(); ++i) {
    const auto arg_index = indices->int_val(i);
    const auto& arg_layout = layouts[i];
    indices_layout_map->emplace(
        arg_index, tensorflow::dtensor::Layout::FromString(arg_layout).value());
  }
  return OkStatus();
}

Status ParseResourceArgumentLayouts(
    const Node& node, std::map<int, Layout>* inferred_resource_input_layouts) {
  return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts,
                      inferred_resource_input_layouts);
}

Status ParseShapeInputLayouts(const Node& node,
                              std::map<int, Layout>* shape_output_metadata) {
  return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout,
                      shape_output_metadata);
}

// Gets the layout attached to a specific node at a given output index,
// ignoring any Identity ops.
StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) {
  while (op->op_def().name() == "Identity" ||
         op->op_def().name() == "IdentityN") {
    const Edge* edge;
    TF_RETURN_IF_ERROR(op->input_edge(output_index, &edge));
    op = edge->src();
    output_index = edge->src_output();
  }
  const auto serialized_layouts = op->attrs().Find(kLayoutAttr);

  if (!serialized_layouts) {
    return errors::InvalidArgument(
        op->op_def().name(), " doesn't contain attribute: ", kLayoutAttr);
  }

  // We assume that there is one layout for each output.
  if (serialized_layouts->list().s_size() != op->num_outputs()) {
    return errors::InvalidArgument(
        "Number of outputs to ", op->op_def().name(),
        " does not match number of layouts attached");
  }

  return Layout::FromString(serialized_layouts->list().s(output_index));
}

}  // namespace

tensorflow::Fprint128 TensorWithLayout::CacheKey() const {
  tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout_.ToString());
  // Use exact shape to compute the key.
  for (const int64_t dim : local_shape()) {
    f = FingerprintCat128(f, dim);
  }
  if (const_value_.has_value()) {
    std::string serialized;
    SerializeToStringDeterministic(const_value_.value(), &serialized);
    f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
  }
  return f;
}

std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast(
    TFE_Context* context, TFE_TensorHandle* tensor,
    const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
    TF_Status* status) {
  const char* input_device = TFE_TensorHandleDeviceName(tensor, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  if (dtensor_device_name == input_device) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Input to Broadcast must be an eager tensor.");
    return nullptr;
  }

  // Handle resource tensor broadcasting to the mesh.
  if (TFE_TensorHandleDataType(tensor) == TF_RESOURCE) {
    return BroadcastResourceTensor(context, tensor, mesh, dtensor_device_name,
                                   status);
  }

  if (mesh.mesh_config().is_remote()) {
    TF_DataType dtype = TFE_TensorHandleDataType(tensor);
    std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());

    auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
    absl::optional<NodeDef> const_value =
        ExtractSmallTensorValue(context, tensor, layout, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    if (const_value) {
      ret->set_const_value(const_value.value());
    }
    return ret;
  }

  std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
      BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const std::vector<int64_t>* shape;
  Status s = parallel_tensor->Shape(&shape);
  if (!s.ok()) {
    TF_SetStatus(status, static_cast<TF_Code>(s.code()),
                 s.error_message().c_str());
    return nullptr;
  }
  size_t num_dims = shape->size();
  const Layout layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), num_dims);

  absl::optional<NodeDef> const_value =
      ExtractSmallTensorValue(context, tensor, layout, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  std::unique_ptr<TensorWithLayout> result(new TensorWithLayout(
      std::move(parallel_tensor), mesh, std::move(layout), *shape,
      /*dtype=*/absl::nullopt, std::move(const_value)));
  return result;
}

StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap(
    std::unique_ptr<parallel_device::ParallelTensor> tensor,
    const MeshWithParallelDevice& mesh, const Layout& layout) {
  const std::vector<int64_t>* shape;
  TF_RETURN_IF_ERROR(tensor->Shape(&shape));

  if (tensor->dtype() != TF_RESOURCE) {
    return std::unique_ptr<TensorWithLayout>(
        new TensorWithLayout(std::move(tensor), mesh, layout, *shape));
  } else {
    return std::unique_ptr<TensorWithLayout>(
        new ResourceHandleWithLayout(std::move(tensor), mesh, layout, *shape));
  }
}

std::unique_ptr<TensorWithLayout> TensorWithLayout::Dummy(
    const std::vector<int64_t>& local_shape, const TF_DataType dtype,
    const MeshWithParallelDevice& mesh, const Layout& layout) {
  if (dtype != TF_RESOURCE) {
    return std::unique_ptr<TensorWithLayout>(new TensorWithLayout(
        /*tensor=*/nullptr, mesh, layout, local_shape, dtype));
  } else {
    return std::unique_ptr<TensorWithLayout>(new ResourceHandleWithLayout(
        /*tensor=*/nullptr, mesh, layout, local_shape));
  }
}

std::string TensorWithLayout::SummarizeValue() const {
  std::string value_summary;
  Status status;
  if (dtype() != TF_RESOURCE && layout().IsFullyReplicated()) {
    status =
        tensorflow::unwrap(tensor()->tensor(0))->SummarizeValue(value_summary);
  } else {
    // Note that this just prints the local values for sharded tensors. We
    // could instead run a collective here to relayout to replicated.
    status = tensor()->SummarizeValue(value_summary);
  }
  if (!status.ok()) {
    value_summary = "<error computing value>";
  }
  return absl::StrCat(value_summary, ", layout=\"", layout().ToString(), "\"");
}

std::string TensorWithLayout::DebugString() const {
  auto dtype = static_cast<DataType>(tensor()->dtype());

  const auto& shape_vector = global_shape();
  return absl::StrCat("DTensor(", SummarizeValue(),
                      ", shape=", ShapeToDebugString(shape_vector),
                      ", type=", DataTypeString(dtype), ")");
}

void ResourceHandleWithLayout::EncodeAttributes(
    tensorflow::NodeDefBuilder& builder) const {
  // If set, attach shape and dtype to the given node def.
  if (dereferenced_shape().has_value()) {
    builder.Attr("_handle_shapes", {*dereferenced_shape()});
  }
  if (dereferenced_dtype().has_value()) {
    builder.Attr("_handle_dtypes", {*dereferenced_dtype()});
  }
}

tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const {
  tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout().ToString());
  if (dereferenced_shape().has_value()) {
    std::string serialized;
    SerializeToStringDeterministic(dereferenced_shape().value(), &serialized);
    f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
  }
  if (dereferenced_dtype().has_value()) {
    f = FingerprintCat128(f, dereferenced_dtype().value());
  }
  return f;
}

void ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout,
                                            TF_Status* status) {
  // Only set the dereferenced layout if the incoming layout is non-empty.
  // This is still hacky, as we use an empty layout as a placeholder for an
  // eagerly placed VarHandleOp.
  if (!dereferenced_layout_.has_value() && new_layout.IsEmpty()) return;
  if (dereferenced_layout_.has_value() &&
      !LayoutsAreCompatible(dereferenced_layout_, new_layout)) {
    // TODO(xiejw, allenl): Consider allowing variables to switch layouts.
    RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                  "Attempted to overwrite an existing Layout.");
  }
  dereferenced_layout_.emplace(new_layout);
}

void ResourceHandleWithLayout::UpdateAttrs(const EmbeddingResourceAttrs& attrs,
                                           TF_Status* status) {
  if (attrs_.has_value()) {
    RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                  "Attempted to overwrite an existing embedding resource "
                  "attribute.");
  }
  attrs_.emplace(attrs);
}

StatusOr<std::unique_ptr<TensorWithLayout>> SparseTensorWithLayout::Wrap(
    std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
    std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
    std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
    const MeshWithParallelDevice& mesh, const Layout& layout,
    std::vector<int64_t> local_shape) {
  return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout(
      std::move(indices_tensor), std::move(values_tensor),
      std::move(shapes_tensor), mesh, layout, local_shape));
}

std::string SparseTensorWithLayout::SummarizeValue() const {
  std::string indices_summary;
  std::string values_summary;
  std::string dense_shapes_summary;

  Status indices_status;
  Status values_status;
  Status dense_shapes_status;

  if (layout().IsFullyReplicated()) {
    indices_status = tensorflow::unwrap(indices_->tensor(0))
                         ->SummarizeValue(indices_summary);
    values_status =
        tensorflow::unwrap(values_->tensor(0))->SummarizeValue(values_summary);
    dense_shapes_status = tensorflow::unwrap(dense_shapes_->tensor(0))
                              ->SummarizeValue(dense_shapes_summary);
  } else {
    indices_status = indices_->SummarizeValue(indices_summary);
    values_status = values_->SummarizeValue(values_summary);
    dense_shapes_status = dense_shapes_->SummarizeValue(dense_shapes_summary);
  }

  // Each failed summary overwrites its own string (the original code assigned
  // the error messages to mismatched variables).
  if (!indices_status.ok())
    indices_summary = "<error computing summary for indices>";
  if (!values_status.ok())
    values_summary = "<error computing summary for values>";
  if (!dense_shapes_status.ok())
    dense_shapes_summary = "<error computing summary for dense_shapes>";

  return absl::StrCat("indices: ", indices_summary, ", ",
                      "values: ", values_summary, ", ",
                      "dense_shapes: ", dense_shapes_summary, ", layout=\"",
                      layout().ToString(), "\"");
}

std::string SparseTensorWithLayout::DebugString() const {
  auto dtype = static_cast<DataType>(values_->dtype());

  const auto& shape_vector = global_shape();
  return absl::StrCat("DTensor(", SummarizeValue(),
                      ", shape=", ShapeToDebugString(shape_vector),
                      ", type=", DataTypeString(dtype), ")");
}

TF_DataType SparseTensorWithLayout::dtype() const {
  if (dtype_.has_value()) {
    return dtype_.value();
  } else {
    return values_->dtype();
  }
}

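// A SparseTensorWithLayout is backed by three parallel tensors whose
// flattened component handles are ordered as
//   [indices components | values components | dense_shapes components],
// with num_tensors() / 3 handles per group. get_tensor() below maps a flat
// `index` to the matching group and per-device component.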
TFE_TensorHandle* SparseTensorWithLayout::get_tensor(size_t index) const {
  int num_sparse_tensors = num_tensors() / 3;
  if (index < num_sparse_tensors) {
    return indices()->tensor(index);
  } else if (index < 2 * num_sparse_tensors) {
    return values()->tensor(index % num_sparse_tensors);
  } else {
    return dense_shapes()->tensor(index % num_sparse_tensors);
  }
}

absl::flat_hash_map<int, NodeDef> GetConstantFoldableTensors(
    const std::vector<TensorWithLayout*>& inputs) {
  absl::flat_hash_map<int, NodeDef> small_tensors;
  for (auto index = 0; index < inputs.size(); ++index) {
    if (inputs[index]->const_value().has_value()) {
      small_tensors.insert({index, inputs[index]->const_value().value()});
    }
  }
  return small_tensors;
}

// Thread-unsafe method. go/thread-unsafe
// Cache key computation should consider all features of an op that affect
// the SPMD lowering. The cache keys of two ops must be different if the
// translated functions are different.
// - op name and attrs
// - input shapes and layouts
// - default layouts of outputs
// - values of constant-foldable inputs
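// For example, two executions of the same function must hash to different
// keys if any input layout differs, or if a small constant-foldable input
// (such as a shape or axis argument) carries a different value, since either
// difference changes the lowered SPMD program.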
tensorflow::Fprint128 FunctionManager::CacheKeyForGraph(
    const DTensorOperation& doperation, const NameAttrList& attributes,
    const std::vector<TensorWithLayout*>& inputs,
    const std::vector<const Layout*>& output_layouts) {
  tensorflow::Fprint128 cache_key = tensorflow::Fingerprint128(doperation.name);
  std::string serialized;
  SerializeToStringDeterministic(attributes, &serialized);
  cache_key =
      FingerprintCat128(cache_key, tensorflow::Fingerprint128(serialized));
  // Higher level cache based on operation name and input shapes.
  for (auto i = 0; i < inputs.size(); ++i) {
    if (!IsConstantFoldable(doperation, i)) {
      inputs[i]->reset_const_value();
    }
    cache_key = FingerprintCat128(cache_key, inputs[i]->CacheKey());
  }
  for (int output_index = 0; output_index < output_layouts.size();
       ++output_index) {
    if (output_layouts[output_index]) {
      cache_key = FingerprintCat128(cache_key, output_index);
      cache_key = FingerprintCat128(
          cache_key,
          tensorflow::Fingerprint128(output_layouts[output_index]->ToString()));
    }
  }
  return cache_key;
}

// Thread-unsafe method. go/thread-unsafe
std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
FunctionManager::GetCachedFunction(
    const DTensorOperation& doperation, const NameAttrList& attributes,
    const std::vector<TensorWithLayout*>& inputs,
    const std::vector<const Layout*>& output_layouts) {
  tensorflow::Fprint128 cache_key =
      CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
  auto iter = function_cache_.find(cache_key);

  // Early return if we have a cache hit.
  if (iter != function_cache_.end()) {
    return std::pair<Fprint128, ExecutionFunctions*>(cache_key, &iter->second);
  }

  // For eager ops we early return the cache miss and do not make further
  // optimizations.
  if (!doperation.is_func()) {
    return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
  }

  const tensorflow::Fprint128 doperation_hash =
      CacheKeyForDTensorOperation(doperation);

  // Save the constant-foldable inputs of this doperation if we have not seen
  // it before. This is needed so that on the next call to this operation we
  // can compare these inputs and confirm which ones are indeed constant.
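  // In other words, the first call records every small input as potentially
  // constant; later calls (below) compare against that record and demote any
  // input whose value has changed to non-constant.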
  auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
  if (doperation_iter == dtensor_op_and_small_inputs_.end()) {
    dtensor_op_and_small_inputs_.insert(
        {doperation_hash, GetConstantFoldableTensors(inputs)});
    return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
  }

  // If we get here, we have run this function before but constant-folded some
  // input(s) that turned out not to be constant, i.e. one of the small input
  // values to this function has changed. Mark those changed values as
  // non-constant.
  absl::flat_hash_map<int, NodeDef>& previous_small_inputs =
      doperation_iter->second;
  std::vector<int> non_constant_indices;

  for (auto const& [index, previous_small_input] : previous_small_inputs) {
    if (inputs[index]->const_value().has_value()) {
      if (NodeDefsHaveDifferentTensorProto(
              previous_small_input, inputs[index]->const_value().value())) {
        inputs[index]->reset_const_value();
        non_constant_indices.push_back(index);
      }
    }
  }
  for (int non_constant_index : non_constant_indices) {
    previous_small_inputs.erase(non_constant_index);
  }
  // Recompute the cache key since we reset some small constant inputs, which
  // changes the key.
  cache_key = CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
  return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
}

const ExecutionFunctions* FunctionManager::AddCachedFunction(
    const DTensorOperation& op, tensorflow::Fprint128 cache_key,
    ExecutionFunctions function) {
  return &function_cache_.insert({cache_key, std::move(function)})
              .first->second;
}

bool FunctionManager::IsConstantFoldable(const DTensorOperation& doperation,
                                         const int input_index) const {
  // For eager ops, assume the inputs are constant foldable.
  if (!doperation.is_func()) return true;
  const tensorflow::Fprint128 doperation_hash =
      CacheKeyForDTensorOperation(doperation);
  // If we didn't see this doperation before, then optimistically assume it is
  // foldable. Otherwise the input at `input_index` is foldable only if it is
  // one of the indices we have saved as the small inputs.
  auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
  return doperation_iter == dtensor_op_and_small_inputs_.end() ||
         doperation_iter->second.contains(input_index);
}

const tensorflow::Fprint128 FunctionManager::CacheKeyForDTensorOperation(
    const DTensorOperation& doperation) const {
  return tensorflow::Fingerprint128(doperation.name);
}

std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
                                         TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(tensor, status));
  if (TF_GetCode(status) != TF_OK) return {};
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return {};
  }
  return shape;
}

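// Builds the graph handed to the MLIR passes. A sketch of its shape, using
// the node names created below:
//
//   _Arg("device_id"), _Arg("op_input_<i>") or Const("input_<i>_const_value")
//     -> "eager_operation" (the op itself, or a StatefulPartitionedCall
//        wrapping the function)
//     -> _Retval("op_output_<j>")
//
// Shape inference runs alongside construction so global output shapes and any
// default output layouts can be reported back to the caller.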
Status PrepareGraphForMlir(
    const FunctionManager& function_manager,
    const std::vector<TensorWithLayout*>& inputs,
    const DTensorOperation& doperation,
    const tensorflow::FunctionLibraryDefinition& flib_def,
    const NameAttrList& attributes,
    const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
    std::vector<PartialTensorShape>* global_output_shapes,
    std::vector<const Layout*>* output_layouts) {
  // We run shape inference on the graph to find output shapes, which may
  // determine default layouts.
  ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def);
  shape_refiner.set_function_library_for_shape_inference(&flib_def);
  tensorflow::Status status;
  {
    // We include an _Arg node for the device ID, but this isn't used by the
    // initial function. It will be provided a value, though, so it's available
    // for use in rewrites.
    tensorflow::NodeDefBuilder builder("device_id", "_Arg");
    tensorflow::PartialTensorShape partial_shape;
    TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
        static_cast<int*>(nullptr), 0, &partial_shape));
    tensorflow::NodeDef arg_node_def;
    TF_RETURN_IF_ERROR(builder.Attr("shape", partial_shape)
                           .Attr("T", tensorflow::DT_INT32)
                           .Attr("index", 0)
                           .Finalize(&arg_node_def, /*consume=*/true));
    tensorflow::Node* arg_node = graph->AddNode(arg_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    graph->AddControlEdge(graph->source_node(), arg_node);
    TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
  }
  std::vector<FunctionArgument> graph_op_inputs;
  graph_op_inputs.reserve(inputs.size());
  for (int i = 0; i < inputs.size(); ++i) {
    const TensorWithLayout* input = inputs[i];
    // TODO(allenl): This will block until async execution is complete, which
    // will be slow. We should find a non-blocking way of fetching the shape,
    // at least pre-cache.
    // The shape passed into the MLIR transformation represents the global
    // shape of the tensor. Ideally, the local shape on each parallel device
    // should not be consulted at all and we should use the shape on our input
    // tensor directly.
    const auto& shape = input->global_shape();
    std::vector<tensorflow::int64> cast_shape(shape.begin(), shape.end());
    tensorflow::PartialTensorShape partial_shape;
    // For resource tensors, the `shape` attribute should not be specified, as
    // the shape of a resource tensor is carried by the resource's shape
    // subtype, not by the shape attribute.
    auto* resource = dynamic_cast<const ResourceHandleWithLayout*>(input);
    if (!resource) {
      TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
          cast_shape.data(), cast_shape.size(), &partial_shape));
    }

    tensorflow::NodeDef arg_node_def;
    auto dtype = static_cast<tensorflow::DataType>(input->dtype());
    tensorflow::NodeDefBuilder builder(absl::StrCat("op_input_", i), "_Arg");

    // Delegate to TensorWithLayout to encode attributes if applicable.
    input->EncodeAttributes(builder);

    // Here we set each arg node's `index` attribute to the position of
    // the dtensor inputs. This is important for later use when we create
    // a mapping from the graph argument node to the corresponding argument
    // index of the list of dtensor inputs. Thus, even if the argument node
    // orderings change within the graph, we can always correctly
    // find the dtensor input corresponding to that arg node.
    //
    // This assumes that the dtensor inputs stay unchanged in ordering,
    // and if there is an ordering change of dtensor inputs, then special
    // care must be taken.
    TF_RETURN_IF_ERROR(
        builder.Attr("shape", partial_shape)
            .Attr("T", dtype)
            .Attr("index", i + 1)  // Indices are offset by 1 for device_id.
            .Attr(kLayoutAttr, input->layout().ToString())
            .Attr(kMeshAttr, input->mesh().mesh_config().ToString())
            .Finalize(&arg_node_def, /*consume=*/true));
    Node* arg_node = graph->AddNode(arg_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));

    shape_inference::InferenceContext* inference_context =
        shape_refiner.GetContext(arg_node);
    shape_inference::ShapeHandle shape_handle;
    TF_RETURN_IF_ERROR(inference_context->MakeShapeFromPartialTensorShape(
        partial_shape, &shape_handle));
    TF_RETURN_IF_ERROR(shape_refiner.SetShape(arg_node, 0, shape_handle));

    // Small constants are converted into constant graph nodes, instead of
    // being passed in as input arguments. This provides more information to
    // the SPMD and layout propagation passes.
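    // For instance, a small integer vector such as a shape or axis argument
    // becomes a Const node (named "input_<i>_const_value" below) rather than
    // an _Arg, so its value is visible to those passes.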
    if (!input->const_value().has_value() ||
        !function_manager.IsConstantFoldable(doperation, i)) {
      graph_op_inputs.push_back(FunctionArgument{
          arg_node, NodeDefBuilder::NodeOut{arg_node->name(), i, dtype}});
      graph->AddControlEdge(graph->source_node(), arg_node);
    } else {
      // TODO(xiejw): Refactor the TensorWithLayout representation to avoid
      // special code here.
      NodeDef const_node = input->const_value().value();
      const_node.set_name(absl::StrCat("input_", i, "_const_value"));
      Node* const_value_n = graph->AddNode(const_node, &status);
      TF_RETURN_IF_ERROR(status);
      TF_RETURN_IF_ERROR(shape_refiner.AddNode(const_value_n));
      graph_op_inputs.push_back(FunctionArgument{
          const_value_n, tensorflow::NodeDefBuilder::NodeOut{
                             const_value_n->name(), i, dtype}});
    }
  }

  tensorflow::NodeDef op_node_def;
  const FunctionDef* function_def = doperation.function_def;
  if (function_def) {
    AttrValue func_attr;
    func_attr.mutable_func()->set_name(doperation.name);
    std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
    std::vector<tensorflow::DataType> inputs_types;
    for (const auto& in : graph_op_inputs) {
      func_inputs.emplace_back(in.output);
      inputs_types.emplace_back(in.output.data_type);
    }

    std::vector<tensorflow::DataType> output_types;
    for (const auto& out : function_def->signature().output_arg())
      output_types.emplace_back(out.type());

    TF_RETURN_IF_ERROR(
        NodeDefBuilder("eager_operation", "StatefulPartitionedCall")
            .Attr("Tin", inputs_types)
            .Attr("Tout", output_types)
            .Attr("f", func_attr)
            .Input(func_inputs)
            .Finalize(&op_node_def, true));
  } else {
    op_node_def.set_op(doperation.name);
    op_node_def.set_name("eager_operation");
  }

  op_node_def.mutable_attr()->insert(attributes.attr().begin(),
                                     attributes.attr().end());

  tensorflow::Node* op_node = graph->AddNode(op_node_def, &status);
  TF_RETURN_IF_ERROR(status);

  for (int i = 0; i < graph_op_inputs.size(); ++i) {
    graph->AddEdge(graph_op_inputs[i].node, 0, op_node, i);
  }
  TF_RETURN_IF_ERROR(shape_refiner.AddNode(op_node));

  output_layouts->clear();
  output_layouts->reserve(op_node->num_outputs());
  global_output_shapes->reserve(op_node->num_outputs());
  for (int output_index = 0; output_index < op_node->num_outputs();
       ++output_index) {
    tensorflow::NodeDefBuilder builder(absl::StrCat("op_output_", output_index),
                                       "_Retval");
    tensorflow::NodeDef ret_node_def;
    tensorflow::DataType output_type = op_node->output_type(output_index);

    TF_RETURN_IF_ERROR(builder.Attr("T", output_type)
                           .Attr("index", output_index)
                           .Input("eager_operation", output_index, output_type)
                           .Finalize(&ret_node_def, /*consume=*/true));
    tensorflow::Node* ret_node = graph->AddNode(ret_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(op_node, output_index, ret_node, 0);
    graph->AddControlEdge(ret_node, graph->sink_node());

    shape_inference::InferenceContext* inference_context =
        shape_refiner.GetContext(op_node);
    shape_inference::ShapeHandle output_shape_handle =
        inference_context->output(output_index);
    TensorShapeProto output_shape_proto;
    inference_context->ShapeHandleToProto(output_shape_handle,
                                          &output_shape_proto);
    PartialTensorShape global_output_shape(output_shape_proto);
    VLOG(3) << "Inferred shape for operation '" << doperation.name
            << "':" << global_output_shape.DebugString();
    global_output_shapes->push_back(global_output_shape);

    const Layout* layout = nullptr;
    if (default_layout.has_value() && output_index == 0) {
      // Record the user's requested output layout. The scope currently only
      // covers the first output of an op.
      layout = &default_layout.value();
      ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString());
    }
    output_layouts->push_back(layout);
  }
  return OkStatus();
}

// Returns the set of functions to run to execute DTensor computation.
StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
    const tensorflow::Graph& graph,
    const std::vector<PartialTensorShape>& global_output_shapes) {
  ExecutionFunctions execution_functions;
  execution_functions.function_list = std::vector<TranslatedFunction>();
  for (Node* node : graph.nodes()) {
    if (node->op_def().name() != "StatefulPartitionedCall") continue;
    // Extract the mesh to execute the function on.
    std::string serialized_mesh;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kMeshAttr, &serialized_mesh));
    Mesh mesh;
    TF_ASSIGN_OR_RETURN(mesh, Mesh::FromString(serialized_mesh));

    TranslatedFunction function;
    function.function_mesh = std::move(mesh);
    function.node_to_execute = node;

    // Identify input arg information.
    TF_RETURN_IF_ERROR(
        ParseResourceArgumentLayouts(*node, &function.resource_input_layouts));

    TF_RETURN_IF_ERROR(
        ParseShapeInputLayouts(*node, &function.shape_output_metadata));

    function.input_index_map.resize(node->num_inputs());
    // Identity mapping between local mesh function input index and global
    // input index.
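    // For example, if the lowered function's 0th argument corresponds to the
    // _Arg node whose "index" attribute is 3, then input_index_map[0] == 3.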
    for (int in_index = 0; in_index < node->num_inputs(); ++in_index) {
      Node* input_node;

      TF_RETURN_IF_ERROR(node->input_node(in_index, &input_node));
      if (!input_node->IsArg())
        return errors::InvalidArgument(
            "Input node to mesh computation must be an arg node.");

      int global_index;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(input_node->attrs(), "index", &global_index));
      function.input_index_map[in_index] = global_index;
    }

    // Identify output mappings and layouts for each output.
    std::map<int, const Edge*> output_edges;
    for (const Edge* out_edge : node->out_edges()) {
      if (out_edge->IsControlEdge()) continue;

      const Node* retval_or_identity_node = out_edge->dst();
      while (retval_or_identity_node->IsIdentity()) {
        retval_or_identity_node =
            *(retval_or_identity_node->out_nodes().begin());
      }

      TF_RET_CHECK(retval_or_identity_node->IsRetval());
      int global_index;
      TF_RETURN_IF_ERROR(GetNodeAttr(retval_or_identity_node->attrs(), "index",
                                     &global_index));
      output_edges[global_index] = out_edge;
    }

    for (auto it = output_edges.begin(); it != output_edges.end(); it++) {
      const int global_index = it->first;
      function.output_index_map.emplace_back(global_index);

      const Edge* retval_edge = it->second;
      const int output_index = retval_edge->src_output();

      // Add output layout and shape information.
      TF_ASSIGN_OR_RETURN(
          const Layout output_layout,
          GetLayoutThroughIdentityOps(retval_edge->src(), output_index));

      function.output_layouts.emplace_back(output_layout);
      function.local_output_shapes.emplace_back(
          output_layout.LocalShapeFromGlobalShape(
              global_output_shapes[global_index]));
    }

    execution_functions.function_list.emplace_back(std::move(function));
  }

  if (execution_functions.function_list.empty()) {
    return errors::InvalidArgument(
        "MLIR transformed graph does not have any functions to execute for "
        "mesh.");
  }

  return execution_functions;
}

// For functions with control outputs, add identity nodes between
// StatefulPartitionedCall and _Retvals, in order to preserve control output
// dependencies after StatefulPartitionedCall is inlined at runtime.
// Consider calling this in PrepareGraphForMlir, once the identity nodes won't
// be dropped during MLIR lowering.
// TODO(b/171265131): fix the underlying issue to avoid inserting identity
// nodes.
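// The inserted pattern is StatefulPartitionedCall:i -> Identity -> _Retval,
// replacing the direct StatefulPartitionedCall:i -> _Retval edge.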
Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) {
  if (function_def == nullptr || function_def->control_ret().empty()) {
    return OkStatus();
  }
  tensorflow::Status status;
  for (Node* n : graph->nodes()) {
    if (!n->IsRetval()) {
      continue;
    }
    const Edge* edge;
    TF_RETURN_IF_ERROR(n->input_edge(0, &edge));
    int ret_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &ret_index));
    tensorflow::NodeDefBuilder identity_builder(
        absl::StrCat("op_output_identity_", ret_index), "Identity");
    tensorflow::NodeDef ret_identity_node_def;
    tensorflow::DataType output_type = n->input_type(0);
    TF_RETURN_IF_ERROR(
        identity_builder.Attr("T", output_type)
            .Input(edge->src()->name(), edge->src_output(), output_type)
            .Finalize(&ret_identity_node_def, /*consume=*/true));
    Node* ret_identity_node = graph->AddNode(ret_identity_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    // Capture the source endpoint before removing the edge so we do not read
    // from a removed Edge object.
    Node* src = edge->src();
    const int src_output = edge->src_output();
    // Delete the edge between StatefulPartitionedCall and _Retval.
    graph->RemoveEdge(edge);
    // Add an edge between StatefulPartitionedCall and Identity.
    graph->AddEdge(src, src_output, ret_identity_node, 0);
    graph->AddControlEdge(src, ret_identity_node);
    // Add an edge between Identity and _Retval.
    graph->AddEdge(ret_identity_node, 0, n, 0);
  }
  return OkStatus();
}

void AddDTensorFunctionAttr(FunctionDef& function_def) {
  // Do not XLA-compile the function returned by the DTensor MLIR graph
  // transformation, as it already returns a compiled graph.
  AttrValue xla_must_compile_val;
  xla_must_compile_val.set_b(false);
  function_def.mutable_attr()->insert(
      {"_XlaMustCompile", xla_must_compile_val});

  // Explicitly place function outputs on the default function device to avoid
  // redundant host <-> device copies (Placer may place outputs on the host
  // CPU).
  AttrValue outputs_on_op_device;
  outputs_on_op_device.set_b(true);
  function_def.mutable_attr()->insert(
      {"_OutputsOnOpDevice", outputs_on_op_device});
}

StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
    const std::vector<TensorWithLayout*>& inputs) {
  // Use an ordered map so that the iteration below yields ascending table ids.
  std::map<int64_t, std::vector<int64_t>> table_vars_input_index;
  for (int64_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i]->tensor_type() != kResource) continue;

    const absl::optional<EmbeddingResourceAttrs>& resource_attrs =
        inputs[i]->attrs();
    if (resource_attrs.has_value()) {
      table_vars_input_index[resource_attrs->table_id].push_back(i);
    }
  }

  // Return an error if no embedding resource input was found.
  if (table_vars_input_index.empty()) {
    return errors::Internal("No TPU embedding resource inputs were found.");
  }
  std::vector<parallel_device::ParallelTensor*> parallel_inputs;
  // Ensure parallel inputs are ordered numerically by table id.
  for (const auto& [table_id, table_vars_indices] : table_vars_input_index) {
    for (const int64_t input_index : table_vars_indices) {
      parallel_inputs.push_back(inputs[input_index]->tensor());
    }
  }
  return parallel_inputs;
}

StatusOr<std::map<int64_t, std::vector<Node*>>> GetTPUEmbeddingInputNodes(
    TF_Status* s, const Graph& graph,
    const std::vector<TensorWithLayout*>& inputs) {
  // After the graph is lowered, the sparse tensors live at the end of the
  // argument list, so process only the dense DTensor inputs so that we index
  // correctly.
  std::vector<TensorWithLayout*> non_sparse_inputs;
  non_sparse_inputs.reserve(inputs.size());
  for (TensorWithLayout* input : inputs) {
    if (input->tensor_type() != TensorType::kSparse) {
      non_sparse_inputs.push_back(input);
    }
  }
  std::map<int64_t, std::vector<Node*>> table_id_node_map;
  for (Node* node : graph.nodes()) {
    if (!node->IsArg()) continue;

    const int64_t arg_id = node->attrs().Find("index")->i();
    const AttrValue* embedding_attr =
        node->attrs().Find("_tpu_embedding_table_id");

    if (embedding_attr == nullptr) continue;
    EmbeddingResourceAttrs embedding_input_attrs;

    // Add the embedding table id.
    const int64_t table_id = embedding_attr->i();
    embedding_input_attrs.table_id = table_id;

    // Add the embedding slot id if there is one.
    const AttrValue* embedding_slot_attr =
        node->attrs().Find("_tpu_embedding_slot_id");
    if (embedding_slot_attr != nullptr) {
      const int64_t slot_id = embedding_slot_attr->i();
      embedding_input_attrs.slot_id = slot_id;
    }

    table_id_node_map[table_id].push_back(node);

    // Arg inputs are offset by one because of the device id argument.
    if (non_sparse_inputs[arg_id - 1]->attrs().has_value()) continue;
    non_sparse_inputs[arg_id - 1]->UpdateAttrs(embedding_input_attrs, s);
    if (!s->status.ok()) {
      return errors::Internal(
          "Failed to set embedding resource attrs. \n Got error: ",
          s->status.error_message());
    }
  }
  return table_id_node_map;
}

StatusOr<std::string> ValidateResourceMeshConsistency(
    const std::vector<TensorWithLayout*>& inputs) {
  std::string mesh_str;
  for (TensorWithLayout* inp : inputs) {
    if ((inp->tensor_type() != kResource) || !inp->attrs().has_value())
      continue;
    const std::string& input_mesh_str = inp->layout().mesh().ToString();
    if (mesh_str.empty()) {
      mesh_str = input_mesh_str;
    } else if (mesh_str != input_mesh_str) {
      return errors::Internal(absl::StrCat(
          "All embedding resource inputs must be on the same mesh, but got: ",
          mesh_str, " != ", input_mesh_str));
    }
  }
  VLOG(1) << "Resource input mesh is: " << mesh_str;
  return mesh_str;
}

Status InsertFunctionForTPUEmbeddingCheckpoint(
    TF_Status* status, Graph* graph,
    const std::vector<TensorWithLayout*>& inputs,
    const std::string& checkpoint_fn_name) {
  if (checkpoint_fn_name != kLoadEmbeddingFn &&
      checkpoint_fn_name != kRetrieveEmbeddingFn) {
    return errors::InvalidArgument(absl::StrCat(
        "Found wrong function name: ", checkpoint_fn_name,
        " \n expected: ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
  }

  StatusOr<std::map<int64_t, std::vector<Node*>>> table_id_node_map =
      GetTPUEmbeddingInputNodes(status, *graph, inputs);
  if (!table_id_node_map.ok()) {
    return errors::Internal(table_id_node_map.status().error_message());
  }

  StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs);

  const int64_t num_tables = table_id_node_map->size();
  NodeDef func_node_def;
  std::vector<NodeDefBuilder::NodeOut> func_inputs;
  std::vector<DataType> input_types, output_types;

  func_inputs.reserve(num_tables);
  input_types.reserve(num_tables);

  for (int i = 0; i < num_tables; ++i) {
    auto node_vec_ptr = table_id_node_map->find(i);
    if (node_vec_ptr == table_id_node_map->end()) {
      return errors::Internal(
          absl::StrCat("Embedding table id ", i, " is not found."));
    }
    for (const Node* n : node_vec_ptr->second) {
      const std::string& node_name = n->name();
      func_inputs.push_back({node_name, i, DT_RESOURCE});
      input_types.push_back(DT_RESOURCE);
    }
  }

  AttrValue mesh_attr;
  *mesh_attr.mutable_s() = *mesh_str;
  NameAttrList func_attr;
  func_attr.set_name(checkpoint_fn_name);
  TF_RETURN_IF_ERROR(
      NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
          .Attr("Tin", input_types)
          .Attr("Tout", output_types)
          .Attr("f", func_attr)
          .Attr(kMeshAttr, mesh_attr)
          .Attr("config", mesh_attr)
          .Input(func_inputs)
          .Finalize(&func_node_def, true));

  TF_ASSIGN_OR_RETURN(Node* func_node, graph->AddNode(func_node_def));
  for (int i = 0; i < num_tables; ++i) {
    const std::vector<Node*>& node_vec = table_id_node_map->find(i)->second;
    for (int j = 0; j < node_vec.size(); ++j) {
      graph->AddEdge(node_vec[j], 0, func_node, j + i);
    }
  }

  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow