1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/data_flow_ops.cc. |
17 | |
18 | #include <vector> |
19 | #include "tensorflow/core/framework/bounds_check.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/register_types.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/types.h" |
24 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
25 | #include "tensorflow/core/util/util.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // Shared code that is not dependent on the type of T. We do this to reduce |
30 | // code size by not duplicating all this for all T (float, double, int32, etc.) |
31 | class DynamicPartitionOp_Shared : public OpKernel { |
32 | public: |
33 | explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) { |
34 | OP_REQUIRES_OK(c, c->GetAttr("num_partitions" , &num_partitions_)); |
35 | // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc. |
36 | // to input[1]. Should we have the framework do some sort of |
37 | // integer promotion automatically, or should that be something |
38 | // that users have to do explicitly with a conversion operator |
39 | // in the graph? |
40 | } |
41 | |
42 | void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data, |
43 | const Tensor** partitions, |
44 | OpOutputList* Tout) { |
45 | OP_REQUIRES_OK(c, c->input("data" , data)); |
46 | OP_REQUIRES_OK(c, c->input("partitions" , partitions)); |
47 | OP_REQUIRES( |
48 | c, |
49 | TensorShapeUtils::StartsWith((*data)->shape(), (*partitions)->shape()), |
50 | errors::InvalidArgument( |
51 | "data.shape must start with partitions.shape, " , |
52 | "got data.shape = " , (*data)->shape().DebugString(), |
53 | ", partitions.shape = " , (*partitions)->shape().DebugString())); |
54 | |
55 | // Count how many occurrences of each partition id we have in partitions |
56 | gtl::InlinedVector<int, 32> partition_count(num_partitions_); |
57 | auto e_partitions = (*partitions)->flat<int32>(); |
58 | const int64_t N = e_partitions.dimension(0); |
59 | for (int64_t i = 0; i < N; i++) { |
60 | const int32_t p = internal::SubtleMustCopy(e_partitions(i)); |
61 | OP_REQUIRES(c, FastBoundsCheck(p, num_partitions_), |
62 | errors::InvalidArgument( |
63 | "partitions" , SliceDebugString((*partitions)->shape(), i), |
64 | " = " , p, " is not in [0, " , num_partitions_, ")" )); |
65 | partition_count[p]++; |
66 | } |
67 | |
68 | // Allocate output tensors of the right size |
69 | OP_REQUIRES_OK(c, c->output_list("outputs" , Tout)); |
70 | for (int p = 0; p < num_partitions_; p++) { |
71 | TensorShape shape; |
72 | shape.AddDim(partition_count[p]); |
73 | for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) { |
74 | shape.AddDim((*data)->dim_size(i)); |
75 | } |
76 | Tensor* out; |
77 | OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out)); |
78 | } |
79 | } |
80 | |
81 | protected: |
82 | int num_partitions_; |
83 | }; |
84 | |
// Type-specific kernel: partitions `data` into num_partitions outputs
// according to the int32 ids in `partitions`.
template <class T>
class DynamicPartitionOp : public DynamicPartitionOp_Shared {
 public:
  explicit DynamicPartitionOp(OpKernelConstruction* c)
      : DynamicPartitionOp_Shared(c) {}
  void Compute(OpKernelContext* c) override {
    const Tensor* data;
    const Tensor* partitions;
    OpOutputList outputs;
    ValidateAndAllocateOutputs(c, &data, &partitions, &outputs);
    // The OP_REQUIRES* macros in the helper return from the helper, not from
    // Compute, so the status must be re-checked here before touching outputs.
    if (!c->status().ok()) return;
    // Nothing to copy: the (already allocated) outputs are all empty.
    if (num_partitions_ == 0 || data->NumElements() == 0) return;

    auto e_partitions = partitions->flat<int32>();
    const int64_t N = e_partitions.dimension(0);
    // output_index[p] is the next free row in output p.
    gtl::InlinedVector<int, 32> output_index(num_partitions_);

    if (partitions->dims() == data->dims()) {
      // Scalar-per-id case: each data element maps to one output element.
      // Walk through data and copy the data to the appropriate output tensor
      const auto data_flat = data->flat<T>();
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_vec;
      out_vec.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_vec.push_back(outputs[p]->vec<T>());
      }
      for (int64_t i = 0; i < N; i++) {
        // Re-validate p even though the helper already did: the partitions
        // buffer may have been overwritten between validation and this copy
        // (see the error message below), and the output sizes were computed
        // from the earlier values.
        const int32_t p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("indices[", i, "] is out of range"));
        auto oi = output_index[p];
        // Guard against writing past the rows allocated for partition p.
        OP_REQUIRES(c, FastBoundsCheck(oi, out_vec[p].size()),
                    errors::InvalidArgument(
                        "out_vec[", p, "] size: ", out_vec[p].size(),
                        " is not LTE output_index[", p, "] : ", oi));
        out_vec[p](oi) = data_flat(i);
        output_index[p]++;
      }
    } else {
      // If data has extra dimensions, use Eigen slices
      // (each partition id selects a whole inner slice of `data`).
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_flat;
      out_flat.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_flat.push_back(outputs[p]->flat_outer_dims<T>());
      }

      // Walk through data and copy the data to the appropriate output tensor
      // viewing data as [N, slice_size] (N > 0 here since NumElements > 0).
      const int64_t slice_size = data->NumElements() / N;
      const auto data_flat = data->shaped<T, 2>({N, slice_size});
      Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
      for (int64_t i = 0; i < N; i++) {
        // outputs[p][output_index[p]++] = data[i]
        const int32_t p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("indices[", i,
                                    "] has been asynchronously overwritten and "
                                    "is no longer in range!"));
        auto oi = output_index[p];
        // Same overwrite guard as above, phrased for the sliced path.
        OP_REQUIRES(c, FastBoundsCheck(oi, out_flat[p].dimension(0)),
                    errors::InvalidArgument("Size of output_index: ", oi,
                                            " is no longer in range."));
        // Copy one row: data[i, :] -> outputs[p][oi, :].
        Eigen::DSizes<Eigen::DenseIndex, 2> out_indices(oi, 0);
        Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
        out_flat[p].slice(out_indices, sizes) =
            data_flat.slice(data_indices, sizes);
        output_index[p]++;
      }
    }
  }
};
160 | |
// Register the CPU kernel for every supported dtype T.
#define REGISTER_DYNAMIC_PARTITION(T)                                     \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DynamicPartitionOp<T>)

TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION
168 | |
169 | } // namespace tensorflow |
170 | |