1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/mlir/sparse_expander_common.h" |
17 | |
18 | #include "mlir/IR/Value.h" // from @llvm-project |
19 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
20 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace dtensor { |
24 | |
25 | StatusOr<mlir::TF::SparseToDenseOp> GetSparseToDenseOp(mlir::Value value) { |
26 | // Travel back until we see a TF op. We generally expect this value |
27 | // to be connected by a series of DTensor ops like DTensorLayout or |
28 | // various DTensorRelayout ops, so skip past the tf.DTensor ops. |
29 | auto op = value.getDefiningOp(); |
30 | while (op && op->getName().getStringRef().startswith("tf.DTensor")) { |
31 | op = op->getOperand(0).getDefiningOp(); |
32 | } |
33 | |
34 | if (op && llvm::isa<mlir::TF::SparseToDenseOp>(op)) |
35 | return llvm::dyn_cast_or_null<mlir::TF::SparseToDenseOp>(op); |
36 | return errors::NotFound("SparseToDenseOp not found from value."); |
37 | } |
38 | |
39 | bool IsSparseValue(mlir::Value value) { return GetSparseToDenseOp(value).ok(); } |
40 | |
41 | bool HasAnySparseInput(mlir::Operation* op) { |
42 | for (auto operand : op->getOperands()) |
43 | if (IsSparseValue(operand)) return true; |
44 | return false; |
45 | } |
46 | |
47 | bool AllSparseInput(mlir::Operation* op) { |
48 | for (auto operand : op->getOperands()) |
49 | if (!IsSparseValue(operand)) return false; |
50 | return true; |
51 | } |
52 | |
53 | StatusOr<mlir::Value> GetIndicesFromSparseTensor(mlir::Value value) { |
54 | auto sparse_op = GetSparseToDenseOp(value); |
55 | if (!sparse_op.ok()) |
56 | return errors::NotFound( |
57 | "Indices tensor not found from value because it was not from a " |
58 | "SparseTensor."); |
59 | return sparse_op->getOperand(0); |
60 | } |
61 | |
62 | StatusOr<mlir::Value> GetValuesFromSparseTensor(mlir::Value value) { |
63 | auto sparse_op = GetSparseToDenseOp(value); |
64 | if (!sparse_op.ok()) |
65 | return errors::NotFound( |
66 | "Values tensor not found from value because it was not from a " |
67 | "SparseTensor."); |
68 | return sparse_op->getOperand(2); |
69 | } |
70 | |
71 | StatusOr<mlir::Value> GetDenseShapesFromSparseTensor(mlir::Value value) { |
72 | auto sparse_op = GetSparseToDenseOp(value); |
73 | if (!sparse_op.ok()) |
74 | return errors::NotFound( |
75 | "Dense shape tensor not found from value because it was not from a " |
76 | "SparseTensor."); |
77 | return sparse_op->getOperand(1); |
78 | } |
79 | |
80 | } // namespace dtensor |
81 | } // namespace tensorflow |
82 |