1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
17#define TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
18
19#include <optional>
20
21#include "absl/types/optional.h"
22#include "mlir/IR/Operation.h" // from @llvm-project
23#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
24
25namespace tensorflow {
26namespace dtensor {
27
28// Gets the SparseToDenseOp that generates `value` if `value` is the result of
29// a SparseToDenseOp. Returns empty otherwise. This is useful
30// in SparseExpansion where we want to check whether some operand
31// is a SparseTensor, by checking whether that operand is a result of a
32// SparseToDenseOp. If this value is eventually an output of a SparseToDenseOp,
33// there should only be DTensor related ops between the actual SparseToDenseOp,
34// e.g. DTensorRelayout ops or DTensorLayout op.
35StatusOr<mlir::TF::SparseToDenseOp> GetSparseToDenseOp(mlir::Value value);
36
37// Checks whether `value is an output of a SparseToDenseOp value.
38bool IsSparseValue(mlir::Value value);
39
40// Checks if `op` has any sparse value operands.
41bool HasAnySparseInput(mlir::Operation* op);
42
43// Checks if all operands of `op` is a sparse value.
44bool AllSparseInput(mlir::Operation* op);
45
46// Returns the indices component dense tensor from `value`. `value` represents
47// a SparseTensor value.
48StatusOr<mlir::Value> GetIndicesFromSparseTensor(mlir::Value value);
49
50// Returns the values component dense tensor from `value`.`value` represents
51// a SparseTensor value.
52StatusOr<mlir::Value> GetValuesFromSparseTensor(mlir::Value value);
53
54// Returns the dense shape component dense tensor from `value`. `value`
55// represents a SparseTensor value.
56StatusOr<mlir::Value> GetDenseShapesFromSparseTensor(mlir::Value value);
57
58} // namespace dtensor
59} // namespace tensorflow
60
61#endif // TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
62