1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_ |
18 | |
19 | #include <string> |
20 | #include <vector> |
21 | |
22 | #include "absl/container/flat_hash_map.h" |
23 | #include "absl/strings/string_view.h" |
24 | #include "llvm/ADT/ArrayRef.h" |
25 | #include "llvm/ADT/SmallPtrSet.h" |
26 | #include "llvm/ADT/SmallVector.h" |
27 | #include "llvm/Support/Casting.h" |
28 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
29 | #include "mlir/IR/Builders.h" // from @llvm-project |
30 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
31 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
32 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
33 | #include "mlir/IR/Value.h" // from @llvm-project |
34 | #include "mlir/IR/Visitors.h" // from @llvm-project |
35 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
36 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
37 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" |
38 | #include "tensorflow/dtensor/cc/dstatus.h" |
39 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
40 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
41 | |
42 | namespace tensorflow { |
43 | namespace dtensor { |
44 | |
45 | constexpr absl::string_view kReduceOpAdd = "Add" ; |
46 | constexpr absl::string_view kReduceOpAll = "All" ; |
47 | constexpr absl::string_view kReduceOpAny = "Any" ; |
48 | constexpr absl::string_view kReduceOpMax = "Max" ; |
49 | constexpr absl::string_view kReduceOpMin = "Min" ; |
50 | constexpr absl::string_view kReduceOpMul = "Mul" ; |
51 | // Mean is not a valid combinator function on its own. It is handled specially |
52 | // by the reduce expansion. |
53 | constexpr absl::string_view kReduceOpMean = "Mean" ; |
54 | |
// Returns true iff every layout in `layouts` is fully replicated (no sharded
// dimensions).
bool AllReplicated(const std::vector<Layout>& layouts);
57 | |
// Takes a global type and converts it to a local (per-device) type. Fails if
// the number of shards does not divide the size of a sharded dimension (for
// static, i.e. non-dynamic, dimensions).
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type);
62 | |
// Takes a local (per-device) type and converts it to a global type. (The
// original comment said "global to local", which contradicts the function
// name and the direction of the companion function above.)
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type);
66 | |
// Creates a tf::SplitOp that splits 'src_input' 'num_split' ways
// along the 'split_dimension' dimension; the created op is returned through
// 'split_op'.
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op);
72 | |
// Given layouts + shapes, determines if the two are broadcast compatible.
// See source file for more documentation.
// `to_split_a` and `to_split_b` are output parameters; see the .cc for their
// exact semantics.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b);
80 | |
// Returns a merged layout using the `GetBroadcastLayoutForElementWise()`
// function given a list of operand layouts.
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op);
85 | |
// Returns the forwarded input value of DTensorLayout op for which `value` is
// the output. This must be used after layout propagation and before SPMD
// expansion when all mlir::Value's of tf ops are followed by DTensorLayout op
// to specify output layout.
// To make the implementation safe for Layout Propagation V1 algorithm, if the
// defining op of `value` is not DTensorLayout op (only the case for V1),
// returns `value` directly.
// TODO(b/172936130): Remove special casing for v1 Layout Propagation
// algorithm.
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value);
96 | |
// Goal of this function is to connect 'mlir::Value's (read 'mlir::OpResult's)
// to the 'mlir::OpOperand's which use them, crossing function call
// boundaries. The only keys in consumers which will not actually be
// 'mlir::OpResult's will be the 'mlir::Value's representing the inputs of the
// main function. The rest will be direct output of operations -- i.e.
// mlir::OpResult. Note that 'mlir::Value's that are not used by any op or are
// simply returned from the main function will not be in this list. In these
// cases, there are no conditions on the layouts for these 'mlir::Value's.
//
// A list of current assumptions in this code:
// * Functions are only called once.
// * Functions that are not reachable from main have been trimmed.
// * Input to CopyToMesh can always be traced back to function inputs.
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers);
113 | |
// From device id, returns an mlir::Value for a tensor of shape [1,
// mesh.rank()] whose entries are the mesh coordinates of the device. The mesh
// used is the mesh for the given cluster.
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster);
119 | |
// Checks that optional metadata attributes of `op` are valid if they
// exist. More specifically, output layouts of tf.Shape op and layouts of
// resources inferred from AssignVariable op are added as metadata.
mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op);
124 | |
// Creates a map from each function (keyed by name) to the op which calls it.
// Assumes at most one caller per function (see assumptions above).
mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller);
129 | |
// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into
// many.
// If `skipped_values` is non-null, the intermediate values crossed while
// tracing are appended to it.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values = nullptr);
137 | |
// Replaces `cluster` with a new tf_device.cluster without return values
// if result values are not used by any other ops.
//
// For example:
//
// %unused_value = "tf_device.cluster"() ({
//   %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () ->
//   tensor<i32> %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//   tf_device.return %2 : tensor<i32>
// }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
// Will be transformed to:
//
// "tf_device.cluster"() ({
//   %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () ->
//   tensor<i32> %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//   tf_device.return
// }) {_mesh="mesh:CPU,x=2,y=2"} : () -> ()
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster);
157 | |
// Returns a function name derived from `prefix` that is unique within the
// module `builder` operates on. Used when creating (control-flow branch)
// functions — see the .cc for the exact uniquification scheme.
mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder);
160 | |
// Sets the builder insertion point to after value. If value is a block
// argument, this checks that all users of the value are in the same cluster.
// If not it errors out. If they are then it sets the insertion point to the
// top of the cluster.
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder);
167 | |
// Inserts a StringFormat and Print op; should only be used for debugging
// on CPU. `format_string` is the format passed to the StringFormat op.
Status PrintTensor(mlir::Value value, const std::string& format_string);
171 | |
172 | // Extract a vector of string from mlir value. |
173 | Status ( |
174 | mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector); |
175 | |
176 | StatusOr<std::string> (mlir::Value value); |
177 | |
// A general Iterator that visits a FuncOp's body in topological order. Note
// that this does not visit the given FuncOp itself. Function ops are visited
// exactly once if functions are used in multiple call sites.
//
// An example usage of this Iterator is for SPMD Expansion or Sparse
// Expansion, where we expand ops in topological order starting from the
// `main` FuncOp, only visiting function ops once so that we don't expand
// multiple times.
class TopologicalIterator {
 public:
  explicit TopologicalIterator(mlir::func::FuncOp main_func);

  // Returns whether there are any further ops to visit.
  bool hasNext();

  // Returns the next op to visit in the topological ordering. Returns
  // a nullptr if there is no next op to visit.
  mlir::Operation* next();

 private:
  // Stack to keep track of ops to visit.
  llvm::SmallVector<mlir::Operation*, 4> ops_to_visit_;

  // Keep track of functions we are walking; this is needed to avoid recursive
  // function calls.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_in_call_stack_;

  // Keep track of all visited functions. This is to guarantee that
  // functions are visited exactly once if functions are used in multiple
  // callsites.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_;
};
210 | } // namespace dtensor |
211 | } // namespace tensorflow |
212 | |
213 | #endif // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_ |
214 | |