1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_COLLECTIVES_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_COLLECTIVES_H_ |
18 | |
19 | #include <string> |
20 | |
21 | #include "absl/container/flat_hash_set.h" |
22 | #include "absl/strings/string_view.h" |
23 | #include "mlir/IR/Value.h" // from @llvm-project |
24 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
25 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace dtensor { |
29 | |
// Emits collective ops to convert `input` from `src_layout` to `tgt_layout`.
// `src_layout` and `tgt_layout` must have the same rank. For each dimension,
// it can only go from sharded to replicated. `input` must have static shapes.
// If `newly_created_ops` is non-null, the ops emitted by this function are
// added to it. Returns the value holding the gathered (more replicated)
// tensor, or an error status.
StatusOr<mlir::Value> EmitAllGather(
    mlir::OpBuilder& builder, mlir::Value input,
    const dtensor::Layout& src_layout, const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops = nullptr);
37 | |
// Given an input layout and a desired layout, inserts the necessary slice to
// slice the original value based on the device id. All ops created by this
// function are added to `newly_created_ops` (if non-null).
//
// Note that the newly created ops are inserted `after` original_value.
// Returns the sliced (more sharded) value, or an error status.
StatusOr<const mlir::Value> EmitAllScatter(
    mlir::OpBuilder& builder, const mlir::Value& original_value,
    const Layout& original_layout, const Layout& desired_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops = nullptr);
47 | |
// Emits splits and calls EmitAllGather (once) to relayout `input` from
// `src_layout` to `tgt_layout` on a single mesh.
// The shape of `input` is expected to be the local shape for `src_layout`.
// If `newly_created_ops` is non-null, the ops emitted by this function are
// added to it. Returns the relayed-out value, or an error status.
StatusOr<mlir::Value> EmitRelayout(
    mlir::Value input, const dtensor::Layout& src_layout,
    const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops = nullptr);
55 | |
// Emits collective ops to reduce `input` over the mesh dimensions named in
// `reduced_dims`, using the reduction kind `reduce_op` (e.g. sum/max —
// whatever the underlying collective supports). `output_layout` is the
// layout of the reduced result. Returns the reduction op, or an error status.
StatusOr<mlir::Operation*> EmitAllReduce(
    mlir::OpBuilder& builder, const dtensor::Layout& output_layout,
    const absl::flat_hash_set<std::string>& reduced_dims,
    mlir::Operation* input, absl::string_view reduce_op);
61 | |
// Given input `tensor` that is sharded across spatial dimensions, conducts a
// halo exchange such that each spatially sharded input block exchanges a
// `halo_size` slice with its neighboring processors.
// If the input block is at the left/right/top/bottom edge, then ghost halo
// tensors (zeros) are padded instead. `mesh_dim` specifies the mesh dimension
// along which the halo exchange will be conducted. For example, if we
// consider a 4D Tensor (batch, height, width, channel) that has layout
// (*, h, w, *), then `mesh_dim` == "w" would mean that the halo exchange will
// occur along the width dimension. That is, halo tensors with right/left
// neighbors will be exchanged.
// `mesh_coordinates` holds the device's coordinates in the mesh (used to
// detect edge blocks), and the emitted ops are placed inside `cluster` with
// source location `location`. Returns the halo-exchanged tensor, or an error
// status.
StatusOr<mlir::Value> EmitHaloExchange(mlir::OpBuilder& builder, int halo_size,
                                       const std::string& mesh_dim,
                                       const Layout& layout,
                                       mlir::Value mesh_coordinates,
                                       mlir::tf_device::ClusterOp cluster,
                                       mlir::Location location,
                                       mlir::Value tensor);
78 | |
// Emits a DenseToSparse op followed by a SparseToDenseOp.
// This is useful for emitting a Relayout on a SparseTensor.
// One usage of this is in EmitRelayout when the input is a SparseTensor.
// If `newly_created_ops` is non-null, the ops emitted by this function are
// added to it. Returns the resulting dense value, or an error status.
StatusOr<mlir::Value> EmitDenseToSparseToDense(
    mlir::OpBuilder& builder, mlir::Value input,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops = nullptr);
85 | |
86 | } // namespace dtensor |
87 | } // namespace tensorflow |
88 | |
89 | #endif // TENSORFLOW_DTENSOR_MLIR_COLLECTIVES_H_ |
90 | |