/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/value_utils.h"

#include "absl/strings/str_cat.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/op_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Traces `value` back through DTensorLayout ops and the block arguments of
// while-loop regions to find the original forwarded input. This is roughly
// the reverse of TraceUseToNextTFOp.
mlir::Value GetForwardedInput(mlir::Value value) {
  bool value_updated;
  do {
    value_updated = false;
    if (mlir::BlockArgument argument = value.dyn_cast<mlir::BlockArgument>()) {
      mlir::Region* region = argument.getParentRegion();
      if (region == nullptr) break;
      mlir::Operation* parent_op = region->getParentOp();
      // TODO(bfontain): handle if and other control flow blocks.
      if (mlir::TF::WhileRegionOp while_op =
              mlir::dyn_cast<mlir::TF::WhileRegionOp>(parent_op)) {
        value = while_op.getOperand(argument.getArgNumber());
        value_updated = true;
      }
    } else {
      mlir::Operation* op = value.getDefiningOp();
      // TODO(bfontain): Add cases for identity and control flow return values.
      if (mlir::TF::DTensorLayout layout_op =
              mlir::dyn_cast<mlir::TF::DTensorLayout>(op)) {
        value = layout_op.input();
        value_updated = true;
      }
    }
  } while (value_updated);

  return value;
}

}  // namespace

namespace ops_util = ::mlir::TF::collection_ops_util;

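// Returns the rank of `operand_value`, or -1 if its tensor type is unranked.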
int ValueRank(mlir::Value operand_value) {
  mlir::Type type = GetSubtypeOrSelf(operand_value);
  const auto operand_type = type.cast<mlir::TensorType>();
  if (!operand_type.hasRank()) return -1;
  return operand_type.getRank();
}

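// Returns a rank-1 tensor type with a single element of `element_type`,
// i.e. tensor<1 x element_type>.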
mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type) {
  return mlir::RankedTensorType::get({1}, element_type);
}

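// Reshapes a single-element size tensor into a rank-0 (scalar) int32 tensor.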
mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc,
                                    mlir::Value tensor) {
  auto scalar_type =
      mlir::RankedTensorType::get({}, builder.getIntegerType(32));
  mlir::Value scalar_shape =
      ops_util::GetR1Const(scalar_type.getShape(), builder, loc);
  return builder.create<mlir::TF::ReshapeOp>(
      loc, mlir::ArrayRef<mlir::Type>{scalar_type},
      mlir::ArrayRef<mlir::Value>{tensor, scalar_shape});
}

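// Creates a rank-1 int32 constant op from `values`.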
mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc,
                     llvm::ArrayRef<int32> values) {
  auto const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getIntegerType(32));
  mlir::Attribute const_attr =
      mlir::DenseIntElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

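// Creates a rank-1 int64 constant op from `values`.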
mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc,
                       llvm::ArrayRef<int64_t> values) {
  auto const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getIntegerType(64));
  mlir::Attribute const_attr =
      mlir::DenseIntElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

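// Creates a rank-1 float32 constant op from `values`.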
mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc,
                       llvm::ArrayRef<float> values) {
  mlir::RankedTensorType const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getF32Type());
  mlir::Attribute const_attr =
      mlir::DenseFPElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

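// Creates a rank-0 (scalar) string constant op holding `value`.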
mlir::Value StringScalarConst(mlir::OpBuilder& builder, mlir::Location loc,
                              llvm::StringRef value) {
  return builder.create<mlir::TF::ConstOp>(
      loc, mlir::DenseStringElementsAttr::get(
               mlir::RankedTensorType::get(
                   {}, builder.getType<mlir::TF::StringType>()),
               value));
}

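// Creates a rank-1 string constant op from `values`.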
mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc,
                        llvm::ArrayRef<llvm::StringRef> values) {
  auto const_type =
      mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
                                  builder.getType<mlir::TF::StringType>());
  mlir::Attribute const_attr =
      mlir::DenseStringElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

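// Extracts a single int64 from a constant-foldable `value`. Returns an error
// if the value is a block argument, is not a constant, or holds more than one
// element.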
StatusOr<int64_t> ExtractConstIntFromValue(mlir::Value value) {
  value = GetForwardedInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "unable to get constant value from block argument");
  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(absl::StrCat("required constant value for ",
                                         OpName(value.getDefiningOp())));
  }
  if (attr.size() != 1) {
    return errors::Internal(absl::StrCat("expected 1 element, got ",
                                         attr.size(), " for ",
                                         OpName(value.getDefiningOp())));
  }
  auto a = *attr.value_begin<llvm::APInt>();
  return a.getSExtValue();
}

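// Extracts the elements of a constant-foldable `value` into `out_vector` as
// int64s. Returns an error if the value is a block argument or is not a
// constant.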
Status ExtractConstVectorFromValue(mlir::Value value,
                                   llvm::SmallVector<int64_t, 4>* out_vector) {
  value = GetForwardedInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "unable to get constant value from block argument");
  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(
        absl::StrCat("failed to extract constant value from ",
                     value.getDefiningOp()->getName().getStringRef().str()));
  }
  for (const mlir::APInt& index : attr)
    out_vector->emplace_back(index.getSExtValue());
  return OkStatus();
}

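// Creates a scalar integer constant holding `value`, as int64 when
// `use_int64` is true and as int32 otherwise.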
mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder,
                                 mlir::Location loc, bool use_int64) {
  if (use_int64) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI64Type()), value));
  } else {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI32Type()),
                 static_cast<int32_t>(value)));
  }
}

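// Creates a scalar zero constant of `type`. Returns absl::nullopt for element
// types other than f64, f32, i32, and i64.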
absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder,
                                                  mlir::Location loc,
                                                  mlir::Type type) {
  if (type.isF64()) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseFPElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getF64Type()),
                 static_cast<double>(0.)));
  } else if (type.isF32()) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseFPElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getF32Type()),
                 static_cast<float>(0.f)));
  } else if (type.isInteger(32)) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI32Type()),
                 static_cast<int32_t>(0)));
  } else if (type.isInteger(64)) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI64Type()),
                 static_cast<int64_t>(0)));
  } else {
    return absl::nullopt;
  }
}

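// Extracts the scalar element at `index` from a [1, N] tensor `array` by
// slicing out a [1, 1] tensor and reshaping it to a scalar. Returns an error
// if `array` does not have shape [1, N].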
StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder,
                                                 int index,
                                                 mlir::Location location,
                                                 mlir::Value array) {
  mlir::TensorType array_type = array.getType().cast<mlir::TensorType>();
  if (array_type.getRank() != 2 || array_type.getDimSize(0) != 1) {
    return errors::InvalidArgument("Input array must have shape [1, N].");
  }

  mlir::TF::SliceOp sliced_value = builder.create<mlir::TF::SliceOp>(
      location,
      mlir::RankedTensorType::get({1, 1}, array_type.getElementType()),
      /*input=*/array,
      /*begin=*/IntConst(builder, location, {0, index}),
      /*size=*/IntConst(builder, location, {1, 1}));

  // Reshape the sliced [1, 1] tensor to a rank-0 scalar.
  auto scalar_size_type =
      mlir::RankedTensorType::get({}, builder.getIntegerType(32));
  mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
      scalar_size_type.getShape(), builder, location);
  mlir::Value scalar_sliced_value = builder.create<mlir::TF::ReshapeOp>(
      location, mlir::ArrayRef<mlir::Type>{scalar_size_type},
      mlir::ArrayRef<mlir::Value>{sliced_value.output(), scalar_shape},
      mlir::ArrayRef<mlir::NamedAttribute>{});
  return scalar_sliced_value;
}

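// Returns the single subtype of `val`'s element type when it is a TensorFlow
// type with exactly one subtype (e.g. a resource or variant handle);
// otherwise returns `val`'s type unchanged.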
mlir::Type GetSubtypeOrSelf(mlir::Value val) {
  mlir::Type type = val.getType();
  if (auto type_with_subtype =
          mlir::getElementTypeOrSelf(val)
              .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) {
    if (type_with_subtype.GetSubtypes().size() == 1) {
      type = type_with_subtype.GetSubtypes().front();
    }
  }
  return type;
}

}  // namespace dtensor
}  // namespace tensorflow