1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/dtensor/cc/dtensor_device.h"
17
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
24
25#include "absl/base/attributes.h"
26#include "absl/container/flat_hash_map.h"
27#include "absl/container/flat_hash_set.h"
28#include "absl/memory/memory.h"
29#include "absl/strings/match.h"
30#include "absl/strings/str_cat.h"
31#include "absl/strings/str_join.h"
32#include "absl/strings/string_view.h"
33#include "absl/strings/strip.h"
34#include "tensorflow/c/c_api_experimental.h"
35#include "tensorflow/c/eager/c_api.h"
36#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
37#include "tensorflow/c/eager/tfe_context_internal.h"
38#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
39#include "tensorflow/c/tf_datatype.h"
40#include "tensorflow/c/tf_status.h"
41#include "tensorflow/c/tf_status_helper.h"
42#include "tensorflow/c/tf_tensor_internal.h"
43#include "tensorflow/compiler/xla/status_macros.h"
44#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
45#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h"
46#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_topology.h"
47#include "tensorflow/core/common_runtime/device_set.h"
48#include "tensorflow/core/common_runtime/eager/context.h"
49#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
50#include "tensorflow/core/common_runtime/graph_constructor.h"
51#include "tensorflow/core/common_runtime/shape_refiner.h"
52#include "tensorflow/core/framework/attr_value.pb.h"
53#include "tensorflow/core/framework/function.h"
54#include "tensorflow/core/framework/function.pb.h"
55#include "tensorflow/core/framework/graph_to_functiondef.h"
56#include "tensorflow/core/framework/node_def_builder.h"
57#include "tensorflow/core/framework/node_def_util.h"
58#include "tensorflow/core/framework/op.h"
59#include "tensorflow/core/framework/tensor_shape.h"
60#include "tensorflow/core/graph/algorithm.h"
61#include "tensorflow/core/graph/graph.h"
62#include "tensorflow/core/lib/strings/proto_serialization.h"
63#include "tensorflow/core/platform/errors.h"
64#include "tensorflow/core/platform/fingerprint.h"
65#include "tensorflow/core/platform/types.h"
66#include "tensorflow/core/profiler/lib/traceme.h"
67#include "tensorflow/core/util/dump_graph.h"
68#include "tensorflow/dtensor/cc/constants.h"
69#include "tensorflow/dtensor/cc/dstatus.h"
70#include "tensorflow/dtensor/cc/dtensor_device_util.h"
71#include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h"
72#include "tensorflow/dtensor/cc/small_constant_optimization.h"
73#include "tensorflow/dtensor/cc/tensor_layout.h"
74#include "tensorflow/dtensor/cc/tpu_system_interface.h"
75#include "tensorflow/dtensor/proto/layout.pb.h"
76
77namespace tensorflow {
78namespace dtensor {
79
80// TODO(b/189332820): Replace this with a Partitioner stub swapped in by the
81// Copybara workflow.
82StatusOr<ExecutionFunctions> ABSL_ATTRIBUTE_WEAK PipeliningPartitionerRun(
83 const absl::flat_hash_map<std::string, const MeshWithParallelDevice*>*
84 device_name_to_mesh_device,
85 FunctionLibraryDefinition* flib_def, DTensorMlirPassRunner* pass_runner,
86 const FunctionDef& fdef, const NameAttrList& eager_attributes,
87 const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set,
88 int num_outputs) {
89 // The actual definition is in the pipelining package.
90 return errors::Unimplemented("DTensor pipelining is unavailable.");
91}
92
93class DTensorDevice {
94 public:
95 explicit DTensorDevice(absl::string_view name)
96 : name_(name),
97 same_shape_policy_enabled_(false),
98 cancellation_manager_(std::make_unique<CancellationManager>()) {}
99
100 void AddMesh(std::unique_ptr<MeshWithParallelDevice> mesh,
101 bool is_host_mesh) {
102 if (is_host_mesh) {
103 std::string& tpu_host_mesh = Mesh::tpu_host_mesh();
104 const std::string new_tpu_host_mesh = mesh->mesh_config().ToString();
105 if (!tpu_host_mesh.empty()) {
106 // TODO(b/180046115): Add per-TPU-mesh host mesh bookkeeping.
107 LOG(WARNING)
108 << "A new TPU host mesh is overwriting the old TPU host mesh. The "
109 "old TPU mesh cannot be used in sea of donuts mode anymore.";
110 }
111 tpu_host_mesh.assign(new_tpu_host_mesh);
112 }
113 // For idempotency, don't register the same mesh twice.
114 if (!mesh_to_device_map_.insert({mesh->mesh_config(), std::move(mesh)})
115 .second)
116 return;
117 if (!default_mesh_) {
118 global_default_mesh_ = mesh_to_device_map_.begin()->second.get();
119 default_mesh_ = global_default_mesh_;
120 }
121 }
122
  // Returns the sub-meshes used for pipelining.
  // The key is the name of a composite device.
125 StatusOr<absl::flat_hash_map<std::string, const MeshWithParallelDevice*>>
126 PipelineSubMeshes(TFE_Context* context) {
127 absl::flat_hash_map<std::string, const MeshWithParallelDevice*>
128 device_to_mesh;
129 for (const auto& pair : mesh_to_device_map_) {
130 TF_ASSIGN_OR_RETURN(CompositeDevice * device,
131 pair.second->FindOrCreateCompositeDevice(context));
132 if (device != nullptr) {
133 device_to_mesh[pair.second->composite_device()->name()] =
134 pair.second.get();
135 }
136 }
137 return device_to_mesh;
138 }
139
  // Runs an operation on the DTensorDevice.
  //
  // The placement of the original op (TFE_OpGetDevice(original_op)) is
  // ignored. It indicates whether the user explicitly placed the op on the
  // DTensor device (vs. having it placed there because an input was placed
  // there), but DTensor does type-based dispatch and so handles these cases
  // identically at the moment.
147 void Execute(const TFE_Op* original_op, int* num_outputs,
148 TFE_TensorHandle** outputs, TF_Status* status);
149
150 void SetDefaultLayout(Layout layout) { default_layout_.emplace(layout); }
151 void ClearDefaultLayout() { default_layout_.reset(); }
152 void SetDefaultMesh(Mesh mesh) {
153 default_mesh_ = mesh_to_device_map_.at(mesh).get();
154 }
155 void ClearDefaultMesh() { default_mesh_ = global_default_mesh_; }
156 void SetSameShapePolicy(bool enabled) {
157 same_shape_policy_enabled_ = enabled;
158 }
159
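  // Records the TPU core IDs to use for `mesh_name` in Mesh::tpu_core_ids().
  // These IDs bridge global device IDs to physical TPU cores (see the comment
  // in MeshWithParallelDevice::DeviceIDs below).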
160 Status SetTPUCoreIDs(const std::string& mesh_name,
161 const std::vector<int>& tpu_core_ids) {
162 if (VLOG_IS_ON(1)) {
163 LOG(INFO) << "Setting TPU core IDs for "
164 << (mesh_name.empty() ? "default mesh" : mesh_name) << ": ";
165 for (auto i : tpu_core_ids) {
166 LOG(INFO) << i;
167 }
168 }
    // Setting the default mesh under an empty name repeatedly is fine; this
    // happens when dtensor_initialize_tpu_system is called multiple times,
    // especially in tests. All such mappings should be identical anyway.
172 if (!mesh_name.empty() && Mesh::tpu_core_ids().count(mesh_name) > 0) {
173 return errors::AlreadyExists("Mesh name already in use: ", mesh_name);
174 }
175 Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(),
176 tpu_core_ids.end());
177 return OkStatus();
178 }
179
180 void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); }
181
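  // Returns, for each TPU core ID, its physical location in the TPU topology
  // as {chip_x, chip_y, chip_z, core_index}.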
182 std::vector<std::vector<int>> TPUCoreIDsToLocations(
183 TFE_Context* context, const std::vector<int>& tpu_core_ids) {
184 TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
185 if (tpu_system == nullptr) {
186 VLOG(1) << "Calling TPUCoreIDsToLocations on the default TPU system.";
187 std::vector<std::vector<int>> tpu_core_locations;
188 tpu_core_locations.reserve(tpu_core_ids.size());
189 tpu::TpuPlatformInterface* tpu_platform =
190 tpu::TpuPlatformInterface::GetRegisteredPlatform();
191 if (tpu_platform == nullptr) {
192 LOG(WARNING) << "No TPU platform is found.";
193 return {{}};
194 }
195 if (!tpu_platform->Initialized()) {
196 LOG(WARNING) << "TPU platform is not initialized.";
197 return {{}};
198 }
199 tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
200
201 for (const int& tpu_core_id : tpu_core_ids) {
202 tpu::TpuCoreLocationExternal core =
203 tpu_topology.CoreForId(TpuCoreTypeEnum::kTensorCore, tpu_core_id);
204 tpu::TpuDimensionsExternal tpu_chip_location = core.chip_coordinates();
205 tpu_core_locations.push_back({tpu_chip_location.x, tpu_chip_location.y,
206 tpu_chip_location.z, core.index()});
207 }
208 return tpu_core_locations;
209 } else {
210 VLOG(1) << "Calling TPUCoreIDsToLocations on a preferred TPU system.";
211 return tpu_system->TPUCoreIDsToLocations(context, tpu_core_ids);
212 }
213 }
214
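  // The inverse of TPUCoreIDsToLocations: maps {chip_x, chip_y, chip_z,
  // core_index} locations back to TPU core IDs.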
215 std::vector<int> TPUCoreLocationsToIDs(
216 TFE_Context* context,
217 const std::vector<std::vector<int>>& tpu_core_locations) {
218 TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
219 if (tpu_system == nullptr) {
220 VLOG(1) << "Calling TPUCoreLocationsToIDs on the default TPU system.";
221 std::vector<int> tpu_core_ids;
222 tpu_core_ids.reserve(tpu_core_locations.size());
223 tpu::TpuPlatformInterface* tpu_platform =
224 tpu::TpuPlatformInterface::GetRegisteredPlatform();
225 if (tpu_platform == nullptr) {
226 LOG(WARNING) << "No TPU platform is found.";
227 return {};
228 }
229 if (!tpu_platform->Initialized()) {
230 LOG(WARNING) << "TPU platform is not initialized.";
231 return {};
232 }
233 tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
234
235 for (const std::vector<int>& tpu_core_location : tpu_core_locations) {
236 tpu::TpuCoreLocationExternal core = tpu_topology.Core(
237 TpuCoreTypeEnum::kTensorCore, tpu_core_location[0],
238 tpu_core_location[1], tpu_core_location[2], tpu_core_location[3]);
239 tpu_core_ids.push_back(core.Id());
240 }
241 return tpu_core_ids;
242 } else {
243 VLOG(1) << "Calling TPUCoreLocationsToIDs on a preferred TPU system.";
244 return tpu_system->TPUCoreLocationsToIDs(context, tpu_core_locations);
245 }
246 }
247
248 // Waits for ops to finish in ALL meshes as we share the cancellation manager.
249 void AsyncWait(TFE_Context* context, TF_Status* status) {
250 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> first_bad_status(
251 nullptr, TF_DeleteStatus);
252
253 for (const auto& pair : mesh_to_device_map_) {
254 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
255 TF_NewStatus(), TF_DeleteStatus);
256
257 pair.second->parallel_device().AsyncWait(context,
258 async_wait_status.get());
259
260 TF_Code error_code = TF_GetCode(async_wait_status.get());
261 if (error_code != TF_OK &&
262 (first_bad_status == nullptr ||
263 TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
264 first_bad_status.reset(TF_NewStatus());
265 TF_SetStatus(first_bad_status.get(), error_code,
266 TF_Message(async_wait_status.get()));
267 }
268 }
269
270 if (first_bad_status != nullptr) {
271 TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
272 TF_Message(first_bad_status.get()));
273 }
274
275 // Reset the global function rendezvous, which otherwise stores a failure
276 // state.
277 tensorflow::unwrap(context)->ResetGlobalRendezvousForFunction();
278
    // Reset the cancellation manager on (potential) failure so we don't cancel
    // future ops. This is only safe because we have just cleared pending async
    // nodes, which may have had a reference to the cancellation manager.
282 cancellation_manager_ = std::make_unique<CancellationManager>();
283 }
284
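  // Packs the per-device tensors in `inputs` into a single DTensor whose
  // layout is parsed from `string_layout`. The inputs must not already be on
  // the DTensor device.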
285 TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
286 TFE_TensorHandle** inputs,
287 const std::string& string_layout, TF_Status* status);
288
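  // Unpacks a DTensor placed on this device into its per-device component
  // tensors.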
289 std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
290 TFE_TensorHandle* input,
291 TF_Status* status);
292
  // Returns the layout of the input tensor as a string.
294 std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
295 TF_Status* status);
296
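  // Packs component `indices`, `values`, and `shapes` tensors into a single
  // SparseTensor-backed DTensor with the layout parsed from `string_layout`.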
297 TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
298 TFE_TensorHandle** indices,
299 TFE_TensorHandle** values,
300 TFE_TensorHandle** shapes,
301 const std::string& string_layout,
302 TF_Status* status);
303
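  // Returns true if `input` is a DTensor backed by a SparseTensor.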
304 bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
305 TF_Status* status);
306
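  // Returns the function compilation cache statistics, keyed by "hit" and
  // "miss".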
307 std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
308 TFE_Context* context, TF_Status* status) const;
309
310 private:
311 // If the `operation_name` of an op indicates a custom DTensor op (e.g.
312 // CopyToMesh), then separately handle those custom ops instead of running
313 // default DTensor graph compilation.
314 void MaybeHandleDTensorCustomOps(
315 const char* operation_name, const int num_inputs,
316 const TFE_OpAttrs* attributes, TFE_Context* context,
317 TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
318 bool* is_custom_dtensor_op, TF_Status* status);
319
  // Copies a non-DTensor eager tensor or a DTensor to the mesh specified by
  // `attributes`.
  // Currently, only copying to a replicated layout on the target mesh is
  // supported.
323 void CopyToMesh(TFE_Context* context, int num_inputs,
324 TFE_TensorHandle** inputs, const TFE_OpAttrs* attributes,
325 TFE_TensorHandle** outputs, int* num_outputs,
326 TF_Status* status);
327
  // Updates output layouts for eager ops based on the same-shape policy.
329 void UpdateOutputLayoutsWithSameShapePolicy(
330 const std::vector<PartialTensorShape>& global_output_shapes,
331 const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
332 tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
333 TF_Status* status);
334
335 // Takes the description of an operation and makes a function out of it with
336 // the same signature, running DTensor MLIR passes. Registers that function
337 // with `context`. `translated_function_name` is set to the name of the
338 // function.
339 //
340 // The resulting function expects a device ID as its first input.
341 void LowerToSPMDFunction(TFE_Context* context,
342 const std::vector<TensorWithLayout*>& inputs,
343 const DTensorOperation& doperation,
344 const TFE_OpAttrs* attributes, const int num_outputs,
345 const ExecutionFunctions** execution_functions,
346 TF_Status* status);
347
348 // Execute a given function.
349 void ExecuteFunctionAndWait(
350 TFE_Context* context, const TranslatedFunction* function_ptr,
351 const MeshWithParallelDevice* parallel_device_mesh,
352 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
353 const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status);
354
  // Implements `Execute` for operations that aren't special-cased in
  // `MaybeHandleDTensorCustomOps`.
356 void ExecuteRegularOperation(TFE_Context* context,
357 const std::vector<TensorWithLayout*>& inputs,
358 const DTensorOperation& doperation,
359 const TFE_OpAttrs* attributes, int* num_outputs,
360 TFE_TensorHandle** outputs, TF_Status* status);
361
362 // Wraps a TensorWithLayout into a TFE_TensorHandle.
363 TFE_TensorHandle* MakeLayoutTensorHandle(TFE_Context* context,
364 std::unique_ptr<TensorWithLayout> t,
365 TF_Status* status);
366
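  // Records the tensor's layout in the shape-to-layout cache used by the
  // same-shape policy.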
367 void RecordInShapeLayoutCache(const TensorWithLayout& tensor);
368
369 // Returns whether a given mesh is a remote mesh.
370 bool is_remote_mesh(const Mesh& mesh) const;
371
372 // The name of the device (the custom device)
373 std::string name_;
374 // Mesh configs with matching parallel devices.
375 //
376 // For now we just consider the first entry added to dtensor_device as the
377 // default mesh. Before we reach an agreement on this, we'll leave it as is.
378 absl::flat_hash_map<Mesh, std::unique_ptr<MeshWithParallelDevice>>
379 mesh_to_device_map_;
  // TODO(hthu): Consider whether we want to preserve the default_mesh
  // semantic. The current default mesh is consistent with default_layout_. If
  // default_layout_ is not set, it equals global_default_mesh_.
383 const MeshWithParallelDevice* default_mesh_ = nullptr;
  // The default mesh of a DTensorDevice, which cannot be modified once set.
386 const MeshWithParallelDevice* global_default_mesh_ = nullptr;
387 // If the user has specified a default output layout.
388 absl::optional<Layout> default_layout_;
389
390 // Determines whether tensors with a shape previously associated with only one
391 // layout use that layout if nothing else can be inferred.
392 bool same_shape_policy_enabled_;
393
394 DTensorMlirPassRunner pass_runner_;
395
396 struct CachedLayout {
397 // The first layout seen with this shape
398 Layout layout;
399 // Whether the layout is unique for this shape
400 bool is_unique;
401 };
402 absl::flat_hash_map<int64_t, CachedLayout> shape_layout_cache_;
403
404 FunctionManager function_manager_;
405
406 // Records the function compilation cache hits and misses.
407 std::unordered_map<std::string, int> function_compilation_hits_and_misses_;
408
409 // Coordinates cancelling ops across meshes on error. Must outlive any queued
410 // async op launches, so we only reset it after seeing a failure status.
411 std::unique_ptr<CancellationManager> cancellation_manager_;
412
  // Maps each function_mesh_fingerprint (based on the set of meshes involved)
  // to the number of times the function has been executed. The
  // function_mesh_fingerprint and the counter together are used to generate
  // the step id, which is used for rendezvous creation.
417 absl::flat_hash_map<uint64, uint64> func_mesh_fingerprint_to_step_counter_;
418};
419
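// Fingerprints a shape by chaining FingerprintCat64 over its dimensions.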
420int64_t FingerprintShape(const absl::Span<const int64_t> shape) {
421 int64_t fprint = 0;
422 for (int64_t dim : shape) {
423 fprint = FingerprintCat64(fprint, dim);
424 }
425 return fprint;
426}
427
428parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs(
429 TFE_Context* context, TF_Status* status) const {
430 if (device_ids_tensor_ == nullptr) {
    // Global device IDs sequentially increase.
    //
    // This is the assumption in the DTensor software stack. The MLIR pass
    // relies on this assumption to generate mesh coordinates for each core
    // efficiently.
    //
    // The rule for setting local ids and the mapping from global ids to real
    // physical core indices (e.g., on TPU) is unfortunately nontrivial. It is
    // possible to use an identity mapping, but collective operation
    // performance is terrible in most cases.
    //
    // - For an ICI-connected TPU slice, see
    //   go/dtensor-device-assignment-summary for a guide on how to create
    //   efficient core assignments for peak performance.
    //
    // The global id to core assignment mapping is bridged by
    // `Mesh::tpu_core_ids()` and consumed by `UpdateTPUCompileMetadata`.
    //
    // - For a DCN-connected topology, we need to map different sections of
    //   the global ids to their real physical cores separately, according to
    //   the runtime requirements. For example, for a 4x32 mesh in which the
    //   outer dimension is connected via DCN and the inner dimension is
    //   connected by ICI, the device assignments for the inner dimension
    //   should typically form their own ring order (not the plain physical
    //   core index) in each sub-mesh, and the outer dimension should be
    //   assigned according to the real physical ring of DCN hosts.
    //
    // Note: In order to change this assumption, the MLIR pass needs
    // adjustment. One possible approach is to take an N-D mapping vector for
    // an N-D mesh and look up the coordinates in MLIR, consulting the tensor
    // layout as well, rather than computing them on the fly.
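    //
    // For illustration only (not a prescribed assignment): a 1x4 mesh spread
    // across two hosts could have global_device_ids = {0, 1, 2, 3}, with
    // local_device_ids = {0, 1} on the first host and {2, 3} on the second.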
461
462 // LINT.IfChange
463 for (int64_t i = 0; i < mesh_config_.global_device_ids().size(); ++i) {
464 if (mesh_config_.global_device_ids()[i] - i !=
465 mesh_config_.global_device_ids()[0]) {
466 TF_SetStatus(
467 status, TF_INTERNAL,
468 absl::StrCat("Global device IDs should be consecutive: ",
469 absl::StrJoin(mesh_config_.global_device_ids(), ", "))
470 .c_str());
471 return nullptr;
472 }
473 }
474 // LINT.ThenChange(//tensorflow/dtensor/python/layout.py)
475
476 // Local device IDs are a subset of global device IDs, arranged in device
477 // ordinal order.
478 std::vector<int32_t> ids;
479 for (int64_t id : mesh_config_.local_device_ids()) {
480 ids.push_back(id);
481 }
482 VLOG(1) << "Parallel device IDs: " << absl::StrJoin(ids, ", ");
483 device_ids_tensor_ =
484 parallel_device_->ScalarsFromSequence<int32_t>(ids, context, status);
485 if (TF_GetCode(status) != TF_OK) return nullptr;
486 }
487 return device_ids_tensor_.get();
488}
489
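// The following callbacks expose a TensorWithLayout through the
// TFE_NewCustomDeviceTensorHandle C API; `data` is the TensorWithLayout*
// stored in the handle.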
490int TensorWithLayoutNumDims(void* data, TF_Status* status) {
491 return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size();
492}
493
494int64_t TensorWithLayoutDim(void* data, int dim_index, TF_Status* status) {
495 return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index];
496}
497
498void TensorWithLayoutDeallocator(void* data) {
499 delete reinterpret_cast<TensorWithLayout*>(data);
500}
501
502TF_Buffer* TensorWithLayoutSummarize(void* data, TF_Status* status) {
503 std::string summary =
504 reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue();
505 return TF_NewBufferFromString(summary.data(), summary.size());
506}
507
508TFE_TensorHandle* DTensorDevice::MakeLayoutTensorHandle(
509 TFE_Context* context, std::unique_ptr<TensorWithLayout> t,
510 TF_Status* status) {
511 TF_DataType dtype = t->dtype();
512 TFE_CustomDeviceTensorHandleMethods handle_methods;
513 handle_methods.num_dims = &TensorWithLayoutNumDims;
514 handle_methods.dim = &TensorWithLayoutDim;
515 handle_methods.deallocator = &TensorWithLayoutDeallocator;
516 handle_methods.summarize = &TensorWithLayoutSummarize;
517 return TFE_NewCustomDeviceTensorHandle(context, name_.c_str(), dtype,
518 /*data=*/t.release(), handle_methods,
519 status);
520}
521
522void DTensorDevice::RecordInShapeLayoutCache(const TensorWithLayout& tensor) {
523 auto existing = shape_layout_cache_.insert(
524 {FingerprintShape(tensor.global_shape()),
525 CachedLayout{tensor.layout(), /*is_unique=*/true}});
526
527 if (!existing.second) {
528 // There is an entry already; if the layout doesn't match we should record
529 // the fact that it's not unique.
530 if (tensor.layout() != existing.first->second.layout) {
531 existing.first->second.is_unique = false;
532 }
533 }
534}
535
536bool DTensorDevice::is_remote_mesh(const Mesh& mesh) const {
537 // An empty mesh might be assigned to VarHandleOp during DTensor MLIR lowering
538 // pass. Decide whether the empty mesh is remote based on the current default
539 // mesh.
540 return mesh.is_remote() ||
541 (mesh.IsEmpty() && default_mesh_->mesh_config().is_remote());
542}
543
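// Deserializes the eager op attributes in `attributes` into a NameAttrList.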
544StatusOr<NameAttrList> FetchAttributes(const TFE_OpAttrs* attributes) {
545 // TODO(allenl): Should we just give up on the public C API to save on
546 // serialization/deserialization? We need all of the attributes and to treat
547 // them generically, which isn't going to be pleasant with typed attribute
548 // methods.
549 std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> serialized_attributes(
550 TF_NewBuffer(), TF_DeleteBuffer);
551
552 TF_Status* status = TF_NewStatus();
553 TFE_OpAttrsSerialize(attributes, serialized_attributes.get(), status);
554 if (TF_GetCode(status) == TF_OK) {
555 TF_DeleteStatus(status);
556 } else {
557 Status failure_status = StatusFromTF_Status(status);
558 TF_DeleteStatus(status);
559 return failure_status;
560 }
561
562 NameAttrList name_and_attrs;
563 if (!name_and_attrs.ParseFromArray(serialized_attributes->data,
564 serialized_attributes->length)) {
565 return tensorflow::errors::Unknown("Could not parse attributes");
566 }
567 return name_and_attrs;
568}
569
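// Reads the string attribute `attribute_name` from `attributes` and parses it
// into a Layout. Assumes the attribute is present.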
570StatusOr<Layout> FetchLayoutFromAttributes(const TFE_OpAttrs* attributes,
571 absl::string_view attribute_name) {
572 // Get attributes.
573 TF_ASSIGN_OR_RETURN(NameAttrList name_and_attrs, FetchAttributes(attributes));
574
575 // Get layout string from attributes.
576 absl::string_view layout_str =
577 name_and_attrs.attr().find(std::string(attribute_name))->second.s();
578
579 // This would probably be slow at the moment without caching.
580 // We should consider making this faster in the future.
581 return Layout::FromString(string(layout_str));
582}
583
584std::string DTensorDevice::FetchLayout(TFE_Context* context,
585 TFE_TensorHandle* input,
586 TF_Status* status) {
587 VLOG(1) << "Checking layout...";
588 const char* input_device = TFE_TensorHandleDeviceName(input, status);
589 if (input_device != name_) {
590 TF_SetStatus(status, TF_INVALID_ARGUMENT,
591 "FetchLayout expects a tensor placed on the layout device.");
592 return {};
593 }
594 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
595 TFE_TensorHandleDevicePointer(input, status));
596 if (TF_GetCode(status) != TF_OK) return {};
597 return t->layout().ToString();
598}
599
600std::vector<TFE_TensorHandle*> DTensorDevice::Unpack(TFE_Context* context,
601 TFE_TensorHandle* input,
602 TF_Status* status) {
603 std::vector<TFE_TensorHandle*> outputs;
604
605 const char* input_device = TFE_TensorHandleDeviceName(input, status);
606 if (TF_GetCode(status) != TF_OK) return outputs;
607 if (input_device != name_) {
608 TF_SetStatus(
609 status, TF_INVALID_ARGUMENT,
610 absl::StrCat(
611 "DTensorUnpack expects a tensor placed on the DTensor device: ",
612 name_, ", but input was placed on device: ", input_device)
613 .c_str());
614 return outputs;
615 }
616 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
617 TFE_TensorHandleDevicePointer(input, status));
618 if (TF_GetCode(status) != TF_OK) return outputs;
619
620 if (is_remote_mesh(t->mesh().mesh_config())) {
621 TF_SetStatus(status, TF_UNIMPLEMENTED,
622 "DTensorUnpack is not supported on a remote mesh.");
623 return outputs;
624 }
625 const int output_size = t->num_tensors();
626 outputs.resize(output_size);
627
628 for (int output_index = 0; output_index < output_size; ++output_index) {
629 outputs[output_index] =
630 TFE_TensorHandleCopySharingTensor(t->get_tensor(output_index), status);
631 if (TF_GetCode(status) != TF_OK) {
632 return outputs;
633 }
634 }
635 return outputs;
636}
637
638void DTensorDevice::MaybeHandleDTensorCustomOps(
639 const char* operation_name, const int num_inputs,
640 const TFE_OpAttrs* attributes, TFE_Context* context,
641 TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
642 bool* is_custom_dtensor_op, TF_Status* status) {
643 *is_custom_dtensor_op = true;
644 if (operation_name == std::string("_EagerConst")) {
    // Op-by-op const has no obvious layout. DTensor skips an SPMD expansion
    // and instead relies on copying the constant onto a mesh when the value
    // is used later.
647 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
648 TFE_NewOp(context, operation_name, status), TFE_DeleteOp);
649 if (TF_GetCode(status) != TF_OK) return;
650 for (int input_index = 0; input_index < num_inputs; ++input_index) {
651 TFE_OpAddInput(op.get(), inputs[input_index], status);
652 if (TF_GetCode(status) != TF_OK) return;
653 }
654 TFE_OpAddAttrs(op.get(), attributes);
655 TFE_Execute(op.get(), outputs, num_outputs, status);
656 return;
657 }
658 if (operation_name == std::string("CopyToMesh")) {
659 CopyToMesh(context, num_inputs, inputs, attributes, outputs, num_outputs,
660 status);
661 return;
662 }
663
664 *is_custom_dtensor_op = false;
665}
666
667void DTensorDevice::CopyToMesh(TFE_Context* context, int num_inputs,
668 TFE_TensorHandle** inputs,
669 const TFE_OpAttrs* attributes,
670 TFE_TensorHandle** outputs, int* num_outputs,
671 TF_Status* status) {
672 if (num_inputs != 1) {
673 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
674 "DTensor CopyToMesh requires exactly 1 input.");
675 }
676 if (*num_outputs < 1) {
677 RETURN_STATUS(status, TF_INTERNAL,
678 "DTensor CopyToMesh must have output buffer to allocate at "
679 "least 1 output.");
680 }
681
682 // Assign layout.
683 StatusOr<Layout> target_layout_or =
684 FetchLayoutFromAttributes(attributes, kQualifiedLayoutAttr);
685 if (!target_layout_or.ok()) {
686 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
687 "DTensor CopyToMesh requires valid layout attribute for "
688 "destination DTensor.");
689 }
690
691 const Layout target_layout = *target_layout_or;
692 const Mesh& target_mesh = target_layout.mesh();
693
694 // TODO(b/193443769): Support sharded layout for eager copy to mesh.
695 if (!target_layout.IsFullyReplicated()) {
696 RETURN_STATUS(status, TF_UNIMPLEMENTED,
697 "Target layout of DTensor CopyToMesh must be replicated. "
698 "Consider changing the target layout to replicated layout or "
699 "file a bug to the DTensor team (b/193443769).");
700 }
701
702 TFE_TensorHandle* input_tensor = inputs[0];
703
  // Check that, if the input tensor is a DTensor, its layout is replicated.
706 const char* input_device = TFE_TensorHandleDeviceName(input_tensor, status);
707 if (TF_GetCode(status) != TF_OK) return;
708
709 if (name_ == input_device) {
710 // Handle input which is on DTensor device already.
711 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
712 TFE_TensorHandleDevicePointer(input_tensor, status));
713 if (TF_GetCode(status) != TF_OK) return;
714
715 if (!t->layout().IsFullyReplicated())
716 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
717 "Input tensor to CopyToMesh must be replicated DTensor or "
718 "normal eager Tensor.");
719
720 // If input to CopyToMesh is a DTensor, we use the first local tensor as
721 // input tensor handle to invoke copy.
722 input_tensor = t->get_tensor(0);
723 }
724
725 auto it = mesh_to_device_map_.find(target_mesh);
726 if (it == mesh_to_device_map_.end()) {
727 RETURN_STATUS(
728 status, TF_INTERNAL,
729 "DTensor CopyToMesh target mesh is not registered. Meshes should be "
730 "automatically registered. Please file a bug. (component id: 833864)");
731 }
732
733 const MeshWithParallelDevice* target_parallel_mesh = it->second.get();
734
735 // Broadcast non-dtensor value to dtensor.
736 std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
737 context, input_tensor, *target_parallel_mesh, name_, status);
738 if (TF_GetCode(status) != TF_OK) return;
739
740 RecordInShapeLayoutCache(*wrapper);
741 *num_outputs = 1;
742 *outputs = MakeLayoutTensorHandle(context, std::move(wrapper), status);
743}
744
745namespace {
746
747// Verifies that all components have the same dtype and shape.
748// The component shape will be set upon success.
749void VerifyPackTensorShapeAndDtype(
750 std::vector<parallel_device::TensorHandlePtr>& components,
751 std::vector<int64_t>* component_shape, TF_Status* status) {
752 TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
753 auto size = TFE_TensorHandleNumDims(components[0].get(), status);
754 if (TF_GetCode(status) != TF_OK) return;
755 component_shape->clear();
756 component_shape->reserve(size);
757 for (int i = 0; i < size; ++i) {
758 component_shape->push_back(
759 TFE_TensorHandleDim(components[0].get(), i, status));
760 if (TF_GetCode(status) != TF_OK) return;
761 }
762
763 // Verify that the TensorHandle's shape and dtype match all of the component
764 // shapes and dtypes.
765 for (const auto& component : components) {
766 for (int i = 0; i < component_shape->size(); ++i) {
767 int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
768 if (TF_GetCode(status) != TF_OK) return;
769 if (tensor_dim != (*component_shape)[i]) {
770 TF_SetStatus(status, TF_UNIMPLEMENTED,
771 "Components of a PackedTensor must currently all have "
772 "the same shape");
773 return;
774 }
775 if (TFE_TensorHandleDataType(component.get()) != dtype) {
776 TF_SetStatus(status, TF_INTERNAL,
777 "Components of a PackedTensor must all have "
778 "the same dtype");
779 return;
780 }
781 }
782 }
783}
784
// Verifies that all TensorHandles have rank `expected_rank` and, when
// `expected_dtype` is non-null, dtype `expected_dtype`.
786void VerifyTensorRankAndDType(TFE_TensorHandle** tensors, int num_input,
787 int expected_rank, TF_DataType* expected_dtype,
788 TF_Status* status) {
789 for (int i = 0; i < num_input; ++i) {
790 auto actual_rank = TFE_TensorHandleNumDims(tensors[i], status);
791 if (TF_GetCode(status) != TF_OK)
792 RETURN_STATUS(status, TF_INTERNAL, "Error getting rank of tensor.");
793 if (actual_rank != expected_rank)
794 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
795 "Rank of tensor did not match the expected rank.");
796 if (expected_dtype != nullptr &&
797 TFE_TensorHandleDataType(tensors[i]) != *expected_dtype)
798 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
799 "Dtype of tensor did not match the expected dtype.");
800 }
801}
802
803} // namespace
804
805TFE_TensorHandle* DTensorDevice::Pack(TFE_Context* context, int num_inputs,
806 TFE_TensorHandle** inputs,
807 const std::string& string_layout,
808 TF_Status* status) {
809 if (num_inputs < 1) {
810 TF_SetStatus(status, TF_INVALID_ARGUMENT,
811 "DTensorPack requires 1 or more inputs");
812 return nullptr;
813 }
814 StatusOr<Layout> target_layout = Layout::FromString(string_layout);
815 if (!target_layout.ok()) {
816 TF_SetStatus(status, TF_INVALID_ARGUMENT,
817 "Failed to parse layout from string layout");
818 return nullptr;
819 }
820 const Mesh& target_mesh = target_layout->mesh();
821 const MeshWithParallelDevice* target_parallel_device =
822 mesh_to_device_map_[target_mesh].get();
823 if (target_parallel_device == nullptr) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 absl::StrCat("Required mesh : ", target_mesh.ToString(),
                              " is not registered with DTensor.")
                     .c_str());
828 return nullptr;
829 }
830
831 std::unique_ptr<TensorWithLayout> packed_tensor;
832 if (is_remote_mesh(target_parallel_device->mesh_config())) {
833 // Create a dummy output for DTensorPack if inputs are on a remote mesh.
834 TF_DataType dtype = TFE_TensorHandleDataType(inputs[0]);
835 auto size = TFE_TensorHandleNumDims(inputs[0], status);
836 if (TF_GetCode(status) != TF_OK) return nullptr;
837 std::vector<int64_t> component_shape;
838 component_shape.reserve(size);
839 for (int i = 0; i < size; ++i) {
840 component_shape.push_back(TFE_TensorHandleDim(inputs[0], i, status));
841 if (TF_GetCode(status) != TF_OK) return nullptr;
842 }
843 packed_tensor = TensorWithLayout::Dummy(
844 component_shape, dtype, *target_parallel_device, *target_layout);
845
846 } else {
847 auto local_devices = target_parallel_device->mesh_config().local_devices();
848
849 if (num_inputs !=
850 target_parallel_device->parallel_device().num_underlying_devices()) {
851 TF_SetStatus(status, TF_INVALID_ARGUMENT,
852 absl::StrCat("The dtensor device ", name_, " expected ",
853 local_devices.size(),
854 " inputs to DTensorPack, but got ", num_inputs)
855 .c_str());
856 return nullptr;
857 }
858
859 std::vector<parallel_device::TensorHandlePtr> components;
860 components.reserve(num_inputs);
861 for (int i = 0; i < num_inputs; ++i) {
862 TFE_TensorHandle* input = inputs[i];
863 const char* input_device = TFE_TensorHandleDeviceName(input, status);
864 if (TF_GetCode(status) != TF_OK) return nullptr;
865 if (name_ == input_device) {
866 TF_SetStatus(status, TF_INVALID_ARGUMENT,
867 "Does not support packing a Tensor that is already on "
868 "dtensor device");
869 return nullptr;
870 }
871 // If `input` is on the target device, this creates a new handle sharing
872 // the underlying data; otherwise, async copies are invoked.
873 components.emplace_back(TFE_TensorHandleCopyToDevice(
874 input, context, local_devices[i].c_str(), status));
875 if (TF_GetCode(status) != TF_OK) return nullptr;
876 }
877
878 std::vector<int64_t> component_shape;
879 VerifyPackTensorShapeAndDtype(components, &component_shape, status);
880 if (TF_GetCode(status) != TF_OK) return nullptr;
881
882 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
883 parallel_device::ParallelTensor::FromTensorHandles(
884 target_parallel_device->parallel_device(), std::move(components),
885 status);
886 if (TF_GetCode(status) != TF_OK) return nullptr;
887
888 if (target_layout->rank() != component_shape.size()) {
889 TF_SetStatus(
890 status, TF_INVALID_ARGUMENT,
891 absl::StrCat(
892 "Packed layout should have the same rank as the rank for each "
893 "component. The rank of each component is: ",
894 component_shape.size(),
895 ", while layout has rank: ", target_layout->rank(),
896 "\nLayout: ", target_layout->ToString(), "\n")
897 .c_str());
898 return nullptr;
899 }
900
901 packed_tensor =
902 TensorWithLayout::Wrap(std::move(parallel_tensor),
903 *target_parallel_device, *target_layout)
904 .value();
905 }
906
907 RecordInShapeLayoutCache(*packed_tensor);
908 TFE_TensorHandle* output =
909 MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
910 if (TF_GetCode(status) != TF_OK) return nullptr;
911 return output;
912}
913
914TFE_TensorHandle* DTensorDevice::SparsePack(
915 TFE_Context* context, int num_inputs, TFE_TensorHandle** indices,
916 TFE_TensorHandle** values, TFE_TensorHandle** shapes,
917 const std::string& string_layout, TF_Status* status) {
918 StatusOr<Layout> target_layout = Layout::FromString(string_layout);
919 if (!target_layout.ok()) {
920 TF_SetStatus(status, TF_INVALID_ARGUMENT,
921 "Failed to parse layout from string layout");
922 return nullptr;
923 }
924 const Mesh& target_mesh = target_layout->mesh();
925 const MeshWithParallelDevice* target_parallel_device =
926 mesh_to_device_map_[target_mesh].get();
927 if (target_parallel_device == nullptr) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 absl::StrCat("Required mesh : ", target_mesh.ToString(),
                              " is not registered with DTensor.")
                     .c_str());
932 return nullptr;
933 }
934
935 TF_DataType tf_int64 = TF_INT64;
936 // Verify rank and dtype of shapes.
937 VerifyTensorRankAndDType(shapes, num_inputs, /*expected_rank=*/1,
938 /*expected_dtype=*/&tf_int64, status);
939 if (TF_GetCode(status) != TF_OK) return nullptr;
940
941 // Verify rank and dtype of indices.
942 VerifyTensorRankAndDType(indices, num_inputs, /*expected_rank=*/2,
943 /*expected_dtype=*/&tf_int64, status);
944 if (TF_GetCode(status) != TF_OK) return nullptr;
945
946 // Verify rank of values.
947 VerifyTensorRankAndDType(values, num_inputs, /*expected_rank=*/1,
948 /*expected_dtype=*/nullptr, status);
949 if (TF_GetCode(status) != TF_OK) return nullptr;
950
951 // Compute the local shape from a shape tensor.
952 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> shape_tensor(
953 TFE_TensorHandleResolve(shapes[0], status), TF_DeleteTensor);
954 if (TF_GetCode(status) != TF_OK) {
955 TF_SetStatus(
956 status, TF_GetCode(status),
957 absl::StrCat("Error resolving the tensor handle of shape tensor"
958 ". Original message: ",
959 TF_Message(status))
960 .c_str());
961 return nullptr;
962 }
963 int shape_tensor_size = TFE_TensorHandleDim(shapes[0], 0, status);
964 if (TF_GetCode(status) != TF_OK || shape_tensor_size <= 0) {
965 TF_SetStatus(status, TF_GetCode(status),
966 absl::StrCat("Error computing the num dims of shape tensor",
967 TF_Message(status))
968 .c_str());
969 return nullptr;
970 }
971
972 const int64_t* data =
973 static_cast<int64_t*>(TF_TensorData(shape_tensor.get()));
974 std::vector<int64_t> local_shape(data, data + shape_tensor_size);
975 if (local_shape.size() != target_layout->rank()) {
976 TF_SetStatus(
977 status, TF_INVALID_ARGUMENT,
978 absl::StrCat(
979 "Packed layout should have the same rank as the rank for each "
980 "component. The rank of each component is: ",
981 local_shape.size(),
982 ", while layout has rank: ", target_layout->rank(),
983 "\nLayout: ", target_layout->ToString(), "\n")
984 .c_str());
985 return nullptr;
986 }
987
988 // Create the SparseTensorWithLayout.
989 std::unique_ptr<TensorWithLayout> packed_tensor;
990 if (is_remote_mesh(target_parallel_device->mesh_config())) {
991 // Create a dummy SparseTensorWithLayout.
992 packed_tensor = SparseTensorWithLayout::Dummy(
993 local_shape, *target_parallel_device, target_layout.value());
994 } else {
995 // Parse the indices, values, and dense_shape tensors and put them into
996 // parallel tensors, and then pack it into a single SparseTensorWithLayout.
997 auto local_devices = target_parallel_device->mesh_config().local_devices();
998
999 std::vector<parallel_device::TensorHandlePtr> indices_components;
1000 std::vector<parallel_device::TensorHandlePtr> values_components;
1001 std::vector<parallel_device::TensorHandlePtr> dense_shapes_components;
1002
    // A small trick to keep the code clean: pack each of indices, values, and
    // shapes with the same loop.
1005 std::vector<std::vector<parallel_device::TensorHandlePtr>*> components{
1006 &indices_components, &values_components, &dense_shapes_components};
1007 std::vector<TFE_TensorHandle**> input_vectors{indices, values, shapes};
1008 for (int component_index = 0; component_index < 3; ++component_index) {
1009 components[component_index]->reserve(num_inputs);
1010 TFE_TensorHandle** inputs = input_vectors[component_index];
1011 for (int i = 0; i < num_inputs; ++i) {
1012 const char* input_device =
1013 TFE_TensorHandleDeviceName(inputs[i], status);
1014 if (TF_GetCode(status) != TF_OK) return nullptr;
1015 if (name_ == input_device) {
1016 TF_SetStatus(status, TF_INVALID_ARGUMENT,
1017 "Does not support packing a Tensor that is already on "
1018 "dtensor device.");
1019 return nullptr;
1020 }
1021
1022 components[component_index]->emplace_back(TFE_TensorHandleCopyToDevice(
1023 inputs[i], context, local_devices[i].c_str(), status));
1024 if (TF_GetCode(status) != TF_OK) return nullptr;
1025 }
1026 }
1027 std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor =
1028 parallel_device::ParallelTensor::FromTensorHandles(
1029 target_parallel_device->parallel_device(),
1030 std::move(indices_components), status);
1031
1032 std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor =
1033 parallel_device::ParallelTensor::FromTensorHandles(
1034 target_parallel_device->parallel_device(),
1035 std::move(values_components), status);
1036
1037 std::unique_ptr<parallel_device::ParallelTensor>
1038 parallel_dense_shapes_tensor =
1039 parallel_device::ParallelTensor::FromTensorHandles(
1040 target_parallel_device->parallel_device(),
1041 std::move(dense_shapes_components), status);
1042
1043 if (TF_GetCode(status) != TF_OK) return nullptr;
1044 packed_tensor =
1045 SparseTensorWithLayout::Wrap(std::move(parallel_indices_tensor),
1046 std::move(parallel_values_tensor),
1047 std::move(parallel_dense_shapes_tensor),
1048 *target_parallel_device,
1049 target_layout.value(), local_shape)
1050 .value();
1051 }
1052
1053 RecordInShapeLayoutCache(*packed_tensor);
1054 TFE_TensorHandle* output =
1055 MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
1056 if (TF_GetCode(status) != TF_OK) return nullptr;
1057 return output;
1058}
1059
1060bool DTensorDevice::IsSparseDTensor(TFE_Context* context,
1061 TFE_TensorHandle* input,
1062 TF_Status* status) {
1063 const char* input_device = TFE_TensorHandleDeviceName(input, status);
1064 if (input_device != name_) {
1065 TF_SetStatus(
1066 status, TF_INVALID_ARGUMENT,
1067 "DTensorSparseUnpack expects a tensor placed on the DTensor device.");
1068 return false;
1069 }
1070 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1071 TFE_TensorHandleDevicePointer(input, status));
1072 if (TF_GetCode(status) != TF_OK) return false;
1073 return t->tensor_type() == TensorType::kSparse;
1074}
1075
1076void DTensorDevice::UpdateOutputLayoutsWithSameShapePolicy(
1077 const std::vector<PartialTensorShape>& global_output_shapes,
1078 const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
1079 tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
1080 TF_Status* status) {
1081 if (!same_shape_policy_enabled_) return;
  // Simply do not hint if the inputs span multiple meshes.
1083 if (input_meshes.size() > 1) return;
1084
1085 for (Node* node : graph->op_nodes()) {
1086 if (!node->IsRetval()) {
1087 continue;
1088 }
1089 int output_index;
1090 RETURN_C_STATUS_IF_NOT_OK(
1091 GetNodeAttr(node->attrs(), "index", &output_index), status);
1092 if (output_layouts->at(output_index)) {
1093 continue;
1094 }
1095
1096 const auto& global_output_shape = global_output_shapes.at(output_index);
1097 const Layout* layout = nullptr;
    // TODO(b/180022708): This is useful information; we should be able to hint
    // to layout propagation without making it a hard requirement.
    //
    // Special cases at the moment:
    // - Relayout needs an exemption.
    // - VarHandleOp does not need a hint. VarHandleOp has a scalar shape, so
    //   its layout is trivial. On the other hand, the downstream system
    //   "thinks" the Variable has the same shape as the value it points to,
    //   so providing a layout based on VarHandleOp (scalar) might confuse it.
1108 if (op_name != std::string("Relayout") &&
1109 op_name != std::string("VarHandleOp")) {
1110 // TODO(b/162009702): Support matching between partially-known shapes.
1111 if (global_output_shape.IsFullyDefined()) {
1112 gtl::InlinedVector<int64, 4> shape_vector(
1113 global_output_shape.dim_sizes());
1114 auto layout_iterator =
1115 shape_layout_cache_.find(FingerprintShape(shape_vector));
1116 if (layout_iterator != shape_layout_cache_.end() &&
1117 layout_iterator->second.is_unique) {
1118 // We have a cached layout for this shape. Send it to MLIR.
1119 layout = &layout_iterator->second.layout;
1120 VLOG(3) << op_name << ": found a cached layout for shape "
1121 << global_output_shape.DebugString() << ": \""
1122 << layout->ToString() << "\"";
1123 if (input_meshes.empty() &&
1124 layout->mesh() != default_mesh_->mesh_config()) {
            VLOG(3) << "But we can't infer an input mesh, and the cached "
                    << "layout's mesh \"" << layout->mesh().ToString() << "\" "
                    << "is different from the default mesh \""
                    << default_mesh_->mesh_config().ToString() << "\".\n"
                    << "Not applying the cached layout.";
1130 } else if (!input_meshes.empty() &&
1131 layout->mesh() != *input_meshes.begin()) {
1132 VLOG(3)
                << "But the layout mesh is different from the executing mesh: "
1134 << "\"" << (*input_meshes.begin()).ToString() << "\"\n"
1135 << "Not applying the cached layout.";
1136 } else {
1137 (*output_layouts)[output_index] = layout;
1138 node->AddAttr(kDefaultLayoutAttr, layout->ToString());
1139 }
1140 } else if (layout_iterator == shape_layout_cache_.end()) {
1141 VLOG(3) << op_name << ": no cached layout found for shape "
1142 << global_output_shape.DebugString();
1143 } else {
1144 VLOG(3) << op_name << ": found multiple layouts for shape "
1145 << global_output_shape.DebugString();
1146 }
1147 } else {
1148 VLOG(3) << op_name
1149 << ": not applying same-shape-same-layout due to "
1150 "not-fully-known shape "
1151 << global_output_shape.DebugString();
1152 }
1153 }
1154 }
1155}
1156
1157std::unordered_map<std::string, int>
1158DTensorDevice::GetFunctionCacheHitAndMissCount(TFE_Context* context,
1159 TF_Status* status) const {
1160 return function_compilation_hits_and_misses_;
1161}
1162
// From `graph`, which contains the computation for all meshes, extracts the
// computation for the mesh specified in `function`. The returned graph is a
// clone containing only the ops needed for single-mesh execution.
1166StatusOr<std::unique_ptr<Graph>> SelectGraphToExecute(
1167 const TranslatedFunction& function, const Graph& graph,
1168 std::string* stateful_partitioned_call_name) {
1169 auto new_graph = std::make_unique<Graph>(graph.flib_def());
1170 CopyGraph(graph, new_graph.get());
1171 std::vector<Node*> arg_nodes;
1172 std::vector<Node*> retval_nodes;
1173 for (Node* node : new_graph->nodes()) {
1174 if (node->IsArg()) arg_nodes.emplace_back(node);
1175 if (node->IsRetval()) retval_nodes.emplace_back(node);
1176 }
1177
1178 // Remove irrelevant function calls.
1179 for (Node* node : new_graph->nodes()) {
1180 if (node->op_def().name() != "StatefulPartitionedCall") continue;
1181
1182 if (node->name() != function.node_to_execute->name()) {
1183 // Remove function call that does not match mesh specification and all
1184 // output retval nodes connected to the function call node.
1185 std::queue<Node*> nodes_to_remove;
1186 nodes_to_remove.push(node);
1187 while (!nodes_to_remove.empty()) {
1188 Node* n = nodes_to_remove.front();
1189 for (const Edge* out_edge : n->out_edges()) {
1190 if (out_edge->IsControlEdge()) continue;
1191 Node* out_node = out_edge->dst();
1192 if (!out_node->IsSink()) nodes_to_remove.push(out_node);
1193 }
1194 if (n->IsRetval()) {
1195 auto pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1196 TF_RET_CHECK(pos != retval_nodes.end());
1197 retval_nodes.erase(pos);
1198 }
1199 nodes_to_remove.pop();
1200 new_graph->RemoveNode(n);
1201 }
1202 }
1203 }
1204
1205 *stateful_partitioned_call_name = function.node_to_execute->name();
1206 VLOG(1) << "Selected call " << *stateful_partitioned_call_name;
1207
1208 // Remove unused arg nodes in graph.
1209 for (auto it = arg_nodes.begin(); it != arg_nodes.end(); it++) {
1210 Node* arg_node = *it;
1211 bool arg_unused = true;
1212 for (const Edge* e : arg_node->out_edges()) {
1213 if (e->dst()->IsOp()) {
1214 arg_unused = false;
1215 }
1216 }
1217 if (!arg_unused) continue;
1218
1219 new_graph->RemoveNode(arg_node);
1220 arg_nodes.erase(it--);
1221 }
1222
1223 // Reset index attributes for arg and retval nodes.
1224 for (Node* n : new_graph->nodes()) {
    // Reset each arg node's index attribute to its position within all the
    // arg nodes. The indices should simply increase from 0 to n-1, where n is
    // the total number of arguments. Note that this definition of the `index`
    // attribute is different from the one we set in PrepareGraphForMlir.
    // The attribute is needed on each arg node when converting a Graph to a
    // FunctionDef.
1232 if (n->IsArg()) {
1233 auto pos = std::find(arg_nodes.begin(), arg_nodes.end(), n);
1234 TF_RET_CHECK(pos != arg_nodes.end());
1235 const int new_index = std::distance(arg_nodes.begin(), pos);
1236 n->AddAttr("index", new_index);
1237 }
1238
1239 // Reset retval nodes index attributes.
1240 if (n->IsRetval()) {
1241 auto retval_pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1242 TF_RET_CHECK(retval_pos != retval_nodes.end());
1243 const int new_index = std::distance(retval_nodes.begin(), retval_pos);
1244 n->AddAttr("index", new_index);
1245 }
1246 }
1247
1248 VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_to_execute_",
1249 *new_graph);
1250
1251 return new_graph;
1252}
1253
// Adds the processed graph to run for each mesh computation in
// `execution_functions` to the function definition library.
1256Status AddExecutionFunctionDefsToFunctionDefLibrary(
1257 const absl::flat_hash_set<Node*>& control_ret_nodes, TFE_Context* context,
1258 const Graph& graph, ExecutionFunctions* execution_functions) {
  // Note: We use node names instead of node pointers for comparison because
  // node addresses in the new graph differ from those in the original graph.
1261 absl::flat_hash_set<std::string> control_ret_names;
1262 for (auto* n : control_ret_nodes) {
1263 control_ret_names.emplace(n->name());
1264 }
1265 for (TranslatedFunction& function : execution_functions->function_list) {
1266 std::string selected_call_node_name;
    // TODO(bfontain): We should just try to call the functions directly
    // rather than wrapping them.
    // Construct a graph that executes only the computation for `function`.
1270 TF_ASSIGN_OR_RETURN(
1271 std::unique_ptr<Graph> new_graph,
1272 SelectGraphToExecute(function, graph, &selected_call_node_name));
1273 VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_", *new_graph);
1274
1275 // Add unique identifier based on the function we are executing to the
1276 // function/graph and convert graph to functiondef.
1277 NameAttrList func;
1278 TF_RETURN_IF_ERROR(
1279 GetNodeAttr(function.node_to_execute->attrs(), "f", &func));
1280
1281 static std::atomic<int64_t> unique_function_number(0);
1282 function.translated_function_name =
1283 absl::StrCat(func.name(), "_", unique_function_number.fetch_add(1));
1284 auto control_ret_node_names =
1285 [&control_ret_names, &selected_call_node_name](
1286 const Node* node) -> absl::optional<std::string> {
1287 // Add the stateful partitioned call node as a control return as we need
1288 // to process any control deps inside the inner function.
1289 if (control_ret_names.contains(node->name()) ||
1290 node->name() == selected_call_node_name) {
1291 return node->name();
1292 }
1293 return absl::nullopt;
1294 };
1295
1296 tensorflow::FunctionDef to_run;
1297 TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
1298 *new_graph, function.translated_function_name, control_ret_node_names,
1299 &to_run));
1300
1301 for (const auto& out : to_run.signature().output_arg()) {
1302 function.output_dtypes.emplace_back(static_cast<TF_DataType>(out.type()));
1303 }
1304
1305 AddDTensorFunctionAttr(to_run);
1306 TF_RETURN_IF_ERROR(tensorflow::unwrap(context)->AddFunctionDef(to_run));
1307 }
1308
1309 return OkStatus();
1310}
1311
1312void DTensorDevice::LowerToSPMDFunction(
1313 TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1314 const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1315 const int num_outputs, const ExecutionFunctions** execution_functions,
1316 TF_Status* status) {
1317 profiler::TraceMe activity(
1318 [&] { return "DTensorDevice::LowerToSPMDFunction"; },
1319 profiler::TraceMeLevel::kInfo);
1320 FunctionLibraryDefinition* flib_def =
1321 tensorflow::unwrap(context)->FuncLibDef();
1322 auto graph(std::make_unique<tensorflow::Graph>(flib_def));
1323 NameAttrList eager_attributes;
1324 ASSIGN_OR_RETURN_C_STATUS(eager_attributes, FetchAttributes(attributes),
1325 status);
1326
1327 std::vector<PartialTensorShape> global_output_shapes;
1328 std::vector<const Layout*> output_layouts;
1329 const FunctionDef* function_def = doperation.function_def;
1330 if (!function_def) {
1331 // Output layouts of an eager op (e.g. fill) must be inferred before cache
1332 // key computation, since they might depend on the current DTensorDevice
1333 // state.
1334 Status s = PrepareGraphForMlir(
1335 function_manager_, inputs, doperation, *flib_def, eager_attributes,
1336 default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1337 RETURN_C_STATUS_IF_NOT_OK(s, status);
1338
    // Find all meshes the inputs lie on.
1340 absl::flat_hash_set<Mesh> input_meshes;
1341 for (const TensorWithLayout* tensor : inputs) {
1342 if (!tensor->layout().mesh().IsEmpty()) {
1343 input_meshes.insert(tensor->layout().mesh());
1344 }
1345 }
    // Currently we only provide layout hints for op-by-op execution, since
    // such hints interact badly with layout propagation.
1348 UpdateOutputLayoutsWithSameShapePolicy(global_output_shapes, input_meshes,
1349 doperation.name, graph.get(),
1350 &output_layouts, status);
1351 if (TF_GetCode(status) != TF_OK) return;
1352 }
1353
1354 std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
1355 cache_key_and_func = function_manager_.GetCachedFunction(
1356 doperation, eager_attributes, inputs, output_layouts);
1357 *execution_functions = cache_key_and_func.second;
1358 if (*execution_functions != nullptr) {
1359 function_compilation_hits_and_misses_["hit"]++;
1360 return;
1361 } else if (function_def) {
1362 function_compilation_hits_and_misses_["miss"]++;
1363 LOG(INFO) << "DTensor cache key lookup missed for " << doperation.name
1364 << ". DTensor is (re-)computing its SPMD transformation.";
1365 }
1366
  // The device list includes remote devices when the coordination service is
  // enabled.
1368 const auto device_list = tensorflow::unwrap(context)->ListAllTfDevices();
1369 DeviceSet device_set;
1370 for (const auto device : device_list) device_set.AddDevice(device);
1371
1372 if (function_def) {
1373 ASSIGN_OR_RETURN_C_STATUS(auto device_name_to_mesh_device,
1374 PipelineSubMeshes(context), status);
1375 const bool is_pipelining_function = !device_name_to_mesh_device.empty();
    // For a multi-mesh pipelining function, we take a different execution
    // path: call the partitioner to lower and partition the graph into
    // multiple sub-functions to execute (one per sub-mesh).
1379 if (is_pipelining_function) {
1380 ASSIGN_OR_RETURN_C_STATUS(
1381 ExecutionFunctions functions,
1382 PipeliningPartitionerRun(&device_name_to_mesh_device, flib_def,
1383 &pass_runner_, *doperation.function_def,
1384 eager_attributes, inputs, device_set,
1385 num_outputs),
1386 status);
1387 *execution_functions = function_manager_.AddCachedFunction(
1388 doperation, cache_key_and_func.first, std::move(functions));
1389 return;
1390 }
    // Output layouts of a function are inferred by MLIR lowering. They are
    // not needed for cache key computation, so run PrepareGraphForMlir after
    // the cache key computation to reduce the overhead of running the same
    // function multiple times.
1395 Status s = PrepareGraphForMlir(
1396 function_manager_, inputs, doperation, *flib_def, eager_attributes,
1397 default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1398 RETURN_C_STATUS_IF_NOT_OK(s, status);
1399 }
1400
1401 absl::flat_hash_set<Node*> control_ret_nodes;
  // Run DTensor MLIR passes that convert the input graph to its SPMD version.
1403 {
1404 profiler::TraceMe activity([&] { return "DTensorDevice::RunMLIRPasses"; },
1405 profiler::TraceMeLevel::kInfo);
1406 RETURN_C_STATUS_IF_NOT_OK(
1407 pass_runner_.RunOnGraph(device_set, doperation.is_func(), flib_def,
1408 &graph, control_ret_nodes,
1409 cache_key_and_func.first),
1410 status);
1411 }
1412 VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph,
1413 flib_def);
1414 if (flib_def->Contains(kLoadEmbeddingFn)) {
1415 Status s = InsertFunctionForTPUEmbeddingCheckpoint(
1416 status, graph.get(), inputs, kLoadEmbeddingFn);
1417 RETURN_C_STATUS_IF_NOT_OK(s, status);
1418 }
1419
  // After MLIR transformations, exactly one StatefulPartitionedCall op is
  // returned for each mesh cluster in the computation. Identify all functions
  // to execute for each mesh and the relevant input and output information.
1423 ASSIGN_OR_RETURN_C_STATUS(
1424 ExecutionFunctions functions,
1425 IdentifyAllFunctionsToExecute(*graph.get(), global_output_shapes),
1426 status);
1427
  // To ensure that all resource-assign operations as well as side-effecting
  // ops are executed, we add identity ops before function outputs with
  // control rets.
1431 RETURN_C_STATUS_IF_NOT_OK(MaybeInsertIdentityNodes(function_def, graph.get()),
1432 status);
1433
1434 VLOG(4) << tensorflow::DumpGraphToFile("after_post_processing_graph", *graph,
1435 flib_def);
1436
1437 RETURN_C_STATUS_IF_NOT_OK(
1438 AddExecutionFunctionDefsToFunctionDefLibrary(control_ret_nodes, context,
1439 *graph.get(), &functions),
1440 status);
1441 functions.num_device_ids = 1;
1442 if (function_def) {
1443 for (TranslatedFunction& function : functions.function_list) {
1444 functions.function_mesh_fingerprint =
1445 FingerprintCat64(functions.function_mesh_fingerprint,
1446 function.function_mesh.GlobalFingerprint());
1447 }
1448 }
1449
1450 *execution_functions = function_manager_.AddCachedFunction(
1451 doperation, cache_key_and_func.first, std::move(functions));
1452}
1453
1454void DTensorDevice::ExecuteFunctionAndWait(
1455 TFE_Context* context, const TranslatedFunction* function_ptr,
1456 const MeshWithParallelDevice* parallel_device_mesh,
1457 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
1458 const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status) {
1459 const std::string mesh_str = function_ptr->function_mesh.ToString();
1460 VLOG(4) << "Launching computation for mesh : " << mesh_str;
1461 parallel_device_mesh->parallel_device().StartExecute(
1462 context,
1463 /*inputs=*/parallel_inputs,
1464 /*operation_name=*/function_ptr->translated_function_name.c_str(),
1465 /*attributes=*/attributes,
1466 /*expected_max_outputs=*/function_ptr->local_output_shapes.size(),
1467 /*cancellation_manager=*/*cancellation_manager_,
1468 /*step_id=*/step_id);
1469
1470 VLOG(4) << "Joining computation result from mesh : " << mesh_str;
1471 parallel_device_mesh->parallel_device().Join(
1472 function_ptr->local_output_shapes, status);
1473 VLOG(4) << "Joining status: " << TF_Message(status);
1474 if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_CANCELLED) {
1475 LOG(ERROR) << "Encountered error while executing function: "
1476 << function_ptr->translated_function_name
1477 << " for mesh : " << mesh_str
1478 << " / error : " << TF_Message(status);
1479 }
1480
1481 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1482 TF_NewStatus(), TF_DeleteStatus);
1483 AsyncWait(context, async_wait_status.get());
1484 TF_Code error_code = TF_GetCode(async_wait_status.get());
1485 if (error_code != TF_OK && error_code != TF_CANCELLED) {
1486 LOG(ERROR) << "Async status: " << TF_Message(async_wait_status.get());
1487 }
1488}
1489
1490void DTensorDevice::ExecuteRegularOperation(
1491 TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1492 const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1493 int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status) {
1494 const ExecutionFunctions* execution_functions = nullptr;
1495
1496 LowerToSPMDFunction(context, inputs, doperation, attributes, *num_outputs,
1497 &execution_functions, status);
1498 if (TF_GetCode(status) != TF_OK) return;
1499
1500 // Update input layouts for resource arguments.
1501 for (const TranslatedFunction& function :
1502 execution_functions->function_list) {
1503 for (const auto& entry : function.resource_input_layouts) {
      // TODO(hthu): Add a TensorWithLayout at location 0 of the inputs vector
      // for DeviceId. This offset is needed because the first arg is always
      // DeviceId, and it isn't mapped to an input Tensor.
1507 const int resource_index_to_update = entry.first - 1;
1508 inputs[resource_index_to_update]->UpdateLayout(entry.second, status);
1509 if (TF_GetCode(status) != TF_OK) {
1510 RETURN_STATUS(status, TF_GetCode(status),
1511 absl::StrCat("Attempt to update layout input arg: ",
1512 resource_index_to_update,
1513 ". Original message: ", TF_Message(status))
1514 .c_str());
1515 }
1516 }
1517 }
1518
1519 int num_global_outputs = 0;
1520
1521 std::map<std::string, const MeshWithParallelDevice*>
1522 function_name_and_mesh_mapping;
1523 absl::flat_hash_set<std::string> excluded_fn_names;
1524 std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr;
1525 for (const TranslatedFunction& function :
1526 execution_functions->function_list) {
1527 StatusOr<Mesh> maybe_converted_mesh = function.function_mesh;
1528 if (function.function_mesh.is_epu_mesh()) {
1529 maybe_converted_mesh = function.function_mesh.ToDeviceType("CPU");
1530 }
1531
1532 if (!maybe_converted_mesh.ok()) {
1533 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                    absl::StrCat("Failed to convert mesh, got error: ",
1535 maybe_converted_mesh.status().error_message())
1536 .c_str());
1537 }
1538 const Mesh& mesh = *maybe_converted_mesh;
1539 const MeshWithParallelDevice* parallel_device_mesh =
1540 mesh_to_device_map_.contains(mesh) ? mesh_to_device_map_[mesh].get()
1541 : default_mesh_;
1542 if (parallel_device_mesh == nullptr) {
1543 RETURN_STATUS(status, TF_INTERNAL,
1544 "required mesh is not registered with DTensor device");
1545 }
1546 function_name_and_mesh_mapping[function.translated_function_name] =
1547 parallel_device_mesh;
1548
1549 if (function.function_mesh.is_epu_mesh()) {
1550 if (epu_fn_ptr != nullptr) {
1551 RETURN_STATUS(status, TF_INTERNAL,
                      "More than one function is defined on the EPU mesh.");
1553 }
1554 epu_fn_ptr = std::make_unique<const TranslatedFunction>(function);
1555 excluded_fn_names.insert(function.translated_function_name);
1556 }
1557 if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) {
1558 if (load_embedding_ptr != nullptr) {
1559 RETURN_STATUS(status, TF_INTERNAL,
                      "More than one load embedding function is defined.");
1561 }
1562 load_embedding_ptr = std::make_unique<const TranslatedFunction>(function);
1563 excluded_fn_names.insert(function.translated_function_name);
1564 }
1565 }
1566
1567 // Compute the step_id based on the function_mesh_fingerprint and the
1568 // corresponding function execution counter.
1569 uint64 function_mesh_fingerprint =
1570 execution_functions->function_mesh_fingerprint;
1571 if (func_mesh_fingerprint_to_step_counter_.contains(
1572 function_mesh_fingerprint)) {
1573 func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)++;
1574 } else {
1575 func_mesh_fingerprint_to_step_counter_.insert(
1576 {function_mesh_fingerprint, 0});
1577 }
1578 const uint64 step_id = FingerprintCat64(
1579 function_mesh_fingerprint,
1580 func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint));
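  // Rough sketch of the derivation above, with hypothetical values: if
  // function_mesh_fingerprint == 0x1234, the first execution of this mesh set
  // uses counter 0 and the next uses counter 1, i.e.
  //   step_id = FingerprintCat64(0x1234, 0);  // first run
  //   step_id = FingerprintCat64(0x1234, 1);  // second run
  // so repeated runs of the same mesh set get distinct step ids.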
1581
1582 // Execute excluded functions in sequence.
1583 if (epu_fn_ptr != nullptr) {
1584 ExecuteFunctionAndWait(
1585 context,
1586 /*function_ptr=*/epu_fn_ptr.get(),
1587 /*parallel_device_mesh=*/
1588 function_name_and_mesh_mapping[epu_fn_ptr->translated_function_name],
1589 /*parallel_inputs=*/{}, /*step_id=*/step_id, /*attributes=*/attributes,
1590 /*status=*/status);
1591 }
1592
1593 if (load_embedding_ptr != nullptr) {
1594 StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs =
1595 PrepareEmbeddingInputs(inputs);
1596 if (!parallel_inputs.ok()) {
1597 RETURN_STATUS(status, TF_INTERNAL,
1598 parallel_inputs.status().error_message().c_str());
1599 }
1600 ExecuteFunctionAndWait(
1601 context,
1602 /*function_ptr=*/load_embedding_ptr.get(),
1603 /*parallel_device_mesh=*/
1604 function_name_and_mesh_mapping[load_embedding_ptr
1605 ->translated_function_name],
1606 /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id,
1607 /*attributes=*/attributes, /*status=*/status);
1608 }
1609
1610 // Extract the global parallel inputs and flatten SparseTensors
1611 // into the three component tensors.
1612 std::vector<parallel_device::ParallelTensor*> global_parallel_inputs;
1613 std::vector<parallel_device::ParallelTensor*> global_parallel_sparse_inputs;
1614 absl::flat_hash_set<int> global_sparse_input_indices;
1615 for (auto input : inputs) {
1616 if (input->tensor_type() == TensorType::kSparse) {
1617 SparseTensorWithLayout* sparse_input =
1618 dynamic_cast<SparseTensorWithLayout*>(input);
1619 global_parallel_sparse_inputs.push_back(sparse_input->indices());
1620 global_parallel_sparse_inputs.push_back(sparse_input->dense_shapes());
1621 global_parallel_sparse_inputs.push_back(sparse_input->values());
1622 } else {
1623 global_parallel_inputs.push_back(input->tensor());
1624 }
1625 }
  // Insert SparseTensor components at the end: in the MLIR handling of
  // SparseTensors, SparseTensor components are placed at the end of the main
  // func arguments to guarantee a fixed ordering.
1629 global_parallel_inputs.insert(global_parallel_inputs.end(),
1630 global_parallel_sparse_inputs.begin(),
1631 global_parallel_sparse_inputs.end());
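  // Illustrative ordering (hypothetical inputs): for inputs
  //   [dense_a, sparse_b, dense_c]
  // the flattened global_parallel_inputs become
  //   [dense_a, dense_c, b_indices, b_dense_shapes, b_values]
  // i.e. all dense inputs first, then the per-SparseTensor component triple
  // in the order the components are pushed above.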
1632
1633 // Execute all functions in parallel.
1634 for (const TranslatedFunction& function :
1635 execution_functions->function_list) {
1636 const Mesh& mesh = function.function_mesh;
1637 const std::string& translated_function_name =
1638 function.translated_function_name;
1639
1640 num_global_outputs += function.local_output_shapes.size();
1641
1642 if (is_remote_mesh(mesh) ||
1643 (excluded_fn_names.find(translated_function_name) !=
1644 excluded_fn_names.end())) {
      // Skip execution for a translated function that has a remote mesh or
      // that is excluded.
1647 continue;
1648 }
1649
1650 const MeshWithParallelDevice* parallel_device_mesh =
1651 function_name_and_mesh_mapping[translated_function_name];
1652
1653 // Gather the local inputs for this function.
1654 std::vector<parallel_device::ParallelTensor*> parallel_inputs;
1655 parallel_inputs.reserve(inputs.size() + 1);
1656 auto input_mapping = function.input_index_map;
1657
    // We sort here because, by this time, the function graph we are executing
    // is a reduced version of the main function that includes the
    // StatefulPartitionedCall we are executing for this mesh. Thus, the
    // ordering matches the main function's ordering, which is sorted in
    // increasing order.
1663 std::sort(input_mapping.begin(), input_mapping.end());
1664
1665 for (const int global_index : input_mapping) {
1666 auto input_index = global_index - execution_functions->num_device_ids;
1667
1668 if (global_index < execution_functions->num_device_ids) {
1669 parallel_inputs.push_back(
1670 parallel_device_mesh->DeviceIDs(context, status));
1671 if (TF_GetCode(status) != TF_OK) return;
1672 } else {
1673 parallel_inputs.push_back(global_parallel_inputs[input_index]);
1674 }
1675 }
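    // Illustrative mapping (hypothetical indices): with num_device_ids == 1
    // and a sorted input_index_map of {0, 2, 5}, the gathered parallel_inputs
    // are
    //   [DeviceIDs(), global_parallel_inputs[1], global_parallel_inputs[4]]
    // since global index 0 is the device-id tensor and every other global
    // index is shifted down by num_device_ids.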
1676
1677 VLOG(4) << "Launching computation for mesh : " << mesh.ToString();
1678 parallel_device_mesh->parallel_device().StartExecute(
1679 context, parallel_inputs, translated_function_name.c_str(), attributes,
1680 /*expected_max_outputs=*/function.local_output_shapes.size(),
1681 *cancellation_manager_, /*step_id=*/step_id);
1682 }
1683
1684 *num_outputs = num_global_outputs;
1685 std::vector<std::unique_ptr<TensorWithLayout>> typed_outputs;
1686 typed_outputs.resize(num_global_outputs);
1687
1688 // Join all mesh computation together.
1689 // TODO(b/177932563): Expose cancel logic to handle failures.
1690 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> join_status(
1691 TF_NewStatus(), TF_DeleteStatus);
1692 for (const TranslatedFunction& function :
1693 execution_functions->function_list) {
1694 // Skip execution for a function when it's excluded.
1695 if (excluded_fn_names.contains(function.translated_function_name)) {
1696 continue;
1697 }
1698 const Mesh& mesh = function.function_mesh;
1699 const MeshWithParallelDevice* parallel_device_mesh =
1700 function_name_and_mesh_mapping[function.translated_function_name];
1701
1702 std::vector<std::unique_ptr<TensorWithLayout>> output_with_layout;
1703 output_with_layout.reserve(function.output_index_map.size());
1704 if (is_remote_mesh(mesh)) {
1705 // Create dummy outputs on a remote mesh.
1706 for (int i = 0; i < function.output_index_map.size(); ++i) {
1707 const auto dim_sizes = function.local_output_shapes.at(i).dim_sizes();
1708 std::vector<int64_t> local_shape =
1709 std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
1710 TF_DataType dtype =
1711 static_cast<TF_DataType>(function.output_dtypes.at(i));
1712 auto remote_output =
1713 TensorWithLayout::Dummy(local_shape, dtype, *parallel_device_mesh,
1714 function.output_layouts[i]);
1715 output_with_layout.push_back(std::move(remote_output));
1716 }
1717 } else {
1718 VLOG(4) << "Joining computation result from mesh : " << mesh.ToString();
1719 auto result = parallel_device_mesh->parallel_device().Join(
1720 function.local_output_shapes, status);
1721 if (TF_GetCode(join_status.get()) != TF_OK &&
1722 // Preserve the first failure we see, but only if it is a real failure
1723 // and not a cancellation (which was probably triggered by the error
1724 // we want to propagate).
1725 (TF_GetCode(status) == TF_OK ||
1726 TF_GetCode(join_status.get()) != TF_CANCELLED)) {
1727 continue;
1728 }
1729 if (TF_GetCode(status) != TF_OK) {
1730 if (TF_GetCode(status) != TF_CANCELLED) {
1731 LOG(ERROR) << "Encountered error while executing function: "
1732 << function.translated_function_name
1733 << " for mesh : " << mesh.ToString()
1734 << " / error : " << TF_Message(status);
1735 }
1736 TF_SetStatus(join_status.get(), TF_GetCode(status), TF_Message(status));
1737 continue;
1738 }
1739
1740 for (int i = 0; i < result->size(); ++i) {
1741 ASSIGN_OR_RETURN_C_STATUS(
1742 auto local_output,
1743 TensorWithLayout::Wrap(std::move((*result)[i]),
1744 *parallel_device_mesh,
1745 function.output_layouts[i]),
1746 status);
1747 output_with_layout.push_back(std::move(local_output));
1748 }
1749 }
1750
1751 for (int i = 0; i < function.output_index_map.size(); ++i) {
1752 // TODO(b/162744844): Generalize this pattern so that the extraction is
1753 // not special cased.
1754 if (function.shape_output_metadata.find(i) !=
1755 function.shape_output_metadata.end()) {
1756 output_with_layout[i]->set_input_layout_for_shape_op_result(
1757 function.shape_output_metadata.at(i));
1758 }
1759
1760 RecordInShapeLayoutCache(*output_with_layout[i]);
1761 typed_outputs[function.output_index_map[i]] =
1762 std::move(output_with_layout[i]);
1763 }
1764 }
1765 if (TF_GetCode(join_status.get()) != TF_OK) {
1766 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1767 TF_NewStatus(), TF_DeleteStatus);
1768 AsyncWait(context, async_wait_status.get());
1769 TF_Code error_code = TF_GetCode(async_wait_status.get());
1770 if (error_code != TF_OK && error_code != TF_CANCELLED) {
1771 // Ignore the AsyncWait() status return since we already have a bad status
1772 // to propagate. We've just canceled a bunch of operations, so we expect
1773 // cancellation status returns. We'll log anything else just to be safe.
1774 LOG(ERROR) << "Error executing " << doperation.name << " "
1775 << TF_Message(async_wait_status.get());
1776 }
1777
1778 TF_SetStatus(status, TF_GetCode(join_status.get()),
1779 TF_Message(join_status.get()));
1780 return;
1781 }
1782 if (VLOG_IS_ON(2)) {
1783 LOG(INFO) << "Executed " << doperation.name << ", got "
1784 << typed_outputs.size() << " outputs:";
1785 for (const std::unique_ptr<TensorWithLayout>& output : typed_outputs) {
1786 LOG(INFO) << " " << output->DebugString();
1787 }
1788 }
1789 if (doperation.name == std::string("VarHandleOp")) {
1790 // For new variables, set the dereferenced shape/dtype so we can pass it in
1791 // as _handle_dtype and _handle_shape in the future.
1792 //
1793 // Note that VarHandleOps generated by `tf.Variable` objects are always run
1794 // eagerly, which is almost all of the op's usage in TF2. Theoretically a
1795 // user could run it in a tf.function via tf.raw_ops.VarHandleOp, return it
1796 // from that function, and add it as an input to another, and it would
1797 // currently be missing handle information.
1798 if (typed_outputs.size() != 1) {
1799 RETURN_STATUS(status, TF_INTERNAL,
1800 "Expected one output from VarHandleOp");
1801 }
1802 NameAttrList name_and_attrs;
1803 ASSIGN_OR_RETURN_C_STATUS(name_and_attrs, FetchAttributes(attributes),
1804 status);
1805
1806 typed_outputs[0]->UpdateShapeAndDType(
1807 name_and_attrs.attr().at("shape").shape(),
1808 name_and_attrs.attr().at("dtype").type(), status);
1809 if (TF_GetCode(status) != TF_OK) return;
1810 }
1811
1812 for (int i = 0; i < *num_outputs; ++i) {
1813 outputs[i] =
1814 MakeLayoutTensorHandle(context, std::move(typed_outputs[i]), status);
1815 if (TF_GetCode(status) != TF_OK) return;
1816 }
1817}
1818
1819void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs,
1820 TFE_TensorHandle** outputs, TF_Status* status) {
1821 TFE_Context* context = TFE_OpGetContext(original_op, status);
1822 if (TF_GetCode(status) != TF_OK) return;
1823 const char* operation_name = TFE_OpGetName(original_op, status);
1824 if (TF_GetCode(status) != TF_OK) return;
1825 const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
1826 int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
1827 if (TF_GetCode(status) != TF_OK) return;
1828 std::vector<TFE_TensorHandle*> inputs_vector;
1829 inputs_vector.reserve(num_inputs);
1830 for (int input_index = 0; input_index < num_inputs; ++input_index) {
1831 TFE_TensorHandle* input =
1832 TFE_OpGetFlatInput(original_op, input_index, status);
1833 if (TF_GetCode(status) != TF_OK) return;
1834 inputs_vector.push_back(input);
1835 }
1836 TFE_TensorHandle** inputs = inputs_vector.data();
1837
1838 DTensorOperation dtensor_operation{};
1839 dtensor_operation.name = operation_name;
1840 {
1841 dtensor_operation.function_def =
1842 tensorflow::unwrap(context)->FindFunctionDef(operation_name);
1843 }
1844
1845 // First handle DTensor-specific virtual operations.
1846 bool is_op_handled = false;
1847 MaybeHandleDTensorCustomOps(operation_name, num_inputs, attributes, context,
1848 inputs, num_outputs, outputs, &is_op_handled,
1849 status);
1850 if (is_op_handled) return;
1851
1852 // This isn't a special op, so we'll defer to TFE_Execute to actually execute
1853 // it, but we'll also run DTensor MLIR passes and propagate the layout.
1854 std::vector<TensorWithLayout*> typed_inputs;
1855 std::vector<std::unique_ptr<TensorWithLayout>> inputs_with_no_layout;
1856
  // Record the meshes of all inputs that are already on the DTensor device.
  // If a single unique mesh is identified, it is used as the mesh to
  // broadcast non-DTensor inputs to.
1860 absl::flat_hash_set<Mesh> input_meshes;
1861 std::vector<int> not_on_device_input_indices;
1862
1863 typed_inputs.resize(num_inputs);
1864 for (int j = 0; j < num_inputs; ++j) {
1865 TFE_TensorHandle* input = inputs[j];
1866 const char* input_device = TFE_TensorHandleDeviceName(input, status);
1867 if (TF_GetCode(status) != TF_OK) return;
1868 if (name_ != input_device) {
1869 not_on_device_input_indices.push_back(j);
1870 continue;
1871 }
1872 // Handle input which is on DTensor device already.
1873 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1874 TFE_TensorHandleDevicePointer(input, status));
1875 if (TF_GetCode(status) != TF_OK) return;
1876
    // VarHandleOp runs on an empty mesh, which isn't registered with the
    // device.
1878 if (!t->layout().mesh().IsEmpty()) {
1879 input_meshes.insert(t->layout().mesh());
1880 }
    // Remote mesh inputs cannot be read or evaluated.
1882 if (!is_remote_mesh(t->layout().mesh()) && !t->const_value().has_value()) {
1883 std::optional<NodeDef> const_value =
1884 ExtractSmallTensorValue(context, input, t->layout(), status);
1885 if (TF_GetCode(status) != TF_OK) return;
1886 if (const_value.has_value()) {
1887 t->set_const_value(const_value.value());
1888 }
1889 }
1890 typed_inputs[j] = t;
1891 }
1892
  // If a unique mesh is identified across all inputs, we use that mesh as the
  // mesh to broadcast to. Otherwise we fall back to the default mesh.
1895 const MeshWithParallelDevice* broadcast_mesh =
1896 input_meshes.size() == 1
1897 ? mesh_to_device_map_[*input_meshes.begin()].get()
1898 : default_mesh_;
1899 if (!broadcast_mesh) {
1900 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                  "No mesh has been registered with DTensor. Use copy_to_mesh "
                  "to explicitly specify a mesh instead.");
1903 }
1904 for (int not_on_device_input_index : not_on_device_input_indices) {
1905 TFE_TensorHandle* input = inputs[not_on_device_input_index];
    // DTensor creation should be explicit, with some exceptions for usability
    // (scalars/shapes/slice specs/etc.). Here we do some trivial validation
    // to enforce this rule.
1909 int num_dims = TFE_TensorHandleNumDims(input, status);
1910 if (TF_GetCode(status) != TF_OK) return;
1911 int64_t num_elements = TFE_TensorHandleNumElements(input, status);
1912 if (TF_GetCode(status) != TF_OK) return;
1913 TF_DataType dtype = TFE_TensorHandleDataType(input);
1914 const bool small_int_tensor = num_elements < kSmallTensorThreshold &&
1915 (dtype == TF_INT32 || dtype == TF_INT64);
1916 if (!(num_dims == 0 || dtype == TF_STRING || small_int_tensor)) {
1917 std::vector<int64_t> tensor_shape(TensorShapeAsVector(input, status));
1918 if (TF_GetCode(status) != TF_OK) return;
1919 RETURN_STATUS(
1920 status, TF_UNIMPLEMENTED,
1921 absl::StrCat(
1922 "The op/function ", operation_name,
1923 " got a regular tensor for input ", not_on_device_input_index,
1924 " (shape ", ShapeToDebugString(tensor_shape),
1925 ") but was expecting a DTensor. Currently only scalars and "
1926 "small integer/string tensors are auto-broadcast to "
1927 "DTensors. For other tensors, please use copy_to_mesh to "
1928 "make a DTensor explicitly; note that this may be slow if it "
1929 "happens frequently.")
1930 .c_str());
1931 }
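    // For illustration (hypothetical inputs): a scalar float, a small int32
    // shape tensor such as [2, 4], or a string tensor passes the check above
    // and is auto-broadcast below, whereas a large float matrix hits the
    // TF_UNIMPLEMENTED error and must be copied with copy_to_mesh instead.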
    // Construct temporary TensorWithLayout objects for inputs that didn't
    // have any to start with. These are owned by the `inputs_with_no_layout`
    // vector, whereas the input `TFE_TensorHandle`s maintain ownership for
    // inputs that already had layouts (and therefore already had
    // TensorWithLayout objects).
1937 std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
1938 context, input, *broadcast_mesh, name_, status);
1939 if (TF_GetCode(status) != TF_OK) return;
1940 if (!ShouldFoldInputArgument(dtensor_operation.name,
1941 /*input_index=*/not_on_device_input_index)) {
1942 wrapper->reset_const_value();
1943 }
1944 typed_inputs[not_on_device_input_index] = wrapper.get();
1945 inputs_with_no_layout.emplace_back(wrapper.release());
1946 }
1947
1948 ExecuteRegularOperation(context, typed_inputs, dtensor_operation, attributes,
1949 num_outputs, outputs, status);
1950}
1951
1952void ExecuteOnDTensorDevice(const TFE_Op* original_op, int* num_outputs,
1953 TFE_TensorHandle** outputs, TF_Status* status,
1954 void* device_info) {
1955 DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1956 dev->Execute(original_op, num_outputs, outputs, status);
1957}
1958
1959void DeleteDTensorDevice(void* device_info) {
1960 delete static_cast<DTensorDevice*>(device_info);
1961}
1962
1963TFE_TensorHandle* CopyToDTensorDevice(TFE_Context* context,
1964 TFE_TensorHandle* tensor,
1965 TF_Status* status, void* device_info) {
1966 TF_SetStatus(status, TF_UNIMPLEMENTED,
               "Trying to copy a tensor onto a DTensor mesh without a layout "
1968 "(use the CopyToMesh op for now).");
1969 return nullptr;
1970}
1971
1972TFE_TensorHandle* CopyFromDTensorDevice(TFE_Context* context,
1973 TFE_TensorHandle* tensor,
1974 const char* target_device_name,
1975 TF_Status* status, void* device_info) {
1976 TensorWithLayout* typed_input = reinterpret_cast<TensorWithLayout*>(
1977 TFE_TensorHandleDevicePointer(tensor, status));
1978 if (!tensorflow::dtensor::Layout(typed_input->layout()).IsFullyReplicated()) {
1979 TF_SetStatus(status, TF_UNIMPLEMENTED,
1980 "Trying to copy a non-replicated DTensor is not supported.");
1981 return nullptr;
1982 }
1983 if (typed_input->tensor()->dtype() == TF_RESOURCE) {
1984 TF_SetStatus(status, TF_UNIMPLEMENTED,
1985 "Trying to copy a DTensor resource handle is not supported.");
1986 return nullptr;
1987 }
1988 DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1989 // Since operations are executed asynchronously, the operation which should
1990 // produce the tensor we're trying to copy off the DTensor device may be
1991 // canceled due to a failure on another device. If so, we want to report the
1992 // failure that caused the cancellation, not the cancellation itself. This
1993 // requires blocking waiting for other devices to flush their execution
1994 // queues.
  // Note that we only need to sync the threads on the parallel_device()
  // directly; a context-level sync might cause unintentional deadlocks when
  // grabbing locks on other threads.
1998 dev->AsyncWait(context, status);
1999 if (TF_GetCode(status) != TF_OK) return nullptr;
2000 return TFE_TensorHandleCopySharingTensor(typed_input->get_tensor(0), status);
2001}
2002
2003bool PinToDTensorDevice(const TFE_Op* op, TF_Status* s) {
  // Always pin to the DTensor device if any of the op's inputs is a DTensor.
  // Note that if this function is called, the caller guarantees that all
  // inputs that are on a custom device are on a single DTensor device.
2007
  // Exception 1:
  // If there is a non-DTensor resource tensor and the other DTensor inputs
  // are not on a CPU mesh, then pin to the physical device.
  //
  // This is because upcasting a resource tensor to a DTensor only supports
  // broadcasting to a CPU mesh. If any other DTensor inputs are on a TPU
  // mesh, the mesh broadcast to would be the TPU mesh.
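  //
  // Illustrative decision (hypothetical inputs): a regular (non-DTensor)
  // resource handle combined with DTensor inputs laid out on a TPU mesh
  // returns false, so the op stays on the physical device; the same resource
  // combined with DTensor inputs on a CPU mesh returns true and the op is
  // pinned to the DTensor device.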
2015 int num_inputs = TFE_OpGetFlatInputCount(op, s);
2016 std::vector<TFE_TensorHandle*> inputs_vector;
2017 inputs_vector.reserve(num_inputs);
2018
2019 absl::flat_hash_set<Mesh> input_meshes;
2020
2021 bool has_non_dtensor_resource = false;
2022
2023 for (int input_index = 0; input_index < num_inputs; ++input_index) {
2024 TFE_TensorHandle* input = TFE_OpGetFlatInput(op, input_index, s);
2025
2026 std::string input_device_name =
2027 std::string(TFE_TensorHandleDeviceName(input, s));
2028 if (!absl::StrContains(absl::AsciiStrToLower(input_device_name),
2029 "custom")) {
2030 TF_DataType dtype = TFE_TensorHandleDataType(input);
2031 if (dtype == TF_RESOURCE) {
2032 has_non_dtensor_resource = true;
2033 }
2034 continue;
2035 }
2036
2037 // Handle input which is on DTensor device already.
2038 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
2039 TFE_TensorHandleDevicePointer(input, s));
2040
2041 if (!t->layout().mesh().IsEmpty()) {
2042 input_meshes.insert(t->layout().mesh());
2043 }
2044 }
2045
2046 const Mesh* broadcast_mesh =
2047 input_meshes.size() == 1 ? &(*input_meshes.begin()) : nullptr;
2048
  // Place on the physical device, as DTensor does not support upcasting a
  // resource tensor to a non-CPU mesh.
2051 if (has_non_dtensor_resource && broadcast_mesh &&
2052 !broadcast_mesh->is_cpu_mesh()) {
2053 return false;
2054 }
2055
2056 return true;
2057}
2058
2059void AllocateDTensorDevice(absl::string_view device_name,
2060 TFE_CustomDevice* device, void** device_info) {
2061 device->copy_tensor_to_device = &CopyToDTensorDevice;
2062 device->copy_tensor_from_device = &CopyFromDTensorDevice;
2063 device->delete_device = &DeleteDTensorDevice;
2064 device->execute = &ExecuteOnDTensorDevice;
2065 device->shall_pin_to_this_device = &PinToDTensorDevice;
2066 *device_info = new DTensorDevice(device_name);
2067}
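
// A minimal usage sketch (assuming the caller already holds a TFE_Context and
// a TF_Status; the device name below is hypothetical): the TFE_CustomDevice
// and device_info filled in above are handed to the custom-device
// registration API, e.g.
//
//   TFE_CustomDevice device;
//   void* device_info;
//   AllocateDTensorDevice("/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                         &device, &device_info);
//   TFE_RegisterCustomDevice(context, device,
//                            "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                            device_info, status);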
2068
2069void AddMesh(const std::string& serialized_mesh, void* device_info,
2070 bool is_async, bool is_host_mesh, TF_Status* status) {
2071 auto mesh_config_or_status = Mesh::FromString(serialized_mesh);
2072 if (!mesh_config_or_status.ok()) {
2073 TF_SetStatus(status, TF_INTERNAL,
2074 absl::StrCat("Failed to parse mesh config. ",
2075 mesh_config_or_status.status().error_message())
2076 .c_str());
2077 return;
2078 }
2079 auto mesh_config = mesh_config_or_status.value();
2080 std::vector<std::string> underlying_devices;
2081 underlying_devices.insert(underlying_devices.end(),
2082 mesh_config.local_devices().begin(),
2083 mesh_config.local_devices().end());
  // DTensor uses a multi-client setup that doesn't use remote eager, so we
  // can enable eager async execution in ParallelDevice.
2086 std::unique_ptr<tensorflow::parallel_device::ParallelDevice> parallel(
2087 new tensorflow::parallel_device::ParallelDevice(underlying_devices,
2088 is_async));
2089
2090 std::string composite_device_name;
2091 if (absl::StartsWith(mesh_config.name(), kPipelineMeshNamePrefix)) {
2092 composite_device_name = std::string(
2093 absl::StripPrefix(mesh_config.name(), kPipelineMeshNamePrefix));
2094 }
2095
2096 auto mesh = std::make_unique<MeshWithParallelDevice>(
2097 std::move(mesh_config), std::move(parallel), composite_device_name);
2098 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2099 device->AddMesh(std::move(mesh), is_host_mesh);
2100}
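
// A rough sketch of the serialized mesh string AddMesh parses above (the
// exact grammar is defined by Mesh::FromString; the name, dimensions, and
// device strings below are hypothetical):
//
//   "my_mesh|x=2,y=1|0,1|0,1|"
//   "/job:localhost/replica:0/task:0/device:CPU:0,"
//   "/job:localhost/replica:0/task:0/device:CPU:1"
//
// i.e. a mesh name, named dimensions with sizes, global device ids, local
// device ids, and local device strings, separated by '|'.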
2101
2102void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
2103 void* device_info, TF_Status* status) {
2104 StatusOr<Layout> layout = Layout::FromString(serialized_layout);
2105 if (!layout.ok()) {
2106 RETURN_STATUS(status, TF_INTERNAL, layout.status().error_message().c_str());
2107 }
2108 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2109 device->SetDefaultLayout(layout.value());
2110}
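
// A rough sketch of the serialized layout string parsed above (the exact
// grammar is defined by Layout::FromString; the mesh below is hypothetical):
//
//   "sharding_specs:x,unsharded, mesh:my_mesh|x=2,y=1|0,1|0,1|"
//   "/job:localhost/replica:0/task:0/device:CPU:0,"
//   "/job:localhost/replica:0/task:0/device:CPU:1"
//
// i.e. one sharding spec per tensor dimension (a mesh dimension name or
// "unsharded"), followed by the serialized mesh the layout refers to.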
2111
2112void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status) {
2113 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2114 device->ClearDefaultLayout();
2115}
2116
2117void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
2118 void* device_info, TF_Status* status) {
2119 StatusOr<Mesh> mesh = Mesh::FromString(serialized_mesh);
2120 if (!mesh.ok()) {
2121 RETURN_STATUS(status, TF_INTERNAL, mesh.status().error_message().c_str());
2122 }
2123 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2124 device->SetDefaultMesh(mesh.value());
2125}
2126
2127void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status) {
2128 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2129 device->ClearDefaultMesh();
2130}
2131
2132void SetSameShapePolicy(void* device_info, bool enabled) {
2133 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2134 device->SetSameShapePolicy(enabled);
2135}
2136
2137void SetTPUCoreIDs(const std::string& mesh_name,
2138 const std::vector<int>& tpu_core_ids, void* device_info,
2139 TF_Status* status) {
2140 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2141 RETURN_C_STATUS_IF_NOT_OK(device->SetTPUCoreIDs(mesh_name, tpu_core_ids),
2142 status);
2143}
2144
2145void ClearTPUCoreIDs(void* device_info) {
2146 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2147 device->ClearTPUCoreIDs();
2148}
2149
2150std::vector<std::vector<int>> TPUCoreIDsToLocations(
2151 TFE_Context* context, const std::vector<int>& tpu_core_ids,
2152 void* device_info) {
2153 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2154 return device->TPUCoreIDsToLocations(context, tpu_core_ids);
2155}
2156
2157std::vector<int> TPUCoreLocationsToIDs(
2158 TFE_Context* context,
2159 const std::vector<std::vector<int>>& tpu_core_locations,
2160 void* device_info) {
2161 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2162 return device->TPUCoreLocationsToIDs(context, tpu_core_locations);
2163}
2164
2165TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
2166 TFE_TensorHandle** inputs,
2167 const std::string& string_layout, void* device_info,
2168 TF_Status* status) {
2169 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2170 return device->Pack(context, num_inputs, inputs, string_layout, status);
2171}
2172
2173std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
2174 TFE_TensorHandle* input,
2175 void* device_info, TF_Status* status) {
2176 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2177 return device->Unpack(context, input, status);
2178}
2179
2180std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
2181 void* device_info, TF_Status* status) {
2182 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2183 return device->FetchLayout(context, input, status);
2184}
2185
2186TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
2187 TFE_TensorHandle** indices,
2188 TFE_TensorHandle** values,
2189 TFE_TensorHandle** shapes,
2190 const std::string& string_layout,
2191 void* device_info, TF_Status* status) {
2192 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2193 return device->SparsePack(context, num_inputs, indices, values, shapes,
2194 string_layout, status);
2195}
2196
2197bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
2198 void* device_info, TF_Status* status) {
2199 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2200 return device->IsSparseDTensor(context, input, status);
2201}
2202
2203std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
2204 TFE_Context* context, void* device_info, TF_Status* status) {
2205 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2206 return device->GetFunctionCacheHitAndMissCount(context, status);
2207}
2208} // namespace dtensor
2209} // namespace tensorflow
2210