1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/dtensor/mlir/shape_utils.h"
17
18#include "llvm/Support/FormatVariadic.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20#include "mlir/IR/Builders.h" // from @llvm-project
21#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22#include "mlir/IR/BuiltinOps.h" // from @llvm-project
23#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24#include "mlir/IR/Location.h" // from @llvm-project
25#include "mlir/IR/MLIRContext.h" // from @llvm-project
26#include "mlir/IR/OperationSupport.h" // from @llvm-project
27#include "mlir/IR/Value.h" // from @llvm-project
28#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
29#include "tensorflow/core/public/version.h"
30#include "tensorflow/dtensor/cc/constants.h"
31#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
32#include "tensorflow/dtensor/mlir/value_utils.h"
33
34namespace tensorflow {
35namespace dtensor {
36
37StatusOr<llvm::ArrayRef<int64_t>> ExtractGlobalInputShape(
38 mlir::OpOperand& input_value) {
39 const int operand_index = input_value.getOperandNumber();
40 auto input_defining_op = input_value.get().getDefiningOp();
41
42 if (input_defining_op) {
43 if (auto layout_op =
44 llvm::dyn_cast<mlir::TF::DTensorLayout>(input_defining_op)) {
45 auto global_shape = layout_op.global_shape();
46 if (!global_shape)
47 return errors::Internal("global_shape does not have static rank");
48 return *global_shape;
49 }
50 return ExtractGlobalOutputShape(input_value.get().cast<mlir::OpResult>());
51 }
52
53 // If we reach this point, we're working with a function argument.
54 auto op = input_value.getOwner();
55 auto enclosing_function = op->getParentOfType<mlir::func::FuncOp>();
56 if (!enclosing_function)
57 return errors::InvalidArgument(
58 llvm::formatv("Could not find global shape of {0}-th input to op: {1}",
59 operand_index, op->getName())
60 .str());
61
62 auto block_arg = input_value.get().dyn_cast<mlir::BlockArgument>();
63 auto global_shape_attr =
64 enclosing_function.getArgAttrOfType<mlir::TF::ShapeAttr>(
65 block_arg.getArgNumber(), kGlobalShapeDialectAttr);
66 if (!global_shape_attr)
67 return errors::InvalidArgument(
68 "`tf._global_shape` attribute of operation not found.");
69
70 return global_shape_attr.getShape();
71}
72
73StatusOr<llvm::ArrayRef<int64_t>> ExtractGlobalOutputShape(
74 mlir::OpResult result_value) {
75 auto op = result_value.getOwner();
76 const int output_index = result_value.getResultNumber();
77
78 if (op->getOpResult(output_index).hasOneUse()) {
79 auto user = op->getOpResult(output_index).getUses().begin().getUser();
80 if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>(user)) {
81 auto global_shape = layout_op.global_shape();
82 if (!global_shape)
83 return errors::Internal("global_shape does not have static rank");
84 return *global_shape;
85 }
86 }
87
88 auto global_shape_attr = op->getAttrOfType<mlir::ArrayAttr>(kGlobalShape);
89 if (!global_shape_attr)
90 return errors::InvalidArgument(
91 "`_global_shape` attribute of operation not found.");
92
93 const int num_results = op->getNumResults();
94 assert(global_shape_attr.size() == num_results);
95
96 if (output_index >= op->getNumResults())
97 return errors::InvalidArgument(
98 llvm::formatv("Requested global shape of {0} output but op has only "
99 "{1} return values.",
100 output_index, num_results)
101 .str());
102
103 auto shape_attr = global_shape_attr[output_index];
104 return shape_attr.cast<mlir::TF::ShapeAttr>().getShape();
105}
106
107namespace {
108
109// Extracts attributes from a MLIR operation, including derived attributes, into
110// one NamedAttrList.
111mlir::NamedAttrList GetAllAttributesFromOperation(mlir::Operation* op) {
112 mlir::NamedAttrList attr_list;
113 attr_list.append(op->getAttrDictionary().getValue());
114
115 if (auto derived = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
116 auto materialized = derived.materializeDerivedAttributes();
117 attr_list.append(materialized.getValue());
118 }
119
120 return attr_list;
121}
122
// Infers output shape of `op` given its local operand shape. For shape
// inference function that requires input operation to be a constant, if input
// operation is `DTensorLayout` op, then we use input of DTensorLayout op
// instead for correct constant matching.
//
// Inference strategy, in priority order:
//   1. InferTypeOpInterface — infer full return types and convert each
//      ShapedType into ShapedTypeComponents.
//   2. InferShapedTypeOpInterface — infer return shape components directly.
//   3. Fall back to TF's registered shape functions via
//      InferReturnTypeComponentsForTFOp, supplying custom callbacks for
//      constant-operand and result-as-shape resolution.
//
// Returns failure if the chosen inference path fails; on success
// `inferred_return_shapes` holds one entry per op result.
mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn(
    llvm::Optional<mlir::Location> location, mlir::Operation* op,
    int64_t graph_version,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferred_return_shapes) {
  // Path 1: ops that can infer their complete return types.
  if (auto type_op = llvm::dyn_cast<mlir::InferTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    llvm::SmallVector<mlir::Type, 4> inferred_return_types;
    auto result = type_op.inferReturnTypes(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_types);
    if (failed(result)) return mlir::failure();

    // Translate each inferred type into shape components. Non-shaped types
    // leave their slot default-initialized (unranked, no element type).
    inferred_return_shapes.resize(inferred_return_types.size());
    for (const auto& inferred_return_type :
         llvm::enumerate(inferred_return_types)) {
      if (auto shaped_type =
              inferred_return_type.value().dyn_cast<mlir::ShapedType>()) {
        if (shaped_type.hasRank()) {
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getShape(),
                                         shaped_type.getElementType());
        } else {
          // Unranked: record only the element type.
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getElementType());
        }
      }
    }

    return mlir::success();
  }

  // Path 2: ops that can infer shape components directly.
  if (auto shape_type_op =
          llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    return shape_type_op.inferReturnTypeComponents(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_shapes);
  }

  // Path 3: TF fallback shape inference.
  // If `operand` is from DTensorLayout op, use input value of DTensorLayout op
  // instead.
  auto operand_as_constant_fn = [](mlir::Value operand) -> mlir::Attribute {
    // Walk through (possibly chained) DTensorLayout ops, which are
    // pass-through for constant matching purposes.
    while (auto input_op = llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(
               operand.getDefiningOp())) {
      operand = input_op.input();
    }

    mlir::Attribute attr;
    if (matchPattern(operand, m_Constant(&attr))) return attr;
    return nullptr;
  };

  // Converts an op result that is used as a shape operand into an
  // InferenceContext ShapeHandle. Only 1-D statically-shaped tensors
  // qualify; constant elements become concrete dims, others stay unknown.
  auto op_result_as_shape_fn =
      [](shape_inference::InferenceContext& ic,
         mlir::OpResult op_result) -> shape_inference::ShapeHandle {
    auto rt = op_result.getType().dyn_cast<mlir::RankedTensorType>();
    if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {};

    // Start with all-unknown dims, then fill in any constant values.
    std::vector<shape_inference::DimensionHandle> dims(rt.getDimSize(0),
                                                       ic.UnknownDim());
    mlir::Attribute attr;
    if (matchPattern(op_result, m_Constant(&attr))) {
      auto elements = attr.dyn_cast<mlir::DenseIntElementsAttr>();
      if (elements)
        for (const auto& element :
             llvm::enumerate(elements.getValues<llvm::APInt>()))
          dims[element.index()] = ic.MakeDim(element.value().getSExtValue());
    }
    return ic.MakeShape(dims);
  };

  // No element-type refinement is needed here; shapes only.
  auto result_element_type_fn = [](int) -> mlir::Type { return nullptr; };

  return mlir::TF::InferReturnTypeComponentsForTFOp(
      location, op, graph_version, operand_as_constant_fn,
      op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes);
}
206
207} // namespace
208
209mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op) {
210 llvm::SmallVector<mlir::ShapedTypeComponents, 4> inferred_return_types;
211 (void)InferShapeOfTFOpWithCustomOperandConstantFn(
212 op->getLoc(), op, TF_GRAPH_DEF_VERSION, inferred_return_types);
213 assert(inferred_return_types.size() == op->getNumResults());
214
215 for (auto it : llvm::zip(inferred_return_types, op->getOpResults())) {
216 const auto& return_type = std::get<0>(it);
217 auto& op_result = std::get<1>(it);
218 const auto element_type =
219 op_result.getType().cast<mlir::TensorType>().getElementType();
220
221 if (return_type.hasRank()) {
222 op_result.setType(
223 mlir::RankedTensorType::get(return_type.getDims(), element_type));
224 } else {
225 op_result.setType(mlir::UnrankedTensorType::get(element_type));
226 }
227 }
228
229 return op;
230}
231
232StatusOr<llvm::ArrayRef<int64_t>> GetShapeOfValue(const mlir::Value& value,
233 bool fail_on_dynamic) {
234 // Getting the subtype or self allows supporting extracting the underlying
235 // shape that variant or resource tensors point to.
236 mlir::Type type = GetSubtypeOrSelf(value);
237 if (auto ranked_type = type.dyn_cast<mlir::RankedTensorType>()) {
238 if (ranked_type.hasStaticShape() || !fail_on_dynamic)
239 return ranked_type.getShape();
240 else
241 return errors::InvalidArgument("value shape is not static");
242 }
243 return errors::InvalidArgument("value type is not a RankedTensorType");
244}
245
246StatusOr<llvm::ArrayRef<int64_t>> GetGlobalShapeOfValueFromDTensorLayout(
247 const mlir::Value& value) {
248 if (value.isa<mlir::OpResult>() &&
249 mlir::isa<mlir::TF::DTensorLayout>(value.getDefiningOp())) {
250 auto layout_op = mlir::cast<mlir::TF::DTensorLayout>(value.getDefiningOp());
251 if (layout_op.global_shape()) return layout_op.global_shape().getValue();
252 } else if (value.hasOneUse() &&
253 mlir::isa<mlir::TF::DTensorLayout>(*value.getUsers().begin())) {
254 auto layout_op =
255 mlir::cast<mlir::TF::DTensorLayout>(*value.getUsers().begin());
256 if (layout_op.global_shape()) return layout_op.global_shape().getValue();
257 }
258 return errors::InvalidArgument(
259 "consumer or producer of value is not a DTensorLayout");
260}
261
262} // namespace dtensor
263} // namespace tensorflow
264