1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/common_shape_fns.h"
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/shape_inference.h"
18
19namespace tensorflow {
20
21using shape_inference::DimensionHandle;
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25Status RaggedGatherShapeFn(InferenceContext* c);
26
27//==============================================================================
28// Registered Ops
29//==============================================================================
30
31REGISTER_OP("RaggedGather")
32 .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
33 .Input("params_dense_values: Tvalues")
34 .Input("indices: Tindices")
35 .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
36 .Output("output_dense_values: Tvalues")
37 .Attr("Tvalues: type")
38 .Attr("Tindices: {int32, int64}")
39 .Attr("Tsplits: {int32, int64} = DT_INT64")
40 .Attr("PARAMS_RAGGED_RANK: int >= 1")
41 .Attr("OUTPUT_RAGGED_RANK: int >= 0")
42 .SetShapeFn(RaggedGatherShapeFn);
43
44REGISTER_OP("RaggedCross")
45 .Input("ragged_values: ragged_values_types")
46 .Input("ragged_row_splits: ragged_splits_types")
47 .Input("sparse_indices: Nsparse * int64")
48 .Input("sparse_values: sparse_values_types")
49 .Input("sparse_shape: Nsparse * int64")
50 .Input("dense_inputs: dense_types")
51 .Output("output_values: out_values_type")
52 .Output("output_row_splits: out_row_splits_type")
53 .Attr("Nsparse: int >= 0")
54 .Attr("input_order: string")
55 .Attr("hashed_output: bool")
56 .Attr("num_buckets: int >= 0")
57 .Attr("hash_key: int")
58 .Attr("ragged_values_types: list({int64, string}) >= 0")
59 .Attr("ragged_splits_types: list({int32, int64}) >= 0")
60 .Attr("sparse_values_types: list({int64, string}) >= 0")
61 .Attr("dense_types: list({int64, string}) >= 0")
62 .Attr("out_values_type: {int64, string}")
63 .Attr("out_row_splits_type: {int32, int64}")
64 .SetShapeFn([](shape_inference::InferenceContext* c) {
65 std::vector<DataType> ragged_values_types;
66 std::vector<DataType> ragged_splits_types;
67 std::vector<DataType> sparse_values_types;
68 std::vector<DataType> dense_types;
69
70 TF_RETURN_IF_ERROR(
71 c->GetAttr("ragged_values_types", &ragged_values_types));
72 TF_RETURN_IF_ERROR(
73 c->GetAttr("ragged_splits_types", &ragged_splits_types));
74 TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
75 TF_RETURN_IF_ERROR(
76 c->GetAttr("sparse_values_types", &sparse_values_types));
77
78 int num_ragged = ragged_values_types.size();
79 if (num_ragged != ragged_splits_types.size()) {
80 return errors::InvalidArgument(
81 "ragged values and splits must have the same length.");
82 }
83
84 int num_sparse;
85 TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
86 if (num_sparse != sparse_values_types.size()) {
87 return errors::InvalidArgument(
88 "sparse indices and values must have the same length");
89 }
90
91 ShapeHandle out_values = c->UnknownShapeOfRank(1);
92 ShapeHandle out_splits = c->UnknownShapeOfRank(1);
93
94 // Merge the shapes of row_splits from ragged inputs. (This is one plus
95 // the batch size.)
96 int ragged_splits_start = num_ragged;
97 for (int i = 0; i < ragged_splits_types.size(); ++i) {
98 ShapeHandle row_splits = c->input(i + ragged_splits_start);
99 if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
100 return errors::InvalidArgument(
101 "inputs must all have the same batch dimension size.");
102 }
103 }
104
105 // Merge the batch size of each dense input into out_splits.
106 int dense_start = num_ragged * 2 + num_sparse * 3;
107 for (int i = 0; i < dense_types.size(); ++i) {
108 ShapeHandle dense_input = c->input(i + dense_start);
109 int32 rank = c->Rank(dense_input);
110 if (rank == InferenceContext::kUnknownRank) {
111 continue;
112 } else if (rank != 2) {
113 return errors::InvalidArgument(
114 "tf.ragged.cross only supports inputs with rank=2");
115 }
116 int64_t batch_size = c->Value(c->Dim(dense_input, 0));
117 if (batch_size != InferenceContext::kUnknownDim) {
118 ShapeHandle row_splits = c->Vector(batch_size + 1);
119 if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
120 return errors::InvalidArgument(
121 "inputs must all have the same batch dimension size.");
122 }
123 }
124 }
125
126 c->set_output(0, out_values);
127 c->set_output(1, out_splits);
128 return OkStatus();
129 });
130
131//==============================================================================
132// Shape Functions
133//==============================================================================
134
135Status RaggedGatherShapeFn(InferenceContext* c) {
136 int num_splits;
137 int64_t PARAMS_RAGGED_RANK;
138 TF_RETURN_IF_ERROR(
139 c->GetAttr<int64_t>("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK));
140 TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK", &num_splits));
141
142 // Check rank of `indices`.
143 ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1);
144 TF_RETURN_IF_ERROR(
145 c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices));
146
147 // Check that all params_nested_splits have rank 1.
148 for (int64_t i = 0; i < PARAMS_RAGGED_RANK; ++i) {
149 ShapeHandle splits = c->input(i);
150 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
151 }
152
153 // Check that `params_dense_values` has rank>=1.
154 ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK);
155 TF_RETURN_IF_ERROR(
156 c->WithRankAtLeast(params_dense_values, 1, &params_dense_values));
157
158 // Set the rank for the `splits` outputs.
159 for (int i = 0; i < num_splits; ++i) {
160 c->set_output(i, c->UnknownShapeOfRank(1));
161 }
162
163 // Calculate the `values` shape.
164 ShapeHandle value = c->UnknownShape();
165 ShapeHandle values = c->UnknownShape();
166 TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value));
167 TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values));
168 c->set_output(num_splits, values);
169
170 return OkStatus();
171}
172
173} // namespace tensorflow
174