1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/cc/dtensor_device.h" |
17 | |
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
24 | |
25 | #include "absl/base/attributes.h" |
26 | #include "absl/container/flat_hash_map.h" |
27 | #include "absl/container/flat_hash_set.h" |
28 | #include "absl/memory/memory.h" |
29 | #include "absl/strings/match.h" |
30 | #include "absl/strings/str_cat.h" |
31 | #include "absl/strings/str_join.h" |
32 | #include "absl/strings/string_view.h" |
33 | #include "absl/strings/strip.h" |
34 | #include "tensorflow/c/c_api_experimental.h" |
35 | #include "tensorflow/c/eager/c_api.h" |
36 | #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" |
37 | #include "tensorflow/c/eager/tfe_context_internal.h" |
38 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
39 | #include "tensorflow/c/tf_datatype.h" |
40 | #include "tensorflow/c/tf_status.h" |
41 | #include "tensorflow/c/tf_status_helper.h" |
42 | #include "tensorflow/c/tf_tensor_internal.h" |
43 | #include "tensorflow/compiler/xla/status_macros.h" |
44 | #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" |
45 | #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" |
46 | #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_topology.h" |
47 | #include "tensorflow/core/common_runtime/device_set.h" |
48 | #include "tensorflow/core/common_runtime/eager/context.h" |
49 | #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
50 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
51 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
52 | #include "tensorflow/core/framework/attr_value.pb.h" |
53 | #include "tensorflow/core/framework/function.h" |
54 | #include "tensorflow/core/framework/function.pb.h" |
55 | #include "tensorflow/core/framework/graph_to_functiondef.h" |
56 | #include "tensorflow/core/framework/node_def_builder.h" |
57 | #include "tensorflow/core/framework/node_def_util.h" |
58 | #include "tensorflow/core/framework/op.h" |
59 | #include "tensorflow/core/framework/tensor_shape.h" |
60 | #include "tensorflow/core/graph/algorithm.h" |
61 | #include "tensorflow/core/graph/graph.h" |
62 | #include "tensorflow/core/lib/strings/proto_serialization.h" |
63 | #include "tensorflow/core/platform/errors.h" |
64 | #include "tensorflow/core/platform/fingerprint.h" |
65 | #include "tensorflow/core/platform/types.h" |
66 | #include "tensorflow/core/profiler/lib/traceme.h" |
67 | #include "tensorflow/core/util/dump_graph.h" |
68 | #include "tensorflow/dtensor/cc/constants.h" |
69 | #include "tensorflow/dtensor/cc/dstatus.h" |
70 | #include "tensorflow/dtensor/cc/dtensor_device_util.h" |
71 | #include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h" |
72 | #include "tensorflow/dtensor/cc/small_constant_optimization.h" |
73 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
74 | #include "tensorflow/dtensor/cc/tpu_system_interface.h" |
75 | #include "tensorflow/dtensor/proto/layout.pb.h" |
76 | |
77 | namespace tensorflow { |
78 | namespace dtensor { |
79 | |
80 | // TODO(b/189332820): Replace this with a Partitioner stub swapped in by the |
81 | // Copybara workflow. |
82 | StatusOr<ExecutionFunctions> ABSL_ATTRIBUTE_WEAK PipeliningPartitionerRun( |
83 | const absl::flat_hash_map<std::string, const MeshWithParallelDevice*>* |
84 | device_name_to_mesh_device, |
85 | FunctionLibraryDefinition* flib_def, DTensorMlirPassRunner* pass_runner, |
86 | const FunctionDef& fdef, const NameAttrList& eager_attributes, |
87 | const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set, |
88 | int num_outputs) { |
89 | // The actual definition is in the pipelining package. |
  return errors::Unimplemented("DTensor pipelining is unavailable.");
91 | } |
92 | |
93 | class DTensorDevice { |
94 | public: |
95 | explicit DTensorDevice(absl::string_view name) |
96 | : name_(name), |
97 | same_shape_policy_enabled_(false), |
98 | cancellation_manager_(std::make_unique<CancellationManager>()) {} |
99 | |
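  // Registers `mesh` with this device. The first mesh registered becomes the
  // global default mesh; a TPU host mesh additionally updates
  // Mesh::tpu_host_mesh().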
100 | void AddMesh(std::unique_ptr<MeshWithParallelDevice> mesh, |
101 | bool is_host_mesh) { |
102 | if (is_host_mesh) { |
103 | std::string& tpu_host_mesh = Mesh::tpu_host_mesh(); |
104 | const std::string new_tpu_host_mesh = mesh->mesh_config().ToString(); |
105 | if (!tpu_host_mesh.empty()) { |
106 | // TODO(b/180046115): Add per-TPU-mesh host mesh bookkeeping. |
107 | LOG(WARNING) |
108 | << "A new TPU host mesh is overwriting the old TPU host mesh. The " |
109 | "old TPU mesh cannot be used in sea of donuts mode anymore." ; |
110 | } |
111 | tpu_host_mesh.assign(new_tpu_host_mesh); |
112 | } |
113 | // For idempotency, don't register the same mesh twice. |
114 | if (!mesh_to_device_map_.insert({mesh->mesh_config(), std::move(mesh)}) |
115 | .second) |
116 | return; |
117 | if (!default_mesh_) { |
118 | global_default_mesh_ = mesh_to_device_map_.begin()->second.get(); |
119 | default_mesh_ = global_default_mesh_; |
120 | } |
121 | } |
122 | |
  // Returns the submeshes used for pipelining, keyed by the name of the
  // composite device.
125 | StatusOr<absl::flat_hash_map<std::string, const MeshWithParallelDevice*>> |
126 | PipelineSubMeshes(TFE_Context* context) { |
127 | absl::flat_hash_map<std::string, const MeshWithParallelDevice*> |
128 | device_to_mesh; |
129 | for (const auto& pair : mesh_to_device_map_) { |
130 | TF_ASSIGN_OR_RETURN(CompositeDevice * device, |
131 | pair.second->FindOrCreateCompositeDevice(context)); |
132 | if (device != nullptr) { |
133 | device_to_mesh[pair.second->composite_device()->name()] = |
134 | pair.second.get(); |
135 | } |
136 | } |
137 | return device_to_mesh; |
138 | } |
139 | |
  // Runs an operation on the DTensorDevice, ignoring the placement of the
  // original op (TFE_OpGetDevice(original_op)). The placement indicates
  // whether the user explicitly placed the op on the DTensor device (vs.
  // having it placed there because an input was placed there), but DTensor
  // does type-based dispatch and so handles these cases identically at the
  // moment.
147 | void Execute(const TFE_Op* original_op, int* num_outputs, |
148 | TFE_TensorHandle** outputs, TF_Status* status); |
149 | |
150 | void SetDefaultLayout(Layout layout) { default_layout_.emplace(layout); } |
151 | void ClearDefaultLayout() { default_layout_.reset(); } |
152 | void SetDefaultMesh(Mesh mesh) { |
153 | default_mesh_ = mesh_to_device_map_.at(mesh).get(); |
154 | } |
155 | void ClearDefaultMesh() { default_mesh_ = global_default_mesh_; } |
156 | void SetSameShapePolicy(bool enabled) { |
157 | same_shape_policy_enabled_ = enabled; |
158 | } |
159 | |
160 | Status SetTPUCoreIDs(const std::string& mesh_name, |
161 | const std::vector<int>& tpu_core_ids) { |
162 | if (VLOG_IS_ON(1)) { |
163 | LOG(INFO) << "Setting TPU core IDs for " |
                << (mesh_name.empty() ? "default mesh" : mesh_name) << ": ";
165 | for (auto i : tpu_core_ids) { |
166 | LOG(INFO) << i; |
167 | } |
168 | } |
    // Setting the default mesh under an empty name repeatedly is fine; this
    // happens when dtensor_initialize_tpu_system is called multiple times,
    // especially in tests. All the set mappings should be the same anyway.
    if (!mesh_name.empty() && Mesh::tpu_core_ids().count(mesh_name) > 0) {
      return errors::AlreadyExists("Mesh name already in use: ", mesh_name);
174 | } |
175 | Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(), |
176 | tpu_core_ids.end()); |
177 | return OkStatus(); |
178 | } |
179 | |
180 | void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); } |
181 | |
182 | std::vector<std::vector<int>> TPUCoreIDsToLocations( |
183 | TFE_Context* context, const std::vector<int>& tpu_core_ids) { |
184 | TpuSystemInterface* tpu_system = GetPreferredTpuSystem(); |
185 | if (tpu_system == nullptr) { |
186 | VLOG(1) << "Calling TPUCoreIDsToLocations on the default TPU system." ; |
187 | std::vector<std::vector<int>> tpu_core_locations; |
188 | tpu_core_locations.reserve(tpu_core_ids.size()); |
189 | tpu::TpuPlatformInterface* tpu_platform = |
190 | tpu::TpuPlatformInterface::GetRegisteredPlatform(); |
191 | if (tpu_platform == nullptr) { |
192 | LOG(WARNING) << "No TPU platform is found." ; |
193 | return {{}}; |
194 | } |
195 | if (!tpu_platform->Initialized()) { |
196 | LOG(WARNING) << "TPU platform is not initialized." ; |
197 | return {{}}; |
198 | } |
199 | tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology(); |
200 | |
201 | for (const int& tpu_core_id : tpu_core_ids) { |
202 | tpu::TpuCoreLocationExternal core = |
203 | tpu_topology.CoreForId(TpuCoreTypeEnum::kTensorCore, tpu_core_id); |
204 | tpu::TpuDimensionsExternal tpu_chip_location = core.chip_coordinates(); |
205 | tpu_core_locations.push_back({tpu_chip_location.x, tpu_chip_location.y, |
206 | tpu_chip_location.z, core.index()}); |
207 | } |
208 | return tpu_core_locations; |
209 | } else { |
210 | VLOG(1) << "Calling TPUCoreIDsToLocations on a preferred TPU system." ; |
211 | return tpu_system->TPUCoreIDsToLocations(context, tpu_core_ids); |
212 | } |
213 | } |
214 | |
215 | std::vector<int> TPUCoreLocationsToIDs( |
216 | TFE_Context* context, |
217 | const std::vector<std::vector<int>>& tpu_core_locations) { |
218 | TpuSystemInterface* tpu_system = GetPreferredTpuSystem(); |
219 | if (tpu_system == nullptr) { |
220 | VLOG(1) << "Calling TPUCoreLocationsToIDs on the default TPU system." ; |
221 | std::vector<int> tpu_core_ids; |
222 | tpu_core_ids.reserve(tpu_core_locations.size()); |
223 | tpu::TpuPlatformInterface* tpu_platform = |
224 | tpu::TpuPlatformInterface::GetRegisteredPlatform(); |
225 | if (tpu_platform == nullptr) { |
226 | LOG(WARNING) << "No TPU platform is found." ; |
227 | return {}; |
228 | } |
229 | if (!tpu_platform->Initialized()) { |
230 | LOG(WARNING) << "TPU platform is not initialized." ; |
231 | return {}; |
232 | } |
233 | tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology(); |
234 | |
235 | for (const std::vector<int>& tpu_core_location : tpu_core_locations) { |
236 | tpu::TpuCoreLocationExternal core = tpu_topology.Core( |
237 | TpuCoreTypeEnum::kTensorCore, tpu_core_location[0], |
238 | tpu_core_location[1], tpu_core_location[2], tpu_core_location[3]); |
239 | tpu_core_ids.push_back(core.Id()); |
240 | } |
241 | return tpu_core_ids; |
242 | } else { |
243 | VLOG(1) << "Calling TPUCoreLocationsToIDs on a preferred TPU system." ; |
244 | return tpu_system->TPUCoreLocationsToIDs(context, tpu_core_locations); |
245 | } |
246 | } |
247 | |
248 | // Waits for ops to finish in ALL meshes as we share the cancellation manager. |
249 | void AsyncWait(TFE_Context* context, TF_Status* status) { |
250 | std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> first_bad_status( |
251 | nullptr, TF_DeleteStatus); |
252 | |
253 | for (const auto& pair : mesh_to_device_map_) { |
254 | std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
255 | TF_NewStatus(), TF_DeleteStatus); |
256 | |
257 | pair.second->parallel_device().AsyncWait(context, |
258 | async_wait_status.get()); |
259 | |
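      // Track the first failure we see, but let a more specific error
      // overwrite a TF_CANCELLED status, since cancellation is usually a side
      // effect of the original failure.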
260 | TF_Code error_code = TF_GetCode(async_wait_status.get()); |
261 | if (error_code != TF_OK && |
262 | (first_bad_status == nullptr || |
263 | TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) { |
264 | first_bad_status.reset(TF_NewStatus()); |
265 | TF_SetStatus(first_bad_status.get(), error_code, |
266 | TF_Message(async_wait_status.get())); |
267 | } |
268 | } |
269 | |
270 | if (first_bad_status != nullptr) { |
271 | TF_SetStatus(status, TF_GetCode(first_bad_status.get()), |
272 | TF_Message(first_bad_status.get())); |
273 | } |
274 | |
275 | // Reset the global function rendezvous, which otherwise stores a failure |
276 | // state. |
277 | tensorflow::unwrap(context)->ResetGlobalRendezvousForFunction(); |
278 | |
    // Reset the cancellation manager on (potential) failure so we don't cancel
    // future ops. This is only safe because we have just cleared pending async
    // nodes, which may have had a reference to the cancellation manager.
282 | cancellation_manager_ = std::make_unique<CancellationManager>(); |
283 | } |
284 | |
285 | TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs, |
286 | TFE_TensorHandle** inputs, |
287 | const std::string& string_layout, TF_Status* status); |
288 | |
289 | std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context, |
290 | TFE_TensorHandle* input, |
291 | TF_Status* status); |
292 | |
293 | // Return the layout for the input tensor. |
294 | std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input, |
295 | TF_Status* status); |
296 | |
297 | TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs, |
298 | TFE_TensorHandle** indices, |
299 | TFE_TensorHandle** values, |
300 | TFE_TensorHandle** shapes, |
301 | const std::string& string_layout, |
302 | TF_Status* status); |
303 | |
304 | bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input, |
305 | TF_Status* status); |
306 | |
307 | std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount( |
308 | TFE_Context* context, TF_Status* status) const; |
309 | |
310 | private: |
311 | // If the `operation_name` of an op indicates a custom DTensor op (e.g. |
312 | // CopyToMesh), then separately handle those custom ops instead of running |
313 | // default DTensor graph compilation. |
314 | void MaybeHandleDTensorCustomOps( |
315 | const char* operation_name, const int num_inputs, |
316 | const TFE_OpAttrs* attributes, TFE_Context* context, |
317 | TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs, |
318 | bool* is_custom_dtensor_op, TF_Status* status); |
319 | |
  // Copies a non-DTensor eager tensor or a DTensor to the mesh specified by
  // `attributes`. Currently only copying to a replicated layout on the target
  // mesh is supported.
323 | void CopyToMesh(TFE_Context* context, int num_inputs, |
324 | TFE_TensorHandle** inputs, const TFE_OpAttrs* attributes, |
325 | TFE_TensorHandle** outputs, int* num_outputs, |
326 | TF_Status* status); |
327 | |
  // Updates output layouts for eager ops based on the same-shape policy.
329 | void UpdateOutputLayoutsWithSameShapePolicy( |
330 | const std::vector<PartialTensorShape>& global_output_shapes, |
331 | const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name, |
332 | tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts, |
333 | TF_Status* status); |
334 | |
335 | // Takes the description of an operation and makes a function out of it with |
336 | // the same signature, running DTensor MLIR passes. Registers that function |
337 | // with `context`. `translated_function_name` is set to the name of the |
338 | // function. |
339 | // |
340 | // The resulting function expects a device ID as its first input. |
341 | void LowerToSPMDFunction(TFE_Context* context, |
342 | const std::vector<TensorWithLayout*>& inputs, |
343 | const DTensorOperation& doperation, |
344 | const TFE_OpAttrs* attributes, const int num_outputs, |
345 | const ExecutionFunctions** execution_functions, |
346 | TF_Status* status); |
347 | |
348 | // Execute a given function. |
349 | void ExecuteFunctionAndWait( |
350 | TFE_Context* context, const TranslatedFunction* function_ptr, |
351 | const MeshWithParallelDevice* parallel_device_mesh, |
352 | const std::vector<parallel_device::ParallelTensor*>& parallel_inputs, |
353 | const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status); |
354 | |
  // Implements `Execute` for operations which aren't special-cased in
  // `MaybeHandleDTensorCustomOps`.
356 | void ExecuteRegularOperation(TFE_Context* context, |
357 | const std::vector<TensorWithLayout*>& inputs, |
358 | const DTensorOperation& doperation, |
359 | const TFE_OpAttrs* attributes, int* num_outputs, |
360 | TFE_TensorHandle** outputs, TF_Status* status); |
361 | |
362 | // Wraps a TensorWithLayout into a TFE_TensorHandle. |
363 | TFE_TensorHandle* MakeLayoutTensorHandle(TFE_Context* context, |
364 | std::unique_ptr<TensorWithLayout> t, |
365 | TF_Status* status); |
366 | |
367 | void RecordInShapeLayoutCache(const TensorWithLayout& tensor); |
368 | |
369 | // Returns whether a given mesh is a remote mesh. |
370 | bool is_remote_mesh(const Mesh& mesh) const; |
371 | |
372 | // The name of the device (the custom device) |
373 | std::string name_; |
374 | // Mesh configs with matching parallel devices. |
375 | // |
376 | // For now we just consider the first entry added to dtensor_device as the |
377 | // default mesh. Before we reach an agreement on this, we'll leave it as is. |
378 | absl::flat_hash_map<Mesh, std::unique_ptr<MeshWithParallelDevice>> |
379 | mesh_to_device_map_; |
  // TODO(hthu): Consider whether we want to preserve the default_mesh
  // semantic. The current default mesh is consistent with default_layout_. If
  // default_layout_ is not set, it equals global_default_mesh_.
383 | const MeshWithParallelDevice* default_mesh_ = nullptr; |
384 | // The default mesh of a DTensorDevice, which cannot be modified once being |
385 | // set. |
386 | const MeshWithParallelDevice* global_default_mesh_ = nullptr; |
387 | // If the user has specified a default output layout. |
388 | absl::optional<Layout> default_layout_; |
389 | |
390 | // Determines whether tensors with a shape previously associated with only one |
391 | // layout use that layout if nothing else can be inferred. |
392 | bool same_shape_policy_enabled_; |
393 | |
394 | DTensorMlirPassRunner pass_runner_; |
395 | |
396 | struct CachedLayout { |
397 | // The first layout seen with this shape |
398 | Layout layout; |
399 | // Whether the layout is unique for this shape |
400 | bool is_unique; |
401 | }; |
402 | absl::flat_hash_map<int64_t, CachedLayout> shape_layout_cache_; |
403 | |
404 | FunctionManager function_manager_; |
405 | |
406 | // Records the function compilation cache hits and misses. |
407 | std::unordered_map<std::string, int> function_compilation_hits_and_misses_; |
408 | |
409 | // Coordinates cancelling ops across meshes on error. Must outlive any queued |
410 | // async op launches, so we only reset it after seeing a failure status. |
411 | std::unique_ptr<CancellationManager> cancellation_manager_; |
412 | |
413 | // Map each function_mesh_fingerprint (based on the set of the mesh involved) |
414 | // to the number of times of the function execution. The |
415 | // function_mesh_fingerprint and the counter together are used for generating |
416 | // the step id, which is used for rendezvous creation. |
417 | absl::flat_hash_map<uint64, uint64> func_mesh_fingerprint_to_step_counter_; |
418 | }; |
419 | |
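// Fingerprints `shape` by chaining FingerprintCat64 over its dimensions, e.g.
// FingerprintShape({8, 16}) == FingerprintCat64(FingerprintCat64(0, 8), 16).
// Used as the key of shape_layout_cache_.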
420 | int64_t FingerprintShape(const absl::Span<const int64_t> shape) { |
421 | int64_t fprint = 0; |
422 | for (int64_t dim : shape) { |
423 | fprint = FingerprintCat64(fprint, dim); |
424 | } |
425 | return fprint; |
426 | } |
427 | |
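// Returns a parallel tensor holding each local device's global device ID,
// creating and caching it on first use. SPMD-lowered functions take these IDs
// as their implicit first argument.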
428 | parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs( |
429 | TFE_Context* context, TF_Status* status) const { |
430 | if (device_ids_tensor_ == nullptr) { |
431 | // Global device IDs sequentially increase. |
432 | // |
    // This is an assumption throughout the DTensor software stack: the MLIR
    // pass relies on it to generate mesh coordinates for each core
    // efficiently.
    //
    // Unfortunately, the rule for setting local IDs and for mapping global IDs
    // to real physical core indices (e.g., on TPU) is nontrivial. Using the
    // identity mapping is possible, but collective-operation performance is
    // terrible in most cases.
440 | // |
    // - For an ICI-connected TPU slice, see
    //   go/dtensor-device-assignment-summary for a guide on creating efficient
    //   core assignments that reach peak performance.
444 | // |
445 | // The global id to core assignment mapping is bridged by |
446 | // `Mesh::tpu_core_ids()` and consumed by `UpdateTPUCompileMetadata`. |
447 | // |
    // - For a DCN-connected topology, different sections of the global IDs
    //   must be mapped to their real physical cores separately, according to
    //   the runtime requirements. For example, for a 4x32 mesh in which the
    //   outer dimension is connected via DCN and the inner dimension via ICI,
    //   the device assignments for the inner dimension should typically form
    //   their own ring order (not the plain physical core index) within each
    //   submesh, and the outer dimension should be assigned according to the
    //   real physical ring of DCN hosts.
456 | // |
    // Note: Changing this assumption requires adjusting the MLIR pass. One
    // possible approach is to take an N-D mapping vector for an N-D mesh and
    // look up the coordinates in MLIR, consulting the tensor layout as well,
    // rather than computing them on the fly.
461 | |
462 | // LINT.IfChange |
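    // For example, global device IDs {2, 3, 4, 5} pass this check, while
    // {0, 2, 4, 6} are rejected as non-consecutive.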
463 | for (int64_t i = 0; i < mesh_config_.global_device_ids().size(); ++i) { |
464 | if (mesh_config_.global_device_ids()[i] - i != |
465 | mesh_config_.global_device_ids()[0]) { |
        TF_SetStatus(
            status, TF_INTERNAL,
            absl::StrCat("Global device IDs should be consecutive: ",
                         absl::StrJoin(mesh_config_.global_device_ids(), ", "))
                .c_str());
471 | return nullptr; |
472 | } |
473 | } |
474 | // LINT.ThenChange(//tensorflow/dtensor/python/layout.py) |
475 | |
476 | // Local device IDs are a subset of global device IDs, arranged in device |
477 | // ordinal order. |
478 | std::vector<int32_t> ids; |
479 | for (int64_t id : mesh_config_.local_device_ids()) { |
480 | ids.push_back(id); |
481 | } |
482 | VLOG(1) << "Parallel device IDs: " << absl::StrJoin(ids, ", " ); |
483 | device_ids_tensor_ = |
484 | parallel_device_->ScalarsFromSequence<int32_t>(ids, context, status); |
485 | if (TF_GetCode(status) != TF_OK) return nullptr; |
486 | } |
487 | return device_ids_tensor_.get(); |
488 | } |
489 | |
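// Callbacks handed to TFE_NewCustomDeviceTensorHandle below: they expose a
// TensorWithLayout's global shape, ownership, and value summary to the eager
// runtime.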
490 | int TensorWithLayoutNumDims(void* data, TF_Status* status) { |
491 | return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size(); |
492 | } |
493 | |
494 | int64_t TensorWithLayoutDim(void* data, int dim_index, TF_Status* status) { |
495 | return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index]; |
496 | } |
497 | |
498 | void TensorWithLayoutDeallocator(void* data) { |
499 | delete reinterpret_cast<TensorWithLayout*>(data); |
500 | } |
501 | |
502 | TF_Buffer* TensorWithLayoutSummarize(void* data, TF_Status* status) { |
503 | std::string summary = |
504 | reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue(); |
505 | return TF_NewBufferFromString(summary.data(), summary.size()); |
506 | } |
507 | |
508 | TFE_TensorHandle* DTensorDevice::MakeLayoutTensorHandle( |
509 | TFE_Context* context, std::unique_ptr<TensorWithLayout> t, |
510 | TF_Status* status) { |
511 | TF_DataType dtype = t->dtype(); |
512 | TFE_CustomDeviceTensorHandleMethods handle_methods; |
513 | handle_methods.num_dims = &TensorWithLayoutNumDims; |
514 | handle_methods.dim = &TensorWithLayoutDim; |
515 | handle_methods.deallocator = &TensorWithLayoutDeallocator; |
516 | handle_methods.summarize = &TensorWithLayoutSummarize; |
517 | return TFE_NewCustomDeviceTensorHandle(context, name_.c_str(), dtype, |
518 | /*data=*/t.release(), handle_methods, |
519 | status); |
520 | } |
521 | |
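// Records the layout seen for `tensor`'s shape, marking the shape's cached
// layout as non-unique if a different layout was recorded earlier. Consumed by
// the same-shape policy in UpdateOutputLayoutsWithSameShapePolicy.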
522 | void DTensorDevice::RecordInShapeLayoutCache(const TensorWithLayout& tensor) { |
523 | auto existing = shape_layout_cache_.insert( |
524 | {FingerprintShape(tensor.global_shape()), |
525 | CachedLayout{tensor.layout(), /*is_unique=*/true}}); |
526 | |
527 | if (!existing.second) { |
528 | // There is an entry already; if the layout doesn't match we should record |
529 | // the fact that it's not unique. |
530 | if (tensor.layout() != existing.first->second.layout) { |
531 | existing.first->second.is_unique = false; |
532 | } |
533 | } |
534 | } |
535 | |
536 | bool DTensorDevice::is_remote_mesh(const Mesh& mesh) const { |
537 | // An empty mesh might be assigned to VarHandleOp during DTensor MLIR lowering |
538 | // pass. Decide whether the empty mesh is remote based on the current default |
539 | // mesh. |
540 | return mesh.is_remote() || |
541 | (mesh.IsEmpty() && default_mesh_->mesh_config().is_remote()); |
542 | } |
543 | |
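// Deserializes `attributes` into a NameAttrList using the C API.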
544 | StatusOr<NameAttrList> FetchAttributes(const TFE_OpAttrs* attributes) { |
545 | // TODO(allenl): Should we just give up on the public C API to save on |
546 | // serialization/deserialization? We need all of the attributes and to treat |
547 | // them generically, which isn't going to be pleasant with typed attribute |
548 | // methods. |
549 | std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> serialized_attributes( |
550 | TF_NewBuffer(), TF_DeleteBuffer); |
551 | |
552 | TF_Status* status = TF_NewStatus(); |
553 | TFE_OpAttrsSerialize(attributes, serialized_attributes.get(), status); |
554 | if (TF_GetCode(status) == TF_OK) { |
555 | TF_DeleteStatus(status); |
556 | } else { |
557 | Status failure_status = StatusFromTF_Status(status); |
558 | TF_DeleteStatus(status); |
559 | return failure_status; |
560 | } |
561 | |
562 | NameAttrList name_and_attrs; |
563 | if (!name_and_attrs.ParseFromArray(serialized_attributes->data, |
564 | serialized_attributes->length)) { |
    return tensorflow::errors::Unknown("Could not parse attributes");
566 | } |
567 | return name_and_attrs; |
568 | } |
569 | |
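// Parses the layout stored under `attribute_name` in `attributes`.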
StatusOr<Layout> FetchLayoutFromAttributes(const TFE_OpAttrs* attributes,
                                           absl::string_view attribute_name) {
572 | // Get attributes. |
573 | TF_ASSIGN_OR_RETURN(NameAttrList name_and_attrs, FetchAttributes(attributes)); |
574 | |
575 | // Get layout string from attributes. |
576 | absl::string_view layout_str = |
577 | name_and_attrs.attr().find(std::string(attribute_name))->second.s(); |
578 | |
579 | // This would probably be slow at the moment without caching. |
580 | // We should consider making this faster in the future. |
581 | return Layout::FromString(string(layout_str)); |
582 | } |
583 | |
584 | std::string DTensorDevice::FetchLayout(TFE_Context* context, |
585 | TFE_TensorHandle* input, |
586 | TF_Status* status) { |
587 | VLOG(1) << "Checking layout..." ; |
588 | const char* input_device = TFE_TensorHandleDeviceName(input, status); |
589 | if (input_device != name_) { |
590 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
591 | "FetchLayout expects a tensor placed on the layout device." ); |
592 | return {}; |
593 | } |
594 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
595 | TFE_TensorHandleDevicePointer(input, status)); |
596 | if (TF_GetCode(status) != TF_OK) return {}; |
597 | return t->layout().ToString(); |
598 | } |
599 | |
600 | std::vector<TFE_TensorHandle*> DTensorDevice::Unpack(TFE_Context* context, |
601 | TFE_TensorHandle* input, |
602 | TF_Status* status) { |
603 | std::vector<TFE_TensorHandle*> outputs; |
604 | |
605 | const char* input_device = TFE_TensorHandleDeviceName(input, status); |
606 | if (TF_GetCode(status) != TF_OK) return outputs; |
607 | if (input_device != name_) { |
608 | TF_SetStatus( |
609 | status, TF_INVALID_ARGUMENT, |
610 | absl::StrCat( |
611 | "DTensorUnpack expects a tensor placed on the DTensor device: " , |
612 | name_, ", but input was placed on device: " , input_device) |
613 | .c_str()); |
614 | return outputs; |
615 | } |
616 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
617 | TFE_TensorHandleDevicePointer(input, status)); |
618 | if (TF_GetCode(status) != TF_OK) return outputs; |
619 | |
620 | if (is_remote_mesh(t->mesh().mesh_config())) { |
621 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
622 | "DTensorUnpack is not supported on a remote mesh." ); |
623 | return outputs; |
624 | } |
625 | const int output_size = t->num_tensors(); |
626 | outputs.resize(output_size); |
627 | |
628 | for (int output_index = 0; output_index < output_size; ++output_index) { |
629 | outputs[output_index] = |
630 | TFE_TensorHandleCopySharingTensor(t->get_tensor(output_index), status); |
631 | if (TF_GetCode(status) != TF_OK) { |
632 | return outputs; |
633 | } |
634 | } |
635 | return outputs; |
636 | } |
637 | |
638 | void DTensorDevice::MaybeHandleDTensorCustomOps( |
639 | const char* operation_name, const int num_inputs, |
640 | const TFE_OpAttrs* attributes, TFE_Context* context, |
641 | TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs, |
642 | bool* is_custom_dtensor_op, TF_Status* status) { |
643 | *is_custom_dtensor_op = true; |
644 | if (operation_name == std::string("_EagerConst" )) { |
645 | // Op-by-op const has no obvious layout. DTensor skips an SPMD expansion and |
646 | // instead relies on copy-on when the value is used later. |
647 | std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op( |
648 | TFE_NewOp(context, operation_name, status), TFE_DeleteOp); |
649 | if (TF_GetCode(status) != TF_OK) return; |
650 | for (int input_index = 0; input_index < num_inputs; ++input_index) { |
651 | TFE_OpAddInput(op.get(), inputs[input_index], status); |
652 | if (TF_GetCode(status) != TF_OK) return; |
653 | } |
654 | TFE_OpAddAttrs(op.get(), attributes); |
655 | TFE_Execute(op.get(), outputs, num_outputs, status); |
656 | return; |
657 | } |
658 | if (operation_name == std::string("CopyToMesh" )) { |
659 | CopyToMesh(context, num_inputs, inputs, attributes, outputs, num_outputs, |
660 | status); |
661 | return; |
662 | } |
663 | |
664 | *is_custom_dtensor_op = false; |
665 | } |
666 | |
667 | void DTensorDevice::CopyToMesh(TFE_Context* context, int num_inputs, |
668 | TFE_TensorHandle** inputs, |
669 | const TFE_OpAttrs* attributes, |
670 | TFE_TensorHandle** outputs, int* num_outputs, |
671 | TF_Status* status) { |
672 | if (num_inputs != 1) { |
673 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
674 | "DTensor CopyToMesh requires exactly 1 input." ); |
675 | } |
676 | if (*num_outputs < 1) { |
677 | RETURN_STATUS(status, TF_INTERNAL, |
678 | "DTensor CopyToMesh must have output buffer to allocate at " |
679 | "least 1 output." ); |
680 | } |
681 | |
682 | // Assign layout. |
683 | StatusOr<Layout> target_layout_or = |
684 | FetchLayoutFromAttributes(attributes, kQualifiedLayoutAttr); |
685 | if (!target_layout_or.ok()) { |
686 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
687 | "DTensor CopyToMesh requires valid layout attribute for " |
688 | "destination DTensor." ); |
689 | } |
690 | |
691 | const Layout target_layout = *target_layout_or; |
692 | const Mesh& target_mesh = target_layout.mesh(); |
693 | |
694 | // TODO(b/193443769): Support sharded layout for eager copy to mesh. |
695 | if (!target_layout.IsFullyReplicated()) { |
696 | RETURN_STATUS(status, TF_UNIMPLEMENTED, |
697 | "Target layout of DTensor CopyToMesh must be replicated. " |
698 | "Consider changing the target layout to replicated layout or " |
699 | "file a bug to the DTensor team (b/193443769)." ); |
700 | } |
701 | |
702 | TFE_TensorHandle* input_tensor = inputs[0]; |
703 | |
  // If the input tensor is a DTensor, check that its layout is replicated.
706 | const char* input_device = TFE_TensorHandleDeviceName(input_tensor, status); |
707 | if (TF_GetCode(status) != TF_OK) return; |
708 | |
709 | if (name_ == input_device) { |
710 | // Handle input which is on DTensor device already. |
711 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
712 | TFE_TensorHandleDevicePointer(input_tensor, status)); |
713 | if (TF_GetCode(status) != TF_OK) return; |
714 | |
715 | if (!t->layout().IsFullyReplicated()) |
716 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
717 | "Input tensor to CopyToMesh must be replicated DTensor or " |
718 | "normal eager Tensor." ); |
719 | |
720 | // If input to CopyToMesh is a DTensor, we use the first local tensor as |
721 | // input tensor handle to invoke copy. |
722 | input_tensor = t->get_tensor(0); |
723 | } |
724 | |
725 | auto it = mesh_to_device_map_.find(target_mesh); |
726 | if (it == mesh_to_device_map_.end()) { |
727 | RETURN_STATUS( |
728 | status, TF_INTERNAL, |
729 | "DTensor CopyToMesh target mesh is not registered. Meshes should be " |
730 | "automatically registered. Please file a bug. (component id: 833864)" ); |
731 | } |
732 | |
733 | const MeshWithParallelDevice* target_parallel_mesh = it->second.get(); |
734 | |
735 | // Broadcast non-dtensor value to dtensor. |
736 | std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast( |
737 | context, input_tensor, *target_parallel_mesh, name_, status); |
738 | if (TF_GetCode(status) != TF_OK) return; |
739 | |
740 | RecordInShapeLayoutCache(*wrapper); |
741 | *num_outputs = 1; |
742 | *outputs = MakeLayoutTensorHandle(context, std::move(wrapper), status); |
743 | } |
744 | |
745 | namespace { |
746 | |
747 | // Verifies that all components have the same dtype and shape. |
748 | // The component shape will be set upon success. |
749 | void VerifyPackTensorShapeAndDtype( |
750 | std::vector<parallel_device::TensorHandlePtr>& components, |
751 | std::vector<int64_t>* component_shape, TF_Status* status) { |
752 | TF_DataType dtype = TFE_TensorHandleDataType(components[0].get()); |
753 | auto size = TFE_TensorHandleNumDims(components[0].get(), status); |
754 | if (TF_GetCode(status) != TF_OK) return; |
755 | component_shape->clear(); |
756 | component_shape->reserve(size); |
757 | for (int i = 0; i < size; ++i) { |
758 | component_shape->push_back( |
759 | TFE_TensorHandleDim(components[0].get(), i, status)); |
760 | if (TF_GetCode(status) != TF_OK) return; |
761 | } |
762 | |
763 | // Verify that the TensorHandle's shape and dtype match all of the component |
764 | // shapes and dtypes. |
765 | for (const auto& component : components) { |
766 | for (int i = 0; i < component_shape->size(); ++i) { |
767 | int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status); |
768 | if (TF_GetCode(status) != TF_OK) return; |
769 | if (tensor_dim != (*component_shape)[i]) { |
770 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
771 | "Components of a PackedTensor must currently all have " |
772 | "the same shape" ); |
773 | return; |
774 | } |
775 | if (TFE_TensorHandleDataType(component.get()) != dtype) { |
776 | TF_SetStatus(status, TF_INTERNAL, |
777 | "Components of a PackedTensor must all have " |
778 | "the same dtype" ); |
779 | return; |
780 | } |
781 | } |
782 | } |
783 | } |
784 | |
// Verifies that all TensorHandles have rank `expected_rank` and, if
// `expected_dtype` is non-null, dtype `*expected_dtype`.
786 | void VerifyTensorRankAndDType(TFE_TensorHandle** tensors, int num_input, |
787 | int expected_rank, TF_DataType* expected_dtype, |
788 | TF_Status* status) { |
789 | for (int i = 0; i < num_input; ++i) { |
790 | auto actual_rank = TFE_TensorHandleNumDims(tensors[i], status); |
    if (TF_GetCode(status) != TF_OK)
      RETURN_STATUS(status, TF_INTERNAL, "Error getting rank of tensor.");
    if (actual_rank != expected_rank)
      RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                    "Rank of tensor did not match the expected rank.");
    if (expected_dtype != nullptr &&
        TFE_TensorHandleDataType(tensors[i]) != *expected_dtype)
      RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                    "Dtype of tensor did not match the expected dtype.");
800 | } |
801 | } |
802 | |
803 | } // namespace |
804 | |
805 | TFE_TensorHandle* DTensorDevice::Pack(TFE_Context* context, int num_inputs, |
806 | TFE_TensorHandle** inputs, |
807 | const std::string& string_layout, |
808 | TF_Status* status) { |
809 | if (num_inputs < 1) { |
810 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
811 | "DTensorPack requires 1 or more inputs" ); |
812 | return nullptr; |
813 | } |
814 | StatusOr<Layout> target_layout = Layout::FromString(string_layout); |
815 | if (!target_layout.ok()) { |
816 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
817 | "Failed to parse layout from string layout" ); |
818 | return nullptr; |
819 | } |
820 | const Mesh& target_mesh = target_layout->mesh(); |
821 | const MeshWithParallelDevice* target_parallel_device = |
822 | mesh_to_device_map_[target_mesh].get(); |
823 | if (target_parallel_device == nullptr) { |
824 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
825 | absl::StrCat("Required mesh : " , target_mesh.ToString(), |
826 | "is not registered with DTensor " ) |
827 | .c_str()); |
828 | return nullptr; |
829 | } |
830 | |
831 | std::unique_ptr<TensorWithLayout> packed_tensor; |
832 | if (is_remote_mesh(target_parallel_device->mesh_config())) { |
833 | // Create a dummy output for DTensorPack if inputs are on a remote mesh. |
834 | TF_DataType dtype = TFE_TensorHandleDataType(inputs[0]); |
835 | auto size = TFE_TensorHandleNumDims(inputs[0], status); |
836 | if (TF_GetCode(status) != TF_OK) return nullptr; |
837 | std::vector<int64_t> component_shape; |
838 | component_shape.reserve(size); |
839 | for (int i = 0; i < size; ++i) { |
840 | component_shape.push_back(TFE_TensorHandleDim(inputs[0], i, status)); |
841 | if (TF_GetCode(status) != TF_OK) return nullptr; |
842 | } |
843 | packed_tensor = TensorWithLayout::Dummy( |
844 | component_shape, dtype, *target_parallel_device, *target_layout); |
845 | |
846 | } else { |
847 | auto local_devices = target_parallel_device->mesh_config().local_devices(); |
848 | |
849 | if (num_inputs != |
850 | target_parallel_device->parallel_device().num_underlying_devices()) { |
851 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
852 | absl::StrCat("The dtensor device " , name_, " expected " , |
853 | local_devices.size(), |
854 | " inputs to DTensorPack, but got " , num_inputs) |
855 | .c_str()); |
856 | return nullptr; |
857 | } |
858 | |
859 | std::vector<parallel_device::TensorHandlePtr> components; |
860 | components.reserve(num_inputs); |
861 | for (int i = 0; i < num_inputs; ++i) { |
862 | TFE_TensorHandle* input = inputs[i]; |
863 | const char* input_device = TFE_TensorHandleDeviceName(input, status); |
864 | if (TF_GetCode(status) != TF_OK) return nullptr; |
865 | if (name_ == input_device) { |
866 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
867 | "Does not support packing a Tensor that is already on " |
868 | "dtensor device" ); |
869 | return nullptr; |
870 | } |
871 | // If `input` is on the target device, this creates a new handle sharing |
872 | // the underlying data; otherwise, async copies are invoked. |
873 | components.emplace_back(TFE_TensorHandleCopyToDevice( |
874 | input, context, local_devices[i].c_str(), status)); |
875 | if (TF_GetCode(status) != TF_OK) return nullptr; |
876 | } |
877 | |
878 | std::vector<int64_t> component_shape; |
879 | VerifyPackTensorShapeAndDtype(components, &component_shape, status); |
880 | if (TF_GetCode(status) != TF_OK) return nullptr; |
881 | |
882 | std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = |
883 | parallel_device::ParallelTensor::FromTensorHandles( |
884 | target_parallel_device->parallel_device(), std::move(components), |
885 | status); |
886 | if (TF_GetCode(status) != TF_OK) return nullptr; |
887 | |
888 | if (target_layout->rank() != component_shape.size()) { |
889 | TF_SetStatus( |
890 | status, TF_INVALID_ARGUMENT, |
891 | absl::StrCat( |
892 | "Packed layout should have the same rank as the rank for each " |
893 | "component. The rank of each component is: " , |
894 | component_shape.size(), |
895 | ", while layout has rank: " , target_layout->rank(), |
896 | "\nLayout: " , target_layout->ToString(), "\n" ) |
897 | .c_str()); |
898 | return nullptr; |
899 | } |
900 | |
901 | packed_tensor = |
902 | TensorWithLayout::Wrap(std::move(parallel_tensor), |
903 | *target_parallel_device, *target_layout) |
904 | .value(); |
905 | } |
906 | |
907 | RecordInShapeLayoutCache(*packed_tensor); |
908 | TFE_TensorHandle* output = |
909 | MakeLayoutTensorHandle(context, std::move(packed_tensor), status); |
910 | if (TF_GetCode(status) != TF_OK) return nullptr; |
911 | return output; |
912 | } |
913 | |
914 | TFE_TensorHandle* DTensorDevice::SparsePack( |
915 | TFE_Context* context, int num_inputs, TFE_TensorHandle** indices, |
916 | TFE_TensorHandle** values, TFE_TensorHandle** shapes, |
917 | const std::string& string_layout, TF_Status* status) { |
918 | StatusOr<Layout> target_layout = Layout::FromString(string_layout); |
919 | if (!target_layout.ok()) { |
920 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
921 | "Failed to parse layout from string layout" ); |
922 | return nullptr; |
923 | } |
924 | const Mesh& target_mesh = target_layout->mesh(); |
925 | const MeshWithParallelDevice* target_parallel_device = |
926 | mesh_to_device_map_[target_mesh].get(); |
927 | if (target_parallel_device == nullptr) { |
928 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
929 | absl::StrCat("Required mesh : " , target_mesh.ToString(), |
930 | "is not registered with DTensor " ) |
931 | .c_str()); |
932 | return nullptr; |
933 | } |
934 | |
935 | TF_DataType tf_int64 = TF_INT64; |
936 | // Verify rank and dtype of shapes. |
937 | VerifyTensorRankAndDType(shapes, num_inputs, /*expected_rank=*/1, |
938 | /*expected_dtype=*/&tf_int64, status); |
939 | if (TF_GetCode(status) != TF_OK) return nullptr; |
940 | |
941 | // Verify rank and dtype of indices. |
942 | VerifyTensorRankAndDType(indices, num_inputs, /*expected_rank=*/2, |
943 | /*expected_dtype=*/&tf_int64, status); |
944 | if (TF_GetCode(status) != TF_OK) return nullptr; |
945 | |
946 | // Verify rank of values. |
947 | VerifyTensorRankAndDType(values, num_inputs, /*expected_rank=*/1, |
948 | /*expected_dtype=*/nullptr, status); |
949 | if (TF_GetCode(status) != TF_OK) return nullptr; |
950 | |
951 | // Compute the local shape from a shape tensor. |
952 | std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> shape_tensor( |
953 | TFE_TensorHandleResolve(shapes[0], status), TF_DeleteTensor); |
954 | if (TF_GetCode(status) != TF_OK) { |
955 | TF_SetStatus( |
956 | status, TF_GetCode(status), |
957 | absl::StrCat("Error resolving the tensor handle of shape tensor" |
958 | ". Original message: " , |
959 | TF_Message(status)) |
960 | .c_str()); |
961 | return nullptr; |
962 | } |
963 | int shape_tensor_size = TFE_TensorHandleDim(shapes[0], 0, status); |
964 | if (TF_GetCode(status) != TF_OK || shape_tensor_size <= 0) { |
965 | TF_SetStatus(status, TF_GetCode(status), |
966 | absl::StrCat("Error computing the num dims of shape tensor" , |
967 | TF_Message(status)) |
968 | .c_str()); |
969 | return nullptr; |
970 | } |
971 | |
972 | const int64_t* data = |
973 | static_cast<int64_t*>(TF_TensorData(shape_tensor.get())); |
974 | std::vector<int64_t> local_shape(data, data + shape_tensor_size); |
975 | if (local_shape.size() != target_layout->rank()) { |
976 | TF_SetStatus( |
977 | status, TF_INVALID_ARGUMENT, |
978 | absl::StrCat( |
979 | "Packed layout should have the same rank as the rank for each " |
980 | "component. The rank of each component is: " , |
981 | local_shape.size(), |
982 | ", while layout has rank: " , target_layout->rank(), |
983 | "\nLayout: " , target_layout->ToString(), "\n" ) |
984 | .c_str()); |
985 | return nullptr; |
986 | } |
987 | |
988 | // Create the SparseTensorWithLayout. |
989 | std::unique_ptr<TensorWithLayout> packed_tensor; |
990 | if (is_remote_mesh(target_parallel_device->mesh_config())) { |
991 | // Create a dummy SparseTensorWithLayout. |
992 | packed_tensor = SparseTensorWithLayout::Dummy( |
993 | local_shape, *target_parallel_device, target_layout.value()); |
994 | } else { |
995 | // Parse the indices, values, and dense_shape tensors and put them into |
996 | // parallel tensors, and then pack it into a single SparseTensorWithLayout. |
997 | auto local_devices = target_parallel_device->mesh_config().local_devices(); |
998 | |
999 | std::vector<parallel_device::TensorHandlePtr> indices_components; |
1000 | std::vector<parallel_device::TensorHandlePtr> values_components; |
1001 | std::vector<parallel_device::TensorHandlePtr> dense_shapes_components; |
1002 | |
    // A small trick to keep the packing code for indices, values, and shapes
    // uniform.
1005 | std::vector<std::vector<parallel_device::TensorHandlePtr>*> components{ |
1006 | &indices_components, &values_components, &dense_shapes_components}; |
1007 | std::vector<TFE_TensorHandle**> input_vectors{indices, values, shapes}; |
1008 | for (int component_index = 0; component_index < 3; ++component_index) { |
1009 | components[component_index]->reserve(num_inputs); |
1010 | TFE_TensorHandle** inputs = input_vectors[component_index]; |
1011 | for (int i = 0; i < num_inputs; ++i) { |
1012 | const char* input_device = |
1013 | TFE_TensorHandleDeviceName(inputs[i], status); |
1014 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1015 | if (name_ == input_device) { |
1016 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
1017 | "Does not support packing a Tensor that is already on " |
1018 | "dtensor device." ); |
1019 | return nullptr; |
1020 | } |
1021 | |
1022 | components[component_index]->emplace_back(TFE_TensorHandleCopyToDevice( |
1023 | inputs[i], context, local_devices[i].c_str(), status)); |
1024 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1025 | } |
1026 | } |
1027 | std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor = |
1028 | parallel_device::ParallelTensor::FromTensorHandles( |
1029 | target_parallel_device->parallel_device(), |
1030 | std::move(indices_components), status); |
1031 | |
1032 | std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor = |
1033 | parallel_device::ParallelTensor::FromTensorHandles( |
1034 | target_parallel_device->parallel_device(), |
1035 | std::move(values_components), status); |
1036 | |
1037 | std::unique_ptr<parallel_device::ParallelTensor> |
1038 | parallel_dense_shapes_tensor = |
1039 | parallel_device::ParallelTensor::FromTensorHandles( |
1040 | target_parallel_device->parallel_device(), |
1041 | std::move(dense_shapes_components), status); |
1042 | |
1043 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1044 | packed_tensor = |
1045 | SparseTensorWithLayout::Wrap(std::move(parallel_indices_tensor), |
1046 | std::move(parallel_values_tensor), |
1047 | std::move(parallel_dense_shapes_tensor), |
1048 | *target_parallel_device, |
1049 | target_layout.value(), local_shape) |
1050 | .value(); |
1051 | } |
1052 | |
1053 | RecordInShapeLayoutCache(*packed_tensor); |
1054 | TFE_TensorHandle* output = |
1055 | MakeLayoutTensorHandle(context, std::move(packed_tensor), status); |
1056 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1057 | return output; |
1058 | } |
1059 | |
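// Returns true iff `input` is a DTensor whose underlying type is
// TensorType::kSparse.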
1060 | bool DTensorDevice::IsSparseDTensor(TFE_Context* context, |
1061 | TFE_TensorHandle* input, |
1062 | TF_Status* status) { |
1063 | const char* input_device = TFE_TensorHandleDeviceName(input, status); |
1064 | if (input_device != name_) { |
1065 | TF_SetStatus( |
1066 | status, TF_INVALID_ARGUMENT, |
1067 | "DTensorSparseUnpack expects a tensor placed on the DTensor device." ); |
1068 | return false; |
1069 | } |
1070 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
1071 | TFE_TensorHandleDevicePointer(input, status)); |
1072 | if (TF_GetCode(status) != TF_OK) return false; |
1073 | return t->tensor_type() == TensorType::kSparse; |
1074 | } |
1075 | |
1076 | void DTensorDevice::UpdateOutputLayoutsWithSameShapePolicy( |
1077 | const std::vector<PartialTensorShape>& global_output_shapes, |
1078 | const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name, |
1079 | tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts, |
1080 | TF_Status* status) { |
1081 | if (!same_shape_policy_enabled_) return; |
1082 | // Simply do not hint if inputs span across multiple meshes. |
1083 | if (input_meshes.size() > 1) return; |
1084 | |
1085 | for (Node* node : graph->op_nodes()) { |
1086 | if (!node->IsRetval()) { |
1087 | continue; |
1088 | } |
1089 | int output_index; |
1090 | RETURN_C_STATUS_IF_NOT_OK( |
1091 | GetNodeAttr(node->attrs(), "index" , &output_index), status); |
1092 | if (output_layouts->at(output_index)) { |
1093 | continue; |
1094 | } |
1095 | |
1096 | const auto& global_output_shape = global_output_shapes.at(output_index); |
1097 | const Layout* layout = nullptr; |
1098 | // TODO(b/180022708): This is useful information, we should be |
1099 | // able to hint to layout propagation without making it a hard |
1100 | // requirement |
1101 | // |
1102 | // Special cases at the moment: |
1103 | // - Relayout needs an exemption. |
    // - VarHandleOp does not need a hint. VarHandleOp has scalar shape, so its
    //   layout is trivial. On the other hand, the downstream system "thinks"
    //   a Variable has the same shape as the value it points to, so providing
    //   a layout based on the (scalar) VarHandleOp might confuse it.
1108 | if (op_name != std::string("Relayout" ) && |
1109 | op_name != std::string("VarHandleOp" )) { |
1110 | // TODO(b/162009702): Support matching between partially-known shapes. |
1111 | if (global_output_shape.IsFullyDefined()) { |
1112 | gtl::InlinedVector<int64, 4> shape_vector( |
1113 | global_output_shape.dim_sizes()); |
1114 | auto layout_iterator = |
1115 | shape_layout_cache_.find(FingerprintShape(shape_vector)); |
1116 | if (layout_iterator != shape_layout_cache_.end() && |
1117 | layout_iterator->second.is_unique) { |
1118 | // We have a cached layout for this shape. Send it to MLIR. |
1119 | layout = &layout_iterator->second.layout; |
1120 | VLOG(3) << op_name << ": found a cached layout for shape " |
1121 | << global_output_shape.DebugString() << ": \"" |
                  << layout->ToString() << "\"";
          if (input_meshes.empty() &&
              layout->mesh() != default_mesh_->mesh_config()) {
            VLOG(3) << "But we can't infer an input mesh, and the cached "
                    << "layout's mesh \"" << layout->mesh().ToString() << "\" "
                    << "is different from the default mesh \""
                    << default_mesh_->mesh_config().ToString() << "\".\n"
                    << "Not applying the cached layout.";
1130 | } else if (!input_meshes.empty() && |
1131 | layout->mesh() != *input_meshes.begin()) { |
            VLOG(3)
                << "But the layout mesh is different from the executing mesh: "
                << "\"" << (*input_meshes.begin()).ToString() << "\"\n"
                << "Not applying the cached layout.";
1136 | } else { |
1137 | (*output_layouts)[output_index] = layout; |
1138 | node->AddAttr(kDefaultLayoutAttr, layout->ToString()); |
1139 | } |
1140 | } else if (layout_iterator == shape_layout_cache_.end()) { |
1141 | VLOG(3) << op_name << ": no cached layout found for shape " |
1142 | << global_output_shape.DebugString(); |
1143 | } else { |
1144 | VLOG(3) << op_name << ": found multiple layouts for shape " |
1145 | << global_output_shape.DebugString(); |
1146 | } |
1147 | } else { |
1148 | VLOG(3) << op_name |
1149 | << ": not applying same-shape-same-layout due to " |
1150 | "not-fully-known shape " |
1151 | << global_output_shape.DebugString(); |
1152 | } |
1153 | } |
1154 | } |
1155 | } |
1156 | |
1157 | std::unordered_map<std::string, int> |
1158 | DTensorDevice::GetFunctionCacheHitAndMissCount(TFE_Context* context, |
1159 | TF_Status* status) const { |
1160 | return function_compilation_hits_and_misses_; |
1161 | } |
1162 | |
// From `graph`, which contains the computation for all meshes, extracts the
// computation for the mesh specified in `function`. The returned graph is a
// clone containing only the ops needed for single-mesh execution.
1166 | StatusOr<std::unique_ptr<Graph>> SelectGraphToExecute( |
1167 | const TranslatedFunction& function, const Graph& graph, |
1168 | std::string* stateful_partitioned_call_name) { |
1169 | auto new_graph = std::make_unique<Graph>(graph.flib_def()); |
1170 | CopyGraph(graph, new_graph.get()); |
1171 | std::vector<Node*> arg_nodes; |
1172 | std::vector<Node*> retval_nodes; |
1173 | for (Node* node : new_graph->nodes()) { |
1174 | if (node->IsArg()) arg_nodes.emplace_back(node); |
1175 | if (node->IsRetval()) retval_nodes.emplace_back(node); |
1176 | } |
1177 | |
1178 | // Remove irrelevant function calls. |
1179 | for (Node* node : new_graph->nodes()) { |
    if (node->op_def().name() != "StatefulPartitionedCall") continue;
1181 | |
1182 | if (node->name() != function.node_to_execute->name()) { |
1183 | // Remove function call that does not match mesh specification and all |
1184 | // output retval nodes connected to the function call node. |
1185 | std::queue<Node*> nodes_to_remove; |
1186 | nodes_to_remove.push(node); |
1187 | while (!nodes_to_remove.empty()) { |
1188 | Node* n = nodes_to_remove.front(); |
1189 | for (const Edge* out_edge : n->out_edges()) { |
1190 | if (out_edge->IsControlEdge()) continue; |
1191 | Node* out_node = out_edge->dst(); |
1192 | if (!out_node->IsSink()) nodes_to_remove.push(out_node); |
1193 | } |
1194 | if (n->IsRetval()) { |
1195 | auto pos = std::find(retval_nodes.begin(), retval_nodes.end(), n); |
1196 | TF_RET_CHECK(pos != retval_nodes.end()); |
1197 | retval_nodes.erase(pos); |
1198 | } |
1199 | nodes_to_remove.pop(); |
1200 | new_graph->RemoveNode(n); |
1201 | } |
1202 | } |
1203 | } |
1204 | |
1205 | *stateful_partitioned_call_name = function.node_to_execute->name(); |
1206 | VLOG(1) << "Selected call " << *stateful_partitioned_call_name; |
1207 | |
  // Remove unused arg nodes from the graph. We erase while iterating, so the
  // iterator advances only when the current node is kept.
  for (auto it = arg_nodes.begin(); it != arg_nodes.end();) {
    Node* arg_node = *it;
    bool arg_unused = true;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->dst()->IsOp()) {
        arg_unused = false;
      }
    }
    if (!arg_unused) {
      ++it;
      continue;
    }

    new_graph->RemoveNode(arg_node);
    it = arg_nodes.erase(it);
  }
1222 | |
1223 | // Reset index attributes for arg and retval nodes. |
1224 | for (Node* n : new_graph->nodes()) { |
1225 | // Reset arg node index attributes to its position within all the arg |
1226 | // nodes. This should just be increasing from 0 to n where n |
1227 | // is the total number of arguments. Note that this definition to |
1228 | // the `index` attribute is different from the definition we set in |
1229 | // PrepareGraphForMLIR. |
1230 | // This attribute is needed for each arg node when converting a Graph to |
1231 | // a FunctionDef. |
1232 | if (n->IsArg()) { |
1233 | auto pos = std::find(arg_nodes.begin(), arg_nodes.end(), n); |
1234 | TF_RET_CHECK(pos != arg_nodes.end()); |
1235 | const int new_index = std::distance(arg_nodes.begin(), pos); |
1236 | n->AddAttr("index" , new_index); |
1237 | } |
1238 | |
1239 | // Reset retval nodes index attributes. |
1240 | if (n->IsRetval()) { |
1241 | auto retval_pos = std::find(retval_nodes.begin(), retval_nodes.end(), n); |
1242 | TF_RET_CHECK(retval_pos != retval_nodes.end()); |
1243 | const int new_index = std::distance(retval_nodes.begin(), retval_pos); |
1244 | n->AddAttr("index" , new_index); |
1245 | } |
1246 | } |
1247 | |
1248 | VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_to_execute_" , |
1249 | *new_graph); |
1250 | |
1251 | return new_graph; |
1252 | } |
1253 | |
1254 | // Adds processed graph to run for each mesh computation in |
1255 | // `execution_functions` to function definition library. |
1256 | Status AddExecutionFunctionDefsToFunctionDefLibrary( |
1257 | const absl::flat_hash_set<Node*>& control_ret_nodes, TFE_Context* context, |
1258 | const Graph& graph, ExecutionFunctions* execution_functions) { |
  // Note: We use node names instead of node pointers for comparison because
  // node addresses in the new graph differ from those in the original graph.
1261 | absl::flat_hash_set<std::string> control_ret_names; |
1262 | for (auto* n : control_ret_nodes) { |
1263 | control_ret_names.emplace(n->name()); |
1264 | } |
1265 | for (TranslatedFunction& function : execution_functions->function_list) { |
1266 | std::string selected_call_node_name; |
    // TODO(bfontain): We should just try to call the functions directly rather
    // than wrap them.
    // Construct a graph that executes only the computation for `function`.
1270 | TF_ASSIGN_OR_RETURN( |
1271 | std::unique_ptr<Graph> new_graph, |
1272 | SelectGraphToExecute(function, graph, &selected_call_node_name)); |
1273 | VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_" , *new_graph); |
1274 | |
1275 | // Add unique identifier based on the function we are executing to the |
1276 | // function/graph and convert graph to functiondef. |
1277 | NameAttrList func; |
1278 | TF_RETURN_IF_ERROR( |
1279 | GetNodeAttr(function.node_to_execute->attrs(), "f" , &func)); |
1280 | |
1281 | static std::atomic<int64_t> unique_function_number(0); |
1282 | function.translated_function_name = |
1283 | absl::StrCat(func.name(), "_" , unique_function_number.fetch_add(1)); |
1284 | auto control_ret_node_names = |
1285 | [&control_ret_names, &selected_call_node_name]( |
1286 | const Node* node) -> absl::optional<std::string> { |
1287 | // Add the stateful partitioned call node as a control return as we need |
1288 | // to process any control deps inside the inner function. |
1289 | if (control_ret_names.contains(node->name()) || |
1290 | node->name() == selected_call_node_name) { |
1291 | return node->name(); |
1292 | } |
1293 | return absl::nullopt; |
1294 | }; |
1295 | |
1296 | tensorflow::FunctionDef to_run; |
1297 | TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( |
1298 | *new_graph, function.translated_function_name, control_ret_node_names, |
1299 | &to_run)); |
1300 | |
1301 | for (const auto& out : to_run.signature().output_arg()) { |
1302 | function.output_dtypes.emplace_back(static_cast<TF_DataType>(out.type())); |
1303 | } |
1304 | |
1305 | AddDTensorFunctionAttr(to_run); |
1306 | TF_RETURN_IF_ERROR(tensorflow::unwrap(context)->AddFunctionDef(to_run)); |
1307 | } |
1308 | |
1309 | return OkStatus(); |
1310 | } |
1311 | |
1312 | void DTensorDevice::LowerToSPMDFunction( |
1313 | TFE_Context* context, const std::vector<TensorWithLayout*>& inputs, |
1314 | const DTensorOperation& doperation, const TFE_OpAttrs* attributes, |
1315 | const int num_outputs, const ExecutionFunctions** execution_functions, |
1316 | TF_Status* status) { |
1317 | profiler::TraceMe activity( |
1318 | [&] { return "DTensorDevice::LowerToSPMDFunction" ; }, |
1319 | profiler::TraceMeLevel::kInfo); |
1320 | FunctionLibraryDefinition* flib_def = |
1321 | tensorflow::unwrap(context)->FuncLibDef(); |
1322 | auto graph(std::make_unique<tensorflow::Graph>(flib_def)); |
1323 | NameAttrList eager_attributes; |
1324 | ASSIGN_OR_RETURN_C_STATUS(eager_attributes, FetchAttributes(attributes), |
1325 | status); |
1326 | |
1327 | std::vector<PartialTensorShape> global_output_shapes; |
1328 | std::vector<const Layout*> output_layouts; |
1329 | const FunctionDef* function_def = doperation.function_def; |
1330 | if (!function_def) { |
1331 | // Output layouts of an eager op (e.g. fill) must be inferred before cache |
1332 | // key computation, since they might depend on the current DTensorDevice |
1333 | // state. |
1334 | Status s = PrepareGraphForMlir( |
1335 | function_manager_, inputs, doperation, *flib_def, eager_attributes, |
1336 | default_layout_, graph.get(), &global_output_shapes, &output_layouts); |
1337 | RETURN_C_STATUS_IF_NOT_OK(s, status); |
1338 | |
    // Find all the meshes that the inputs lie on.
1340 | absl::flat_hash_set<Mesh> input_meshes; |
1341 | for (const TensorWithLayout* tensor : inputs) { |
1342 | if (!tensor->layout().mesh().IsEmpty()) { |
1343 | input_meshes.insert(tensor->layout().mesh()); |
1344 | } |
1345 | } |
1346 | // Currently we only provide layout hints for op-by-op, since |
1347 | // they interact badly with layout propagation. |
1348 | UpdateOutputLayoutsWithSameShapePolicy(global_output_shapes, input_meshes, |
1349 | doperation.name, graph.get(), |
1350 | &output_layouts, status); |
1351 | if (TF_GetCode(status) != TF_OK) return; |
1352 | } |
1353 | |
1354 | std::pair<tensorflow::Fprint128, const ExecutionFunctions*> |
1355 | cache_key_and_func = function_manager_.GetCachedFunction( |
1356 | doperation, eager_attributes, inputs, output_layouts); |
1357 | *execution_functions = cache_key_and_func.second; |
1358 | if (*execution_functions != nullptr) { |
1359 | function_compilation_hits_and_misses_["hit" ]++; |
1360 | return; |
1361 | } else if (function_def) { |
1362 | function_compilation_hits_and_misses_["miss" ]++; |
1363 | LOG(INFO) << "DTensor cache key lookup missed for " << doperation.name |
1364 | << ". DTensor is (re-)computing its SPMD transformation." ; |
1365 | } |
1366 | |
  // The device list includes remote devices when the coordination service is
  // enabled.
1368 | const auto device_list = tensorflow::unwrap(context)->ListAllTfDevices(); |
1369 | DeviceSet device_set; |
1370 | for (const auto device : device_list) device_set.AddDevice(device); |
1371 | |
1372 | if (function_def) { |
1373 | ASSIGN_OR_RETURN_C_STATUS(auto device_name_to_mesh_device, |
1374 | PipelineSubMeshes(context), status); |
1375 | const bool is_pipelining_function = !device_name_to_mesh_device.empty(); |
1376 | // For a multi-mesh function for pipelining, we take a different execution |
1377 | // path. Call the partitioner to lower and partition the graph into multiple |
1378 | // sub functions to execute (one per sub mesh). |
1379 | if (is_pipelining_function) { |
1380 | ASSIGN_OR_RETURN_C_STATUS( |
1381 | ExecutionFunctions functions, |
1382 | PipeliningPartitionerRun(&device_name_to_mesh_device, flib_def, |
1383 | &pass_runner_, *doperation.function_def, |
1384 | eager_attributes, inputs, device_set, |
1385 | num_outputs), |
1386 | status); |
1387 | *execution_functions = function_manager_.AddCachedFunction( |
1388 | doperation, cache_key_and_func.first, std::move(functions)); |
1389 | return; |
1390 | } |
1391 | // Output layouts of a function are inferred by MLIR lowering. They are |
1392 | // not necessary for cache key computation, so run PrepareGraphForMlir after |
1393 | // cache key computation to reduce the overheads of running the same |
1394 | // function multiple times. |
1395 | Status s = PrepareGraphForMlir( |
1396 | function_manager_, inputs, doperation, *flib_def, eager_attributes, |
1397 | default_layout_, graph.get(), &global_output_shapes, &output_layouts); |
1398 | RETURN_C_STATUS_IF_NOT_OK(s, status); |
1399 | } |
1400 | |
1401 | absl::flat_hash_set<Node*> control_ret_nodes; |
1402 | // Run DTensor MLIR passes that convert input graph to SPMD version. |
1403 | { |
1404 | profiler::TraceMe activity([&] { return "DTensorDevice::RunMLIRPasses" ; }, |
1405 | profiler::TraceMeLevel::kInfo); |
1406 | RETURN_C_STATUS_IF_NOT_OK( |
1407 | pass_runner_.RunOnGraph(device_set, doperation.is_func(), flib_def, |
1408 | &graph, control_ret_nodes, |
1409 | cache_key_and_func.first), |
1410 | status); |
1411 | } |
1412 | VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering" , *graph, |
1413 | flib_def); |
1414 | if (flib_def->Contains(kLoadEmbeddingFn)) { |
1415 | Status s = InsertFunctionForTPUEmbeddingCheckpoint( |
1416 | status, graph.get(), inputs, kLoadEmbeddingFn); |
1417 | RETURN_C_STATUS_IF_NOT_OK(s, status); |
1418 | } |
1419 | |
  // After MLIR transformations, exactly one StatefulPartitionedCall op is
  // returned for each mesh cluster in the computation. Identify all functions
  // to execute for each mesh, along with the relevant input and output
  // information.
1423 | ASSIGN_OR_RETURN_C_STATUS( |
1424 | ExecutionFunctions functions, |
1425 | IdentifyAllFunctionsToExecute(*graph.get(), global_output_shapes), |
1426 | status); |
1427 | |
1428 | // In order to ensure that all resource assign operations as well as side |
1429 | // effecting ops are executed, we add identity ops before function outputs |
1430 | // with control rets. |
1431 | RETURN_C_STATUS_IF_NOT_OK(MaybeInsertIdentityNodes(function_def, graph.get()), |
1432 | status); |
1433 | |
1434 | VLOG(4) << tensorflow::DumpGraphToFile("after_post_processing_graph" , *graph, |
1435 | flib_def); |
1436 | |
1437 | RETURN_C_STATUS_IF_NOT_OK( |
1438 | AddExecutionFunctionDefsToFunctionDefLibrary(control_ret_nodes, context, |
1439 | *graph.get(), &functions), |
1440 | status); |
1441 | functions.num_device_ids = 1; |
1442 | if (function_def) { |
1443 | for (TranslatedFunction& function : functions.function_list) { |
1444 | functions.function_mesh_fingerprint = |
1445 | FingerprintCat64(functions.function_mesh_fingerprint, |
1446 | function.function_mesh.GlobalFingerprint()); |
1447 | } |
1448 | } |
1449 | |
1450 | *execution_functions = function_manager_.AddCachedFunction( |
1451 | doperation, cache_key_and_func.first, std::move(functions)); |
1452 | } |
1453 | |
1454 | void DTensorDevice::ExecuteFunctionAndWait( |
1455 | TFE_Context* context, const TranslatedFunction* function_ptr, |
1456 | const MeshWithParallelDevice* parallel_device_mesh, |
1457 | const std::vector<parallel_device::ParallelTensor*>& parallel_inputs, |
1458 | const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status) { |
1459 | const std::string mesh_str = function_ptr->function_mesh.ToString(); |
1460 | VLOG(4) << "Launching computation for mesh : " << mesh_str; |
1461 | parallel_device_mesh->parallel_device().StartExecute( |
1462 | context, |
1463 | /*inputs=*/parallel_inputs, |
1464 | /*operation_name=*/function_ptr->translated_function_name.c_str(), |
1465 | /*attributes=*/attributes, |
1466 | /*expected_max_outputs=*/function_ptr->local_output_shapes.size(), |
1467 | /*cancellation_manager=*/*cancellation_manager_, |
1468 | /*step_id=*/step_id); |
1469 | |
1470 | VLOG(4) << "Joining computation result from mesh : " << mesh_str; |
1471 | parallel_device_mesh->parallel_device().Join( |
1472 | function_ptr->local_output_shapes, status); |
1473 | VLOG(4) << "Joining status: " << TF_Message(status); |
1474 | if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_CANCELLED) { |
1475 | LOG(ERROR) << "Encountered error while executing function: " |
1476 | << function_ptr->translated_function_name |
1477 | << " for mesh : " << mesh_str |
1478 | << " / error : " << TF_Message(status); |
1479 | } |
1480 | |
1481 | std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
1482 | TF_NewStatus(), TF_DeleteStatus); |
1483 | AsyncWait(context, async_wait_status.get()); |
1484 | TF_Code error_code = TF_GetCode(async_wait_status.get()); |
1485 | if (error_code != TF_OK && error_code != TF_CANCELLED) { |
1486 | LOG(ERROR) << "Async status: " << TF_Message(async_wait_status.get()); |
1487 | } |
1488 | } |
1489 | |
1490 | void DTensorDevice::ExecuteRegularOperation( |
1491 | TFE_Context* context, const std::vector<TensorWithLayout*>& inputs, |
1492 | const DTensorOperation& doperation, const TFE_OpAttrs* attributes, |
1493 | int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status) { |
1494 | const ExecutionFunctions* execution_functions = nullptr; |
1495 | |
1496 | LowerToSPMDFunction(context, inputs, doperation, attributes, *num_outputs, |
1497 | &execution_functions, status); |
1498 | if (TF_GetCode(status) != TF_OK) return; |
1499 | |
1500 | // Update input layouts for resource arguments. |
1501 | for (const TranslatedFunction& function : |
1502 | execution_functions->function_list) { |
1503 | for (const auto& entry : function.resource_input_layouts) { |
      // TODO(hthu): Add a TensorWithLayout in the inputs vector at location 0
      // for DeviceId. This is done as the first arg is always DeviceId, and
      // it isn't mapped to input Tensors.
1507 | const int resource_index_to_update = entry.first - 1; |
1508 | inputs[resource_index_to_update]->UpdateLayout(entry.second, status); |
1509 | if (TF_GetCode(status) != TF_OK) { |
1510 | RETURN_STATUS(status, TF_GetCode(status), |
1511 | absl::StrCat("Attempt to update layout input arg: " , |
1512 | resource_index_to_update, |
1513 | ". Original message: " , TF_Message(status)) |
1514 | .c_str()); |
1515 | } |
1516 | } |
1517 | } |
1518 | |
1519 | int num_global_outputs = 0; |
1520 | |
1521 | std::map<std::string, const MeshWithParallelDevice*> |
1522 | function_name_and_mesh_mapping; |
1523 | absl::flat_hash_set<std::string> excluded_fn_names; |
1524 | std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr; |
1525 | for (const TranslatedFunction& function : |
1526 | execution_functions->function_list) { |
1527 | StatusOr<Mesh> maybe_converted_mesh = function.function_mesh; |
1528 | if (function.function_mesh.is_epu_mesh()) { |
1529 | maybe_converted_mesh = function.function_mesh.ToDeviceType("CPU" ); |
1530 | } |
1531 | |
1532 | if (!maybe_converted_mesh.ok()) { |
1533 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
1534 | absl::StrCat("Failed to convert mesh, get error: " , |
1535 | maybe_converted_mesh.status().error_message()) |
1536 | .c_str()); |
1537 | } |
1538 | const Mesh& mesh = *maybe_converted_mesh; |
1539 | const MeshWithParallelDevice* parallel_device_mesh = |
1540 | mesh_to_device_map_.contains(mesh) ? mesh_to_device_map_[mesh].get() |
1541 | : default_mesh_; |
1542 | if (parallel_device_mesh == nullptr) { |
1543 | RETURN_STATUS(status, TF_INTERNAL, |
1544 | "required mesh is not registered with DTensor device" ); |
1545 | } |
1546 | function_name_and_mesh_mapping[function.translated_function_name] = |
1547 | parallel_device_mesh; |
1548 | |
1549 | if (function.function_mesh.is_epu_mesh()) { |
1550 | if (epu_fn_ptr != nullptr) { |
1551 | RETURN_STATUS(status, TF_INTERNAL, |
1552 | "There are more than one function defined on EPU mesh." ); |
1553 | } |
1554 | epu_fn_ptr = std::make_unique<const TranslatedFunction>(function); |
1555 | excluded_fn_names.insert(function.translated_function_name); |
1556 | } |
1557 | if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) { |
1558 | if (load_embedding_ptr != nullptr) { |
1559 | RETURN_STATUS(status, TF_INTERNAL, |
1560 | "There are more than one function defined on EPU mesh." ); |
1561 | } |
1562 | load_embedding_ptr = std::make_unique<const TranslatedFunction>(function); |
1563 | excluded_fn_names.insert(function.translated_function_name); |
1564 | } |
1565 | } |
1566 | |
1567 | // Compute the step_id based on the function_mesh_fingerprint and the |
1568 | // corresponding function execution counter. |
1569 | uint64 function_mesh_fingerprint = |
1570 | execution_functions->function_mesh_fingerprint; |
1571 | if (func_mesh_fingerprint_to_step_counter_.contains( |
1572 | function_mesh_fingerprint)) { |
1573 | func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)++; |
1574 | } else { |
1575 | func_mesh_fingerprint_to_step_counter_.insert( |
1576 | {function_mesh_fingerprint, 0}); |
1577 | } |
1578 | const uint64 step_id = FingerprintCat64( |
1579 | function_mesh_fingerprint, |
1580 | func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)); |
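  // For example (illustrative values): the first execution with mesh
  // fingerprint F runs with step_id == FingerprintCat64(F, 0), the next with
  // FingerprintCat64(F, 1), and so on, so all per-mesh functions launched for
  // one step share a step_id while distinct steps do not collide.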
1581 | |
1582 | // Execute excluded functions in sequence. |
1583 | if (epu_fn_ptr != nullptr) { |
1584 | ExecuteFunctionAndWait( |
1585 | context, |
1586 | /*function_ptr=*/epu_fn_ptr.get(), |
1587 | /*parallel_device_mesh=*/ |
1588 | function_name_and_mesh_mapping[epu_fn_ptr->translated_function_name], |
1589 | /*parallel_inputs=*/{}, /*step_id=*/step_id, /*attributes=*/attributes, |
1590 | /*status=*/status); |
1591 | } |
1592 | |
1593 | if (load_embedding_ptr != nullptr) { |
1594 | StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs = |
1595 | PrepareEmbeddingInputs(inputs); |
1596 | if (!parallel_inputs.ok()) { |
1597 | RETURN_STATUS(status, TF_INTERNAL, |
1598 | parallel_inputs.status().error_message().c_str()); |
1599 | } |
1600 | ExecuteFunctionAndWait( |
1601 | context, |
1602 | /*function_ptr=*/load_embedding_ptr.get(), |
1603 | /*parallel_device_mesh=*/ |
1604 | function_name_and_mesh_mapping[load_embedding_ptr |
1605 | ->translated_function_name], |
1606 | /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id, |
1607 | /*attributes=*/attributes, /*status=*/status); |
1608 | } |
1609 | |
1610 | // Extract the global parallel inputs and flatten SparseTensors |
1611 | // into the three component tensors. |
1612 | std::vector<parallel_device::ParallelTensor*> global_parallel_inputs; |
1613 | std::vector<parallel_device::ParallelTensor*> global_parallel_sparse_inputs; |
1614 | absl::flat_hash_set<int> global_sparse_input_indices; |
1615 | for (auto input : inputs) { |
1616 | if (input->tensor_type() == TensorType::kSparse) { |
1617 | SparseTensorWithLayout* sparse_input = |
1618 | dynamic_cast<SparseTensorWithLayout*>(input); |
1619 | global_parallel_sparse_inputs.push_back(sparse_input->indices()); |
1620 | global_parallel_sparse_inputs.push_back(sparse_input->dense_shapes()); |
1621 | global_parallel_sparse_inputs.push_back(sparse_input->values()); |
1622 | } else { |
1623 | global_parallel_inputs.push_back(input->tensor()); |
1624 | } |
1625 | } |
  // Insert SparseTensor components at the end; in the MLIR handling of
  // SparseTensors, SparseTensor components are placed at the end of the main
  // func arguments for a fixed ordering.
1629 | global_parallel_inputs.insert(global_parallel_inputs.end(), |
1630 | global_parallel_sparse_inputs.begin(), |
1631 | global_parallel_sparse_inputs.end()); |
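  // For example, with dense inputs d0, d1 and sparse inputs s0, s1, the
  // flattened ordering is:
  //   d0, d1, s0.indices, s0.dense_shapes, s0.values,
  //   s1.indices, s1.dense_shapes, s1.values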
1632 | |
1633 | // Execute all functions in parallel. |
1634 | for (const TranslatedFunction& function : |
1635 | execution_functions->function_list) { |
1636 | const Mesh& mesh = function.function_mesh; |
1637 | const std::string& translated_function_name = |
1638 | function.translated_function_name; |
1639 | |
1640 | num_global_outputs += function.local_output_shapes.size(); |
1641 | |
1642 | if (is_remote_mesh(mesh) || |
1643 | (excluded_fn_names.find(translated_function_name) != |
1644 | excluded_fn_names.end())) { |
      // Skip execution for a translated function that has a remote mesh or
      // is excluded.
1647 | continue; |
1648 | } |
1649 | |
1650 | const MeshWithParallelDevice* parallel_device_mesh = |
1651 | function_name_and_mesh_mapping[translated_function_name]; |
1652 | |
1653 | // Gather the local inputs for this function. |
1654 | std::vector<parallel_device::ParallelTensor*> parallel_inputs; |
1655 | parallel_inputs.reserve(inputs.size() + 1); |
1656 | auto input_mapping = function.input_index_map; |
1657 | |
    // We sort here because, by this time, the function graph we are executing
    // is a reduced version of the main function that includes the
    // StatefulPartitionedCall we are executing for this mesh. Thus, the
    // ordering is the same as the main function ordering, which is sorted in
    // increasing order.
1663 | std::sort(input_mapping.begin(), input_mapping.end()); |
1664 | |
1665 | for (const int global_index : input_mapping) { |
1666 | auto input_index = global_index - execution_functions->num_device_ids; |
1667 | |
1668 | if (global_index < execution_functions->num_device_ids) { |
1669 | parallel_inputs.push_back( |
1670 | parallel_device_mesh->DeviceIDs(context, status)); |
1671 | if (TF_GetCode(status) != TF_OK) return; |
1672 | } else { |
1673 | parallel_inputs.push_back(global_parallel_inputs[input_index]); |
1674 | } |
1675 | } |
1676 | |
1677 | VLOG(4) << "Launching computation for mesh : " << mesh.ToString(); |
1678 | parallel_device_mesh->parallel_device().StartExecute( |
1679 | context, parallel_inputs, translated_function_name.c_str(), attributes, |
1680 | /*expected_max_outputs=*/function.local_output_shapes.size(), |
1681 | *cancellation_manager_, /*step_id=*/step_id); |
1682 | } |
1683 | |
1684 | *num_outputs = num_global_outputs; |
1685 | std::vector<std::unique_ptr<TensorWithLayout>> typed_outputs; |
1686 | typed_outputs.resize(num_global_outputs); |
1687 | |
1688 | // Join all mesh computation together. |
1689 | // TODO(b/177932563): Expose cancel logic to handle failures. |
1690 | std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> join_status( |
1691 | TF_NewStatus(), TF_DeleteStatus); |
1692 | for (const TranslatedFunction& function : |
1693 | execution_functions->function_list) { |
1694 | // Skip execution for a function when it's excluded. |
1695 | if (excluded_fn_names.contains(function.translated_function_name)) { |
1696 | continue; |
1697 | } |
1698 | const Mesh& mesh = function.function_mesh; |
1699 | const MeshWithParallelDevice* parallel_device_mesh = |
1700 | function_name_and_mesh_mapping[function.translated_function_name]; |
1701 | |
1702 | std::vector<std::unique_ptr<TensorWithLayout>> output_with_layout; |
1703 | output_with_layout.reserve(function.output_index_map.size()); |
1704 | if (is_remote_mesh(mesh)) { |
1705 | // Create dummy outputs on a remote mesh. |
1706 | for (int i = 0; i < function.output_index_map.size(); ++i) { |
1707 | const auto dim_sizes = function.local_output_shapes.at(i).dim_sizes(); |
1708 | std::vector<int64_t> local_shape = |
1709 | std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end()); |
1710 | TF_DataType dtype = |
1711 | static_cast<TF_DataType>(function.output_dtypes.at(i)); |
1712 | auto remote_output = |
1713 | TensorWithLayout::Dummy(local_shape, dtype, *parallel_device_mesh, |
1714 | function.output_layouts[i]); |
1715 | output_with_layout.push_back(std::move(remote_output)); |
1716 | } |
1717 | } else { |
1718 | VLOG(4) << "Joining computation result from mesh : " << mesh.ToString(); |
1719 | auto result = parallel_device_mesh->parallel_device().Join( |
1720 | function.local_output_shapes, status); |
1721 | if (TF_GetCode(join_status.get()) != TF_OK && |
1722 | // Preserve the first failure we see, but only if it is a real failure |
1723 | // and not a cancellation (which was probably triggered by the error |
1724 | // we want to propagate). |
1725 | (TF_GetCode(status) == TF_OK || |
1726 | TF_GetCode(join_status.get()) != TF_CANCELLED)) { |
1727 | continue; |
1728 | } |
1729 | if (TF_GetCode(status) != TF_OK) { |
1730 | if (TF_GetCode(status) != TF_CANCELLED) { |
1731 | LOG(ERROR) << "Encountered error while executing function: " |
1732 | << function.translated_function_name |
1733 | << " for mesh : " << mesh.ToString() |
1734 | << " / error : " << TF_Message(status); |
1735 | } |
1736 | TF_SetStatus(join_status.get(), TF_GetCode(status), TF_Message(status)); |
1737 | continue; |
1738 | } |
1739 | |
1740 | for (int i = 0; i < result->size(); ++i) { |
1741 | ASSIGN_OR_RETURN_C_STATUS( |
1742 | auto local_output, |
1743 | TensorWithLayout::Wrap(std::move((*result)[i]), |
1744 | *parallel_device_mesh, |
1745 | function.output_layouts[i]), |
1746 | status); |
1747 | output_with_layout.push_back(std::move(local_output)); |
1748 | } |
1749 | } |
1750 | |
1751 | for (int i = 0; i < function.output_index_map.size(); ++i) { |
1752 | // TODO(b/162744844): Generalize this pattern so that the extraction is |
1753 | // not special cased. |
1754 | if (function.shape_output_metadata.find(i) != |
1755 | function.shape_output_metadata.end()) { |
1756 | output_with_layout[i]->set_input_layout_for_shape_op_result( |
1757 | function.shape_output_metadata.at(i)); |
1758 | } |
1759 | |
1760 | RecordInShapeLayoutCache(*output_with_layout[i]); |
1761 | typed_outputs[function.output_index_map[i]] = |
1762 | std::move(output_with_layout[i]); |
1763 | } |
1764 | } |
1765 | if (TF_GetCode(join_status.get()) != TF_OK) { |
1766 | std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
1767 | TF_NewStatus(), TF_DeleteStatus); |
1768 | AsyncWait(context, async_wait_status.get()); |
1769 | TF_Code error_code = TF_GetCode(async_wait_status.get()); |
1770 | if (error_code != TF_OK && error_code != TF_CANCELLED) { |
1771 | // Ignore the AsyncWait() status return since we already have a bad status |
1772 | // to propagate. We've just canceled a bunch of operations, so we expect |
1773 | // cancellation status returns. We'll log anything else just to be safe. |
1774 | LOG(ERROR) << "Error executing " << doperation.name << " " |
1775 | << TF_Message(async_wait_status.get()); |
1776 | } |
1777 | |
1778 | TF_SetStatus(status, TF_GetCode(join_status.get()), |
1779 | TF_Message(join_status.get())); |
1780 | return; |
1781 | } |
1782 | if (VLOG_IS_ON(2)) { |
1783 | LOG(INFO) << "Executed " << doperation.name << ", got " |
1784 | << typed_outputs.size() << " outputs:" ; |
1785 | for (const std::unique_ptr<TensorWithLayout>& output : typed_outputs) { |
1786 | LOG(INFO) << " " << output->DebugString(); |
1787 | } |
1788 | } |
1789 | if (doperation.name == std::string("VarHandleOp" )) { |
1790 | // For new variables, set the dereferenced shape/dtype so we can pass it in |
1791 | // as _handle_dtype and _handle_shape in the future. |
1792 | // |
1793 | // Note that VarHandleOps generated by `tf.Variable` objects are always run |
1794 | // eagerly, which is almost all of the op's usage in TF2. Theoretically a |
1795 | // user could run it in a tf.function via tf.raw_ops.VarHandleOp, return it |
1796 | // from that function, and add it as an input to another, and it would |
1797 | // currently be missing handle information. |
1798 | if (typed_outputs.size() != 1) { |
1799 | RETURN_STATUS(status, TF_INTERNAL, |
1800 | "Expected one output from VarHandleOp" ); |
1801 | } |
1802 | NameAttrList name_and_attrs; |
1803 | ASSIGN_OR_RETURN_C_STATUS(name_and_attrs, FetchAttributes(attributes), |
1804 | status); |
1805 | |
    typed_outputs[0]->UpdateShapeAndDType(
        name_and_attrs.attr().at("shape").shape(),
        name_and_attrs.attr().at("dtype").type(), status);
1809 | if (TF_GetCode(status) != TF_OK) return; |
1810 | } |
1811 | |
1812 | for (int i = 0; i < *num_outputs; ++i) { |
1813 | outputs[i] = |
1814 | MakeLayoutTensorHandle(context, std::move(typed_outputs[i]), status); |
1815 | if (TF_GetCode(status) != TF_OK) return; |
1816 | } |
1817 | } |
1818 | |
1819 | void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, |
1820 | TFE_TensorHandle** outputs, TF_Status* status) { |
1821 | TFE_Context* context = TFE_OpGetContext(original_op, status); |
1822 | if (TF_GetCode(status) != TF_OK) return; |
1823 | const char* operation_name = TFE_OpGetName(original_op, status); |
1824 | if (TF_GetCode(status) != TF_OK) return; |
1825 | const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); |
1826 | int num_inputs = TFE_OpGetFlatInputCount(original_op, status); |
1827 | if (TF_GetCode(status) != TF_OK) return; |
1828 | std::vector<TFE_TensorHandle*> inputs_vector; |
1829 | inputs_vector.reserve(num_inputs); |
1830 | for (int input_index = 0; input_index < num_inputs; ++input_index) { |
1831 | TFE_TensorHandle* input = |
1832 | TFE_OpGetFlatInput(original_op, input_index, status); |
1833 | if (TF_GetCode(status) != TF_OK) return; |
1834 | inputs_vector.push_back(input); |
1835 | } |
1836 | TFE_TensorHandle** inputs = inputs_vector.data(); |
1837 | |
1838 | DTensorOperation dtensor_operation{}; |
1839 | dtensor_operation.name = operation_name; |
1840 | { |
1841 | dtensor_operation.function_def = |
1842 | tensorflow::unwrap(context)->FindFunctionDef(operation_name); |
1843 | } |
1844 | |
1845 | // First handle DTensor-specific virtual operations. |
1846 | bool is_op_handled = false; |
1847 | MaybeHandleDTensorCustomOps(operation_name, num_inputs, attributes, context, |
1848 | inputs, num_outputs, outputs, &is_op_handled, |
1849 | status); |
1850 | if (is_op_handled) return; |
1851 | |
1852 | // This isn't a special op, so we'll defer to TFE_Execute to actually execute |
1853 | // it, but we'll also run DTensor MLIR passes and propagate the layout. |
1854 | std::vector<TensorWithLayout*> typed_inputs; |
1855 | std::vector<std::unique_ptr<TensorWithLayout>> inputs_with_no_layout; |
1856 | |
  // Record the unique mesh identified across all inputs that are already on
  // the DTensor device. If we can identify a single mesh, that mesh is also
  // used as the mesh to broadcast non-DTensor inputs to.
1860 | absl::flat_hash_set<Mesh> input_meshes; |
1861 | std::vector<int> not_on_device_input_indices; |
1862 | |
1863 | typed_inputs.resize(num_inputs); |
1864 | for (int j = 0; j < num_inputs; ++j) { |
1865 | TFE_TensorHandle* input = inputs[j]; |
1866 | const char* input_device = TFE_TensorHandleDeviceName(input, status); |
1867 | if (TF_GetCode(status) != TF_OK) return; |
1868 | if (name_ != input_device) { |
1869 | not_on_device_input_indices.push_back(j); |
1870 | continue; |
1871 | } |
1872 | // Handle input which is on DTensor device already. |
1873 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
1874 | TFE_TensorHandleDevicePointer(input, status)); |
1875 | if (TF_GetCode(status) != TF_OK) return; |
1876 | |
    // VarHandleOp runs on an empty mesh, which isn't registered with the
    // device.
1878 | if (!t->layout().mesh().IsEmpty()) { |
1879 | input_meshes.insert(t->layout().mesh()); |
1880 | } |
    // Remote mesh inputs cannot be read or evaluated.
1882 | if (!is_remote_mesh(t->layout().mesh()) && !t->const_value().has_value()) { |
1883 | std::optional<NodeDef> const_value = |
1884 | ExtractSmallTensorValue(context, input, t->layout(), status); |
1885 | if (TF_GetCode(status) != TF_OK) return; |
1886 | if (const_value.has_value()) { |
1887 | t->set_const_value(const_value.value()); |
1888 | } |
1889 | } |
1890 | typed_inputs[j] = t; |
1891 | } |
1892 | |
  // If a unique mesh is identified across all inputs, we use that mesh as the
  // mesh to broadcast to. Otherwise we fall back to the default mesh.
1895 | const MeshWithParallelDevice* broadcast_mesh = |
1896 | input_meshes.size() == 1 |
1897 | ? mesh_to_device_map_[*input_meshes.begin()].get() |
1898 | : default_mesh_; |
1899 | if (!broadcast_mesh) { |
1900 | RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
1901 | "No mesh has been registered to DTensor. Use copy_to_mesh to " |
1902 | "explicit specify a mesh instead." ); |
1903 | } |
1904 | for (int not_on_device_input_index : not_on_device_input_indices) { |
1905 | TFE_TensorHandle* input = inputs[not_on_device_input_index]; |
1906 | // DTensor creation should be explicit, with some exceptions for usability |
1907 | // (scalars/shapes/slice specs/etc.) Here we do some trivial validation to |
1908 | // enforce this rule. |
1909 | int num_dims = TFE_TensorHandleNumDims(input, status); |
1910 | if (TF_GetCode(status) != TF_OK) return; |
1911 | int64_t num_elements = TFE_TensorHandleNumElements(input, status); |
1912 | if (TF_GetCode(status) != TF_OK) return; |
1913 | TF_DataType dtype = TFE_TensorHandleDataType(input); |
1914 | const bool small_int_tensor = num_elements < kSmallTensorThreshold && |
1915 | (dtype == TF_INT32 || dtype == TF_INT64); |
1916 | if (!(num_dims == 0 || dtype == TF_STRING || small_int_tensor)) { |
1917 | std::vector<int64_t> tensor_shape(TensorShapeAsVector(input, status)); |
1918 | if (TF_GetCode(status) != TF_OK) return; |
      RETURN_STATUS(
          status, TF_UNIMPLEMENTED,
          absl::StrCat(
              "The op/function ", operation_name,
              " got a regular tensor for input ", not_on_device_input_index,
              " (shape ", ShapeToDebugString(tensor_shape),
              ") but was expecting a DTensor. Currently only scalars and "
              "small integer/string tensors are auto-broadcast to "
              "DTensors. For other tensors, please use copy_to_mesh to "
              "make a DTensor explicitly; note that this may be slow if it "
              "happens frequently.")
              .c_str());
1931 | } |
1932 | // Construct temporary TensorWithLayout objects for inputs that didn't |
1933 | // have any to start. These are owned by the `inputs_with_no_layout` |
1934 | // vector, whereas the input `TFE_TensorHandle`s maintain ownership for |
    // inputs that already had layouts (and therefore had TensorWithLayout
1936 | // objects). |
1937 | std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast( |
1938 | context, input, *broadcast_mesh, name_, status); |
1939 | if (TF_GetCode(status) != TF_OK) return; |
1940 | if (!ShouldFoldInputArgument(dtensor_operation.name, |
1941 | /*input_index=*/not_on_device_input_index)) { |
1942 | wrapper->reset_const_value(); |
1943 | } |
1944 | typed_inputs[not_on_device_input_index] = wrapper.get(); |
1945 | inputs_with_no_layout.emplace_back(wrapper.release()); |
1946 | } |
1947 | |
1948 | ExecuteRegularOperation(context, typed_inputs, dtensor_operation, attributes, |
1949 | num_outputs, outputs, status); |
1950 | } |
1951 | |
1952 | void ExecuteOnDTensorDevice(const TFE_Op* original_op, int* num_outputs, |
1953 | TFE_TensorHandle** outputs, TF_Status* status, |
1954 | void* device_info) { |
1955 | DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info); |
1956 | dev->Execute(original_op, num_outputs, outputs, status); |
1957 | } |
1958 | |
1959 | void DeleteDTensorDevice(void* device_info) { |
1960 | delete static_cast<DTensorDevice*>(device_info); |
1961 | } |
1962 | |
1963 | TFE_TensorHandle* CopyToDTensorDevice(TFE_Context* context, |
1964 | TFE_TensorHandle* tensor, |
1965 | TF_Status* status, void* device_info) { |
1966 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
1967 | "Trying to copy a tensor on to a DTensor mesh without a layout " |
1968 | "(use the CopyToMesh op for now)." ); |
1969 | return nullptr; |
1970 | } |
1971 | |
1972 | TFE_TensorHandle* CopyFromDTensorDevice(TFE_Context* context, |
1973 | TFE_TensorHandle* tensor, |
1974 | const char* target_device_name, |
1975 | TF_Status* status, void* device_info) { |
1976 | TensorWithLayout* typed_input = reinterpret_cast<TensorWithLayout*>( |
1977 | TFE_TensorHandleDevicePointer(tensor, status)); |
1978 | if (!tensorflow::dtensor::Layout(typed_input->layout()).IsFullyReplicated()) { |
1979 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
1980 | "Trying to copy a non-replicated DTensor is not supported." ); |
1981 | return nullptr; |
1982 | } |
1983 | if (typed_input->tensor()->dtype() == TF_RESOURCE) { |
1984 | TF_SetStatus(status, TF_UNIMPLEMENTED, |
1985 | "Trying to copy a DTensor resource handle is not supported." ); |
1986 | return nullptr; |
1987 | } |
1988 | DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info); |
1989 | // Since operations are executed asynchronously, the operation which should |
1990 | // produce the tensor we're trying to copy off the DTensor device may be |
1991 | // canceled due to a failure on another device. If so, we want to report the |
1992 | // failure that caused the cancellation, not the cancellation itself. This |
1993 | // requires blocking waiting for other devices to flush their execution |
1994 | // queues. |
  // Note that we only need to sync the threads on the parallel_device()
  // directly; a context-level sync might cause unintentional deadlocks when
  // grabbing locks on other threads.
1998 | dev->AsyncWait(context, status); |
1999 | if (TF_GetCode(status) != TF_OK) return nullptr; |
2000 | return TFE_TensorHandleCopySharingTensor(typed_input->get_tensor(0), status); |
2001 | } |
2002 | |
2003 | bool PinToDTensorDevice(const TFE_Op* op, TF_Status* s) { |
  // Always pin to the DTensor device if any of its inputs is a DTensor.
  // Note that if this function is called, the caller guarantees that all
  // inputs on a custom device are on a single DTensor device.
2007 | |
2008 | // Exception 1: |
2009 | // If there is a non-dtensor resource tensor and other dtensor inputs |
2010 | // are not on a CPU mesh, then pin to the physical device. |
2011 | // |
2012 | // This is because our resource upcast to a dtensor only supports broadcasting |
2013 | // to a CPU mesh. If any other dtensor inputs are on a TPU mesh, |
2014 | // then the mesh that is broadcasted will be the TPU mesh. |
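  //
  // For example (illustrative): an assign op whose resource input is a
  // regular (non-DTensor) handle and whose other DTensor inputs live on a TPU
  // mesh returns false here, because the resource can only be upcast to a CPU
  // mesh.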
2015 | int num_inputs = TFE_OpGetFlatInputCount(op, s); |
2018 | |
2019 | absl::flat_hash_set<Mesh> input_meshes; |
2020 | |
2021 | bool has_non_dtensor_resource = false; |
2022 | |
2023 | for (int input_index = 0; input_index < num_inputs; ++input_index) { |
2024 | TFE_TensorHandle* input = TFE_OpGetFlatInput(op, input_index, s); |
2025 | |
2026 | std::string input_device_name = |
2027 | std::string(TFE_TensorHandleDeviceName(input, s)); |
2028 | if (!absl::StrContains(absl::AsciiStrToLower(input_device_name), |
2029 | "custom" )) { |
2030 | TF_DataType dtype = TFE_TensorHandleDataType(input); |
2031 | if (dtype == TF_RESOURCE) { |
2032 | has_non_dtensor_resource = true; |
2033 | } |
2034 | continue; |
2035 | } |
2036 | |
2037 | // Handle input which is on DTensor device already. |
2038 | TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
2039 | TFE_TensorHandleDevicePointer(input, s)); |
2040 | |
2041 | if (!t->layout().mesh().IsEmpty()) { |
2042 | input_meshes.insert(t->layout().mesh()); |
2043 | } |
2044 | } |
2045 | |
2046 | const Mesh* broadcast_mesh = |
2047 | input_meshes.size() == 1 ? &(*input_meshes.begin()) : nullptr; |
2048 | |
  // Place on the physical device, as DTensor does not support upcasting a
  // resource tensor to a non-CPU mesh.
2051 | if (has_non_dtensor_resource && broadcast_mesh && |
2052 | !broadcast_mesh->is_cpu_mesh()) { |
2053 | return false; |
2054 | } |
2055 | |
2056 | return true; |
2057 | } |
2058 | |
2059 | void AllocateDTensorDevice(absl::string_view device_name, |
2060 | TFE_CustomDevice* device, void** device_info) { |
2061 | device->copy_tensor_to_device = &CopyToDTensorDevice; |
2062 | device->copy_tensor_from_device = &CopyFromDTensorDevice; |
2063 | device->delete_device = &DeleteDTensorDevice; |
2064 | device->execute = &ExecuteOnDTensorDevice; |
2065 | device->shall_pin_to_this_device = &PinToDTensorDevice; |
2066 | *device_info = new DTensorDevice(device_name); |
2067 | } |
2068 | |
2069 | void AddMesh(const std::string& serialized_mesh, void* device_info, |
2070 | bool is_async, bool is_host_mesh, TF_Status* status) { |
2071 | auto mesh_config_or_status = Mesh::FromString(serialized_mesh); |
2072 | if (!mesh_config_or_status.ok()) { |
2073 | TF_SetStatus(status, TF_INTERNAL, |
2074 | absl::StrCat("Failed to parse mesh config. " , |
2075 | mesh_config_or_status.status().error_message()) |
2076 | .c_str()); |
2077 | return; |
2078 | } |
2079 | auto mesh_config = mesh_config_or_status.value(); |
2080 | std::vector<std::string> underlying_devices; |
2081 | underlying_devices.insert(underlying_devices.end(), |
2082 | mesh_config.local_devices().begin(), |
2083 | mesh_config.local_devices().end()); |
  // DTensor uses a multi-client setup that doesn't use remote eager, so we
  // can enable eager async execution in the ParallelDevice.
2086 | std::unique_ptr<tensorflow::parallel_device::ParallelDevice> parallel( |
2087 | new tensorflow::parallel_device::ParallelDevice(underlying_devices, |
2088 | is_async)); |
2089 | |
2090 | std::string composite_device_name; |
2091 | if (absl::StartsWith(mesh_config.name(), kPipelineMeshNamePrefix)) { |
2092 | composite_device_name = std::string( |
2093 | absl::StripPrefix(mesh_config.name(), kPipelineMeshNamePrefix)); |
2094 | } |
2095 | |
2096 | auto mesh = std::make_unique<MeshWithParallelDevice>( |
2097 | std::move(mesh_config), std::move(parallel), composite_device_name); |
2098 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2099 | device->AddMesh(std::move(mesh), is_host_mesh); |
2100 | } |
2101 | |
2102 | void ExperimentalSetDefaultLayout(const std::string& serialized_layout, |
2103 | void* device_info, TF_Status* status) { |
2104 | StatusOr<Layout> layout = Layout::FromString(serialized_layout); |
2105 | if (!layout.ok()) { |
2106 | RETURN_STATUS(status, TF_INTERNAL, layout.status().error_message().c_str()); |
2107 | } |
2108 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2109 | device->SetDefaultLayout(layout.value()); |
2110 | } |
2111 | |
2112 | void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status) { |
2113 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2114 | device->ClearDefaultLayout(); |
2115 | } |
2116 | |
2117 | void ExperimentalSetDefaultMesh(const std::string& serialized_mesh, |
2118 | void* device_info, TF_Status* status) { |
2119 | StatusOr<Mesh> mesh = Mesh::FromString(serialized_mesh); |
2120 | if (!mesh.ok()) { |
2121 | RETURN_STATUS(status, TF_INTERNAL, mesh.status().error_message().c_str()); |
2122 | } |
2123 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2124 | device->SetDefaultMesh(mesh.value()); |
2125 | } |
2126 | |
2127 | void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status) { |
2128 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2129 | device->ClearDefaultMesh(); |
2130 | } |
2131 | |
2132 | void SetSameShapePolicy(void* device_info, bool enabled) { |
2133 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2134 | device->SetSameShapePolicy(enabled); |
2135 | } |
2136 | |
2137 | void SetTPUCoreIDs(const std::string& mesh_name, |
2138 | const std::vector<int>& tpu_core_ids, void* device_info, |
2139 | TF_Status* status) { |
2140 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2141 | RETURN_C_STATUS_IF_NOT_OK(device->SetTPUCoreIDs(mesh_name, tpu_core_ids), |
2142 | status); |
2143 | } |
2144 | |
2145 | void ClearTPUCoreIDs(void* device_info) { |
2146 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2147 | device->ClearTPUCoreIDs(); |
2148 | } |
2149 | |
2150 | std::vector<std::vector<int>> TPUCoreIDsToLocations( |
2151 | TFE_Context* context, const std::vector<int>& tpu_core_ids, |
2152 | void* device_info) { |
2153 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2154 | return device->TPUCoreIDsToLocations(context, tpu_core_ids); |
2155 | } |
2156 | |
2157 | std::vector<int> TPUCoreLocationsToIDs( |
2158 | TFE_Context* context, |
2159 | const std::vector<std::vector<int>>& tpu_core_locations, |
2160 | void* device_info) { |
2161 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2162 | return device->TPUCoreLocationsToIDs(context, tpu_core_locations); |
2163 | } |
2164 | |
2165 | TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs, |
2166 | TFE_TensorHandle** inputs, |
2167 | const std::string& string_layout, void* device_info, |
2168 | TF_Status* status) { |
2169 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2170 | return device->Pack(context, num_inputs, inputs, string_layout, status); |
2171 | } |
2172 | |
2173 | std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context, |
2174 | TFE_TensorHandle* input, |
2175 | void* device_info, TF_Status* status) { |
2176 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2177 | return device->Unpack(context, input, status); |
2178 | } |
2179 | |
2180 | std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input, |
2181 | void* device_info, TF_Status* status) { |
2182 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2183 | return device->FetchLayout(context, input, status); |
2184 | } |
2185 | |
2186 | TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs, |
2187 | TFE_TensorHandle** indices, |
2188 | TFE_TensorHandle** values, |
2189 | TFE_TensorHandle** shapes, |
2190 | const std::string& string_layout, |
2191 | void* device_info, TF_Status* status) { |
2192 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2193 | return device->SparsePack(context, num_inputs, indices, values, shapes, |
2194 | string_layout, status); |
2195 | } |
2196 | |
2197 | bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input, |
2198 | void* device_info, TF_Status* status) { |
2199 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2200 | return device->IsSparseDTensor(context, input, status); |
2201 | } |
2202 | |
2203 | std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount( |
2204 | TFE_Context* context, void* device_info, TF_Status* status) { |
2205 | DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
2206 | return device->GetFunctionCacheHitAndMissCount(context, status); |
2207 | } |
2208 | } // namespace dtensor |
2209 | } // namespace tensorflow |
2210 | |