/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.
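//
// DynamicPartition splits `data` into `num_partitions` outputs: each entry
// (or slice) of `data` goes to the output selected by the corresponding entry
// of `partitions`, preserving relative order. Illustrative example (not from
// this file):
//   data       = [10, 20, 30, 40, 50]
//   partitions = [ 0,  0,  1,  1,  0],  num_partitions = 2
//   => outputs[0] = [10, 20, 50], outputs[1] = [30, 40]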

#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// Shared code that is not dependent on the type of T. We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
class DynamicPartitionOp_Shared : public OpKernel {
 public:
  explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
    // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc.
    // to input[1]. Should we have the framework do some sort of
    // integer promotion automatically, or should that be something
    // that users have to do explicitly with a conversion operator
    // in the graph?
  }

  void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data,
                                  const Tensor** partitions,
                                  OpOutputList* Tout) {
    OP_REQUIRES_OK(c, c->input("data", data));
    OP_REQUIRES_OK(c, c->input("partitions", partitions));
    OP_REQUIRES(
        c,
        TensorShapeUtils::StartsWith((*data)->shape(), (*partitions)->shape()),
        errors::InvalidArgument(
            "data.shape must start with partitions.shape, ",
            "got data.shape = ", (*data)->shape().DebugString(),
            ", partitions.shape = ", (*partitions)->shape().DebugString()));

    // Count how many occurrences of each partition id we have in partitions
    gtl::InlinedVector<int, 32> partition_count(num_partitions_);
    auto e_partitions = (*partitions)->flat<int32>();
    const int64_t N = e_partitions.dimension(0);
    for (int64_t i = 0; i < N; i++) {
      const int32_t p = internal::SubtleMustCopy(e_partitions(i));
      OP_REQUIRES(c, FastBoundsCheck(p, num_partitions_),
                  errors::InvalidArgument(
                      "partitions", SliceDebugString((*partitions)->shape(), i),
                      " = ", p, " is not in [0, ", num_partitions_, ")"));
      partition_count[p]++;
    }

    // Allocate output tensors of the right size
    OP_REQUIRES_OK(c, c->output_list("outputs", Tout));
    for (int p = 0; p < num_partitions_; p++) {
      TensorShape shape;
      shape.AddDim(partition_count[p]);
      for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) {
        shape.AddDim((*data)->dim_size(i));
      }
      Tensor* out;
      OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out));
    }
  }

 protected:
  int num_partitions_;
};

template <class T>
class DynamicPartitionOp : public DynamicPartitionOp_Shared {
 public:
  explicit DynamicPartitionOp(OpKernelConstruction* c)
      : DynamicPartitionOp_Shared(c) {}
  void Compute(OpKernelContext* c) override {
    const Tensor* data;
    const Tensor* partitions;
    OpOutputList outputs;
    ValidateAndAllocateOutputs(c, &data, &partitions, &outputs);
    if (!c->status().ok()) return;
    if (num_partitions_ == 0 || data->NumElements() == 0) return;

    auto e_partitions = partitions->flat<int32>();
    const int64_t N = e_partitions.dimension(0);
    gtl::InlinedVector<int, 32> output_index(num_partitions_);

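    // The inputs were bounds-checked in ValidateAndAllocateOutputs, but they
    // may be overwritten concurrently before the copy below runs (see the
    // "asynchronously overwritten" error below), so each partition id is read
    // via SubtleMustCopy and re-checked before it is used to index an output.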
    if (partitions->dims() == data->dims()) {
      // Walk through data and copy the data to the appropriate output tensor
      const auto data_flat = data->flat<T>();
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_vec;
      out_vec.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_vec.push_back(outputs[p]->vec<T>());
      }
      for (int64_t i = 0; i < N; i++) {
        const int32_t p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("indices[", i, "] is out of range"));
        auto oi = output_index[p];
        OP_REQUIRES(c, FastBoundsCheck(oi, out_vec[p].size()),
                    errors::InvalidArgument(
                        "out_vec[", p, "] size: ", out_vec[p].size(),
                        " is not LTE output_index[", p, "] : ", oi));
        out_vec[p](oi) = data_flat(i);
        output_index[p]++;
      }
    } else {
      // If data has extra dimensions, use Eigen slices
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_flat;
      out_flat.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_flat.push_back(outputs[p]->flat_outer_dims<T>());
      }

      // Walk through data and copy the data to the appropriate output tensor
      const int64_t slice_size = data->NumElements() / N;
      const auto data_flat = data->shaped<T, 2>({N, slice_size});
      Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
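      // Each row of the [N, slice_size] view of data is copied as a
      // [1, slice_size] Eigen slice into row output_index[p] of outputs[p].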
      for (int64_t i = 0; i < N; i++) {
        // outputs[p][output_index[p]++] = data[i]
        const int32_t p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("indices[", i,
                                    "] has been asynchronously overwritten and "
                                    "is no longer in range!"));
        auto oi = output_index[p];
        OP_REQUIRES(c, FastBoundsCheck(oi, out_flat[p].dimension(0)),
                    errors::InvalidArgument("Size of output_index: ", oi,
                                            " is no longer in range."));
        Eigen::DSizes<Eigen::DenseIndex, 2> out_indices(oi, 0);
        Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
        out_flat[p].slice(out_indices, sizes) =
            data_flat.slice(data_indices, sizes);
        output_index[p]++;
      }
    }
  }
};

#define REGISTER_DYNAMIC_PARTITION(T)                                     \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DynamicPartitionOp<T>)

TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION

}  // namespace tensorflow