1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_ |
17 | #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_ |
18 | |
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
21 | |
22 | #include "tensorflow/c/eager/c_api.h" |
23 | #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" |
24 | #include "tensorflow/c/eager/tfe_context_internal.h" |
25 | #include "tensorflow/core/common_runtime/composite_device.h" |
26 | #include "tensorflow/core/common_runtime/eager/context.h" |
27 | #include "tensorflow/core/framework/function.h" |
28 | #include "tensorflow/core/framework/function.pb.h" |
29 | #include "tensorflow/core/framework/node_def_builder.h" |
30 | #include "tensorflow/core/framework/tensor_shape.h" |
31 | #include "tensorflow/core/graph/graph.h" |
32 | #include "tensorflow/core/platform/errors.h" |
33 | #include "tensorflow/core/platform/fingerprint.h" |
34 | #include "tensorflow/dtensor/cc/constants.h" |
35 | #include "tensorflow/dtensor/cc/dstatus.h" |
36 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace dtensor { |
40 | |
41 | #define RETURN_STATUS(status, code, message) \ |
42 | { \ |
43 | TF_SetStatus((status), (code), (message)); \ |
44 | return; \ |
45 | } |
46 | |
47 | #define RETURN_C_STATUS_IF_NOT_OK(cpp_status, c_status) \ |
48 | { \ |
49 | auto return_if_not_ok_status = (cpp_status); \ |
50 | if (!return_if_not_ok_status.ok()) { \ |
51 | RETURN_STATUS((c_status), \ |
52 | static_cast<TF_Code>(return_if_not_ok_status.code()), \ |
53 | return_if_not_ok_status.error_message().c_str()); \ |
54 | } \ |
55 | } |
56 | |
// Uniquifying the temporary's name with a counter, rather than opening a new
// scope block, allows `var` to declare a new variable.
59 | #define ASSIGN_OR_RETURN_C_STATUS(var, cpp_status, c_status) \ |
60 | ASSIGN_OR_RETURN_C_STATUS_IMPL( \ |
61 | TF_STATUS_MACROS_CONCAT_NAME(_dtensor_status_or_value, __COUNTER__), \ |
62 | var, cpp_status, c_status) |
63 | |
64 | #define ASSIGN_OR_RETURN_C_STATUS_IMPL(statusor, var, cpp_status, c_status) \ |
65 | auto statusor = (cpp_status); \ |
66 | RETURN_C_STATUS_IF_NOT_OK(statusor.status(), (c_status)); \ |
67 | var = std::move(statusor.value()); |
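
// Usage sketch for the macros above (the caller and the producer functions
// are illustrative, not real APIs):
//
//   void ExampleCApiEntryPoint(TFE_Context* context, TF_Status* status) {
//     if (context == nullptr) {
//       RETURN_STATUS(status, TF_INVALID_ARGUMENT, "context must be non-null");
//     }
//     // `var` may declare a new variable, thanks to the counter-based name.
//     ASSIGN_OR_RETURN_C_STATUS(auto mesh, HypotheticalMeshProducer(),
//                               status);
//     RETURN_C_STATUS_IF_NOT_OK(HypotheticalStatusProducer(mesh), status);
//   }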
68 | |
69 | struct TranslatedFunction { |
  // Mesh on which the specified function will run.
71 | Mesh function_mesh; |
72 | |
73 | // StatefulPartitionedCall op to run the mesh function. |
74 | const Node* node_to_execute = nullptr; |
75 | |
  // Maps the i-th local input index to the input index in the global graph.
77 | std::vector<int> input_index_map; |
78 | |
  // Maps the i-th local output index to the output index in the global graph.
80 | std::vector<int> output_index_map; |
81 | |
82 | std::string translated_function_name; |
83 | // For resource ops, layouts of resource handles are inferred lazily |
84 | // during SPMD expansion of resource assign ops. In that case, |
85 | // inferred layouts of resource handles are attached to arg nodes |
86 | // of the returned graph. |
87 | std::map<int, Layout> resource_input_layouts; |
  // Records metadata for the output of a shape op. This helps recover the
  // local shape in future operations on the tensor.
90 | std::map<int, Layout> shape_output_metadata; |
91 | std::vector<Layout> output_layouts; |
92 | // Local shapes inferred for function outputs; these may be partially known. |
93 | std::vector<PartialTensorShape> local_output_shapes; |
94 | // Output data types. |
95 | std::vector<TF_DataType> output_dtypes; |
96 | }; |
97 | |
98 | struct ExecutionFunctions { |
  // Stores information about all functions to execute for the provided
  // computation.
100 | std::vector<TranslatedFunction> function_list; |
  // Number of device id args added to translated functions.
  // During translation, we insert one device id arg node per mesh.
  // For a single-mesh function, this equals 1. For a multi-mesh function
  // (e.g. with pipelining), it equals the number of meshes.
106 | int num_device_ids; |
107 | |
  // Mesh fingerprint of function_list. For performance reasons this is set
  // only when ExecutionFunctions refers to a function; eager ops do not use
  // it.
110 | uint64 function_mesh_fingerprint = 0; |
111 | }; |
112 | |
113 | // TODO(yujingzhang): move FingerprintCat128 to tensorflow/platform. |
114 | inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, |
115 | const tensorflow::Fprint128& b) { |
116 | return {tensorflow::FingerprintCat64(a.low64, b.low64), |
117 | tensorflow::FingerprintCat64(a.high64, b.high64)}; |
118 | } |
119 | |
120 | inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, |
121 | const int64 b) { |
122 | auto x = tensorflow::FingerprintCat64(a.low64, b); |
123 | return {x, tensorflow::FingerprintCat64(a.high64, x)}; |
124 | } |
125 | |
126 | struct DTensorOperation { |
  // Neither field is owned; their lifetimes must cover the whole usage.
128 | const char* name; |
129 | const FunctionDef* function_def; |
130 | |
131 | inline bool is_func() const { return function_def != nullptr; } |
132 | }; |
133 | |
134 | struct EmbeddingResourceAttrs { |
135 | int64_t table_id; |
136 | absl::optional<int64_t> slot_id; // NOLINT |
137 | bool is_dirty = false; |
138 | }; |
139 | |
140 | // Contains a mesh bundled with a parallel device over all of the devices in |
141 | // that mesh. |
142 | class MeshWithParallelDevice { |
143 | public: |
144 | MeshWithParallelDevice( |
145 | const Mesh& mesh_config, |
146 | std::unique_ptr<parallel_device::ParallelDevice> parallel_device, |
147 | const std::string& composite_device_name = "" ) |
148 | : mesh_config_(mesh_config), |
149 | parallel_device_(std::move(parallel_device)), |
150 | composite_device_name_(composite_device_name), |
151 | // Device IDs are constructed lazily because we don't have a context |
152 | // until we start executing ops. |
153 | device_ids_tensor_(nullptr) {} |
154 | |
155 | // A parallel tensor containing scalar integer device IDs for underlying |
156 | // devices, each placed on its corresponding device. |
157 | // |
158 | // TODO(allenl): It would be nice if DeviceID worked as an op inside the |
159 | // function's graph. Then we wouldn't need to feed it as an argument. |
160 | parallel_device::ParallelTensor* DeviceIDs(TFE_Context* context, |
161 | TF_Status* status) const; |
162 | const parallel_device::ParallelDevice& parallel_device() const { |
163 | return *parallel_device_; |
164 | } |
165 | |
166 | const dtensor::Mesh& mesh_config() const { return mesh_config_; } |
167 | |
  // Creates a CompositeDevice in the eager context if it does not already
  // exist. Called when parallel_device_ contains a subset of the global
  // devices, e.g. when pipelining is enabled.
171 | StatusOr<CompositeDevice*> FindOrCreateCompositeDevice(TFE_Context* context) { |
172 | if (composite_device_ == nullptr && !composite_device_name_.empty()) { |
173 | if (mesh_config_.global_devices().empty()) { |
174 | return errors::InvalidArgument( |
175 | "Expect non-empty global devices when creating a CompositeDevice." ); |
176 | } |
177 | TF_RETURN_IF_ERROR(ContextFromInterface(tensorflow::unwrap(context)) |
178 | ->FindOrCreateCompositeDevice( |
179 | mesh_config_.global_devices(), |
180 | composite_device_name_, &composite_device_)); |
181 | } |
182 | return composite_device_; |
183 | } |
184 | |
185 | CompositeDevice* composite_device() const { return composite_device_; } |
186 | |
187 | private: |
188 | dtensor::Mesh mesh_config_; |
189 | std::unique_ptr<parallel_device::ParallelDevice> parallel_device_; |
190 | |
  // Set when parallel_device_ contains a subset of the global devices, e.g.
  // when pipelining is enabled.
193 | const std::string composite_device_name_; |
194 | // A tensorflow::Device that represents underlying devices of |
195 | // parallel_device_. Set when composite_device_name_ is not empty. |
196 | CompositeDevice* composite_device_ = nullptr; // owned by eager context |
197 | |
198 | // Constructed lazily; contains a parallel tensor with scalar integer device |
199 | // IDs for each device. |
200 | mutable std::unique_ptr<parallel_device::ParallelTensor> device_ids_tensor_; |
201 | }; |
202 | |
203 | enum TensorType { |
204 | kDense = 0, |
205 | kResource = 1, |
206 | kSparse = 2, |
207 | }; |
208 | |
209 | class TensorWithLayout { |
210 | public: |
  // Broadcasts a single non-parallel tensor onto `mesh` with a fully
  // replicated sharding spec. Does not take ownership of `tensor`.
213 | static std::unique_ptr<TensorWithLayout> Broadcast( |
214 | TFE_Context* context, TFE_TensorHandle* tensor, |
215 | const MeshWithParallelDevice& mesh, |
216 | const std::string& dtensor_device_name, TF_Status* status); |
217 | |
218 | // Given an already-parallel tensor, wraps it with a mesh and a layout. |
219 | static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap( |
220 | std::unique_ptr<parallel_device::ParallelTensor> tensor, |
221 | const MeshWithParallelDevice& mesh, const Layout& layout); |
222 | |
  // Creates a dummy TensorWithLayout that does not hold a ParallelTensor.
224 | static std::unique_ptr<TensorWithLayout> Dummy( |
225 | const std::vector<int64_t>& local_shape, const TF_DataType dtype, |
226 | const MeshWithParallelDevice& mesh, const Layout& layout); |
227 | |
228 | virtual ~TensorWithLayout() {} |
229 | |
230 | virtual const Layout& layout() const { return layout_; } |
231 | |
232 | virtual TensorType tensor_type() const { return TensorType::kDense; } |
233 | |
234 | virtual TF_DataType dtype() const { |
235 | if (dtype_.has_value()) { |
236 | return dtype_.value(); |
237 | } else { |
238 | return tensor_->dtype(); |
239 | } |
240 | } |
241 | |
242 | // Small constant value optimization for non-resource-handle tensors. |
243 | virtual void set_const_value(NodeDef& const_node) { |
244 | // If we extracted a constant value from the tensor, check if this |
245 | // value was the output from `tf.shape`. In this case, we need to |
246 | // forward the kShapeOpInputLayout attribute to the new node def. This |
247 | // is needed for layout propagation when running in op-by-op mode. |
248 | // |
249 | // TODO(b/162747667): Improve the presentation for Shape input Op |
250 | // layout. |
251 | if (shape_metadata_layout().has_value()) { |
252 | AddNodeAttr(kShapeOpInputLayout, {shape_metadata_layout()->ToString()}, |
253 | &(const_node)); |
254 | } |
255 | const_value_.emplace(const_node); |
256 | } |
257 | |
258 | // Clears the cached const value if present. |
259 | void reset_const_value() { const_value_.reset(); } |
260 | |
261 | // Encodes the NodeDef via provided builder, if applicable. |
262 | virtual void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const {} |
263 | |
264 | virtual tensorflow::Fprint128 CacheKey() const; |
265 | |
266 | // Updates layout for this Tensor. |
267 | virtual void UpdateLayout(const Layout& new_layout, TF_Status* status) { |
268 | TF_SetStatus(status, TF_INTERNAL, |
269 | "Attempt to update layout on non-resource-handle" ); |
270 | } |
271 | |
272 | // Update shape and dtype. |
273 | virtual void UpdateShapeAndDType(const TensorShapeProto& shape, |
274 | const DataType& dtype, TF_Status* status) { |
275 | TF_SetStatus(status, TF_INTERNAL, |
276 | "Attempt to update shape and layout on non-resource-handle" ); |
277 | } |
278 | |
279 | // Update Attrs for this Tensor. |
280 | virtual void UpdateAttrs(const EmbeddingResourceAttrs& attrs, |
281 | TF_Status* status) { |
282 | TF_SetStatus(status, TF_INTERNAL, |
283 | "Attempt to update layout on non-resource-handle" ); |
284 | } |
285 | |
286 | virtual TFE_TensorHandle* get_tensor(size_t index) const { |
287 | return tensor()->tensor(index); |
288 | } |
289 | |
290 | virtual size_t num_tensors() const { return tensor()->num_tensors(); } |
291 | |
292 | virtual parallel_device::ParallelTensor* tensor() const { |
293 | return tensor_.get(); |
294 | } |
295 | |
296 | // Returns a string which includes just the value and layout of the tensor. |
297 | virtual std::string SummarizeValue() const; |
298 | // Returns a string which includes `SummarizeValue` along with shape and type |
299 | // information. |
300 | virtual std::string DebugString() const; |
301 | |
302 | void set_input_layout_for_shape_op_result(const Layout& layout) { |
303 | input_layout_for_shape_op_result_.emplace(layout); |
304 | } |
305 | |
306 | const absl::optional<Layout> shape_metadata_layout() const { |
307 | return input_layout_for_shape_op_result_; |
308 | } |
309 | |
310 | const MeshWithParallelDevice& mesh() const { return mesh_; } |
311 | |
  // Computes the global shape from the layout and the local tensor shape.
  //
  // For tensors with a replicated layout, the global shape is simply the
  // shape of the local tensor on each device. For sharded tensors, the
  // global shape is recovered from the layout and the per-device local
  // shape.
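  //
  // For example (a sketch): on a mesh dimension "x" of size 2, a tensor with
  // layout ["x", "unsharded"] and per-device local shape [2, 4] has global
  // shape [4, 4].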
317 | const std::vector<int64_t> global_shape() const { |
318 | return layout().GlobalShapeFromLocalShape(local_shape()); |
319 | } |
320 | |
321 | const std::vector<int64_t> local_shape() const { return local_shape_; } |
322 | |
323 | const absl::optional<NodeDef> const_value() const { return const_value_; } |
324 | |
325 | const absl::optional<EmbeddingResourceAttrs>& attrs() const { return attrs_; } |
326 | |
327 | protected: |
328 | TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor, |
329 | const MeshWithParallelDevice& mesh, const Layout& layout, |
330 | std::vector<int64_t> local_shape, |
331 | absl::optional<TF_DataType> dtype = absl::nullopt, |
332 | absl::optional<NodeDef> const_value = absl::nullopt) |
333 | : tensor_(std::move(tensor)), |
334 | layout_(layout), |
335 | mesh_(mesh), |
336 | const_value_(std::move(const_value)), |
337 | local_shape_(local_shape), |
338 | dtype_(dtype) {} |
339 | |
340 | std::unique_ptr<parallel_device::ParallelTensor> tensor_; |
341 | |
342 | Layout layout_; |
343 | |
344 | const MeshWithParallelDevice& mesh_; |
345 | |
  // Optionally holds the value of a small, non-resource tensor. Small
  // constants are directly folded into the SPMD graph instead of being passed
  // as inputs. This provides extra information to the layout propagation and
  // SPMD passes during op-by-op execution (for example, the reduction indices
  // for Sum, or the target shapes for Rng/Reshape).
351 | absl::optional<NodeDef> const_value_; |
352 | |
  // Optionally holds the original input layout for a tensor returned by a
  // shape op. This preserves information about a shape op's output so that
  // future uses can recover the local shape.
356 | // TODO(hthu,allenl,xiejw): Move this into a separate class for clarity. |
357 | absl::optional<Layout> input_layout_for_shape_op_result_ = absl::nullopt; |
358 | |
359 | // The local shape of tensors placed on each of `tensor_`'s component devices. |
360 | std::vector<int64_t> local_shape_; |
361 | |
362 | absl::optional<TF_DataType> dtype_; |
363 | |
364 | // Resource input attributes for embedding inputs. |
365 | absl::optional<EmbeddingResourceAttrs> attrs_; // NOLINT |
366 | }; |
367 | |
368 | // Extension of TensorWithLayout which holds resource handle with layout. |
369 | // |
370 | // The major differences are |
371 | // 1. The layout, shape, dtype are lazily set as they are unavailable upon |
372 | // creation. |
373 | // 2. Small const optimization should be disabled. |
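//
// Usage sketch (a hedged sketch; variable names are illustrative): during
// SPMD expansion of a resource assign op, the inferred layout, shape, and
// dtype of the assigned value are recorded on the handle:
//
//   ResourceHandleWithLayout* handle = ...;
//   handle->UpdateLayout(value_layout, status);
//   handle->UpdateShapeAndDType(value_shape_proto, value_dtype, status);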
374 | class ResourceHandleWithLayout : public TensorWithLayout { |
375 | public: |
376 | // The layout of uninitialized resource tensors, or the layout of the tensor |
377 | // contained in an initialized resource. |
378 | const Layout& layout() const override { |
379 | return dereferenced_layout_.has_value() ? dereferenced_layout_.value() |
380 | : layout_; |
381 | } |
382 | |
383 | TensorType tensor_type() const override { return TensorType::kResource; } |
384 | |
385 | void set_const_value(NodeDef& const_node) override { |
386 | // Just a no-op for resource handle. Maybe we should error out. |
387 | } |
388 | |
389 | void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override; |
390 | |
391 | tensorflow::Fprint128 CacheKey() const override; |
392 | |
393 | void UpdateLayout(const Layout& new_layout, TF_Status* status) override; |
394 | |
395 | void UpdateShapeAndDType(const TensorShapeProto& shape, const DataType& dtype, |
396 | TF_Status* status) override { |
397 | set_dereferenced_shape(shape); |
398 | set_dereferenced_dtype(dtype); |
399 | } |
400 | |
401 | void UpdateAttrs(const EmbeddingResourceAttrs& attrs, |
402 | TF_Status* status) override; |
403 | |
  void UpdateDirtyness(bool is_dirty, TF_Status* status) {
    if (!attrs_.has_value()) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Attempt to update dirtiness on non-embedding resource");
      return;
    }
    attrs_.value().is_dirty = is_dirty;
  }
411 | |
412 | void set_dereferenced_shape(const TensorShapeProto& shape) { |
413 | dereferenced_shape_.emplace(shape); |
414 | } |
415 | void set_dereferenced_dtype(const DataType& dtype) { |
416 | dereferenced_dtype_.emplace(dtype); |
417 | } |
418 | |
419 | const absl::optional<TensorShapeProto>& dereferenced_shape() const { |
420 | return dereferenced_shape_; |
421 | } |
422 | const absl::optional<DataType>& dereferenced_dtype() const { |
423 | return dereferenced_dtype_; |
424 | } |
425 | |
427 | ResourceHandleWithLayout( |
428 | std::unique_ptr<parallel_device::ParallelTensor> tensor, |
429 | const MeshWithParallelDevice& mesh, const Layout& layout, |
430 | std::vector<int64_t> local_shape) |
431 | : TensorWithLayout(std::move(tensor), mesh, layout, local_shape, |
432 | TF_RESOURCE) {} |
433 | |
434 | private: |
435 | // The layout of the tensor pointed to by this handle, if any. |
436 | absl::optional<Layout> dereferenced_layout_; |
437 | // The shape and dtype of the tensor pointed to by this resource tensor. |
438 | absl::optional<TensorShapeProto> dereferenced_shape_; |
439 | absl::optional<DataType> dereferenced_dtype_; |
440 | }; |
441 | |
// TensorWithLayout for SparseTensors.
//
// The main difference from TensorWithLayout is that this holds three lists of
// tensors (indices, values, and dense shapes) instead of one. The shapes of a
// SparseTensor are always the dense-view shapes, so they do not differ from
// TensorWithLayout in terms of shape.
449 | class SparseTensorWithLayout : public TensorWithLayout { |
450 | public: |
451 | static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap( |
452 | std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, |
453 | std::unique_ptr<parallel_device::ParallelTensor> values_tensor, |
454 | std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, |
455 | const MeshWithParallelDevice& mesh, const Layout& layout, |
456 | std::vector<int64_t> local_shape); |
457 | |
  // Creates a dummy TensorWithLayout that does not hold a ParallelTensor.
459 | static std::unique_ptr<TensorWithLayout> Dummy( |
460 | const std::vector<int64_t>& local_shape, |
461 | const MeshWithParallelDevice& mesh, const Layout& layout) { |
462 | return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout( |
463 | /*indices=*/nullptr, /*values=*/nullptr, /*dense_shapes=*/nullptr, mesh, |
464 | layout, local_shape)); |
465 | } |
466 | |
467 | void set_const_value(NodeDef& const_node) override { |
468 | // No-op for SparseTensors, consider erroring out. |
469 | } |
470 | |
  // Adds the attribute '_sparse' to the NodeDefBuilder so that mlir::Values
  // originating from SparseTensorWithLayout are marked as '_sparse'.
473 | void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override { |
474 | builder.Attr("_sparse" , true); |
475 | } |
476 | |
477 | TensorType tensor_type() const override { return TensorType::kSparse; } |
478 | |
479 | size_t num_tensors() const override { return 3 * indices()->num_tensors(); } |
480 | |
481 | TFE_TensorHandle* get_tensor(size_t index) const override; |
482 | |
483 | std::string SummarizeValue() const override; |
484 | |
485 | std::string DebugString() const override; |
486 | |
487 | TF_DataType dtype() const override; |
488 | |
489 | parallel_device::ParallelTensor* indices() const { return indices_.get(); } |
490 | |
491 | parallel_device::ParallelTensor* values() const { return values_.get(); } |
492 | |
493 | parallel_device::ParallelTensor* dense_shapes() const { |
494 | return dense_shapes_.get(); |
495 | } |
496 | |
497 | protected: |
498 | SparseTensorWithLayout( |
499 | std::unique_ptr<parallel_device::ParallelTensor> indices, |
500 | std::unique_ptr<parallel_device::ParallelTensor> values, |
501 | std::unique_ptr<parallel_device::ParallelTensor> dense_shapes, |
502 | const MeshWithParallelDevice& mesh, const Layout& layout, |
503 | std::vector<int64_t> local_shape, |
504 | absl::optional<TF_DataType> dtype = absl::nullopt, |
505 | absl::optional<NodeDef> const_value = absl::nullopt) |
506 | : TensorWithLayout(nullptr, mesh, layout, local_shape), |
507 | indices_(std::move(indices)), |
508 | values_(std::move(values)), |
509 | dense_shapes_(std::move(dense_shapes)) {} |
510 | std::unique_ptr<parallel_device::ParallelTensor> indices_; |
511 | std::unique_ptr<parallel_device::ParallelTensor> values_; |
512 | std::unique_ptr<parallel_device::ParallelTensor> dense_shapes_; |
513 | }; |
514 | |
515 | template <typename T> |
std::string ShapeToDebugString(const std::vector<T>& shape_vector) {
517 | std::vector<tensorflow::int64> cast_shape(shape_vector.begin(), |
518 | shape_vector.end()); |
519 | tensorflow::PartialTensorShape shape; |
520 | if (!tensorflow::PartialTensorShape::MakePartialShape( |
521 | cast_shape.data(), cast_shape.size(), &shape) |
522 | .ok()) { |
523 | return "<error displaying shape>" ; |
524 | } else { |
525 | return shape.DebugString(); |
526 | } |
527 | } |

// Class that holds information about the DTensor functions that have run,
// including cached lowered functions and constant-folding input information
// per function.
//
// The caching policy for constant-folded inputs is as follows: on the first
// call to a function, we assume that all constant-foldable inputs are
// constant folded and save their values. On each subsequent call to the same
// function, we compare the values of the constant-folded inputs against the
// previously saved values. We disable constant folding for the inputs whose
// values changed, and save the new inputs.
// TODO(b/169348205) Support cache eviction if the cache gets bloated.
539 | class FunctionManager { |
540 | public: |
541 | FunctionManager() = default; |
542 | |
  // Caches the lowered 'function' for the given operation under 'cache_key'.
544 | const ExecutionFunctions* AddCachedFunction(const DTensorOperation& op, |
545 | tensorflow::Fprint128 cache_key, |
546 | ExecutionFunctions function); |
547 | |
548 | // Returns the cache key and the cached lowered graph for the function. |
549 | // Returns a nullptr for the lowered graph if there is a cache miss. |
550 | // Upon a cache miss, this will save some metadata about the function |
551 | // and the small inputs to keep track of information for constant folding. |
552 | std::pair<tensorflow::Fprint128, const ExecutionFunctions*> GetCachedFunction( |
553 | const DTensorOperation& doperation, const NameAttrList& attributes, |
554 | const std::vector<TensorWithLayout*>& inputs, |
555 | const std::vector<const Layout*>& output_layouts); |
556 | |
  // Returns whether the input at `input_index` is known to be constant
  // foldable for function `doperation`. An input is not constant foldable if
  // we have run this function at least twice and the small input value
  // changed across separate runs.
561 | bool IsConstantFoldable(const DTensorOperation& doperation, |
562 | const int input_index) const; |
563 | |
564 | private: |
  // Computes a cache key for a dtensor operation, including the op name and
  // the input shapes. This is needed as a higher-level cache key for constant
  // folding.
568 | const tensorflow::Fprint128 CacheKeyForDTensorOperation( |
569 | const DTensorOperation& doperation) const; |
570 | |
571 | // Generates a cache key for the graph, including its attributes, |
572 | // inputs, and outputs. |
573 | tensorflow::Fprint128 CacheKeyForGraph( |
574 | const DTensorOperation& doperation, const NameAttrList& attributes, |
575 | const std::vector<TensorWithLayout*>& inputs, |
576 | const std::vector<const Layout*>& output_layouts); |
577 | |
  // Maps the hash of a graph to the lowered graph.
579 | absl::flat_hash_map<tensorflow::Fprint128, ExecutionFunctions, |
580 | tensorflow::Fprint128Hasher> |
581 | function_cache_; |
582 | |
  // Maps the hash of a dtensor operation and its input shapes to a map from
  // small-constant input indices to their values. The indices are saved to
  // make comparisons faster during constant-folding validation.
587 | absl::flat_hash_map<tensorflow::Fprint128, absl::flat_hash_map<int, NodeDef>, |
588 | tensorflow::Fprint128Hasher> |
589 | dtensor_op_and_small_inputs_; |
590 | }; |
591 | |
592 | // Returns the shape of a given tensor. |
593 | std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor, |
594 | TF_Status* status); |
595 | |
// Creates a Graph with _Arg and _Retval nodes surrounding a node of
// `doperation`'s type.
598 | Status PrepareGraphForMlir( |
599 | const FunctionManager& function_manager, |
600 | const std::vector<TensorWithLayout*>& inputs, |
601 | const DTensorOperation& doperation, |
602 | const tensorflow::FunctionLibraryDefinition& flib_def, |
603 | const NameAttrList& attributes, |
604 | const absl::optional<Layout>& default_layout, tensorflow::Graph* graph, |
605 | std::vector<PartialTensorShape>* global_output_shapes, |
606 | std::vector<const Layout*>* output_layouts); |
607 | |
608 | // Returns set of functions to run to execute DTensor computation. |
609 | StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute( |
610 | const tensorflow::Graph& graph, |
611 | const std::vector<PartialTensorShape>& global_output_shapes); |
612 | |
613 | // For functions with control outputs, add identity nodes between |
614 | // StatefulPartitionedCall and _Retvals, in order to preserve control output |
615 | // dependencies after StatefulPartitionedCall is inlined at runtime. |
616 | // Consider calling this in PrepareGraphForMlir, once the identity nodes won't |
617 | // be dropped during MLIR lowering. |
618 | // TODO(b/171265131): fix the underlying issue to avoid inserting identity |
619 | // nodes. |
620 | Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph); |
621 | |
// Adds DTensor-specific function attributes to be compatible with the eager
// runtime.
623 | void AddDTensorFunctionAttr(FunctionDef& function_def); |
624 | |
// Prepares embedding inputs for checkpoint functions.
626 | StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs( |
627 | const std::vector<TensorWithLayout*>& inputs); |
628 | |
629 | Status InsertFunctionForTPUEmbeddingCheckpoint( |
630 | TF_Status* status, Graph* graph, |
631 | const std::vector<TensorWithLayout*>& inputs, |
632 | const std::string& checkpoint_fn_name); |
633 | |
634 | } // namespace dtensor |
635 | } // namespace tensorflow |
636 | |
637 | #endif // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_ |
638 | |