1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "absl/types/optional.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
22 | #include "mlir/IR/Builders.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
24 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
25 | #include "mlir/IR/Operation.h" // from @llvm-project |
26 | #include "mlir/IR/Visitors.h" // from @llvm-project |
27 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
28 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
29 | #include "tensorflow/dtensor/cc/constants.h" |
30 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
31 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" |
32 | #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h" |
33 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
34 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
35 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
36 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
37 | #include "tensorflow/dtensor/mlir/value_utils.h" |
38 | |
39 | namespace tensorflow { |
40 | namespace dtensor { |
41 | |
42 | namespace { |
43 | #define GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT |
44 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
45 | |
46 | // Creates tf.DTensorLayout op that forwards `input` value. |
47 | void CreateDTensorLayoutOp(const Layout& layout, mlir::Value input, |
48 | mlir::TensorType& type, mlir::Location loc, |
49 | mlir::OpBuilder* builder, |
50 | mlir::MLIRContext* context) { |
51 | if (layout.IsEmpty()) return; |
52 | |
53 | auto layout_op = builder->create<mlir::TF::DTensorLayout>( |
54 | loc, input, mlir::dtensor::LayoutAttr::get(context, layout), |
55 | mlir::TF::ShapeAttr::get(context, type)); |
56 | llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op}; |
57 | input.replaceAllUsesExcept(layout_op.output(), exception); |
58 | } |
59 | |
60 | // Adds DTensorLayout op following each Relayout operation to ensure that |
61 | // tensor from `relayout` has fixed layout. |
62 | mlir::LogicalResult PropagateDTensorLayoutForRelayout( |
63 | mlir::MLIRContext& c, mlir::TF::RelayoutOp relayout) { |
64 | const std::string layout_str = relayout.layout().str(); |
65 | auto layout_or_status = Layout::FromString(layout_str); |
66 | if (!layout_or_status.ok()) { |
67 | return relayout.emitOpError( |
68 | llvm::formatv("found Relayout op with incorrect/unparsable layout. " |
69 | "Found layout: {0} " , |
70 | layout_str)); |
71 | } |
72 | const Layout& layout = layout_or_status.value(); |
73 | |
74 | // Skip adding a DTensorLayout if Relayout is 'dynamic'. Any dimension with |
75 | // MATCH for the layout will have its layout preserved in layout propagation. |
76 | for (const std::string& sharding_spec : layout.sharding_spec_strs()) |
77 | if (sharding_spec == Layout::kMatch) return mlir::success(); |
78 | |
79 | mlir::OpBuilder builder(relayout->getBlock(), |
80 | ++mlir::Block::iterator(relayout)); |
81 | mlir::TensorType type = relayout.getType().dyn_cast<mlir::TensorType>(); |
82 | if (!type) return relayout.emitOpError("type required for Relayout op" ); |
83 | |
84 | CreateDTensorLayoutOp(layout, relayout.output(), type, relayout.getLoc(), |
85 | &builder, &c); |
86 | return mlir::success(); |
87 | } |
88 | |
89 | // Creates tf.DTensorLayout that is connected to each function argument if |
90 | // function arg contains layout attribute. |
91 | mlir::LogicalResult PropagateFunctionArgAttrToLayoutOp( |
92 | mlir::MLIRContext& c, mlir::func::FuncOp function) { |
93 | for (int arg_index = 0; arg_index < function.getNumArguments(); ++arg_index) { |
94 | auto layout_attr = function.getArgAttrOfType<mlir::StringAttr>( |
95 | arg_index, kCustomDeviceAttr); |
96 | if (!layout_attr) continue; |
97 | const auto layout_str = layout_attr.getValue().str(); |
98 | auto layout_or_status = Layout::FromString(layout_str); |
99 | if (!layout_or_status.ok()) |
100 | return function.emitOpError(llvm::formatv( |
101 | "function includes attribute {0} for {1}-th arg that cannot be " |
102 | "serialized to correct layout format. Found attribute {3}" , |
103 | kCustomDeviceAttr, arg_index, layout_str)); |
104 | |
105 | mlir::OpBuilder builder(function.getBody()); |
106 | auto arg = function.getArgument(arg_index); |
107 | mlir::Type tensor_type = GetSubtypeOrSelf(arg); |
108 | if (auto type = tensor_type.dyn_cast<mlir::TensorType>()) { |
109 | CreateDTensorLayoutOp(layout_or_status.value(), arg, type, |
110 | function.getLoc(), &builder, &c); |
111 | } else { |
112 | return function.emitOpError() |
113 | << "is missing tensor type for argument " << arg_index; |
114 | } |
115 | } |
116 | |
117 | return mlir::success(); |
118 | } |
119 | |
120 | // Creates tf.DTensorLayout that is connected to terminator op of function if |
121 | // function contains default layout attribute that represents layout of function |
122 | // outputs. |
123 | mlir::LogicalResult PropagateFunctionDefaultLayoutAttrToLayoutOp( |
124 | mlir::MLIRContext& c, mlir::func::FuncOp function) { |
125 | for (int ret_index = 0; ret_index < function.getNumResults(); ++ret_index) { |
126 | auto layout_attr_from_func_result = |
127 | function.getResultAttrOfType<mlir::StringAttr>( |
128 | ret_index, kCustomDefaultLayoutAttr); |
129 | if (!layout_attr_from_func_result) continue; |
130 | |
131 | const std::string layout_string = |
132 | layout_attr_from_func_result.getValue().str(); |
133 | auto result_layout_or_status = Layout::FromString(layout_string); |
134 | if (!result_layout_or_status.ok()) |
135 | return function.emitOpError( |
136 | llvm::formatv("function includes default layout attribute {0} for " |
137 | "{1}-th output that cannot be serialized to correct " |
138 | "layout format. Found attribute {3}" , |
139 | kCustomDefaultLayoutAttr, ret_index, layout_string)); |
140 | |
141 | auto function_terminator = function.getBody().front().getTerminator(); |
142 | mlir::OpBuilder builder(function_terminator); |
143 | auto return_value = function_terminator->getOperand(ret_index); |
144 | |
145 | if (auto type = return_value.getType().dyn_cast<mlir::TensorType>()) |
146 | CreateDTensorLayoutOp(result_layout_or_status.value(), return_value, type, |
147 | function.getLoc(), &builder, &c); |
148 | else |
149 | return function.emitOpError() |
150 | << "is missing tensor type for result " << ret_index; |
151 | } |
152 | |
153 | return mlir::success(); |
154 | } |
155 | |
156 | // MLIR pass that removes trivially unused operations in graph. |
157 | struct DTensorPropagateDefaultLayout |
158 | : public impl::DTensorPropagateDefaultLayoutBase< |
159 | DTensorPropagateDefaultLayout> { |
160 | void getDependentDialects(mlir::DialectRegistry& registry) const override { |
161 | registry.insert<mlir::dtensor::DTensorDialect>(); |
162 | } |
163 | |
164 | void runOnOperation() override { |
165 | mlir::MLIRContext& context = getContext(); |
166 | mlir::OpBuilder builder(&context); |
167 | |
168 | auto function = getOperation(); |
169 | |
170 | auto walk_result = |
171 | getOperation().walk([&](mlir::Operation* op) -> mlir::WalkResult { |
172 | if (auto relayout = llvm::dyn_cast<mlir::TF::RelayoutOp>(op)) { |
173 | (void)PropagateDTensorLayoutForRelayout(context, relayout); |
174 | return mlir::WalkResult::advance(); |
175 | } |
176 | |
177 | // Set user annotated layout on operations. |
178 | auto layout_or_status = ExtractLayoutFromOp(op); |
179 | if (!layout_or_status.ok()) { |
180 | op->emitOpError(llvm::formatv( |
181 | "op has layout attribute {0} that cannot be deserizlied." , |
182 | kLayoutAttr)); |
183 | return mlir::WalkResult::interrupt(); |
184 | } |
185 | |
186 | mlir::OpBuilder builder(&context); |
187 | builder.setInsertionPointAfter(op); |
188 | const auto layouts = layout_or_status.value(); |
189 | for (const auto& layout_and_index : llvm::enumerate(layouts)) { |
190 | const int index = layout_and_index.index(); |
191 | const auto& layout = layout_and_index.value(); |
192 | if (!layout || layout->IsEmpty()) continue; |
193 | |
194 | auto op_output = op->getResult(index); |
195 | if (auto type = op_output.getType().dyn_cast<mlir::TensorType>()) { |
196 | auto layout_op = builder.create<mlir::TF::DTensorLayout>( |
197 | function.getLoc(), op_output, |
198 | mlir::dtensor::LayoutAttr::get(&context, *layout), |
199 | mlir::TF::ShapeAttr::get(&context, type)); |
200 | llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op}; |
201 | op_output.replaceAllUsesExcept(layout_op.output(), exception); |
202 | } else { |
203 | return op->emitOpError() |
204 | << "type for output " << index << " is not a TensorType" ; |
205 | } |
206 | } |
207 | |
208 | return mlir::WalkResult::advance(); |
209 | }); |
210 | |
211 | if (walk_result.wasInterrupted()) return signalPassFailure(); |
212 | |
213 | // Set user annotated layout on function arguments. |
214 | if (mlir::failed(PropagateFunctionArgAttrToLayoutOp(context, function))) |
215 | return signalPassFailure(); |
216 | |
217 | // Set user annotated layout on function outputs. |
218 | if (mlir::failed( |
219 | PropagateFunctionDefaultLayoutAttrToLayoutOp(context, function))) |
220 | return signalPassFailure(); |
221 | } |
222 | }; |
223 | |
224 | } // namespace |
225 | |
226 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
227 | CreateDTensorPropagateDefaultLayout() { |
228 | return std::make_unique<DTensorPropagateDefaultLayout>(); |
229 | } |
230 | |
231 | } // namespace dtensor |
232 | } // namespace tensorflow |
233 | |