1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ |
17 | #define TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ |
18 | |
19 | #include "llvm/ADT/ArrayRef.h" |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "mlir/IR/Builders.h" // from @llvm-project |
22 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
24 | #include "mlir/IR/Location.h" // from @llvm-project |
25 | #include "mlir/IR/Value.h" // from @llvm-project |
26 | #include "tensorflow/dtensor/cc/dstatus.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace dtensor { |
30 | |
31 | int ValueRank(mlir::Value operand_value); |
32 | |
33 | // Creates a effective scalar type as rank 1 with a single element. |
34 | mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type); |
35 | |
36 | // Reshapes a value of size type tensor<i32> to scalar. |
37 | mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc, |
38 | mlir::Value tensor); |
39 | |
40 | // Return a 1-D int32 constant array with the given values. |
41 | mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc, |
42 | llvm::ArrayRef<int32> values); |
43 | // Return a 1-D int64 constant array with the given values. |
44 | mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc, |
45 | llvm::ArrayRef<int64_t> values); |
46 | // Return a 1-D float32 constant array with the given values. |
47 | mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc, |
48 | llvm::ArrayRef<float> values); |
49 | // Returns a 1-D tf.string constant array with given values. |
50 | mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc, |
51 | llvm::ArrayRef<llvm::StringRef> values); |
52 | // Returns a tf.string scalar constant with given value. |
53 | mlir::Value StringScalarConst(mlir::OpBuilder& builder, mlir::Location loc, |
54 | llvm::StringRef value); |
55 | StatusOr<int64_t> (mlir::Value value); |
56 | Status (mlir::Value value, |
57 | llvm::SmallVector<int64_t, 4>* out_vector); |
58 | |
59 | // Returns a int64 scalar constant with `value`. |
60 | mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder, |
61 | mlir::Location loc, bool use_int64 = true); |
62 | |
63 | // Returns a scalar constant with 'value' of 'type'. |
64 | absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder, |
65 | mlir::Location loc, |
66 | mlir::Type type); |
67 | |
68 | // Selects a scalar tensor value from a 1D array in specified index. |
69 | StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder, |
70 | int index, |
71 | mlir::Location location, |
72 | mlir::Value array); |
73 | |
74 | // Returns the type that value holds. If value holds a Type that has a subtype, |
75 | // then it returns the subtype. |
76 | mlir::Type GetSubtypeOrSelf(mlir::Value value); |
77 | |
78 | } // namespace dtensor |
79 | } // namespace tensorflow |
80 | #endif // TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ |
81 | |