/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_
#define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

#define RETURN_STATUS(status, code, message)   \
  {                                            \
    TF_SetStatus((status), (code), (message)); \
    return;                                    \
  }

#define RETURN_C_STATUS_IF_NOT_OK(cpp_status, c_status)                   \
  {                                                                       \
    auto return_if_not_ok_status = (cpp_status);                          \
    if (!return_if_not_ok_status.ok()) {                                  \
      RETURN_STATUS((c_status),                                           \
                    static_cast<TF_Code>(return_if_not_ok_status.code()), \
                    return_if_not_ok_status.error_message().c_str());     \
    }                                                                     \
  }

// Using a counter to uniquify the temporary StatusOr variable, instead of
// opening a new block, allows `var` to declare a new variable.
#define ASSIGN_OR_RETURN_C_STATUS(var, cpp_status, c_status)               \
  ASSIGN_OR_RETURN_C_STATUS_IMPL(                                          \
      TF_STATUS_MACROS_CONCAT_NAME(_dtensor_status_or_value, __COUNTER__), \
      var, cpp_status, c_status)

#define ASSIGN_OR_RETURN_C_STATUS_IMPL(statusor, var, cpp_status, c_status) \
  auto statusor = (cpp_status);                                             \
  RETURN_C_STATUS_IF_NOT_OK(statusor.status(), (c_status));                 \
  var = std::move(statusor.value());
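
// Example (a hypothetical sketch, for illustration only; `ComputeLayout` and
// `ValidateLayout` are not part of this header): inside a void function that
// reports failures through a TF_Status, the macros above can be used as
//
//   void MyCallback(TF_Status* status) {
//     ASSIGN_OR_RETURN_C_STATUS(Layout layout, ComputeLayout(), status);
//     RETURN_C_STATUS_IF_NOT_OK(ValidateLayout(layout), status);
//   }
//
// On error, the C++ status is copied into `status` and the function returns
// early; on success, `layout` holds the extracted value.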

struct TranslatedFunction {
  // Mesh on which the specified function will run.
  Mesh function_mesh;

  // StatefulPartitionedCall op to run the mesh function.
  const Node* node_to_execute = nullptr;

  // Maps the i-th local input index to the corresponding input index in the
  // global graph.
  std::vector<int> input_index_map;

  // Maps the i-th local output to the corresponding output index of the
  // global graph.
  std::vector<int> output_index_map;

  std::string translated_function_name;
  // For resource ops, layouts of resource handles are inferred lazily
  // during SPMD expansion of resource assign ops. In that case, the
  // inferred layouts of resource handles are attached to the arg nodes
  // of the returned graph.
  std::map<int, Layout> resource_input_layouts;
  // Records metadata for the output of a shape op. This helps recover the
  // local shape in future operations on the Tensor.
  std::map<int, Layout> shape_output_metadata;
  std::vector<Layout> output_layouts;
  // Local shapes inferred for function outputs; these may be partially known.
  std::vector<PartialTensorShape> local_output_shapes;
  // Output data types.
  std::vector<TF_DataType> output_dtypes;
};

struct ExecutionFunctions {
  // Stores information about all functions to execute for the provided
  // computation.
  std::vector<TranslatedFunction> function_list;
  // Number of device id args added to translated functions.
  // During translation, we insert one device id arg node per mesh.
  // For a single-mesh function, this equals 1.
  // For a multi-mesh function (e.g. pipelining), this equals the number of
  // meshes.
  int num_device_ids;

  // Mesh fingerprint of function_list. Set only when ExecutionFunctions
  // refers to a function, for performance reasons; an eager op doesn't use it.
  uint64 function_mesh_fingerprint = 0;
};

// TODO(yujingzhang): move FingerprintCat128 to tensorflow/platform.
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  return {tensorflow::FingerprintCat64(a.low64, b.low64),
          tensorflow::FingerprintCat64(a.high64, b.high64)};
}

inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const int64 b) {
  auto x = tensorflow::FingerprintCat64(a.low64, b);
  return {x, tensorflow::FingerprintCat64(a.high64, x)};
}
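
// A minimal sketch of how these helpers can be chained to build a 128-bit
// cache key from several components (the seed string and the values below are
// illustrative only):
//
//   tensorflow::Fprint128 key = tensorflow::Fingerprint128("some_op_name");
//   key = FingerprintCat128(key, /*num_inputs=*/3);
//   key = FingerprintCat128(key, tensorflow::Fingerprint128("attributes"));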

struct DTensorOperation {
  // Both fields are not owned; their lifetime covers the whole usage.
  const char* name;
  const FunctionDef* function_def;

  inline bool is_func() const { return function_def != nullptr; }
};

struct EmbeddingResourceAttrs {
  int64_t table_id;
  absl::optional<int64_t> slot_id;  // NOLINT
  bool is_dirty = false;
};

// Contains a mesh bundled with a parallel device over all of the devices in
// that mesh.
class MeshWithParallelDevice {
 public:
  MeshWithParallelDevice(
      const Mesh& mesh_config,
      std::unique_ptr<parallel_device::ParallelDevice> parallel_device,
      const std::string& composite_device_name = "")
      : mesh_config_(mesh_config),
        parallel_device_(std::move(parallel_device)),
        composite_device_name_(composite_device_name),
        // Device IDs are constructed lazily because we don't have a context
        // until we start executing ops.
        device_ids_tensor_(nullptr) {}

  // A parallel tensor containing scalar integer device IDs for underlying
  // devices, each placed on its corresponding device.
  //
  // TODO(allenl): It would be nice if DeviceID worked as an op inside the
  // function's graph. Then we wouldn't need to feed it as an argument.
  parallel_device::ParallelTensor* DeviceIDs(TFE_Context* context,
                                             TF_Status* status) const;
  const parallel_device::ParallelDevice& parallel_device() const {
    return *parallel_device_;
  }

  const dtensor::Mesh& mesh_config() const { return mesh_config_; }

  // Creates a CompositeDevice in the eager context if it does not already
  // exist. Called when parallel_device_ contains a subset of global devices,
  // e.g. when pipelining is enabled.
  StatusOr<CompositeDevice*> FindOrCreateCompositeDevice(TFE_Context* context) {
    if (composite_device_ == nullptr && !composite_device_name_.empty()) {
      if (mesh_config_.global_devices().empty()) {
        return errors::InvalidArgument(
            "Expect non-empty global devices when creating a CompositeDevice.");
      }
      TF_RETURN_IF_ERROR(ContextFromInterface(tensorflow::unwrap(context))
                             ->FindOrCreateCompositeDevice(
                                 mesh_config_.global_devices(),
                                 composite_device_name_, &composite_device_));
    }
    return composite_device_;
  }

  CompositeDevice* composite_device() const { return composite_device_; }

 private:
  dtensor::Mesh mesh_config_;
  std::unique_ptr<parallel_device::ParallelDevice> parallel_device_;

  // Set when parallel_device_ contains a subset of global devices, e.g. when
  // pipelining is enabled.
  const std::string composite_device_name_;
  // A tensorflow::Device that represents underlying devices of
  // parallel_device_. Set when composite_device_name_ is not empty.
  CompositeDevice* composite_device_ = nullptr;  // owned by eager context

  // Constructed lazily; contains a parallel tensor with scalar integer device
  // IDs for each device.
  mutable std::unique_ptr<parallel_device::ParallelTensor> device_ids_tensor_;
};
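
// A rough usage sketch (illustrative only; constructing the underlying
// parallel_device::ParallelDevice and Mesh is elided and the variable names
// are hypothetical):
//
//   MeshWithParallelDevice mesh_with_device(mesh, std::move(parallel_device));
//   // During op execution, fetch the per-device ID tensor to feed as an
//   // extra argument to translated functions.
//   parallel_device::ParallelTensor* ids =
//       mesh_with_device.DeviceIDs(context, status);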

enum TensorType {
  kDense = 0,
  kResource = 1,
  kSparse = 2,
};

class TensorWithLayout {
 public:
  // Broadcast a single non-parallel tensor onto `mesh` with a fully replicated
  // sharding spec. Does not take ownership of `tensor`.
  static std::unique_ptr<TensorWithLayout> Broadcast(
      TFE_Context* context, TFE_TensorHandle* tensor,
      const MeshWithParallelDevice& mesh,
      const std::string& dtensor_device_name, TF_Status* status);

  // Given an already-parallel tensor, wraps it with a mesh and a layout.
  static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap(
      std::unique_ptr<parallel_device::ParallelTensor> tensor,
      const MeshWithParallelDevice& mesh, const Layout& layout);

  // A dummy TensorWithLayout without holding a ParallelTensor.
  static std::unique_ptr<TensorWithLayout> Dummy(
      const std::vector<int64_t>& local_shape, const TF_DataType dtype,
      const MeshWithParallelDevice& mesh, const Layout& layout);

  virtual ~TensorWithLayout() {}

  virtual const Layout& layout() const { return layout_; }

  virtual TensorType tensor_type() const { return TensorType::kDense; }

  virtual TF_DataType dtype() const {
    if (dtype_.has_value()) {
      return dtype_.value();
    } else {
      return tensor_->dtype();
    }
  }

  // Small constant value optimization for non-resource-handle tensors.
  virtual void set_const_value(NodeDef& const_node) {
    // If we extracted a constant value from the tensor, check if this
    // value was the output from `tf.shape`. In this case, we need to
    // forward the kShapeOpInputLayout attribute to the new node def. This
    // is needed for layout propagation when running in op-by-op mode.
    //
    // TODO(b/162747667): Improve the presentation for Shape input Op
    // layout.
    if (shape_metadata_layout().has_value()) {
      AddNodeAttr(kShapeOpInputLayout, {shape_metadata_layout()->ToString()},
                  &(const_node));
    }
    const_value_.emplace(const_node);
  }

  // Clears the cached const value if present.
  void reset_const_value() { const_value_.reset(); }

  // Encodes the NodeDef via provided builder, if applicable.
  virtual void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const {}

  virtual tensorflow::Fprint128 CacheKey() const;

  // Updates layout for this Tensor.
  virtual void UpdateLayout(const Layout& new_layout, TF_Status* status) {
    TF_SetStatus(status, TF_INTERNAL,
                 "Attempt to update layout on non-resource-handle");
  }

  // Update shape and dtype.
  virtual void UpdateShapeAndDType(const TensorShapeProto& shape,
                                   const DataType& dtype, TF_Status* status) {
    TF_SetStatus(status, TF_INTERNAL,
                 "Attempt to update shape and layout on non-resource-handle");
  }

  // Update Attrs for this Tensor.
  virtual void UpdateAttrs(const EmbeddingResourceAttrs& attrs,
                           TF_Status* status) {
    TF_SetStatus(status, TF_INTERNAL,
                 "Attempt to update attrs on non-resource-handle");
  }

  virtual TFE_TensorHandle* get_tensor(size_t index) const {
    return tensor()->tensor(index);
  }

  virtual size_t num_tensors() const { return tensor()->num_tensors(); }

  virtual parallel_device::ParallelTensor* tensor() const {
    return tensor_.get();
  }

  // Returns a string which includes just the value and layout of the tensor.
  virtual std::string SummarizeValue() const;
  // Returns a string which includes `SummarizeValue` along with shape and type
  // information.
  virtual std::string DebugString() const;

  void set_input_layout_for_shape_op_result(const Layout& layout) {
    input_layout_for_shape_op_result_.emplace(layout);
  }

  const absl::optional<Layout> shape_metadata_layout() const {
    return input_layout_for_shape_op_result_;
  }

  const MeshWithParallelDevice& mesh() const { return mesh_; }

  // Computes the global shape from the layout and the local tensor shape.
  //
  // For tensors with a replicated layout, the global shape is simply the
  // shape of the local tensor on each device. For a sharded tensor, the
  // global shape is recovered by combining the layout with the local shape
  // on each device (e.g. a local shape of [2, 4] sharded two ways over the
  // first dimension corresponds to a global shape of [4, 4]).
  const std::vector<int64_t> global_shape() const {
    return layout().GlobalShapeFromLocalShape(local_shape());
  }

  const std::vector<int64_t> local_shape() const { return local_shape_; }

  const absl::optional<NodeDef> const_value() const { return const_value_; }

  const absl::optional<EmbeddingResourceAttrs>& attrs() const { return attrs_; }

 protected:
  TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor,
                   const MeshWithParallelDevice& mesh, const Layout& layout,
                   std::vector<int64_t> local_shape,
                   absl::optional<TF_DataType> dtype = absl::nullopt,
                   absl::optional<NodeDef> const_value = absl::nullopt)
      : tensor_(std::move(tensor)),
        layout_(layout),
        mesh_(mesh),
        const_value_(std::move(const_value)),
        local_shape_(local_shape),
        dtype_(dtype) {}

  std::unique_ptr<parallel_device::ParallelTensor> tensor_;

  Layout layout_;

  const MeshWithParallelDevice& mesh_;

  // Optionally holds the value of a small, non-resource tensor. Small
  // constants are directly folded into the SPMD graph instead of being passed
  // as inputs. This provides extra information to the layout propagation and
  // SPMD passes during op-by-op execution (for example, the reduction indices
  // for Sum, target shapes for Rng/Reshape, etc.).
  absl::optional<NodeDef> const_value_;

  // Optionally holds the original input layout for a Tensor returned by a
  // shape Op. This is used to preserve information about a shape op output so
  // that future uses can recover the local shape.
  // TODO(hthu,allenl,xiejw): Move this into a separate class for clarity.
  absl::optional<Layout> input_layout_for_shape_op_result_ = absl::nullopt;

  // The local shape of tensors placed on each of `tensor_`'s component
  // devices.
  std::vector<int64_t> local_shape_;

  absl::optional<TF_DataType> dtype_;

  // Resource input attributes for embedding inputs.
  absl::optional<EmbeddingResourceAttrs> attrs_;  // NOLINT
};
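
// A rough usage sketch (illustrative only; `context`, `handle`,
// `mesh_with_device`, `dtensor_device_name`, and `status` stand in for values
// owned by the caller):
//
//   std::unique_ptr<TensorWithLayout> t = TensorWithLayout::Broadcast(
//       context, handle, mesh_with_device, dtensor_device_name, status);
//   // The broadcast result is fully replicated, so its global shape matches
//   // the local shape on every device.
//   std::vector<int64_t> shape = t->global_shape();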

// Extension of TensorWithLayout which holds a resource handle with a layout.
//
// The major differences are:
// 1. The layout, shape, and dtype are set lazily because they are unavailable
//    upon creation.
// 2. The small const optimization is disabled.
class ResourceHandleWithLayout : public TensorWithLayout {
 public:
  // The layout of uninitialized resource tensors, or the layout of the tensor
  // contained in an initialized resource.
  const Layout& layout() const override {
    return dereferenced_layout_.has_value() ? dereferenced_layout_.value()
                                            : layout_;
  }

  TensorType tensor_type() const override { return TensorType::kResource; }

  void set_const_value(NodeDef& const_node) override {
    // Just a no-op for resource handles. Maybe we should error out.
  }

  void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override;

  tensorflow::Fprint128 CacheKey() const override;

  void UpdateLayout(const Layout& new_layout, TF_Status* status) override;

  void UpdateShapeAndDType(const TensorShapeProto& shape, const DataType& dtype,
                           TF_Status* status) override {
    set_dereferenced_shape(shape);
    set_dereferenced_dtype(dtype);
  }

  void UpdateAttrs(const EmbeddingResourceAttrs& attrs,
                   TF_Status* status) override;

  void UpdateDirtyness(bool is_dirty, TF_Status* status) {
    if (!attrs_.has_value()) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Attempt to update dirtiness on a non-embedding resource");
      return;
    }
    attrs_.value().is_dirty = is_dirty;
  }

  void set_dereferenced_shape(const TensorShapeProto& shape) {
    dereferenced_shape_.emplace(shape);
  }
  void set_dereferenced_dtype(const DataType& dtype) {
    dereferenced_dtype_.emplace(dtype);
  }

  const absl::optional<TensorShapeProto>& dereferenced_shape() const {
    return dereferenced_shape_;
  }
  const absl::optional<DataType>& dereferenced_dtype() const {
    return dereferenced_dtype_;
  }

 public:
  ResourceHandleWithLayout(
      std::unique_ptr<parallel_device::ParallelTensor> tensor,
      const MeshWithParallelDevice& mesh, const Layout& layout,
      std::vector<int64_t> local_shape)
      : TensorWithLayout(std::move(tensor), mesh, layout, local_shape,
                         TF_RESOURCE) {}

 private:
  // The layout of the tensor pointed to by this handle, if any.
  absl::optional<Layout> dereferenced_layout_;
  // The shape and dtype of the tensor pointed to by this resource tensor.
  absl::optional<TensorShapeProto> dereferenced_shape_;
  absl::optional<DataType> dereferenced_dtype_;
};

// TensorWithLayout for SparseTensors.
//
// The main difference from TensorWithLayout is that this holds three lists of
// tensors (indices, values, and dense shapes) instead of one. The dense shapes
// always describe the dense view of the SparseTensor, so in terms of shapes
// this class does not differ from TensorWithLayout.
class SparseTensorWithLayout : public TensorWithLayout {
 public:
  static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap(
      std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
      std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
      std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
      const MeshWithParallelDevice& mesh, const Layout& layout,
      std::vector<int64_t> local_shape);

  // A dummy TensorWithLayout without holding a ParallelTensor.
  static std::unique_ptr<TensorWithLayout> Dummy(
      const std::vector<int64_t>& local_shape,
      const MeshWithParallelDevice& mesh, const Layout& layout) {
    return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout(
        /*indices=*/nullptr, /*values=*/nullptr, /*dense_shapes=*/nullptr, mesh,
        layout, local_shape));
  }

  void set_const_value(NodeDef& const_node) override {
    // No-op for SparseTensors; consider erroring out.
  }

  // Adds the attribute '_sparse' to the NodeDefBuilder so that mlir::Values
  // that originate from SparseTensorWithLayout are marked as '_sparse'.
  void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override {
    builder.Attr("_sparse", true);
  }

  TensorType tensor_type() const override { return TensorType::kSparse; }

  size_t num_tensors() const override { return 3 * indices()->num_tensors(); }

  TFE_TensorHandle* get_tensor(size_t index) const override;

  std::string SummarizeValue() const override;

  std::string DebugString() const override;

  TF_DataType dtype() const override;

  parallel_device::ParallelTensor* indices() const { return indices_.get(); }

  parallel_device::ParallelTensor* values() const { return values_.get(); }

  parallel_device::ParallelTensor* dense_shapes() const {
    return dense_shapes_.get();
  }

 protected:
  SparseTensorWithLayout(
      std::unique_ptr<parallel_device::ParallelTensor> indices,
      std::unique_ptr<parallel_device::ParallelTensor> values,
      std::unique_ptr<parallel_device::ParallelTensor> dense_shapes,
      const MeshWithParallelDevice& mesh, const Layout& layout,
      std::vector<int64_t> local_shape,
      absl::optional<TF_DataType> dtype = absl::nullopt,
      absl::optional<NodeDef> const_value = absl::nullopt)
      : TensorWithLayout(nullptr, mesh, layout, local_shape),
        indices_(std::move(indices)),
        values_(std::move(values)),
        dense_shapes_(std::move(dense_shapes)) {}
  std::unique_ptr<parallel_device::ParallelTensor> indices_;
  std::unique_ptr<parallel_device::ParallelTensor> values_;
  std::unique_ptr<parallel_device::ParallelTensor> dense_shapes_;
};

template <typename T>
std::string ShapeToDebugString(const std::vector<T>& shape_vector) {
  std::vector<tensorflow::int64> cast_shape(shape_vector.begin(),
                                            shape_vector.end());
  tensorflow::PartialTensorShape shape;
  if (!tensorflow::PartialTensorShape::MakePartialShape(
           cast_shape.data(), cast_shape.size(), &shape)
           .ok()) {
    return "<error displaying shape>";
  } else {
    return shape.DebugString();
  }
}
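
// For example (a sketch; the exact formatting comes from
// PartialTensorShape::DebugString), ShapeToDebugString(std::vector<int64_t>{2,
// 4}) yields something like "[2,4]", and a dimension of -1 is rendered as
// unknown.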

// Class that holds information about DTensor functions that have run,
// including cached lowered functions and constant-folding input information
// per function.
//
// The caching policy for constant-folded inputs is the following:
// on the first call to a function, we assume that all constant-foldable
// inputs are constant folded and save their values. On the next call to the
// same function, we compare the values of the constant-folded inputs to the
// previously saved ones. We disable constant folding for the inputs whose
// values changed, and save the new inputs.
// TODO(b/169348205) Support cache eviction if the cache gets bloated.
class FunctionManager {
 public:
  FunctionManager() = default;

  // Caches the lowered 'function' for the given operation under `cache_key`.
  const ExecutionFunctions* AddCachedFunction(const DTensorOperation& op,
                                              tensorflow::Fprint128 cache_key,
                                              ExecutionFunctions function);

  // Returns the cache key and the cached lowered graph for the function.
  // Returns a nullptr for the lowered graph if there is a cache miss.
  // Upon a cache miss, this will save some metadata about the function
  // and the small inputs to keep track of information for constant folding.
  std::pair<tensorflow::Fprint128, const ExecutionFunctions*> GetCachedFunction(
      const DTensorOperation& doperation, const NameAttrList& attributes,
      const std::vector<TensorWithLayout*>& inputs,
      const std::vector<const Layout*>& output_layouts);

  // Returns whether the input at `input_index` is known to be constant
  // foldable for function `doperation`. An input is not constant foldable if
  // we have run this function at least twice and the small input value changed
  // across separate runs.
  bool IsConstantFoldable(const DTensorOperation& doperation,
                          const int input_index) const;

 private:
  // Cache key for dtensor operation name, which includes the op name
  // and the input shapes. This is needed as a higher level cache for constant
  // folding.
  const tensorflow::Fprint128 CacheKeyForDTensorOperation(
      const DTensorOperation& doperation) const;

  // Generates a cache key for the graph, including its attributes,
  // inputs, and outputs.
  tensorflow::Fprint128 CacheKeyForGraph(
      const DTensorOperation& doperation, const NameAttrList& attributes,
      const std::vector<TensorWithLayout*>& inputs,
      const std::vector<const Layout*>& output_layouts);

  // Maps the hash of a graph to its lowered graph.
  absl::flat_hash_map<tensorflow::Fprint128, ExecutionFunctions,
                      tensorflow::Fprint128Hasher>
      function_cache_;

  // Maps the hash of a dtensor_operation and its input shapes to a map from
  // small-constant input indices to their values for that function. The
  // small-constant indices are saved to make comparisons faster during
  // constant folding validation.
  absl::flat_hash_map<tensorflow::Fprint128, absl::flat_hash_map<int, NodeDef>,
                      tensorflow::Fprint128Hasher>
      dtensor_op_and_small_inputs_;
};
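
// A minimal sketch of the intended caching flow (illustrative only; the
// surrounding driver code and variable names are hypothetical):
//
//   auto [cache_key, cached] =
//       manager.GetCachedFunction(doperation, attributes, inputs, layouts);
//   if (cached == nullptr) {
//     // Cache miss: lower the graph (e.g. via PrepareGraphForMlir and the
//     // MLIR passes), then register the result under the same key.
//     cached = manager.AddCachedFunction(doperation, cache_key,
//                                        std::move(execution_functions));
//   }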

// Returns the shape of a given tensor.
std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
                                         TF_Status* status);

// Creates a Graph with _Arg and _Retval nodes surrounding an
// `operation_name`-type node.
Status PrepareGraphForMlir(
    const FunctionManager& function_manager,
    const std::vector<TensorWithLayout*>& inputs,
    const DTensorOperation& doperation,
    const tensorflow::FunctionLibraryDefinition& flib_def,
    const NameAttrList& attributes,
    const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
    std::vector<PartialTensorShape>* global_output_shapes,
    std::vector<const Layout*>* output_layouts);

// Returns set of functions to run to execute DTensor computation.
StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
    const tensorflow::Graph& graph,
    const std::vector<PartialTensorShape>& global_output_shapes);

// For functions with control outputs, add identity nodes between
// StatefulPartitionedCall and _Retvals, in order to preserve control output
// dependencies after StatefulPartitionedCall is inlined at runtime.
// Consider calling this in PrepareGraphForMlir, once the identity nodes won't
// be dropped during MLIR lowering.
// TODO(b/171265131): fix the underlying issue to avoid inserting identity
// nodes.
Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph);

// Adds DTensor-specific function attributes to be compatible with the eager
// runtime.
void AddDTensorFunctionAttr(FunctionDef& function_def);

// Prepares embedding inputs for checkpoint functions.
StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
    const std::vector<TensorWithLayout*>& inputs);

Status InsertFunctionForTPUEmbeddingCheckpoint(
    TF_Status* status, Graph* graph,
    const std::vector<TensorWithLayout*>& inputs,
    const std::string& checkpoint_fn_name);

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_