1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/dtensor/mlir/layout_parsing.h"
17
18#include <string>
19#include <utility>
20
21#include "absl/strings/str_cat.h"
22#include "absl/types/optional.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26#include "mlir/IR/Attributes.h" // from @llvm-project
27#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28#include "mlir/IR/Operation.h" // from @llvm-project
29#include "mlir/IR/OperationSupport.h" // from @llvm-project
30#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
32#include "tensorflow/core/platform/errors.h"
33#include "tensorflow/core/platform/mutex.h"
34#include "tensorflow/dtensor/cc/constants.h"
35#include "tensorflow/dtensor/cc/tensor_layout.h"
36#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
37
38namespace tensorflow {
39namespace dtensor {
40namespace {
41
42bool OpUsesV2LayoutAnnotation(mlir::Operation* op) {
43 return !op->getUsers().empty() &&
44 llvm::all_of(op->getUsers(), [](mlir::Operation* user_op) {
45 return llvm::isa<mlir::TF::DTensorLayout>(user_op);
46 });
47}
48
49} // namespace
50
51StatusOr<absl::optional<Layout>> ExtractSingleLayoutFromOp(
52 mlir::Operation* op, std::string attr_name) {
53 absl::optional<Layout> out;
54
55 // If v2 layout propagation algorithm is used, parse layout from DTensorLayout
56 // op.
57 if (OpUsesV2LayoutAnnotation(op)) {
58 // If DTensorLayout is used, then DTensorLayout op is the only consumer for
59 // the operation output value.
60 auto users = op->getUsers();
61 out.emplace(llvm::cast<mlir::TF::DTensorLayout>(*users.begin()).layout());
62 } else {
63 TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, attr_name));
64 if (layouts.empty()) return out;
65 if (layouts.size() != 1) {
66 return errors::Internal(
67 "Extracting single layout on Op that has multiple layout attached is "
68 "ambiguous. op : ",
69 op->getName().getStringRef().str());
70 }
71 out.swap(layouts[0]);
72 }
73 return out;
74}
75
// Convenience overload that reads the layout from the default layout
// attribute (kLayoutAttr).
StatusOr<absl::optional<Layout>> ExtractSingleLayoutFromOp(
    mlir::Operation* op) {
  return ExtractSingleLayoutFromOp(op, kLayoutAttr);
}
80
81StatusOr<Layout> ExtractRequiredSingleLayoutFromOp(mlir::Operation* op) {
82 TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
83 ExtractSingleLayoutFromOp(op));
84 if (!layout) return errors::Internal("expected layout missing");
85
86 return *layout;
87}
88
89StatusOr<std::vector<absl::optional<Layout>>> ExtractLayoutFromOp(
90 mlir::Operation* op, std::string attr_name) {
91 std::vector<absl::optional<Layout>> outs;
92 outs.reserve(op->getNumResults());
93
94 // If v2 layout propagation algorithm is used, parse layout from DTensorLayout
95 // op.
96 if (OpUsesV2LayoutAnnotation(op)) {
97 for (auto op_result : op->getOpResults()) {
98 outs.emplace_back(
99 llvm::cast<mlir::TF::DTensorLayout>(*op_result.getUsers().begin())
100 .layout());
101 }
102 } else {
103 auto serialized_layouts = op->getAttrOfType<mlir::ArrayAttr>(attr_name);
104 if (!serialized_layouts) return outs;
105
106 for (auto const& attr : serialized_layouts) {
107 auto attr_str = attr.cast<mlir::StringAttr>().getValue().str();
108 if (!attr_str.empty()) {
109 TF_ASSIGN_OR_RETURN(auto layout, Layout::FromString(attr_str));
110 outs.emplace_back(std::move(layout));
111 } else {
112 outs.emplace_back(absl::nullopt);
113 }
114 }
115 }
116 return outs;
117}
118
// Convenience overload that reads the layouts from the default layout
// attribute (kLayoutAttr).
StatusOr<std::vector<absl::optional<Layout>>> ExtractLayoutFromOp(
    mlir::Operation* op) {
  return ExtractLayoutFromOp(op, kLayoutAttr);
}
123
124StatusOr<std::vector<Layout>> ExtractRequiredLayoutFromOp(mlir::Operation* op) {
125 TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> optional_layouts,
126 ExtractLayoutFromOp(op));
127 std::vector<Layout> layouts;
128 for (const absl::optional<Layout>& layout : optional_layouts) {
129 if (!layout) return errors::Internal("expected layout missing");
130 layouts.emplace_back(*layout);
131 }
132
133 return layouts;
134}
135
136StatusOr<Mesh> ExtractDeviceMeshEnclosingCluster(mlir::Operation* op) {
137 auto enclosing_cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
138 if (!enclosing_cluster)
139 return errors::InvalidArgument("op is not inside a device mesh cluster.");
140
141 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshFromOp(enclosing_cluster));
142 if (!mesh)
143 return errors::InvalidArgument(
144 "op's enclosing device cluster does not have mesh defined.");
145
146 return *mesh;
147}
148
149StatusOr<absl::optional<Mesh>> ExtractDeviceMeshFromOp(mlir::Operation* op) {
150 absl::optional<Mesh> extracted_mesh;
151 if (op == nullptr) return extracted_mesh;
152
153 auto mesh_str_attr = op->getAttrOfType<mlir::StringAttr>(kMeshAttr);
154 if (!mesh_str_attr) return extracted_mesh;
155
156 TF_ASSIGN_OR_RETURN(Mesh mesh,
157 Mesh::FromString(mesh_str_attr.getValue().str()));
158
159 extracted_mesh.emplace(std::move(mesh));
160 return extracted_mesh;
161}
162
163StatusOr<absl::optional<Layout>> ExtractLayoutFromOperand(mlir::Value operand) {
164 if (auto op_result = operand.dyn_cast<mlir::OpResult>()) {
165 mlir::Operation* op = op_result.getDefiningOp();
166 absl::optional<Layout> out;
167 if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
168 out.emplace(layout_op.layout());
169 } else {
170 const int result_number = op_result.getResultNumber();
171 TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, kLayoutAttr));
172
173 if (layouts.empty()) return out;
174
175 if (result_number >= layouts.size()) {
176 return errors::Internal(
177 "Expect to extract the ", result_number,
178 "-th output's layout, but "
179 "only see ",
180 layouts.size(), " outputs: ", op->getName().getStringRef().str());
181 }
182 out.swap(layouts[result_number]);
183 }
184 return out;
185 }
186
187 auto block_arg = operand.dyn_cast<mlir::BlockArgument>();
188 if (!block_arg)
189 return errors::Internal(
190 "Operand is not either a OpResult or a BlockArgument. This should not "
191 "happen.");
192 auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
193 block_arg.getOwner()->getParentOp());
194 if (!func_op) {
195 return errors::InvalidArgument("op must be enclosed by a function");
196 }
197
198 absl::optional<Layout> extracted_layout;
199 auto layout_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
200 block_arg.getArgNumber(), kCustomDeviceAttr);
201 if (!layout_attr) return extracted_layout;
202
203 TF_ASSIGN_OR_RETURN(auto layout,
204 Layout::FromString(layout_attr.getValue().str()));
205 extracted_layout.emplace(std::move(layout));
206 return extracted_layout;
207}
208
209StatusOr<Layout> ExtractRequiredLayoutFromOperand(mlir::Value operand) {
210 TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
211 ExtractLayoutFromOperand(operand));
212 if (!layout) return errors::Internal("expected layout missing");
213
214 return *layout;
215}
216
217StatusOr<std::vector<Layout>> ExtractRequiredLayoutFromOperands(
218 mlir::Operation* op) {
219 std::vector<Layout> layouts;
220 for (const auto& operand : op->getOpOperands()) {
221 TF_ASSIGN_OR_RETURN(auto operand_layout,
222 ExtractRequiredLayoutFromOperand(operand.get()));
223 layouts.emplace_back(operand_layout);
224 }
225 return layouts;
226}
227
228void SetLayoutOnOp(mlir::Operation* op, mlir::OpBuilder builder,
229 absl::Span<const absl::optional<Layout>> layouts) {
230 llvm::SmallVector<std::string, 8> serialized_layouts;
231 for (auto const& layout : layouts) {
232 serialized_layouts.emplace_back(layout.has_value() ? layout->ToString()
233 : "");
234 }
235 op->setAttr(kLayoutAttr,
236 builder.getStrArrayAttr(llvm::SmallVector<llvm::StringRef, 8>(
237 serialized_layouts.begin(), serialized_layouts.end())));
238}
239
// Convenience overload that builds an OpBuilder anchored at `op`.
void SetLayoutOnOp(mlir::Operation* op,
                   absl::Span<const absl::optional<Layout>> layouts) {
  SetLayoutOnOp(op, mlir::OpBuilder(op), layouts);
}
244
// Convenience wrapper for ops with a single result: attaches exactly one
// layout under kLayoutAttr.
void SetSingleLayoutOnOp(mlir::Operation* op, const Layout& layout) {
  SetLayoutOnOp(op, mlir::OpBuilder(op), {absl::optional<Layout>(layout)});
}
248
249StatusOr<absl::optional<Layout>> ExtractLayoutFromFunctionReturnAttr(
250 mlir::func::ReturnOp return_op, const int return_index) {
251 absl::optional<Layout> layout;
252 // If value feeds into func op return op, then check to see if layout
253 // attribute is set for the return value.
254 auto function = return_op->getParentOfType<mlir::func::FuncOp>();
255 auto layout_attr_from_func_result =
256 function.getResultAttrOfType<mlir::StringAttr>(return_index,
257 kCustomDefaultLayoutAttr);
258 if (!layout_attr_from_func_result) return layout;
259
260 const std::string layout_string =
261 layout_attr_from_func_result.getValue().str();
262 auto result_layout_or_status = Layout::FromString(layout_string);
263 if (!result_layout_or_status.ok())
264 return errors::InvalidArgument(
265 llvm::formatv("Malformed default return layout received. {0} Received "
266 "layout : {1}",
267 result_layout_or_status.status().error_message(),
268 layout_string)
269 .str());
270
271 layout.emplace(result_layout_or_status.value());
272 return layout;
273}
274
275} // namespace dtensor
276} // namespace tensorflow
277