1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/value_utils.h" |
17 | |
18 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
19 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
20 | #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" |
21 | #include "tensorflow/core/platform/errors.h" |
22 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
23 | #include "tensorflow/dtensor/mlir/op_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace dtensor { |
27 | namespace { |
28 | |
29 | // Given a mlir::Value will trace the value back through |
30 | // DTensorLayout and basic blocks of while loops. |
31 | // This is like a reverse version of TraceUseToNextTFOp. |
32 | mlir::Value GetForwardedInput(mlir::Value value) { |
33 | bool value_updated; |
34 | do { |
35 | value_updated = false; |
36 | if (mlir::BlockArgument argument = value.dyn_cast<mlir::BlockArgument>()) { |
37 | mlir::Region* region = argument.getParentRegion(); |
38 | if (region == nullptr) break; |
39 | mlir::Operation* parent_op = region->getParentOp(); |
40 | // TODO(bfontain): handle if and other control flow blocks. |
41 | if (mlir::TF::WhileRegionOp while_op = |
42 | mlir::dyn_cast<mlir::TF::WhileRegionOp>(parent_op)) { |
43 | value = while_op.getOperand(argument.getArgNumber()); |
44 | value_updated = true; |
45 | } |
46 | } else { |
47 | mlir::Operation* op = value.getDefiningOp(); |
48 | // TODO(bfontain): Add cases for identity and control flow return values. |
49 | if (mlir::TF::DTensorLayout layout_op = |
50 | mlir::dyn_cast<mlir::TF::DTensorLayout>(op)) { |
51 | value = layout_op.input(); |
52 | value_updated = true; |
53 | } |
54 | } |
55 | } while (value_updated); |
56 | |
57 | return value; |
58 | } |
59 | } // namespace |
60 | |
// Short alias for the TF collection-ops helpers; used in this file for
// `GetR1Const` when building shape constants.
namespace ops_util = ::mlir::TF::collection_ops_util;
62 | |
63 | int ValueRank(mlir::Value operand_value) { |
64 | mlir::Type type = GetSubtypeOrSelf(operand_value); |
65 | const auto operand_type = type.cast<mlir::TensorType>(); |
66 | if (!operand_type.hasRank()) return -1; |
67 | return operand_type.getRank(); |
68 | } |
69 | |
70 | mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type) { |
71 | return mlir::RankedTensorType::get({1}, element_type); |
72 | } |
73 | |
74 | mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc, |
75 | mlir::Value tensor) { |
76 | auto scalar_type = |
77 | mlir::RankedTensorType::get({}, builder.getIntegerType(32)); |
78 | mlir::Value scalar_shape = |
79 | ops_util::GetR1Const(scalar_type.getShape(), builder, loc); |
80 | return builder.create<mlir::TF::ReshapeOp>( |
81 | loc, mlir::ArrayRef<mlir::Type>{scalar_type}, |
82 | mlir::ArrayRef<mlir::Value>{tensor, scalar_shape}); |
83 | } |
84 | |
85 | mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc, |
86 | llvm::ArrayRef<int32> values) { |
87 | auto const_type = mlir::RankedTensorType::get( |
88 | {static_cast<int64_t>(values.size())}, builder.getIntegerType(32)); |
89 | mlir::Attribute const_attr = |
90 | mlir::DenseIntElementsAttr::get(const_type, values); |
91 | return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult(); |
92 | } |
93 | |
94 | mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc, |
95 | llvm::ArrayRef<int64_t> values) { |
96 | auto const_type = mlir::RankedTensorType::get( |
97 | {static_cast<int64_t>(values.size())}, builder.getIntegerType(64)); |
98 | mlir::Attribute const_attr = |
99 | mlir::DenseIntElementsAttr::get(const_type, values); |
100 | return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult(); |
101 | } |
102 | |
103 | mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc, |
104 | llvm::ArrayRef<float> values) { |
105 | mlir::RankedTensorType const_type = mlir::RankedTensorType::get( |
106 | {static_cast<int64_t>(values.size())}, builder.getF32Type()); |
107 | mlir::Attribute const_attr = |
108 | mlir::DenseFPElementsAttr::get(const_type, values); |
109 | return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult(); |
110 | } |
111 | |
112 | mlir::Value StringScalarConst(mlir::OpBuilder& builder, mlir::Location loc, |
113 | llvm::StringRef value) { |
114 | return builder.create<mlir::TF::ConstOp>( |
115 | loc, mlir::DenseStringElementsAttr::get( |
116 | mlir::RankedTensorType::get( |
117 | {}, builder.getType<mlir::TF::StringType>()), |
118 | value)); |
119 | } |
120 | |
121 | mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc, |
122 | llvm::ArrayRef<llvm::StringRef> values) { |
123 | auto const_type = |
124 | mlir::RankedTensorType::get({static_cast<int64_t>(values.size())}, |
125 | builder.getType<mlir::TF::StringType>()); |
126 | mlir::Attribute const_attr = |
127 | mlir::DenseStringElementsAttr::get(const_type, values); |
128 | return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult(); |
129 | } |
130 | |
131 | StatusOr<int64_t> (mlir::Value value) { |
132 | value = GetForwardedInput(value); |
133 | if (value.isa<mlir::BlockArgument>()) |
134 | return errors::Internal("unable get constant value from block argument" ); |
135 | mlir::DenseIntElementsAttr attr; |
136 | if (!matchPattern(value, m_Constant(&attr))) { |
137 | return errors::Internal(absl::StrCat("required constant value for " , |
138 | OpName(value.getDefiningOp()))); |
139 | } |
140 | if (attr.size() != 1) { |
141 | return errors::Internal(absl::StrCat("expected 1 element, got " , |
142 | attr.size(), " for " , |
143 | OpName(value.getDefiningOp()))); |
144 | } |
145 | auto a = *attr.value_begin<llvm::APInt>(); |
146 | return a.getSExtValue(); |
147 | } |
148 | |
149 | Status (mlir::Value value, |
150 | llvm::SmallVector<int64_t, 4>* out_vector) { |
151 | value = GetForwardedInput(value); |
152 | if (value.isa<mlir::BlockArgument>()) |
153 | return errors::Internal("unable get constant value from block argument" ); |
154 | mlir::DenseIntElementsAttr attr; |
155 | if (!matchPattern(value, m_Constant(&attr))) { |
156 | return errors::Internal( |
157 | absl::StrCat("failed to extract constant value from " , |
158 | value.getDefiningOp()->getName().getStringRef().str())); |
159 | } |
160 | for (const mlir::APInt& index : attr) |
161 | out_vector->emplace_back(index.getSExtValue()); |
162 | return OkStatus(); |
163 | } |
164 | |
165 | mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder, |
166 | mlir::Location loc, bool use_int64) { |
167 | if (use_int64) { |
168 | return builder.create<mlir::TF::ConstOp>( |
169 | loc, mlir::DenseIntElementsAttr::get( |
170 | mlir::RankedTensorType::get({}, builder.getI64Type()), value)); |
171 | } else { |
172 | return builder.create<mlir::TF::ConstOp>( |
173 | loc, mlir::DenseIntElementsAttr::get( |
174 | mlir::RankedTensorType::get({}, builder.getI32Type()), |
175 | static_cast<int32_t>(value))); |
176 | } |
177 | } |
178 | |
179 | absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder, |
180 | mlir::Location loc, |
181 | mlir::Type type) { |
182 | if (type.isF64()) { |
183 | return builder.create<mlir::TF::ConstOp>( |
184 | loc, mlir::DenseFPElementsAttr::get( |
185 | mlir::RankedTensorType::get({}, builder.getF64Type()), |
186 | static_cast<double>(0.))); |
187 | } else if (type.isF32()) { |
188 | return builder.create<mlir::TF::ConstOp>( |
189 | loc, mlir::DenseFPElementsAttr::get( |
190 | mlir::RankedTensorType::get({}, builder.getF32Type()), |
191 | static_cast<float>(0.f))); |
192 | } else if (type.isInteger(32)) { |
193 | return builder.create<mlir::TF::ConstOp>( |
194 | loc, mlir::DenseIntElementsAttr::get( |
195 | mlir::RankedTensorType::get({}, builder.getI32Type()), |
196 | static_cast<int32_t>(0))); |
197 | } else if (type.isInteger(64)) { |
198 | return builder.create<mlir::TF::ConstOp>( |
199 | loc, mlir::DenseIntElementsAttr::get( |
200 | mlir::RankedTensorType::get({}, builder.getI64Type()), |
201 | static_cast<int64_t>(0))); |
202 | } else { |
203 | return absl::nullopt; |
204 | } |
205 | } |
206 | |
207 | StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder, |
208 | int index, |
209 | mlir::Location location, |
210 | mlir::Value array) { |
211 | mlir::TensorType arrayType = array.getType().cast<mlir::TensorType>(); |
212 | if (arrayType.getRank() != 2 || arrayType.getDimSize(0) != 1) { |
213 | return errors::InvalidArgument("Input array must have shape [1, N]." ); |
214 | } |
215 | |
216 | mlir::TF::SliceOp sliced_value = builder.create<mlir::TF::SliceOp>( |
217 | location, mlir::RankedTensorType::get({1, 1}, arrayType.getElementType()), |
218 | /*input=*/array, |
219 | /*begin=*/IntConst(builder, location, {0, index}), |
220 | /*size=*/IntConst(builder, location, {1, 1})); |
221 | |
222 | // Reshape the sliced shape (1,1) tensor to shape 0 scalar. |
223 | auto scalar_size_type = |
224 | mlir::RankedTensorType::get({}, builder.getIntegerType(32)); |
225 | mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const( |
226 | scalar_size_type.getShape(), builder, location); |
227 | mlir::Value scalar_sliced_value = builder.create<mlir::TF::ReshapeOp>( |
228 | location, mlir::ArrayRef<mlir::Type>{scalar_size_type}, |
229 | mlir::ArrayRef<mlir::Value>{sliced_value.output(), scalar_shape}, |
230 | mlir::ArrayRef<mlir::NamedAttribute>{}); |
231 | return scalar_sliced_value; |
232 | } |
233 | |
234 | mlir::Type GetSubtypeOrSelf(mlir::Value val) { |
235 | mlir::Type type = val.getType(); |
236 | if (auto type_with_subtype = |
237 | mlir::getElementTypeOrSelf(val) |
238 | .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) { |
239 | if (type_with_subtype.GetSubtypes().size() == 1) { |
240 | type = type_with_subtype.GetSubtypes().front(); |
241 | } |
242 | } |
243 | return type; |
244 | } |
245 | |
246 | } // namespace dtensor |
247 | } // namespace tensorflow |
248 | |