1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_
17#define TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_
18
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/SmallVector.h"
21#include "mlir/IR/Builders.h" // from @llvm-project
22#include "mlir/IR/BuiltinOps.h" // from @llvm-project
23#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24#include "mlir/IR/Location.h" // from @llvm-project
25#include "mlir/IR/Value.h" // from @llvm-project
26#include "tensorflow/dtensor/cc/dstatus.h"
27
28namespace tensorflow {
29namespace dtensor {
30
31int ValueRank(mlir::Value operand_value);
32
33// Creates a effective scalar type as rank 1 with a single element.
34mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type);
35
36// Reshapes a value of size type tensor<i32> to scalar.
37mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc,
38 mlir::Value tensor);
39
40// Return a 1-D int32 constant array with the given values.
41mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc,
42 llvm::ArrayRef<int32> values);
43// Return a 1-D int64 constant array with the given values.
44mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc,
45 llvm::ArrayRef<int64_t> values);
46// Return a 1-D float32 constant array with the given values.
47mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc,
48 llvm::ArrayRef<float> values);
49// Returns a 1-D tf.string constant array with given values.
50mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc,
51 llvm::ArrayRef<llvm::StringRef> values);
52// Returns a tf.string scalar constant with given value.
53mlir::Value StringScalarConst(mlir::OpBuilder& builder, mlir::Location loc,
54 llvm::StringRef value);
55StatusOr<int64_t> ExtractConstIntFromValue(mlir::Value value);
56Status ExtractConstVectorFromValue(mlir::Value value,
57 llvm::SmallVector<int64_t, 4>* out_vector);
58
59// Returns a int64 scalar constant with `value`.
60mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder,
61 mlir::Location loc, bool use_int64 = true);
62
63// Returns a scalar constant with 'value' of 'type'.
64absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder,
65 mlir::Location loc,
66 mlir::Type type);
67
68// Selects a scalar tensor value from a 1D array in specified index.
69StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder,
70 int index,
71 mlir::Location location,
72 mlir::Value array);
73
74// Returns the type that value holds. If value holds a Type that has a subtype,
75// then it returns the subtype.
76mlir::Type GetSubtypeOrSelf(mlir::Value value);
77
78} // namespace dtensor
79} // namespace tensorflow
80#endif // TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_
81