1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/string_ops.cc. |
17 | |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/framework/kernel_def_builder.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | #include "tensorflow/core/lib/core/errors.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/lib/core/stringpiece.h" |
28 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
29 | #include "tensorflow/core/lib/strings/str_util.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | namespace { |
34 | |
35 | template <typename INDICES_TYPE> |
36 | gtl::InlinedVector<INDICES_TYPE, 8> GetFlattenedRelativeOffsets( |
37 | INDICES_TYPE small_stride, INDICES_TYPE big_stride) { |
38 | gtl::InlinedVector<INDICES_TYPE, 8> flattened_offsets(small_stride); |
39 | for (auto i = 0; i < small_stride; i++) { |
40 | flattened_offsets[i] = i * big_stride; |
41 | } |
42 | return flattened_offsets; |
43 | } |
44 | |
45 | template <typename INDICES_TYPE> |
46 | std::pair<INDICES_TYPE, INDICES_TYPE> GetStrides( |
47 | const TensorShape& input_shape, const TensorShape& segment_id_shape) { |
48 | int64_t small_stride = 1; |
49 | int64_t big_stride = 1; |
50 | for (auto i = 0; i < input_shape.dims(); i++) { |
51 | if (i < segment_id_shape.dims()) { |
52 | small_stride *= segment_id_shape.dim_size(i); |
53 | } else { |
54 | big_stride *= input_shape.dim_size(i); |
55 | } |
56 | } |
57 | return std::make_pair(big_stride, small_stride); |
58 | } |
59 | |
60 | TensorShape GetOutputShape(const TensorShape& input_shape, |
61 | const TensorShape& segment_id_shape, |
62 | const int64_t num_segments) { |
63 | TensorShape output_shape; |
64 | output_shape.AddDim(num_segments); |
65 | for (size_t index = segment_id_shape.dims(); index < input_shape.dims(); |
66 | ++index) { |
67 | output_shape.AddDim(input_shape.dim_size(index)); |
68 | } |
69 | return output_shape; |
70 | } |
71 | |
72 | } // namespace |
73 | |
74 | template <typename INDICES_TYPE, typename NUM_SEGMENTS_TYPE> |
75 | class UnsortedSegmentJoinOp : public OpKernel { |
76 | public: |
77 | using OpKernel::OpKernel; |
78 | |
79 | explicit UnsortedSegmentJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
80 | OP_REQUIRES_OK(ctx, ctx->GetAttr("separator" , &separator_)); |
81 | } |
82 | |
83 | void Compute(OpKernelContext* context) override { |
84 | const Tensor& input = context->input(0); |
85 | const TensorShape& input_shape = input.shape(); |
86 | const int32_t input_dims = input_shape.dims(); |
87 | |
88 | const Tensor& segment_id = context->input(1); |
89 | const TensorShape& segment_id_shape = segment_id.shape(); |
90 | const int32_t segment_dims = segment_id_shape.dims(); |
91 | |
92 | const Tensor& num_segments_tensor = context->input(2); |
93 | OP_REQUIRES(context, num_segments_tensor.NumElements() != 0, |
94 | errors::InvalidArgument("Number of segments cannot be empty." )); |
95 | OP_REQUIRES(context, |
96 | TensorShapeUtils::IsScalar(num_segments_tensor.shape()), |
97 | errors::InvalidArgument("Number of segments must be a scalar" )); |
98 | auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()(); |
99 | |
100 | OP_REQUIRES( |
101 | context, num_segments >= 0, |
102 | errors::InvalidArgument( |
103 | "Number of segments must be non-negative but got " , num_segments)); |
104 | OP_REQUIRES(context, segment_dims != 0, |
105 | errors::InvalidArgument("Segment_id cannot have rank 0" )); |
106 | |
107 | OP_REQUIRES( |
108 | context, segment_dims <= input_dims, |
109 | errors::OutOfRange("Invalid segment_id rank " , segment_dims, |
110 | " for input with " , input_dims, " dimension(s)" )); |
111 | for (auto i = 0; i < segment_dims; i++) { |
112 | OP_REQUIRES( |
113 | context, segment_id_shape.dim_size(i) == input_shape.dim_size(i), |
114 | errors::InvalidArgument( |
115 | "Segment dimension is " , segment_id_shape.dim_size(i), |
116 | " while input dimension is " , input_dims, " in rank " , i)); |
117 | } |
118 | |
119 | // Making output tensor. |
120 | Tensor* output_tensor = nullptr; |
121 | TensorShape output_shape = |
122 | GetOutputShape(input_shape, segment_id_shape, num_segments); |
123 | OP_REQUIRES_OK(context, context->allocate_output("output" , output_shape, |
124 | &output_tensor)); |
125 | |
126 | // Preparating flat tensors. |
127 | auto output_flat = output_tensor->flat<tstring>(); |
128 | auto flat_segment_id = segment_id.flat<INDICES_TYPE>(); |
129 | auto flat_input = input.flat<tstring>(); |
130 | |
131 | for (int i = 0; i < flat_segment_id.size(); i++) { |
132 | OP_REQUIRES( |
133 | context, |
134 | ((flat_segment_id(i) < num_segments) && (flat_segment_id(i) >= 0)), |
135 | errors::InvalidArgument( |
136 | "segment_ids are not allowed to exceed num_segments or" |
137 | " to have negative values." )); |
138 | } |
139 | |
140 | int64_t big_stride; |
141 | int64_t small_stride; |
142 | std::tie(big_stride, small_stride) = |
143 | GetStrides<INDICES_TYPE>(input_shape, segment_id_shape); |
144 | auto relative_offset_set = |
145 | GetFlattenedRelativeOffsets<INDICES_TYPE>(small_stride, big_stride); |
146 | for (auto start_offset = 0; start_offset < big_stride; start_offset++) { |
147 | for (auto i = 0; i < relative_offset_set.size(); i++) { |
148 | auto output_index = start_offset + flat_segment_id(i) * big_stride; |
149 | auto offset = start_offset + relative_offset_set[i]; |
150 | if (output_flat(output_index).length() != 0) |
151 | output_flat(output_index).append(separator_.c_str()); |
152 | output_flat(output_index).append(flat_input(offset)); |
153 | } |
154 | } |
155 | } |
156 | |
157 | private: |
158 | string separator_; |
159 | }; |
160 | |
// Register the CPU kernel for every supported (Tindices, Tnumsegments)
// combination: both attrs accept int32 or int64.
#define REGISTER_CPU_KERNEL(indices_type, num_segments_type)  \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("UnsortedSegmentJoin")                             \
          .Device(DEVICE_CPU)                                 \
          .TypeConstraint<indices_type>("Tindices")           \
          .TypeConstraint<num_segments_type>("Tnumsegments"), \
      UnsortedSegmentJoinOp<indices_type, num_segments_type>);

REGISTER_CPU_KERNEL(int32, int32);
REGISTER_CPU_KERNEL(int32, int64_t);
REGISTER_CPU_KERNEL(int64_t, int32);
REGISTER_CPU_KERNEL(int64_t, int64_t);
#undef REGISTER_CPU_KERNEL
174 | |
175 | } // namespace tensorflow |
176 | |