/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/string_ops.cc.

#include <string>
#include <tuple>
#include <utility>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {

namespace {

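// Returns, for each of the `small_stride` segment entries, the flat offset of
// the start of its row in the flattened input, i.e. {0, big_stride,
// 2 * big_stride, ...}.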
template <typename INDICES_TYPE>
gtl::InlinedVector<INDICES_TYPE, 8> GetFlattenedRelativeOffsets(
    INDICES_TYPE small_stride, INDICES_TYPE big_stride) {
  gtl::InlinedVector<INDICES_TYPE, 8> flattened_offsets(small_stride);
  for (auto i = 0; i < small_stride; i++) {
    flattened_offsets[i] = i * big_stride;
  }
  return flattened_offsets;
}

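// Splits the input shape into (big_stride, small_stride): `small_stride` is
// the number of elements covered by the segment_id shape (the leading dims),
// and `big_stride` is the number of elements in each remaining inner slice.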
template <typename INDICES_TYPE>
std::pair<INDICES_TYPE, INDICES_TYPE> GetStrides(
    const TensorShape& input_shape, const TensorShape& segment_id_shape) {
  int64_t small_stride = 1;
  int64_t big_stride = 1;
  for (auto i = 0; i < input_shape.dims(); i++) {
    if (i < segment_id_shape.dims()) {
      small_stride *= segment_id_shape.dim_size(i);
    } else {
      big_stride *= input_shape.dim_size(i);
    }
  }
  return std::make_pair(big_stride, small_stride);
}

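// The output shape is [num_segments] followed by the input dimensions that
// are not covered by segment_id.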
TensorShape GetOutputShape(const TensorShape& input_shape,
                           const TensorShape& segment_id_shape,
                           const int64_t num_segments) {
  TensorShape output_shape;
  output_shape.AddDim(num_segments);
  for (size_t index = segment_id_shape.dims(); index < input_shape.dims();
       ++index) {
    output_shape.AddDim(input_shape.dim_size(index));
  }
  return output_shape;
}

}  // namespace

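// CPU kernel for UnsortedSegmentJoin: for each segment, joins the input
// strings whose segment_id maps to that segment, separated by the
// `separator` attribute, producing one joined string per segment entry.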
template <typename INDICES_TYPE, typename NUM_SEGMENTS_TYPE>
class UnsortedSegmentJoinOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  explicit UnsortedSegmentJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t input_dims = input_shape.dims();

    const Tensor& segment_id = context->input(1);
    const TensorShape& segment_id_shape = segment_id.shape();
    const int32_t segment_dims = segment_id_shape.dims();

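    // Validate num_segments and check that segment_ids forms a non-scalar
    // prefix of the input shape.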
    const Tensor& num_segments_tensor = context->input(2);
    OP_REQUIRES(context, num_segments_tensor.NumElements() != 0,
                errors::InvalidArgument("Number of segments cannot be empty."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(num_segments_tensor.shape()),
                errors::InvalidArgument("Number of segments must be a scalar"));
    auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();

    OP_REQUIRES(
        context, num_segments >= 0,
        errors::InvalidArgument(
            "Number of segments must be non-negative but got ", num_segments));
    OP_REQUIRES(context, segment_dims != 0,
                errors::InvalidArgument("Segment_id cannot have rank 0"));

    OP_REQUIRES(
        context, segment_dims <= input_dims,
        errors::OutOfRange("Invalid segment_id rank ", segment_dims,
                           " for input with ", input_dims, " dimension(s)"));
    for (auto i = 0; i < segment_dims; i++) {
      OP_REQUIRES(
          context, segment_id_shape.dim_size(i) == input_shape.dim_size(i),
          errors::InvalidArgument(
              "Segment dimension is ", segment_id_shape.dim_size(i),
              " while input dimension is ", input_shape.dim_size(i),
              " at rank ", i));
    }

    // Allocate the output tensor.
    Tensor* output_tensor = nullptr;
    TensorShape output_shape =
        GetOutputShape(input_shape, segment_id_shape, num_segments);
    OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                     &output_tensor));

    // Preparing flat tensors.
    auto output_flat = output_tensor->flat<tstring>();
    auto flat_segment_id = segment_id.flat<INDICES_TYPE>();
    auto flat_input = input.flat<tstring>();

    for (int i = 0; i < flat_segment_id.size(); i++) {
      OP_REQUIRES(
          context,
          ((flat_segment_id(i) < num_segments) && (flat_segment_id(i) >= 0)),
          errors::InvalidArgument(
              "segment_ids are not allowed to exceed num_segments or"
              " to have negative values."));
    }

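    // View the input as a [small_stride, big_stride] matrix: rows follow the
    // segment_id layout and columns index the inner slice elements. For each
    // column, append every row's string to its segment's output entry,
    // inserting the separator between consecutive entries.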
    int64_t big_stride;
    int64_t small_stride;
    std::tie(big_stride, small_stride) =
        GetStrides<INDICES_TYPE>(input_shape, segment_id_shape);
    auto relative_offset_set =
        GetFlattenedRelativeOffsets<INDICES_TYPE>(small_stride, big_stride);
    for (auto start_offset = 0; start_offset < big_stride; start_offset++) {
      for (auto i = 0; i < relative_offset_set.size(); i++) {
        auto output_index = start_offset + flat_segment_id(i) * big_stride;
        auto offset = start_offset + relative_offset_set[i];
        if (output_flat(output_index).length() != 0)
          output_flat(output_index).append(separator_.c_str());
        output_flat(output_index).append(flat_input(offset));
      }
    }
  }

 private:
  string separator_;
};

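// Register the CPU kernel for every supported combination of the Tindices
// and Tnumsegments types (int32 / int64).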
#define REGISTER_CPU_KERNEL(indices_type, num_segments_type)    \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("UnsortedSegmentJoin")                                \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<indices_type>("Tindices")              \
          .TypeConstraint<num_segments_type>("Tnumsegments"),    \
      UnsortedSegmentJoinOp<indices_type, num_segments_type>);

REGISTER_CPU_KERNEL(int32, int32);
REGISTER_CPU_KERNEL(int32, int64_t);
REGISTER_CPU_KERNEL(int64_t, int32);
REGISTER_CPU_KERNEL(int64_t, int64_t);
#undef REGISTER_CPU_KERNEL

}  // namespace tensorflow