1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <utility>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
21 | |
22 | namespace tensorflow { |
23 | |
24 | // Given shapes of two tensors, computes the broadcast shape. |
25 | template <typename T> |
26 | class BCastArgsOp : public OpKernel { |
27 | public: |
28 | explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
29 | |
30 | void Compute(OpKernelContext* ctx) override { |
31 | OP_REQUIRES( |
32 | ctx, ctx->num_inputs() == 2, |
33 | errors::Unimplemented("Broadcast for n-ary operations (n > 2)" )); |
34 | gtl::InlinedVector<BCast::Vec, 4> shapes; |
35 | for (int i = 0; i < ctx->num_inputs(); ++i) { |
36 | const Tensor& in = ctx->input(i); |
37 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), |
38 | errors::InvalidArgument("In[" , i, "] must be a vector." , |
39 | in.shape().DebugString())); |
40 | BCast::Vec vec; |
41 | for (int64_t i = 0; i < in.NumElements(); ++i) { |
42 | vec.push_back(in.vec<T>()(i)); |
43 | } |
44 | shapes.push_back(vec); |
45 | } |
46 | BCast bcast(shapes[0], shapes[1]); |
47 | OP_REQUIRES(ctx, bcast.IsValid(), |
48 | errors::InvalidArgument( |
49 | "Incompatible shapes: [" , absl::StrJoin(shapes[0], "," ), |
50 | "] vs. [" , absl::StrJoin(shapes[1], "," ), "]" )); |
51 | Output(ctx, 0, bcast.output_shape()); |
52 | } |
53 | |
54 | bool IsExpensive() override { return false; } |
55 | |
56 | private: |
57 | void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) { |
58 | const int64_t len = v.size(); |
59 | Tensor* o = nullptr; |
60 | OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o)); |
61 | for (int64_t i = 0; i < len; ++i) { |
62 | o->flat<T>()(i) = static_cast<T>(v[i]); |
63 | } |
64 | } |
65 | |
66 | TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); |
67 | }; |
68 | |
69 | // Given shapes of two tensors, computes the reduction indices for the |
70 | // gradient computation. |
71 | // |
72 | // TODO(zhifengc): |
73 | // 1. Adds support for n-ary (n >= 2). |
74 | template <typename T> |
75 | class BCastGradArgsOp : public OpKernel { |
76 | public: |
77 | explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
78 | |
79 | void Compute(OpKernelContext* ctx) override { |
80 | OP_REQUIRES( |
81 | ctx, ctx->num_inputs() == 2, |
82 | errors::Unimplemented("Broadcast for n-ary operations (n > 2)" )); |
83 | gtl::InlinedVector<BCast::Vec, 4> shapes; |
84 | for (int i = 0; i < ctx->num_inputs(); ++i) { |
85 | const Tensor& in = ctx->input(i); |
86 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), |
87 | errors::InvalidArgument("In[" , i, "] must be a vector." , |
88 | in.shape().DebugString())); |
89 | BCast::Vec vec; |
90 | for (int64_t i = 0; i < in.NumElements(); ++i) { |
91 | vec.push_back(in.vec<T>()(i)); |
92 | } |
93 | shapes.push_back(vec); |
94 | } |
95 | BCast bcast(shapes[0], shapes[1]); |
96 | OP_REQUIRES(ctx, bcast.IsValid(), |
97 | errors::InvalidArgument( |
98 | "Incompatible shapes: [" , absl::StrJoin(shapes[0], "," ), |
99 | "] vs. [" , absl::StrJoin(shapes[1], "," ), "]" )); |
100 | Output(ctx, 0, bcast.grad_x_reduce_idx()); |
101 | Output(ctx, 1, bcast.grad_y_reduce_idx()); |
102 | } |
103 | |
104 | bool IsExpensive() override { return false; } |
105 | |
106 | private: |
107 | void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) { |
108 | const int64_t len = v.size(); |
109 | Tensor* o = nullptr; |
110 | OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o)); |
111 | for (int64_t i = 0; i < len; ++i) { |
112 | o->flat<T>()(i) = static_cast<T>(v[i]); |
113 | } |
114 | } |
115 | |
116 | TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); |
117 | }; |
118 | |
// BroadcastArgs kernel registrations for int32/int64 shape types.  All
// inputs/outputs are pinned to host memory, so the same host-side kernel
// serves DEVICE_CPU and DEVICE_DEFAULT (e.g. accelerator) placements.
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64_t>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64_t>);
147 | |
// BroadcastGradientArgs kernel registrations for int32/int64 shape types.
// As above, shape data always lives in host memory, so one host-side kernel
// covers both DEVICE_CPU and DEVICE_DEFAULT placements.
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64_t>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64_t>);
180 | |
181 | } // end namespace tensorflow |
182 | |