1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "absl/types/optional.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22#include "mlir/IR/Builders.h" // from @llvm-project
23#include "mlir/IR/BuiltinOps.h" // from @llvm-project
24#include "mlir/IR/MLIRContext.h" // from @llvm-project
25#include "mlir/IR/Operation.h" // from @llvm-project
26#include "mlir/IR/Visitors.h" // from @llvm-project
27#include "mlir/Support/LogicalResult.h" // from @llvm-project
28#include "mlir/Transforms/Passes.h" // from @llvm-project
29#include "tensorflow/dtensor/cc/constants.h"
30#include "tensorflow/dtensor/cc/tensor_layout.h"
31#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
32#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
33#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
34#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
35#include "tensorflow/dtensor/mlir/layout_parsing.h"
36#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
37#include "tensorflow/dtensor/mlir/value_utils.h"
38
39namespace tensorflow {
40namespace dtensor {
41
42namespace {
43#define GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT
44#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
45
46// Creates tf.DTensorLayout op that forwards `input` value.
47void CreateDTensorLayoutOp(const Layout& layout, mlir::Value input,
48 mlir::TensorType& type, mlir::Location loc,
49 mlir::OpBuilder* builder,
50 mlir::MLIRContext* context) {
51 if (layout.IsEmpty()) return;
52
53 auto layout_op = builder->create<mlir::TF::DTensorLayout>(
54 loc, input, mlir::dtensor::LayoutAttr::get(context, layout),
55 mlir::TF::ShapeAttr::get(context, type));
56 llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
57 input.replaceAllUsesExcept(layout_op.output(), exception);
58}
59
60// Adds DTensorLayout op following each Relayout operation to ensure that
61// tensor from `relayout` has fixed layout.
62mlir::LogicalResult PropagateDTensorLayoutForRelayout(
63 mlir::MLIRContext& c, mlir::TF::RelayoutOp relayout) {
64 const std::string layout_str = relayout.layout().str();
65 auto layout_or_status = Layout::FromString(layout_str);
66 if (!layout_or_status.ok()) {
67 return relayout.emitOpError(
68 llvm::formatv("found Relayout op with incorrect/unparsable layout. "
69 "Found layout: {0} ",
70 layout_str));
71 }
72 const Layout& layout = layout_or_status.value();
73
74 // Skip adding a DTensorLayout if Relayout is 'dynamic'. Any dimension with
75 // MATCH for the layout will have its layout preserved in layout propagation.
76 for (const std::string& sharding_spec : layout.sharding_spec_strs())
77 if (sharding_spec == Layout::kMatch) return mlir::success();
78
79 mlir::OpBuilder builder(relayout->getBlock(),
80 ++mlir::Block::iterator(relayout));
81 mlir::TensorType type = relayout.getType().dyn_cast<mlir::TensorType>();
82 if (!type) return relayout.emitOpError("type required for Relayout op");
83
84 CreateDTensorLayoutOp(layout, relayout.output(), type, relayout.getLoc(),
85 &builder, &c);
86 return mlir::success();
87}
88
89// Creates tf.DTensorLayout that is connected to each function argument if
90// function arg contains layout attribute.
91mlir::LogicalResult PropagateFunctionArgAttrToLayoutOp(
92 mlir::MLIRContext& c, mlir::func::FuncOp function) {
93 for (int arg_index = 0; arg_index < function.getNumArguments(); ++arg_index) {
94 auto layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
95 arg_index, kCustomDeviceAttr);
96 if (!layout_attr) continue;
97 const auto layout_str = layout_attr.getValue().str();
98 auto layout_or_status = Layout::FromString(layout_str);
99 if (!layout_or_status.ok())
100 return function.emitOpError(llvm::formatv(
101 "function includes attribute {0} for {1}-th arg that cannot be "
102 "serialized to correct layout format. Found attribute {3}",
103 kCustomDeviceAttr, arg_index, layout_str));
104
105 mlir::OpBuilder builder(function.getBody());
106 auto arg = function.getArgument(arg_index);
107 mlir::Type tensor_type = GetSubtypeOrSelf(arg);
108 if (auto type = tensor_type.dyn_cast<mlir::TensorType>()) {
109 CreateDTensorLayoutOp(layout_or_status.value(), arg, type,
110 function.getLoc(), &builder, &c);
111 } else {
112 return function.emitOpError()
113 << "is missing tensor type for argument " << arg_index;
114 }
115 }
116
117 return mlir::success();
118}
119
120// Creates tf.DTensorLayout that is connected to terminator op of function if
121// function contains default layout attribute that represents layout of function
122// outputs.
123mlir::LogicalResult PropagateFunctionDefaultLayoutAttrToLayoutOp(
124 mlir::MLIRContext& c, mlir::func::FuncOp function) {
125 for (int ret_index = 0; ret_index < function.getNumResults(); ++ret_index) {
126 auto layout_attr_from_func_result =
127 function.getResultAttrOfType<mlir::StringAttr>(
128 ret_index, kCustomDefaultLayoutAttr);
129 if (!layout_attr_from_func_result) continue;
130
131 const std::string layout_string =
132 layout_attr_from_func_result.getValue().str();
133 auto result_layout_or_status = Layout::FromString(layout_string);
134 if (!result_layout_or_status.ok())
135 return function.emitOpError(
136 llvm::formatv("function includes default layout attribute {0} for "
137 "{1}-th output that cannot be serialized to correct "
138 "layout format. Found attribute {3}",
139 kCustomDefaultLayoutAttr, ret_index, layout_string));
140
141 auto function_terminator = function.getBody().front().getTerminator();
142 mlir::OpBuilder builder(function_terminator);
143 auto return_value = function_terminator->getOperand(ret_index);
144
145 if (auto type = return_value.getType().dyn_cast<mlir::TensorType>())
146 CreateDTensorLayoutOp(result_layout_or_status.value(), return_value, type,
147 function.getLoc(), &builder, &c);
148 else
149 return function.emitOpError()
150 << "is missing tensor type for result " << ret_index;
151 }
152
153 return mlir::success();
154}
155
156// MLIR pass that removes trivially unused operations in graph.
157struct DTensorPropagateDefaultLayout
158 : public impl::DTensorPropagateDefaultLayoutBase<
159 DTensorPropagateDefaultLayout> {
160 void getDependentDialects(mlir::DialectRegistry& registry) const override {
161 registry.insert<mlir::dtensor::DTensorDialect>();
162 }
163
164 void runOnOperation() override {
165 mlir::MLIRContext& context = getContext();
166 mlir::OpBuilder builder(&context);
167
168 auto function = getOperation();
169
170 auto walk_result =
171 getOperation().walk([&](mlir::Operation* op) -> mlir::WalkResult {
172 if (auto relayout = llvm::dyn_cast<mlir::TF::RelayoutOp>(op)) {
173 (void)PropagateDTensorLayoutForRelayout(context, relayout);
174 return mlir::WalkResult::advance();
175 }
176
177 // Set user annotated layout on operations.
178 auto layout_or_status = ExtractLayoutFromOp(op);
179 if (!layout_or_status.ok()) {
180 op->emitOpError(llvm::formatv(
181 "op has layout attribute {0} that cannot be deserizlied.",
182 kLayoutAttr));
183 return mlir::WalkResult::interrupt();
184 }
185
186 mlir::OpBuilder builder(&context);
187 builder.setInsertionPointAfter(op);
188 const auto layouts = layout_or_status.value();
189 for (const auto& layout_and_index : llvm::enumerate(layouts)) {
190 const int index = layout_and_index.index();
191 const auto& layout = layout_and_index.value();
192 if (!layout || layout->IsEmpty()) continue;
193
194 auto op_output = op->getResult(index);
195 if (auto type = op_output.getType().dyn_cast<mlir::TensorType>()) {
196 auto layout_op = builder.create<mlir::TF::DTensorLayout>(
197 function.getLoc(), op_output,
198 mlir::dtensor::LayoutAttr::get(&context, *layout),
199 mlir::TF::ShapeAttr::get(&context, type));
200 llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
201 op_output.replaceAllUsesExcept(layout_op.output(), exception);
202 } else {
203 return op->emitOpError()
204 << "type for output " << index << " is not a TensorType";
205 }
206 }
207
208 return mlir::WalkResult::advance();
209 });
210
211 if (walk_result.wasInterrupted()) return signalPassFailure();
212
213 // Set user annotated layout on function arguments.
214 if (mlir::failed(PropagateFunctionArgAttrToLayoutOp(context, function)))
215 return signalPassFailure();
216
217 // Set user annotated layout on function outputs.
218 if (mlir::failed(
219 PropagateFunctionDefaultLayoutAttrToLayoutOp(context, function)))
220 return signalPassFailure();
221 }
222};
223
224} // namespace
225
226std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
227CreateDTensorPropagateDefaultLayout() {
228 return std::make_unique<DTensorPropagateDefaultLayout>();
229}
230
231} // namespace dtensor
232} // namespace tensorflow
233