/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Given shapes of two tensors, computes the broadcast shape.
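// For example, broadcasting shapes [2, 3, 5] and [3, 1] yields [2, 3, 5].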
template <typename T>
class BCastArgsOp : public OpKernel {
 public:
  explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector. ",
                                          in.shape().DebugString()));
      // Copy the shape elements out of the i-th input tensor. The inner
      // index is named j to avoid shadowing the input index i.
      BCast::Vec vec;
      for (int64_t j = 0; j < in.NumElements(); ++j) {
        vec.push_back(in.vec<T>()(j));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
                    "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.output_shape());
  }

  bool IsExpensive() override { return false; }

 private:
  // Writes the vector v as the idx-th output, as a 1-D tensor of type T.
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64_t len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64_t i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
};

// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
//
// TODO(zhifengc):
//   1. Add support for n-ary (n >= 2).
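//
// For example, with x of shape [2, 3, 5] and y of shape [3, 1], the
// broadcast output has shape [2, 3, 5]. The gradient w.r.t. y must be
// summed over the axes along which y was broadcast, so this op returns
// r0 = [] (no reduction for x) and r1 = [0, 2].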
template <typename T>
class BCastGradArgsOp : public OpKernel {
 public:
  explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector. ",
                                          in.shape().DebugString()));
      // Copy the shape elements out of the i-th input tensor. The inner
      // index is named j to avoid shadowing the input index i.
      BCast::Vec vec;
      for (int64_t j = 0; j < in.NumElements(); ++j) {
        vec.push_back(in.vec<T>()(j));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
                    "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.grad_x_reduce_idx());
    Output(ctx, 1, bcast.grad_y_reduce_idx());
  }

  bool IsExpensive() override { return false; }

 private:
  // Writes the vector v as the idx-th output, as a 1-D tensor of type T.
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64_t len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64_t i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};

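// The shape vectors handled by these ops are small and are consumed by
// host-side code, so every input and output is pinned to host memory,
// including in the non-CPU (DEVICE_DEFAULT) registrations below.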
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64_t>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64_t>);

REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64_t>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64_t>);

}  // end namespace tensorflow