1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20#include "mlir/IR/Attributes.h" // from @llvm-project
21#include "mlir/IR/Builders.h" // from @llvm-project
22#include "mlir/IR/BuiltinOps.h" // from @llvm-project
23#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24#include "mlir/IR/Diagnostics.h" // from @llvm-project
25#include "mlir/IR/Operation.h" // from @llvm-project
26#include "mlir/IR/TypeUtilities.h" // from @llvm-project
27#include "mlir/IR/Types.h" // from @llvm-project
28#include "mlir/Transforms/Passes.h" // from @llvm-project
29#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
31#include "tensorflow/dtensor/cc/constants.h"
32#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33
34namespace tensorflow {
35namespace dtensor {
36
37namespace {
38#define GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE
39#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
40
41// Sets `_global_shape` attributes to argument/return values of `function`.
42void AnnotateFunctionArgRetvalGlobalShapes(mlir::func::FuncOp function,
43 mlir::OpBuilder* builder) {
44 for (const auto& argument_type_and_index :
45 llvm::enumerate(function.getArgumentTypes())) {
46 const int index = argument_type_and_index.index();
47 const auto& argument_type = argument_type_and_index.value();
48 // Extract TensorType from element of resource type to allow setting proper
49 // global shape of resource types.
50 if (auto resource_type = mlir::getElementTypeOrSelf(argument_type)
51 .dyn_cast<mlir::TF::ResourceType>()) {
52 auto subtype = resource_type.getSubtypes();
53 if (subtype.size() == 1) {
54 // subtype returns a Array of TensorType -- if it contains more than one
55 // Tensor type, we give up extracting the single TensorType inside the
56 // subtype.
57 function.setArgAttr(index, kGlobalShapeDialectAttr,
58 ConvertTypeToTensorShapeAttr(subtype[0]));
59 }
60 } else {
61 function.setArgAttr(index, kGlobalShapeDialectAttr,
62 ConvertTypeToTensorShapeAttr(argument_type));
63 }
64 }
65
66 for (const auto& retval_type_and_index :
67 llvm::enumerate(function.getFunctionType().getResults())) {
68 const int index = retval_type_and_index.index();
69 const auto& retval_type = retval_type_and_index.value();
70 function.setResultAttr(index, kGlobalShapeDialectAttr,
71 ConvertTypeToTensorShapeAttr(retval_type));
72 }
73}
74
75// Sets `_global_shape` attribute of an `op` with array of ShapeAttr of
76// `outputs.
77void AnnotateOperationGlobalShape(mlir::Operation* op,
78 mlir::OpBuilder* builder) {
79 llvm::SmallVector<mlir::Attribute, 4> op_global_shape;
80 op_global_shape.reserve(op->getNumResults());
81
82 for (const auto& result_type : op->getResultTypes())
83 op_global_shape.emplace_back(ConvertTypeToTensorShapeAttr(result_type));
84
85 op->setAttr(kGlobalShape, builder->getArrayAttr(op_global_shape));
86}
87
88// Pass that annotates function argument/return values and all operation with
89// `_global_shape` attribute. This will be used during SPMD expansion to
90// preserve original global shape of operations in graph after shape has been
91// modified to local shape.
92struct DTensorAnnotateGlobalShape
93 : public impl::DTensorAnnotateGlobalShapeBase<DTensorAnnotateGlobalShape> {
94 void runOnOperation() override {
95 mlir::MLIRContext& context = getContext();
96 mlir::OpBuilder builder(&context);
97
98 auto module = getOperation();
99 module.walk([&](mlir::func::FuncOp function) {
100 if (function.empty()) return;
101
102 auto* terminator = function.getBody().front().getTerminator();
103 AnnotateFunctionArgRetvalGlobalShapes(function, &builder);
104 function.getBody().walk([&](mlir::Operation* op) {
105 if (op == terminator) return;
106
107 AnnotateOperationGlobalShape(op, &builder);
108 });
109 });
110 }
111};
112
113} // namespace
114
115std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
116CreateDTensorAnnotateGlobalShape() {
117 return std::make_unique<DTensorAnnotateGlobalShape>();
118}
119
120} // namespace dtensor
121} // namespace tensorflow
122