1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "absl/types/optional.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/DenseSet.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/Casting.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25#include "mlir/IR/Attributes.h" // from @llvm-project
26#include "mlir/IR/Builders.h" // from @llvm-project
27#include "mlir/IR/BuiltinOps.h" // from @llvm-project
28#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
29#include "mlir/IR/Dialect.h" // from @llvm-project
30#include "mlir/IR/Operation.h" // from @llvm-project
31#include "mlir/IR/TypeUtilities.h" // from @llvm-project
32#include "mlir/IR/UseDefLists.h" // from @llvm-project
33#include "mlir/IR/Value.h" // from @llvm-project
34#include "mlir/Pass/Pass.h" // from @llvm-project
35#include "mlir/Pass/PassManager.h" // from @llvm-project
36#include "mlir/Support/LLVM.h" // from @llvm-project
37#include "mlir/Support/LogicalResult.h" // from @llvm-project
38#include "mlir/Transforms/Passes.h" // from @llvm-project
39#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
42#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
43#include "tensorflow/dtensor/cc/constants.h"
44#include "tensorflow/dtensor/cc/tensor_layout.h"
45#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
46#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
47#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
48#include "tensorflow/dtensor/mlir/layout_parsing.h"
49#include "tensorflow/dtensor/mlir/op_utils.h"
50#include "tensorflow/dtensor/mlir/spmd_expander.h"
51#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
52
53namespace tensorflow {
54namespace dtensor {
55
56namespace {
57#define GEN_PASS_DEF_DTENSORSPMDEXPANSION
58#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
59
// Symbol name of the function that ConductSPMDExpansion looks up in the module
// and uses as the starting point of SPMD expansion.
constexpr char kMainFunctionName[] = "main";
61
62// Updates `function` input signature operand at `argument_index` with
63// `new_shape`.
64void UpdateFunctionInputShape(const int argument_index,
65 mlir::RankedTensorType new_arg_type,
66 mlir::func::FuncOp function) {
67 auto func_type = function.getFunctionType();
68 auto input_types = llvm::to_vector<8>(func_type.getInputs());
69 input_types[argument_index] = new_arg_type;
70 auto new_func_type = mlir::FunctionType::get(
71 function.getContext(), input_types, func_type.getResults());
72 function.setType(new_func_type);
73 function.getBody()
74 .getArgument(argument_index)
75 .setType(function.getFunctionType().getInput(argument_index));
76}
77
78// If `op` is a TF operation, return itself. If it is an DTensorLayout op,
79// return it's consumer TF operation.
80mlir::Operation* NextTFOp(mlir::Operation* op) {
81 while (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
82 if (op->getUsers().empty()) return nullptr;
83 op = *(op->getUsers().begin());
84 }
85 return op;
86}
87
88// Updates the shape of resource argument if argument has `tf._layout`
89// attribute.
90// For example:
91// main(%arg0: tensor<!tf_type.resource<tensor<4x4xf32>>
92// {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
93//
94// will be converted to:
95//
96// main(%arg0: tensor<!tf_type.resource<tensor<2x4xf32>>
97// {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
98//
99// Note that resource argument type is still a resource type. But it's subtype
100// has been changed to reflect local shape.
101// If resource argument does not have subtype or subtype does not have static
102// shapes or if resource argument does not have corresponding layout attribute,
103// this function is an no-op.
104mlir::LogicalResult UpdateResourceArgumentType(
105 const int arg_index, mlir::func::FuncOp function,
106 absl::optional<mlir::RankedTensorType> new_subtype = absl::nullopt) {
107 auto resource_arg = function.getArgument(arg_index);
108 if (new_subtype) {
109 auto new_var_type = mlir::RankedTensorType::get(
110 {}, mlir::TF::ResourceType::get(
111 mlir::ArrayRef<mlir::TensorType>{*new_subtype},
112 function.getContext()));
113 UpdateFunctionInputShape(arg_index, new_var_type, function);
114 function.setArgAttr(arg_index, kAssignedResourceLocalShape,
115 ConvertTypeToTensorShapeAttr(*new_subtype));
116 return mlir::success();
117 }
118
119 auto resource_type = resource_arg.getType()
120 .cast<mlir::TensorType>()
121 .getElementType()
122 .dyn_cast<mlir::TF::ResourceType>();
123 if (!resource_type) return mlir::success();
124
125 auto sub_types = resource_type.getSubtypes();
126 if (sub_types.size() != 1) return mlir::success();
127
128 auto resource_arg_sub_type = sub_types.front();
129 if (!resource_arg_sub_type.hasStaticShape()) return mlir::success();
130
131 // The local shape that is to be assigned to this resource argument type. We
132 // will either pull it from the assigned local shape attribute or compute it
133 // based on the layout.
134 // TODO(srujun): use the attribute value only to check the computed shape.
135 // This is currently blocked by an "empty_layout" set on the resource
136 // arguments, meaning it is not possible to compute local layout.
137 llvm::SmallVector<int64_t, 4> local_arg_shape;
138 auto assigned_resource_local_shape_attr =
139 function.getArgAttrOfType<mlir::TF::ShapeAttr>(
140 arg_index, kAssignedResourceLocalShape);
141 if (assigned_resource_local_shape_attr) {
142 local_arg_shape.append(
143 assigned_resource_local_shape_attr.getShape().begin(),
144 assigned_resource_local_shape_attr.getShape().end());
145 } else {
146 auto layout_or_status = ExtractLayoutFromOperand(resource_arg);
147 if (!layout_or_status.ok())
148 return function.emitOpError(layout_or_status.status().error_message());
149
150 const auto& layout = layout_or_status.value();
151 if (!layout) return mlir::success();
152
153 std::vector<int64_t> local_arg_shape_vec =
154 layout->LocalShapeFromGlobalShape(resource_arg_sub_type.getShape());
155 local_arg_shape.append(local_arg_shape_vec.begin(),
156 local_arg_shape_vec.end());
157 }
158
159 auto local_variable_subtype = mlir::RankedTensorType::get(
160 local_arg_shape, resource_arg_sub_type.getElementType());
161 auto new_var_type = mlir::RankedTensorType::get(
162 {}, mlir::TF::ResourceType::get(
163 mlir::ArrayRef<mlir::TensorType>{local_variable_subtype},
164 function.getContext()));
165
166 UpdateFunctionInputShape(arg_index, new_var_type, function);
167 function.setArgAttr(
168 arg_index, kAssignedResourceLocalShape,
169 mlir::TF::ShapeAttr::get(local_variable_subtype.getContext(),
170 mlir::ArrayRef<int64_t>(local_arg_shape)));
171
172 return mlir::success();
173}
174
175// Returns whether `value` is used by AssignVariable op, skipping DTensorLayout
176// op.
177bool IsValueUsedByAssignVariableOp(
178 mlir::Value value, int* resource_argument_index_for_assign_variable) {
179 for (auto user : value.getUsers()) {
180 if (auto assign_variable_op =
181 llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>(
182 NextTFOp(user))) {
183 *resource_argument_index_for_assign_variable =
184 GetForwardedDTensorLayoutInput(assign_variable_op.resource())
185 .cast<mlir::BlockArgument>()
186 .getArgNumber();
187 return true;
188 }
189 }
190 return false;
191}
192
193// Updates argument shapes of `function` based on `tf._layout` attribute.
194mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) {
195 for (int argument_index = 0; argument_index < function.getNumArguments();
196 ++argument_index) {
197 auto arg_layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
198 argument_index, kCustomDeviceAttr);
199 if (!arg_layout_attr) continue;
200
201 auto arg_layout = Layout::FromString(arg_layout_attr.getValue().str());
202 if (!arg_layout.ok())
203 return function.emitOpError(llvm::formatv(
204 "Invalid layout attribute found during SPMD expansion: {0}",
205 arg_layout.status().error_message()));
206
207 mlir::Type arg_type = mlir::getElementTypeOrSelf(
208 function.getFunctionType().getInput(argument_index));
209
210 // If argument is a resource type update the subtype shape information
211 // to reflect local shape of resources.
212 if (arg_type.isa<mlir::TF::ResourceType>()) {
213 if (mlir::failed(UpdateResourceArgumentType(argument_index, function)))
214 return mlir::failure();
215 continue;
216 }
217
218 mlir::RankedTensorType ranked_type =
219 function.getFunctionType()
220 .getInput(argument_index)
221 .dyn_cast<mlir::RankedTensorType>();
222 if (!ranked_type) continue;
223
224 // If input value is non-resource type, then update the value to reflect
225 // local shape.
226 llvm::ArrayRef<int64_t> arg_shape = ranked_type.getShape();
227 const std::vector<int64_t> arg_local_shape =
228 arg_layout->LocalShapeFromGlobalShape(arg_shape);
229 mlir::RankedTensorType new_arg_type = mlir::RankedTensorType::get(
230 arg_local_shape, ranked_type.getElementType());
231 UpdateFunctionInputShape(argument_index, new_arg_type, function);
232
233 // If non-resource value was used for AssignVariable op, then ensure that
234 // resource shape of updated/assigned resource is consistent with the
235 // local shape of assigned value.
236 int assigned_resource_argument_index = -1;
237 if (IsValueUsedByAssignVariableOp(function.getArgument(argument_index),
238 &assigned_resource_argument_index)) {
239 (void)UpdateResourceArgumentType(assigned_resource_argument_index,
240 function, new_arg_type);
241 }
242 }
243 return mlir::success();
244}
245
246// Given SPMD expanded `function_operands` to `function`, update the function
247// signature to reflect the local shape of `function_operands`.
248mlir::LogicalResult UpdateFunctionWithLocalInputShapes(
249 mlir::MutableArrayRef<mlir::OpOperand> function_operands,
250 mlir::func::FuncOp function) {
251 for (auto& operand : function_operands) {
252 const int index = operand.getOperandNumber();
253 auto arg_type = operand.get().getType().dyn_cast<mlir::RankedTensorType>();
254 if (!arg_type) continue;
255
256 auto arg_local_shape = arg_type.getShape();
257 auto new_arg_type =
258 mlir::RankedTensorType::get(arg_local_shape, arg_type.getElementType());
259 UpdateFunctionInputShape(index, new_arg_type, function);
260 }
261 return mlir::success();
262}
263
264// Updates output shapes of enclosing op or function containing `terminator_op`
265// to local shapes.
266mlir::LogicalResult UpdateReturnValueShapes(mlir::ModuleOp module,
267 mlir::Operation* terminator_op) {
268 auto parent_op = terminator_op->getBlock()->getParentOp();
269 if (!parent_op) return mlir::success();
270
271 auto output_types = llvm::to_vector<8>(terminator_op->getOperandTypes());
272 if (auto function = llvm::dyn_cast<mlir::func::FuncOp>(parent_op)) {
273 // Update function output type to have local shape.
274 auto new_func_type = mlir::FunctionType::get(
275 function.getContext(), function.getFunctionType().getInputs(),
276 output_types);
277 function.setType(new_func_type);
278
279 // Update function callsite operations to reflect local output shapes.
280 auto function_uses =
281 mlir::SymbolTable::getSymbolUses(function, &module.getBodyRegion());
282 if (!function_uses) return mlir::success();
283
284 // Update function callsite operations to reflect local output shapes.
285 for (auto function_use : *function_uses) {
286 auto callsite_op = function_use.getUser();
287 if (!callsite_op) continue;
288
289 for (auto& output_type_and_index : llvm::enumerate(output_types)) {
290 int index = output_type_and_index.index();
291 const auto& type = output_type_and_index.value();
292 callsite_op->getResult(index).setType(type);
293 }
294 }
295 } else {
296 for (auto& output_type_and_index : llvm::enumerate(output_types)) {
297 int index = output_type_and_index.index();
298 const auto& type = output_type_and_index.value();
299 parent_op->getResult(index).setType(type);
300 }
301 }
302
303 return mlir::success();
304}
305
306// Conducts SPMD expansion for all ops in `module`. If function call operation
307// exists, walk the function in topological order to update inputs/outputs of
308// functions before SPMD expansion of callsite operations is done.
309// Note that the iteration won't work with recursive function calls.
310mlir::LogicalResult ConductSPMDExpansion(mlir::ModuleOp module) {
311 auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName);
312 if (!main_func)
313 return module.emitOpError(
314 "could not find `main` function in module for SPMD expansion.");
315
316 if (mlir::failed(UpdateFunctionArgsUsingLayout(main_func)))
317 return mlir::failure();
318
319 TopologicalIterator iterator(main_func);
320 while (iterator.hasNext()) {
321 mlir::Operation* op = iterator.next();
322 absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
323 if (func.has_value()) {
324 if (mlir::failed(
325 UpdateFunctionWithLocalInputShapes(op->getOpOperands(), *func)))
326 return mlir::failure();
327 }
328
329 const bool is_terminator_op =
330 llvm::isa<mlir::func::ReturnOp, mlir::tf_device::ReturnOp>(op);
331 if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op))
332 layout_op.output().setType(layout_op.input().getType());
333
334 mlir::Operation* expanded_op = nullptr;
335 auto status = RunSPMDExpansion(op, &expanded_op);
336 if (!status.ok() || expanded_op == nullptr) {
337 // Sometimes op may been erased and expanded_op set.
338 // In this case we should emit the error on the expanded op.
339 mlir::Operation* emit_op = op;
340 if (expanded_op != nullptr) emit_op = expanded_op;
341 return emit_op->emitError(WithContext(status, __FILE__, __LINE__,
342 "While computing SPMD expansion")
343 .error_message());
344 }
345
346 // If expanded op is terminator of tf_device.Cluster or a function, then
347 // make sure to update the function return value as well as the shape of
348 // it's callsite operation.
349 if (is_terminator_op)
350 if (mlir::failed(UpdateReturnValueShapes(module, expanded_op)))
351 return mlir::failure();
352 }
353 return mlir::success();
354}
355
356// DTensorLayout only conveys layout information of tensors which is no
357// longer needed after SPMD expansion. As so, remove all layouts from
358// graph.
359void RemoveDTensorLayoutOps(mlir::ModuleOp module) {
360 llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
361 module.walk(
362 [&](mlir::TF::DTensorLayout layout) { layout_ops.emplace_back(layout); });
363
364 for (auto layout_op : layout_ops) RemoveDTensorLayoutOp(layout_op);
365}
366
367// Removes temporary attrs created during SPMD expansion.
368void RemoveTemporarySPMDAttrs(mlir::ModuleOp module) {
369 module.walk([&](mlir::Operation* op) {
370 if (op->hasAttr(kDeviceSeedForMeshDims)) {
371 op->removeAttr(kDeviceSeedForMeshDims);
372 }
373 });
374}
375
376// MLIR pass that converts graph in global view into a local view which can be
377// invoked in parallel on distributed set of devices. This pass removes
378// all DTensorLayout ops after the expansion is done. Temporary nodes and
379// attributes are also removed after the pass is done.
380struct DTensorSPMDExpansion
381 : public impl::DTensorSPMDExpansionBase<DTensorSPMDExpansion> {
382 void getDependentDialects(mlir::DialectRegistry& registry) const override {
383 registry.insert<mlir::dtensor::DTensorDialect>();
384 }
385
386 void runOnOperation() override {
387 auto module = getOperation();
388 if (failed(ConductSPMDExpansion(module))) return signalPassFailure();
389
390 RemoveDTensorLayoutOps(module);
391
392 RemoveTemporarySPMDAttrs(module);
393 };
394};
395
396} // namespace
397
398std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
399CreateDTensorSPMDExpansion() {
400 return std::make_unique<DTensorSPMDExpansion>();
401}
402
403} // namespace dtensor
404} // namespace tensorflow
405