/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
16 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
17 | |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | #include "absl/strings/str_cat.h" |
22 | #include "absl/types/optional.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/Support/FormatVariadic.h" |
25 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
26 | #include "mlir/IR/Attributes.h" // from @llvm-project |
27 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
28 | #include "mlir/IR/Operation.h" // from @llvm-project |
29 | #include "mlir/IR/OperationSupport.h" // from @llvm-project |
30 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
31 | #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" |
32 | #include "tensorflow/core/platform/errors.h" |
33 | #include "tensorflow/core/platform/mutex.h" |
34 | #include "tensorflow/dtensor/cc/constants.h" |
35 | #include "tensorflow/dtensor/cc/tensor_layout.h" |
36 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace dtensor { |
40 | namespace { |
41 | |
42 | bool OpUsesV2LayoutAnnotation(mlir::Operation* op) { |
43 | return !op->getUsers().empty() && |
44 | llvm::all_of(op->getUsers(), [](mlir::Operation* user_op) { |
45 | return llvm::isa<mlir::TF::DTensorLayout>(user_op); |
46 | }); |
47 | } |
48 | |
49 | } // namespace |
50 | |
51 | StatusOr<absl::optional<Layout>> ( |
52 | mlir::Operation* op, std::string attr_name) { |
53 | absl::optional<Layout> out; |
54 | |
55 | // If v2 layout propagation algorithm is used, parse layout from DTensorLayout |
56 | // op. |
57 | if (OpUsesV2LayoutAnnotation(op)) { |
58 | // If DTensorLayout is used, then DTensorLayout op is the only consumer for |
59 | // the operation output value. |
60 | auto users = op->getUsers(); |
61 | out.emplace(llvm::cast<mlir::TF::DTensorLayout>(*users.begin()).layout()); |
62 | } else { |
63 | TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, attr_name)); |
64 | if (layouts.empty()) return out; |
65 | if (layouts.size() != 1) { |
66 | return errors::Internal( |
67 | "Extracting single layout on Op that has multiple layout attached is " |
68 | "ambiguous. op : " , |
69 | op->getName().getStringRef().str()); |
70 | } |
71 | out.swap(layouts[0]); |
72 | } |
73 | return out; |
74 | } |
75 | |
76 | StatusOr<absl::optional<Layout>> ( |
77 | mlir::Operation* op) { |
78 | return ExtractSingleLayoutFromOp(op, kLayoutAttr); |
79 | } |
80 | |
81 | StatusOr<Layout> (mlir::Operation* op) { |
82 | TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout, |
83 | ExtractSingleLayoutFromOp(op)); |
84 | if (!layout) return errors::Internal("expected layout missing" ); |
85 | |
86 | return *layout; |
87 | } |
88 | |
89 | StatusOr<std::vector<absl::optional<Layout>>> ( |
90 | mlir::Operation* op, std::string attr_name) { |
91 | std::vector<absl::optional<Layout>> outs; |
92 | outs.reserve(op->getNumResults()); |
93 | |
94 | // If v2 layout propagation algorithm is used, parse layout from DTensorLayout |
95 | // op. |
96 | if (OpUsesV2LayoutAnnotation(op)) { |
97 | for (auto op_result : op->getOpResults()) { |
98 | outs.emplace_back( |
99 | llvm::cast<mlir::TF::DTensorLayout>(*op_result.getUsers().begin()) |
100 | .layout()); |
101 | } |
102 | } else { |
103 | auto serialized_layouts = op->getAttrOfType<mlir::ArrayAttr>(attr_name); |
104 | if (!serialized_layouts) return outs; |
105 | |
106 | for (auto const& attr : serialized_layouts) { |
107 | auto attr_str = attr.cast<mlir::StringAttr>().getValue().str(); |
108 | if (!attr_str.empty()) { |
109 | TF_ASSIGN_OR_RETURN(auto layout, Layout::FromString(attr_str)); |
110 | outs.emplace_back(std::move(layout)); |
111 | } else { |
112 | outs.emplace_back(absl::nullopt); |
113 | } |
114 | } |
115 | } |
116 | return outs; |
117 | } |
118 | |
119 | StatusOr<std::vector<absl::optional<Layout>>> ( |
120 | mlir::Operation* op) { |
121 | return ExtractLayoutFromOp(op, kLayoutAttr); |
122 | } |
123 | |
124 | StatusOr<std::vector<Layout>> (mlir::Operation* op) { |
125 | TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> optional_layouts, |
126 | ExtractLayoutFromOp(op)); |
127 | std::vector<Layout> layouts; |
128 | for (const absl::optional<Layout>& layout : optional_layouts) { |
129 | if (!layout) return errors::Internal("expected layout missing" ); |
130 | layouts.emplace_back(*layout); |
131 | } |
132 | |
133 | return layouts; |
134 | } |
135 | |
136 | StatusOr<Mesh> (mlir::Operation* op) { |
137 | auto enclosing_cluster = op->getParentOfType<mlir::tf_device::ClusterOp>(); |
138 | if (!enclosing_cluster) |
139 | return errors::InvalidArgument("op is not inside a device mesh cluster." ); |
140 | |
141 | TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshFromOp(enclosing_cluster)); |
142 | if (!mesh) |
143 | return errors::InvalidArgument( |
144 | "op's enclosing device cluster does not have mesh defined." ); |
145 | |
146 | return *mesh; |
147 | } |
148 | |
149 | StatusOr<absl::optional<Mesh>> (mlir::Operation* op) { |
150 | absl::optional<Mesh> ; |
151 | if (op == nullptr) return extracted_mesh; |
152 | |
153 | auto mesh_str_attr = op->getAttrOfType<mlir::StringAttr>(kMeshAttr); |
154 | if (!mesh_str_attr) return extracted_mesh; |
155 | |
156 | TF_ASSIGN_OR_RETURN(Mesh mesh, |
157 | Mesh::FromString(mesh_str_attr.getValue().str())); |
158 | |
159 | extracted_mesh.emplace(std::move(mesh)); |
160 | return extracted_mesh; |
161 | } |
162 | |
163 | StatusOr<absl::optional<Layout>> ExtractLayoutFromOperand(mlir::Value operand) { |
164 | if (auto op_result = operand.dyn_cast<mlir::OpResult>()) { |
165 | mlir::Operation* op = op_result.getDefiningOp(); |
166 | absl::optional<Layout> out; |
167 | if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) { |
168 | out.emplace(layout_op.layout()); |
169 | } else { |
170 | const int result_number = op_result.getResultNumber(); |
171 | TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, kLayoutAttr)); |
172 | |
173 | if (layouts.empty()) return out; |
174 | |
175 | if (result_number >= layouts.size()) { |
176 | return errors::Internal( |
177 | "Expect to extract the " , result_number, |
178 | "-th output's layout, but " |
179 | "only see " , |
180 | layouts.size(), " outputs: " , op->getName().getStringRef().str()); |
181 | } |
182 | out.swap(layouts[result_number]); |
183 | } |
184 | return out; |
185 | } |
186 | |
187 | auto block_arg = operand.dyn_cast<mlir::BlockArgument>(); |
188 | if (!block_arg) |
189 | return errors::Internal( |
190 | "Operand is not either a OpResult or a BlockArgument. This should not " |
191 | "happen." ); |
192 | auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>( |
193 | block_arg.getOwner()->getParentOp()); |
194 | if (!func_op) { |
195 | return errors::InvalidArgument("op must be enclosed by a function" ); |
196 | } |
197 | |
198 | absl::optional<Layout> ; |
199 | auto layout_attr = func_op.getArgAttrOfType<mlir::StringAttr>( |
200 | block_arg.getArgNumber(), kCustomDeviceAttr); |
201 | if (!layout_attr) return extracted_layout; |
202 | |
203 | TF_ASSIGN_OR_RETURN(auto layout, |
204 | Layout::FromString(layout_attr.getValue().str())); |
205 | extracted_layout.emplace(std::move(layout)); |
206 | return extracted_layout; |
207 | } |
208 | |
209 | StatusOr<Layout> ExtractRequiredLayoutFromOperand(mlir::Value operand) { |
210 | TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout, |
211 | ExtractLayoutFromOperand(operand)); |
212 | if (!layout) return errors::Internal("expected layout missing" ); |
213 | |
214 | return *layout; |
215 | } |
216 | |
217 | StatusOr<std::vector<Layout>> ExtractRequiredLayoutFromOperands( |
218 | mlir::Operation* op) { |
219 | std::vector<Layout> layouts; |
220 | for (const auto& operand : op->getOpOperands()) { |
221 | TF_ASSIGN_OR_RETURN(auto operand_layout, |
222 | ExtractRequiredLayoutFromOperand(operand.get())); |
223 | layouts.emplace_back(operand_layout); |
224 | } |
225 | return layouts; |
226 | } |
227 | |
228 | void SetLayoutOnOp(mlir::Operation* op, mlir::OpBuilder builder, |
229 | absl::Span<const absl::optional<Layout>> layouts) { |
230 | llvm::SmallVector<std::string, 8> serialized_layouts; |
231 | for (auto const& layout : layouts) { |
232 | serialized_layouts.emplace_back(layout.has_value() ? layout->ToString() |
233 | : "" ); |
234 | } |
235 | op->setAttr(kLayoutAttr, |
236 | builder.getStrArrayAttr(llvm::SmallVector<llvm::StringRef, 8>( |
237 | serialized_layouts.begin(), serialized_layouts.end()))); |
238 | } |
239 | |
240 | void SetLayoutOnOp(mlir::Operation* op, |
241 | absl::Span<const absl::optional<Layout>> layouts) { |
242 | SetLayoutOnOp(op, mlir::OpBuilder(op), layouts); |
243 | } |
244 | |
245 | void SetSingleLayoutOnOp(mlir::Operation* op, const Layout& layout) { |
246 | SetLayoutOnOp(op, mlir::OpBuilder(op), {absl::optional<Layout>(layout)}); |
247 | } |
248 | |
249 | StatusOr<absl::optional<Layout>> ( |
250 | mlir::func::ReturnOp return_op, const int return_index) { |
251 | absl::optional<Layout> layout; |
252 | // If value feeds into func op return op, then check to see if layout |
253 | // attribute is set for the return value. |
254 | auto function = return_op->getParentOfType<mlir::func::FuncOp>(); |
255 | auto layout_attr_from_func_result = |
256 | function.getResultAttrOfType<mlir::StringAttr>(return_index, |
257 | kCustomDefaultLayoutAttr); |
258 | if (!layout_attr_from_func_result) return layout; |
259 | |
260 | const std::string layout_string = |
261 | layout_attr_from_func_result.getValue().str(); |
262 | auto result_layout_or_status = Layout::FromString(layout_string); |
263 | if (!result_layout_or_status.ok()) |
264 | return errors::InvalidArgument( |
265 | llvm::formatv("Malformed default return layout received. {0} Received " |
266 | "layout : {1}" , |
267 | result_layout_or_status.status().error_message(), |
268 | layout_string) |
269 | .str()); |
270 | |
271 | layout.emplace(result_layout_or_status.value()); |
272 | return layout; |
273 | } |
274 | |
275 | } // namespace dtensor |
276 | } // namespace tensorflow |
277 | |