1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/shape_utils.h" |
17 | |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
20 | #include "mlir/IR/Builders.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
22 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
24 | #include "mlir/IR/Location.h" // from @llvm-project |
25 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
26 | #include "mlir/IR/OperationSupport.h" // from @llvm-project |
27 | #include "mlir/IR/Value.h" // from @llvm-project |
28 | #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" |
29 | #include "tensorflow/core/public/version.h" |
30 | #include "tensorflow/dtensor/cc/constants.h" |
31 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
32 | #include "tensorflow/dtensor/mlir/value_utils.h" |
33 | |
34 | namespace tensorflow { |
35 | namespace dtensor { |
36 | |
37 | StatusOr<llvm::ArrayRef<int64_t>> ExtractGlobalInputShape( |
38 | mlir::OpOperand& input_value) { |
39 | const int operand_index = input_value.getOperandNumber(); |
40 | auto input_defining_op = input_value.get().getDefiningOp(); |
41 | |
42 | if (input_defining_op) { |
43 | if (auto layout_op = |
44 | llvm::dyn_cast<mlir::TF::DTensorLayout>(input_defining_op)) { |
45 | auto global_shape = layout_op.global_shape(); |
46 | if (!global_shape) |
47 | return errors::Internal("global_shape does not have static rank" ); |
48 | return *global_shape; |
49 | } |
50 | return ExtractGlobalOutputShape(input_value.get().cast<mlir::OpResult>()); |
51 | } |
52 | |
53 | // If we reach this point, we're working with a function argument. |
54 | auto op = input_value.getOwner(); |
55 | auto enclosing_function = op->getParentOfType<mlir::func::FuncOp>(); |
56 | if (!enclosing_function) |
57 | return errors::InvalidArgument( |
58 | llvm::formatv("Could not find global shape of {0}-th input to op: {1}" , |
59 | operand_index, op->getName()) |
60 | .str()); |
61 | |
62 | auto block_arg = input_value.get().dyn_cast<mlir::BlockArgument>(); |
63 | auto global_shape_attr = |
64 | enclosing_function.getArgAttrOfType<mlir::TF::ShapeAttr>( |
65 | block_arg.getArgNumber(), kGlobalShapeDialectAttr); |
66 | if (!global_shape_attr) |
67 | return errors::InvalidArgument( |
68 | "`tf._global_shape` attribute of operation not found." ); |
69 | |
70 | return global_shape_attr.getShape(); |
71 | } |
72 | |
73 | StatusOr<llvm::ArrayRef<int64_t>> ( |
74 | mlir::OpResult result_value) { |
75 | auto op = result_value.getOwner(); |
76 | const int output_index = result_value.getResultNumber(); |
77 | |
78 | if (op->getOpResult(output_index).hasOneUse()) { |
79 | auto user = op->getOpResult(output_index).getUses().begin().getUser(); |
80 | if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>(user)) { |
81 | auto global_shape = layout_op.global_shape(); |
82 | if (!global_shape) |
83 | return errors::Internal("global_shape does not have static rank" ); |
84 | return *global_shape; |
85 | } |
86 | } |
87 | |
88 | auto global_shape_attr = op->getAttrOfType<mlir::ArrayAttr>(kGlobalShape); |
89 | if (!global_shape_attr) |
90 | return errors::InvalidArgument( |
91 | "`_global_shape` attribute of operation not found." ); |
92 | |
93 | const int num_results = op->getNumResults(); |
94 | assert(global_shape_attr.size() == num_results); |
95 | |
96 | if (output_index >= op->getNumResults()) |
97 | return errors::InvalidArgument( |
98 | llvm::formatv("Requested global shape of {0} output but op has only " |
99 | "{1} return values." , |
100 | output_index, num_results) |
101 | .str()); |
102 | |
103 | auto shape_attr = global_shape_attr[output_index]; |
104 | return shape_attr.cast<mlir::TF::ShapeAttr>().getShape(); |
105 | } |
106 | |
107 | namespace { |
108 | |
109 | // Extracts attributes from a MLIR operation, including derived attributes, into |
110 | // one NamedAttrList. |
111 | mlir::NamedAttrList GetAllAttributesFromOperation(mlir::Operation* op) { |
112 | mlir::NamedAttrList attr_list; |
113 | attr_list.append(op->getAttrDictionary().getValue()); |
114 | |
115 | if (auto derived = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) { |
116 | auto materialized = derived.materializeDerivedAttributes(); |
117 | attr_list.append(materialized.getValue()); |
118 | } |
119 | |
120 | return attr_list; |
121 | } |
122 | |
123 | // Infers output shape of `op` given its local operand shape. For shape |
124 | // inference function that requires input operation to be a constant, if input |
125 | // operation is `DTensorLayout` op, then we use input of DTensorLayout op |
126 | // instead for correct constant matching. |
// Infers output shape of `op` given its local operand shape. For shape
// inference function that requires input operation to be a constant, if input
// operation is `DTensorLayout` op, then we use input of DTensorLayout op
// instead for correct constant matching.
//
// Dispatch order:
//   1. InferTypeOpInterface  — infer full return types, keep shape parts.
//   2. InferShapedTypeOpInterface — infer shaped components directly.
//   3. Generic TF fallback (InferReturnTypeComponentsForTFOp) with custom
//      callbacks that look through DTensorLayout ops when matching constants.
//
// On success, `inferred_return_shapes` holds one ShapedTypeComponents per op
// result. Returns failure if the used inference path fails.
mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn(
    llvm::Optional<mlir::Location> location, mlir::Operation* op,
    int64_t graph_version,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferred_return_shapes) {
  // Case 1: the op can infer its full return types; convert each inferred
  // type's shape information into ShapedTypeComponents.
  if (auto type_op = llvm::dyn_cast<mlir::InferTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    llvm::SmallVector<mlir::Type, 4> inferred_return_types;
    auto result = type_op.inferReturnTypes(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_types);
    if (failed(result)) return mlir::failure();

    inferred_return_shapes.resize(inferred_return_types.size());
    for (const auto& inferred_return_type :
         llvm::enumerate(inferred_return_types)) {
      // Non-shaped inferred types leave the corresponding entry
      // default-constructed.
      if (auto shaped_type =
              inferred_return_type.value().dyn_cast<mlir::ShapedType>()) {
        if (shaped_type.hasRank()) {
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getShape(),
                                         shaped_type.getElementType());
        } else {
          // Unranked: record only the element type.
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getElementType());
        }
      }
    }

    return mlir::success();
  }

  // Case 2: the op can infer shaped components directly.
  if (auto shape_type_op =
          llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    return shape_type_op.inferReturnTypeComponents(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_shapes);
  }

  // Case 3: generic TF shape inference with custom callbacks.
  //
  // If `operand` is from DTensorLayout op, use input value of DTensorLayout op
  // instead.
  auto operand_as_constant_fn = [](mlir::Value operand) -> mlir::Attribute {
    // Walk through any chain of DTensorLayout ops to reach the real producer.
    while (auto input_op = llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(
               operand.getDefiningOp())) {
      operand = input_op.input();
    }

    mlir::Attribute attr;
    if (matchPattern(operand, m_Constant(&attr))) return attr;
    return nullptr;
  };

  // Converts a rank-1, statically-shaped op result into a ShapeHandle; if the
  // result is a constant, its element values become concrete dimensions,
  // otherwise every dimension is unknown.
  auto op_result_as_shape_fn =
      [](shape_inference::InferenceContext& ic,
         mlir::OpResult op_result) -> shape_inference::ShapeHandle {
    auto rt = op_result.getType().dyn_cast<mlir::RankedTensorType>();
    if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {};

    std::vector<shape_inference::DimensionHandle> dims(rt.getDimSize(0),
                                                       ic.UnknownDim());
    mlir::Attribute attr;
    if (matchPattern(op_result, m_Constant(&attr))) {
      auto elements = attr.dyn_cast<mlir::DenseIntElementsAttr>();
      if (elements)
        for (const auto& element :
             llvm::enumerate(elements.getValues<llvm::APInt>()))
          dims[element.index()] = ic.MakeDim(element.value().getSExtValue());
    }
    return ic.MakeShape(dims);
  };

  // No element-type refinement is requested from the fallback.
  auto result_element_type_fn = [](int) -> mlir::Type { return nullptr; };

  return mlir::TF::InferReturnTypeComponentsForTFOp(
      location, op, graph_version, operand_as_constant_fn,
      op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes);
}
206 | |
207 | } // namespace |
208 | |
209 | mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op) { |
210 | llvm::SmallVector<mlir::ShapedTypeComponents, 4> inferred_return_types; |
211 | (void)InferShapeOfTFOpWithCustomOperandConstantFn( |
212 | op->getLoc(), op, TF_GRAPH_DEF_VERSION, inferred_return_types); |
213 | assert(inferred_return_types.size() == op->getNumResults()); |
214 | |
215 | for (auto it : llvm::zip(inferred_return_types, op->getOpResults())) { |
216 | const auto& return_type = std::get<0>(it); |
217 | auto& op_result = std::get<1>(it); |
218 | const auto element_type = |
219 | op_result.getType().cast<mlir::TensorType>().getElementType(); |
220 | |
221 | if (return_type.hasRank()) { |
222 | op_result.setType( |
223 | mlir::RankedTensorType::get(return_type.getDims(), element_type)); |
224 | } else { |
225 | op_result.setType(mlir::UnrankedTensorType::get(element_type)); |
226 | } |
227 | } |
228 | |
229 | return op; |
230 | } |
231 | |
232 | StatusOr<llvm::ArrayRef<int64_t>> GetShapeOfValue(const mlir::Value& value, |
233 | bool fail_on_dynamic) { |
234 | // Getting the subtype or self allows supporting extracting the underlying |
235 | // shape that variant or resource tensors point to. |
236 | mlir::Type type = GetSubtypeOrSelf(value); |
237 | if (auto ranked_type = type.dyn_cast<mlir::RankedTensorType>()) { |
238 | if (ranked_type.hasStaticShape() || !fail_on_dynamic) |
239 | return ranked_type.getShape(); |
240 | else |
241 | return errors::InvalidArgument("value shape is not static" ); |
242 | } |
243 | return errors::InvalidArgument("value type is not a RankedTensorType" ); |
244 | } |
245 | |
246 | StatusOr<llvm::ArrayRef<int64_t>> GetGlobalShapeOfValueFromDTensorLayout( |
247 | const mlir::Value& value) { |
248 | if (value.isa<mlir::OpResult>() && |
249 | mlir::isa<mlir::TF::DTensorLayout>(value.getDefiningOp())) { |
250 | auto layout_op = mlir::cast<mlir::TF::DTensorLayout>(value.getDefiningOp()); |
251 | if (layout_op.global_shape()) return layout_op.global_shape().getValue(); |
252 | } else if (value.hasOneUse() && |
253 | mlir::isa<mlir::TF::DTensorLayout>(*value.getUsers().begin())) { |
254 | auto layout_op = |
255 | mlir::cast<mlir::TF::DTensorLayout>(*value.getUsers().begin()); |
256 | if (layout_op.global_shape()) return layout_op.global_shape().getValue(); |
257 | } |
258 | return errors::InvalidArgument( |
259 | "consumer or producer of value is not a DTensorLayout" ); |
260 | } |
261 | |
262 | } // namespace dtensor |
263 | } // namespace tensorflow |
264 | |