1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "absl/types/optional.h" |
17 | #include "llvm/ADT/ArrayRef.h" |
18 | #include "llvm/ADT/DenseSet.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "llvm/Support/Casting.h" |
22 | #include "llvm/Support/Debug.h" |
23 | #include "llvm/Support/FormatVariadic.h" |
24 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
25 | #include "mlir/IR/Attributes.h" // from @llvm-project |
26 | #include "mlir/IR/Builders.h" // from @llvm-project |
27 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
28 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
29 | #include "mlir/IR/Dialect.h" // from @llvm-project |
30 | #include "mlir/IR/Operation.h" // from @llvm-project |
31 | #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
32 | #include "mlir/IR/UseDefLists.h" // from @llvm-project |
33 | #include "mlir/IR/Value.h" // from @llvm-project |
34 | #include "mlir/Pass/Pass.h" // from @llvm-project |
35 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
36 | #include "mlir/Support/LLVM.h" // from @llvm-project |
37 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
38 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
39 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
40 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
41 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
42 | #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
43 | #include "tensorflow/dtensor/cc/constants.h" |
44 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
45 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
46 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
47 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
48 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
49 | #include "tensorflow/dtensor/mlir/op_utils.h" |
50 | #include "tensorflow/dtensor/mlir/spmd_expander.h" |
51 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
52 | |
53 | namespace tensorflow { |
54 | namespace dtensor { |
55 | |
56 | namespace { |
57 | #define GEN_PASS_DEF_DTENSORSPMDEXPANSION |
58 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
59 | |
constexpr char kMainFunctionName[] = "main";
61 | |
// Updates the type of the input at `argument_index` in the signature of
// `function` to `new_arg_type`, keeping the entry block argument in sync.
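// For example (shapes illustrative), with argument_index = 0 and
// new_arg_type = tensor<2x4xf32>:
//
//   func.func @main(%arg0: tensor<4x4xf32>) { ... }
//
// becomes:
//
//   func.func @main(%arg0: tensor<2x4xf32>) { ... }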
64 | void UpdateFunctionInputShape(const int argument_index, |
65 | mlir::RankedTensorType new_arg_type, |
66 | mlir::func::FuncOp function) { |
67 | auto func_type = function.getFunctionType(); |
68 | auto input_types = llvm::to_vector<8>(func_type.getInputs()); |
69 | input_types[argument_index] = new_arg_type; |
70 | auto new_func_type = mlir::FunctionType::get( |
71 | function.getContext(), input_types, func_type.getResults()); |
72 | function.setType(new_func_type); |
73 | function.getBody() |
74 | .getArgument(argument_index) |
75 | .setType(function.getFunctionType().getInput(argument_index)); |
76 | } |
77 | |
// If `op` is a TF operation, returns `op` itself. If it is a DTensorLayout op,
// returns its consumer TF operation, skipping over any chain of DTensorLayout
// ops. Returns nullptr if a DTensorLayout op has no users.
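// For example (ops illustrative), given:
//
//   %0 = "tf.DTensorLayout"(%arg0) ...
//   %1 = "tf.Identity"(%0)
//
// NextTFOp on the DTensorLayout op returns the tf.Identity op.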
80 | mlir::Operation* NextTFOp(mlir::Operation* op) { |
  while (llvm::isa<mlir::TF::DTensorLayout>(op)) {
82 | if (op->getUsers().empty()) return nullptr; |
83 | op = *(op->getUsers().begin()); |
84 | } |
85 | return op; |
86 | } |
87 | |
// Updates the shape of a resource argument if the argument has a `tf._layout`
// attribute.
90 | // For example: |
91 | // main(%arg0: tensor<!tf_type.resource<tensor<4x4xf32>> |
92 | // {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"}) |
93 | // |
94 | // will be converted to: |
95 | // |
96 | // main(%arg0: tensor<!tf_type.resource<tensor<2x4xf32>> |
97 | // {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"}) |
98 | // |
// Note that the resource argument type is still a resource type, but its
// subtype has been changed to reflect the local shape.
// If the resource argument does not have a subtype, the subtype does not have
// a static shape, or the argument does not have a corresponding layout
// attribute, this function is a no-op.
104 | mlir::LogicalResult UpdateResourceArgumentType( |
105 | const int arg_index, mlir::func::FuncOp function, |
106 | absl::optional<mlir::RankedTensorType> new_subtype = absl::nullopt) { |
107 | auto resource_arg = function.getArgument(arg_index); |
108 | if (new_subtype) { |
109 | auto new_var_type = mlir::RankedTensorType::get( |
110 | {}, mlir::TF::ResourceType::get( |
111 | mlir::ArrayRef<mlir::TensorType>{*new_subtype}, |
112 | function.getContext())); |
113 | UpdateFunctionInputShape(arg_index, new_var_type, function); |
114 | function.setArgAttr(arg_index, kAssignedResourceLocalShape, |
115 | ConvertTypeToTensorShapeAttr(*new_subtype)); |
116 | return mlir::success(); |
117 | } |
118 | |
119 | auto resource_type = resource_arg.getType() |
120 | .cast<mlir::TensorType>() |
121 | .getElementType() |
122 | .dyn_cast<mlir::TF::ResourceType>(); |
123 | if (!resource_type) return mlir::success(); |
124 | |
125 | auto sub_types = resource_type.getSubtypes(); |
126 | if (sub_types.size() != 1) return mlir::success(); |
127 | |
128 | auto resource_arg_sub_type = sub_types.front(); |
129 | if (!resource_arg_sub_type.hasStaticShape()) return mlir::success(); |
130 | |
131 | // The local shape that is to be assigned to this resource argument type. We |
132 | // will either pull it from the assigned local shape attribute or compute it |
133 | // based on the layout. |
134 | // TODO(srujun): use the attribute value only to check the computed shape. |
135 | // This is currently blocked by an "empty_layout" set on the resource |
136 | // arguments, meaning it is not possible to compute local layout. |
137 | llvm::SmallVector<int64_t, 4> local_arg_shape; |
138 | auto assigned_resource_local_shape_attr = |
139 | function.getArgAttrOfType<mlir::TF::ShapeAttr>( |
140 | arg_index, kAssignedResourceLocalShape); |
141 | if (assigned_resource_local_shape_attr) { |
142 | local_arg_shape.append( |
143 | assigned_resource_local_shape_attr.getShape().begin(), |
144 | assigned_resource_local_shape_attr.getShape().end()); |
145 | } else { |
146 | auto layout_or_status = ExtractLayoutFromOperand(resource_arg); |
147 | if (!layout_or_status.ok()) |
148 | return function.emitOpError(layout_or_status.status().error_message()); |
149 | |
150 | const auto& layout = layout_or_status.value(); |
151 | if (!layout) return mlir::success(); |
152 | |
153 | std::vector<int64_t> local_arg_shape_vec = |
154 | layout->LocalShapeFromGlobalShape(resource_arg_sub_type.getShape()); |
155 | local_arg_shape.append(local_arg_shape_vec.begin(), |
156 | local_arg_shape_vec.end()); |
157 | } |
158 | |
159 | auto local_variable_subtype = mlir::RankedTensorType::get( |
160 | local_arg_shape, resource_arg_sub_type.getElementType()); |
161 | auto new_var_type = mlir::RankedTensorType::get( |
162 | {}, mlir::TF::ResourceType::get( |
163 | mlir::ArrayRef<mlir::TensorType>{local_variable_subtype}, |
164 | function.getContext())); |
165 | |
166 | UpdateFunctionInputShape(arg_index, new_var_type, function); |
167 | function.setArgAttr( |
168 | arg_index, kAssignedResourceLocalShape, |
169 | mlir::TF::ShapeAttr::get(local_variable_subtype.getContext(), |
170 | mlir::ArrayRef<int64_t>(local_arg_shape))); |
171 | |
172 | return mlir::success(); |
173 | } |
174 | |
// Returns whether `value` is used by an AssignVariable op, looking through any
// intervening DTensorLayout ops. If so, the argument index of the resource
// written by the AssignVariable op is returned via
// `resource_argument_index_for_assign_variable`.
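// For example (ops illustrative), given:
//
//   func.func @main(%arg0: tensor<*x!tf_type.resource<tensor<4x4xf32>>>,
//                   %arg1: tensor<4x4xf32>) {
//     %0 = "tf.DTensorLayout"(%arg1) ...
//     "tf.AssignVariableOp"(%arg0, %0) ...
//   }
//
// this returns true for %arg1 and sets the output index to 0, the argument
// number of %arg0.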
177 | bool IsValueUsedByAssignVariableOp( |
178 | mlir::Value value, int* resource_argument_index_for_assign_variable) { |
179 | for (auto user : value.getUsers()) { |
180 | if (auto assign_variable_op = |
181 | llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>( |
182 | NextTFOp(user))) { |
183 | *resource_argument_index_for_assign_variable = |
184 | GetForwardedDTensorLayoutInput(assign_variable_op.resource()) |
185 | .cast<mlir::BlockArgument>() |
186 | .getArgNumber(); |
187 | return true; |
188 | } |
189 | } |
190 | return false; |
191 | } |
192 | |
// Updates argument shapes of `function` based on the `tf._layout` attribute.
194 | mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) { |
195 | for (int argument_index = 0; argument_index < function.getNumArguments(); |
196 | ++argument_index) { |
197 | auto arg_layout_attr = function.getArgAttrOfType<mlir::StringAttr>( |
198 | argument_index, kCustomDeviceAttr); |
199 | if (!arg_layout_attr) continue; |
200 | |
201 | auto arg_layout = Layout::FromString(arg_layout_attr.getValue().str()); |
202 | if (!arg_layout.ok()) |
203 | return function.emitOpError(llvm::formatv( |
204 | "Invalid layout attribute found during SPMD expansion: {0}" , |
205 | arg_layout.status().error_message())); |
206 | |
207 | mlir::Type arg_type = mlir::getElementTypeOrSelf( |
208 | function.getFunctionType().getInput(argument_index)); |
209 | |
    // If the argument is a resource type, update the subtype shape
    // information to reflect the local shape of the resource.
212 | if (arg_type.isa<mlir::TF::ResourceType>()) { |
213 | if (mlir::failed(UpdateResourceArgumentType(argument_index, function))) |
214 | return mlir::failure(); |
215 | continue; |
216 | } |
217 | |
218 | mlir::RankedTensorType ranked_type = |
219 | function.getFunctionType() |
220 | .getInput(argument_index) |
221 | .dyn_cast<mlir::RankedTensorType>(); |
222 | if (!ranked_type) continue; |
223 | |
    // If the input value is a non-resource type, update its type to reflect
    // the local shape.
226 | llvm::ArrayRef<int64_t> arg_shape = ranked_type.getShape(); |
227 | const std::vector<int64_t> arg_local_shape = |
228 | arg_layout->LocalShapeFromGlobalShape(arg_shape); |
229 | mlir::RankedTensorType new_arg_type = mlir::RankedTensorType::get( |
230 | arg_local_shape, ranked_type.getElementType()); |
231 | UpdateFunctionInputShape(argument_index, new_arg_type, function); |
232 | |
    // If the non-resource value is used by an AssignVariable op, ensure that
    // the subtype of the updated/assigned resource is consistent with the
    // local shape of the assigned value.
236 | int assigned_resource_argument_index = -1; |
237 | if (IsValueUsedByAssignVariableOp(function.getArgument(argument_index), |
238 | &assigned_resource_argument_index)) { |
239 | (void)UpdateResourceArgumentType(assigned_resource_argument_index, |
240 | function, new_arg_type); |
241 | } |
242 | } |
243 | return mlir::success(); |
244 | } |
245 | |
// Given the SPMD-expanded operands of a call to `function`, updates the
// function's input signature to match the local shapes of those operands.
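// For example (shapes illustrative), if a call operand was expanded from
// tensor<4x4xf32> to the local type tensor<2x4xf32>, the corresponding
// function argument type is rewritten to tensor<2x4xf32>.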
248 | mlir::LogicalResult UpdateFunctionWithLocalInputShapes( |
249 | mlir::MutableArrayRef<mlir::OpOperand> function_operands, |
250 | mlir::func::FuncOp function) { |
251 | for (auto& operand : function_operands) { |
252 | const int index = operand.getOperandNumber(); |
253 | auto arg_type = operand.get().getType().dyn_cast<mlir::RankedTensorType>(); |
254 | if (!arg_type) continue; |
255 | |
256 | auto arg_local_shape = arg_type.getShape(); |
257 | auto new_arg_type = |
258 | mlir::RankedTensorType::get(arg_local_shape, arg_type.getElementType()); |
259 | UpdateFunctionInputShape(index, new_arg_type, function); |
260 | } |
261 | return mlir::success(); |
262 | } |
263 | |
// Updates the output shapes of the op or function enclosing `terminator_op`
// to the local shapes of the terminator's operands.
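// For example (shapes illustrative), if a function's return op now yields
// tensor<2x4xf32> instead of tensor<4x4xf32>, the function's result type and
// the result types of all of its callsites are updated to tensor<2x4xf32>.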
266 | mlir::LogicalResult UpdateReturnValueShapes(mlir::ModuleOp module, |
267 | mlir::Operation* terminator_op) { |
268 | auto parent_op = terminator_op->getBlock()->getParentOp(); |
269 | if (!parent_op) return mlir::success(); |
270 | |
271 | auto output_types = llvm::to_vector<8>(terminator_op->getOperandTypes()); |
272 | if (auto function = llvm::dyn_cast<mlir::func::FuncOp>(parent_op)) { |
273 | // Update function output type to have local shape. |
274 | auto new_func_type = mlir::FunctionType::get( |
275 | function.getContext(), function.getFunctionType().getInputs(), |
276 | output_types); |
277 | function.setType(new_func_type); |
278 | |
279 | // Update function callsite operations to reflect local output shapes. |
280 | auto function_uses = |
281 | mlir::SymbolTable::getSymbolUses(function, &module.getBodyRegion()); |
282 | if (!function_uses) return mlir::success(); |
283 | |
285 | for (auto function_use : *function_uses) { |
286 | auto callsite_op = function_use.getUser(); |
287 | if (!callsite_op) continue; |
288 | |
289 | for (auto& output_type_and_index : llvm::enumerate(output_types)) { |
290 | int index = output_type_and_index.index(); |
291 | const auto& type = output_type_and_index.value(); |
292 | callsite_op->getResult(index).setType(type); |
293 | } |
294 | } |
295 | } else { |
296 | for (auto& output_type_and_index : llvm::enumerate(output_types)) { |
297 | int index = output_type_and_index.index(); |
298 | const auto& type = output_type_and_index.value(); |
299 | parent_op->getResult(index).setType(type); |
300 | } |
301 | } |
302 | |
303 | return mlir::success(); |
304 | } |
305 | |
// Conducts SPMD expansion for all ops in `module`. If a function call
// operation exists, the callee is walked in topological order so that
// function inputs/outputs are updated before SPMD expansion of the callsite
// operation is done.
// Note that the iteration won't work with recursive function calls.
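// For example (mesh and shapes illustrative), on an "x=2" mesh with a
// batch-sharded input, an op such as:
//
//   %1 = "tf.Identity"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
//
// is expanded to operate on the per-device local shape:
//
//   %1 = "tf.Identity"(%0) : (tensor<2x4xf32>) -> tensor<2x4xf32>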
310 | mlir::LogicalResult ConductSPMDExpansion(mlir::ModuleOp module) { |
311 | auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName); |
312 | if (!main_func) |
313 | return module.emitOpError( |
314 | "could not find `main` function in module for SPMD expansion." ); |
315 | |
316 | if (mlir::failed(UpdateFunctionArgsUsingLayout(main_func))) |
317 | return mlir::failure(); |
318 | |
319 | TopologicalIterator iterator(main_func); |
320 | while (iterator.hasNext()) { |
321 | mlir::Operation* op = iterator.next(); |
322 | absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op); |
323 | if (func.has_value()) { |
324 | if (mlir::failed( |
325 | UpdateFunctionWithLocalInputShapes(op->getOpOperands(), *func))) |
326 | return mlir::failure(); |
327 | } |
328 | |
329 | const bool is_terminator_op = |
330 | llvm::isa<mlir::func::ReturnOp, mlir::tf_device::ReturnOp>(op); |
331 | if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) |
332 | layout_op.output().setType(layout_op.input().getType()); |
333 | |
334 | mlir::Operation* expanded_op = nullptr; |
335 | auto status = RunSPMDExpansion(op, &expanded_op); |
336 | if (!status.ok() || expanded_op == nullptr) { |
      // Sometimes `op` may have been erased and `expanded_op` set.
      // In this case we should emit the error on the expanded op.
339 | mlir::Operation* emit_op = op; |
340 | if (expanded_op != nullptr) emit_op = expanded_op; |
341 | return emit_op->emitError(WithContext(status, __FILE__, __LINE__, |
342 | "While computing SPMD expansion" ) |
343 | .error_message()); |
344 | } |
345 | |
    // If the expanded op is the terminator of a tf_device.Cluster or a
    // function, make sure to update the function return values as well as the
    // result shapes of its callsite operations.
349 | if (is_terminator_op) |
350 | if (mlir::failed(UpdateReturnValueShapes(module, expanded_op))) |
351 | return mlir::failure(); |
352 | } |
353 | return mlir::success(); |
354 | } |
355 | |
// DTensorLayout ops only convey layout information of tensors, which is no
// longer needed after SPMD expansion. As such, remove all DTensorLayout ops
// from the graph.
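// For example (ops illustrative):
//
//   %1 = "tf.DTensorLayout"(%0) ...
//   %2 = "tf.Identity"(%1)
//
// becomes:
//
//   %2 = "tf.Identity"(%0)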
359 | void RemoveDTensorLayoutOps(mlir::ModuleOp module) { |
360 | llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops; |
361 | module.walk( |
362 | [&](mlir::TF::DTensorLayout layout) { layout_ops.emplace_back(layout); }); |
363 | |
364 | for (auto layout_op : layout_ops) RemoveDTensorLayoutOp(layout_op); |
365 | } |
366 | |
367 | // Removes temporary attrs created during SPMD expansion. |
368 | void RemoveTemporarySPMDAttrs(mlir::ModuleOp module) { |
369 | module.walk([&](mlir::Operation* op) { |
370 | if (op->hasAttr(kDeviceSeedForMeshDims)) { |
371 | op->removeAttr(kDeviceSeedForMeshDims); |
372 | } |
373 | }); |
374 | } |
375 | |
// MLIR pass that converts a graph in the global view into a local view which
// can be invoked in parallel on a distributed set of devices. This pass
// removes all DTensorLayout ops after the expansion is done. Temporary nodes
// and attributes are also removed once the pass finishes.
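// For example (mesh and shapes illustrative), on a two-device mesh with a
// batch-sharded input, a graph written against the global shape
// tensor<8x4xf32> is rewritten so that every op operates on the per-device
// local shape tensor<4x4xf32>, with collective ops emitted by the per-op SPMD
// expanders where cross-device communication is required.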
380 | struct DTensorSPMDExpansion |
381 | : public impl::DTensorSPMDExpansionBase<DTensorSPMDExpansion> { |
382 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
383 | registry.insert<mlir::dtensor::DTensorDialect>(); |
384 | } |
385 | |
386 | void runOnOperation() override { |
387 | auto module = getOperation(); |
388 | if (failed(ConductSPMDExpansion(module))) return signalPassFailure(); |
389 | |
390 | RemoveDTensorLayoutOps(module); |
391 | |
392 | RemoveTemporarySPMDAttrs(module); |
393 | }; |
394 | }; |
395 | |
396 | } // namespace |
397 | |
398 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
399 | CreateDTensorSPMDExpansion() { |
400 | return std::make_unique<DTensorSPMDExpansion>(); |
401 | } |
402 | |
403 | } // namespace dtensor |
404 | } // namespace tensorflow |
405 | |