1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
// Shape function for the "RaggedGather" op. Declared here so that the op
// registration below can reference it; the definition lives in the
// "Shape Functions" section at the end of this file.
Status RaggedGatherShapeFn(InferenceContext* c);
26 | |
27 | //============================================================================== |
28 | // Registered Ops |
29 | //============================================================================== |
30 | |
// Registers "RaggedGather". Judging by its signature, it gathers from a ragged
// `params` tensor -- encoded as PARAMS_RAGGED_RANK nested row-split tensors
// plus one dense values tensor -- at `indices`, producing a ragged result
// encoded as OUTPUT_RAGGED_RANK split tensors plus dense values.
// TODO(review): confirm the exact gather semantics against the kernel.
//
// Shape inference is delegated to RaggedGatherShapeFn.
REGISTER_OP("RaggedGather")
    .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
    .Input("params_dense_values: Tvalues")
    .Input("indices: Tindices")
    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tindices: {int32, int64}")
    // Row-split tensors default to int64 when Tsplits is not given.
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("PARAMS_RAGGED_RANK: int >= 1")
    .Attr("OUTPUT_RAGGED_RANK: int >= 0")
    .SetShapeFn(RaggedGatherShapeFn);
43 | |
44 | REGISTER_OP("RaggedCross" ) |
45 | .Input("ragged_values: ragged_values_types" ) |
46 | .Input("ragged_row_splits: ragged_splits_types" ) |
47 | .Input("sparse_indices: Nsparse * int64" ) |
48 | .Input("sparse_values: sparse_values_types" ) |
49 | .Input("sparse_shape: Nsparse * int64" ) |
50 | .Input("dense_inputs: dense_types" ) |
51 | .Output("output_values: out_values_type" ) |
52 | .Output("output_row_splits: out_row_splits_type" ) |
53 | .Attr("Nsparse: int >= 0" ) |
54 | .Attr("input_order: string" ) |
55 | .Attr("hashed_output: bool" ) |
56 | .Attr("num_buckets: int >= 0" ) |
57 | .Attr("hash_key: int" ) |
58 | .Attr("ragged_values_types: list({int64, string}) >= 0" ) |
59 | .Attr("ragged_splits_types: list({int32, int64}) >= 0" ) |
60 | .Attr("sparse_values_types: list({int64, string}) >= 0" ) |
61 | .Attr("dense_types: list({int64, string}) >= 0" ) |
62 | .Attr("out_values_type: {int64, string}" ) |
63 | .Attr("out_row_splits_type: {int32, int64}" ) |
64 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
65 | std::vector<DataType> ragged_values_types; |
66 | std::vector<DataType> ragged_splits_types; |
67 | std::vector<DataType> sparse_values_types; |
68 | std::vector<DataType> dense_types; |
69 | |
70 | TF_RETURN_IF_ERROR( |
71 | c->GetAttr("ragged_values_types" , &ragged_values_types)); |
72 | TF_RETURN_IF_ERROR( |
73 | c->GetAttr("ragged_splits_types" , &ragged_splits_types)); |
74 | TF_RETURN_IF_ERROR(c->GetAttr("dense_types" , &dense_types)); |
75 | TF_RETURN_IF_ERROR( |
76 | c->GetAttr("sparse_values_types" , &sparse_values_types)); |
77 | |
78 | int num_ragged = ragged_values_types.size(); |
79 | if (num_ragged != ragged_splits_types.size()) { |
80 | return errors::InvalidArgument( |
81 | "ragged values and splits must have the same length." ); |
82 | } |
83 | |
84 | int num_sparse; |
85 | TF_RETURN_IF_ERROR(c->GetAttr("Nsparse" , &num_sparse)); |
86 | if (num_sparse != sparse_values_types.size()) { |
87 | return errors::InvalidArgument( |
88 | "sparse indices and values must have the same length" ); |
89 | } |
90 | |
91 | ShapeHandle out_values = c->UnknownShapeOfRank(1); |
92 | ShapeHandle out_splits = c->UnknownShapeOfRank(1); |
93 | |
94 | // Merge the shapes of row_splits from ragged inputs. (This is one plus |
95 | // the batch size.) |
96 | int ragged_splits_start = num_ragged; |
97 | for (int i = 0; i < ragged_splits_types.size(); ++i) { |
98 | ShapeHandle row_splits = c->input(i + ragged_splits_start); |
99 | if (!c->Merge(out_splits, row_splits, &out_splits).ok()) { |
100 | return errors::InvalidArgument( |
101 | "inputs must all have the same batch dimension size." ); |
102 | } |
103 | } |
104 | |
105 | // Merge the batch size of each dense input into out_splits. |
106 | int dense_start = num_ragged * 2 + num_sparse * 3; |
107 | for (int i = 0; i < dense_types.size(); ++i) { |
108 | ShapeHandle dense_input = c->input(i + dense_start); |
109 | int32 rank = c->Rank(dense_input); |
110 | if (rank == InferenceContext::kUnknownRank) { |
111 | continue; |
112 | } else if (rank != 2) { |
113 | return errors::InvalidArgument( |
114 | "tf.ragged.cross only supports inputs with rank=2" ); |
115 | } |
116 | int64_t batch_size = c->Value(c->Dim(dense_input, 0)); |
117 | if (batch_size != InferenceContext::kUnknownDim) { |
118 | ShapeHandle row_splits = c->Vector(batch_size + 1); |
119 | if (!c->Merge(out_splits, row_splits, &out_splits).ok()) { |
120 | return errors::InvalidArgument( |
121 | "inputs must all have the same batch dimension size." ); |
122 | } |
123 | } |
124 | } |
125 | |
126 | c->set_output(0, out_values); |
127 | c->set_output(1, out_splits); |
128 | return OkStatus(); |
129 | }); |
130 | |
131 | //============================================================================== |
132 | // Shape Functions |
133 | //============================================================================== |
134 | |
135 | Status RaggedGatherShapeFn(InferenceContext* c) { |
136 | int num_splits; |
137 | int64_t PARAMS_RAGGED_RANK; |
138 | TF_RETURN_IF_ERROR( |
139 | c->GetAttr<int64_t>("PARAMS_RAGGED_RANK" , &PARAMS_RAGGED_RANK)); |
140 | TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK" , &num_splits)); |
141 | |
142 | // Check rank of `indices`. |
143 | ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1); |
144 | TF_RETURN_IF_ERROR( |
145 | c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices)); |
146 | |
147 | // Check that all params_nested_splits have rank 1. |
148 | for (int64_t i = 0; i < PARAMS_RAGGED_RANK; ++i) { |
149 | ShapeHandle splits = c->input(i); |
150 | TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits)); |
151 | } |
152 | |
153 | // Check that `params_dense_values` has rank>=1. |
154 | ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK); |
155 | TF_RETURN_IF_ERROR( |
156 | c->WithRankAtLeast(params_dense_values, 1, ¶ms_dense_values)); |
157 | |
158 | // Set the rank for the `splits` outputs. |
159 | for (int i = 0; i < num_splits; ++i) { |
160 | c->set_output(i, c->UnknownShapeOfRank(1)); |
161 | } |
162 | |
163 | // Calculate the `values` shape. |
164 | ShapeHandle value = c->UnknownShape(); |
165 | ShapeHandle values = c->UnknownShape(); |
166 | TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value)); |
167 | TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values)); |
168 | c->set_output(num_splits, values); |
169 | |
170 | return OkStatus(); |
171 | } |
172 | |
173 | } // namespace tensorflow |
174 | |