1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/cwise_ops_common.h" |
17 | |
18 | namespace tensorflow { |
19 | |
20 | BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, |
21 | DataType in) |
22 | : OpKernel(ctx) { |
23 | #if !defined(INTEL_MKL) || !defined(ENABLE_MKL) |
24 | OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); |
25 | #endif // !INTEL_MKL || !ENABLE_MKL |
26 | } |
27 | |
28 | void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { |
29 | ctx->SetStatus(errors::Unimplemented( |
30 | "Broadcast between " , ctx->input(0).shape().DebugString(), " and " , |
31 | ctx->input(1).shape().DebugString(), " is not supported yet." )); |
32 | } |
33 | |
34 | void BinaryOpShared::SetComputeError(OpKernelContext* ctx) { |
35 | // For speed, errors during compute are caught only via boolean flag, with no |
36 | // associated information. This is sufficient for now, since the only binary |
37 | // ops that have compute errors are integer division and mod, and the only |
38 | // error they produce is zero division. |
39 | const string& op = ctx->op_kernel().type_string(); |
40 | if ((op == "Div" || op == "Mod" || op == "FloorMod" || op == "FloorDiv" ) && |
41 | DataTypeIsInteger(ctx->op_kernel().input_type(0))) { |
42 | ctx->CtxFailure(errors::InvalidArgument("Integer division by zero" )); |
43 | } else if ((op == "Pow" ) && |
44 | DataTypeIsInteger(ctx->op_kernel().input_type(0)) && |
45 | DataTypeIsSigned(ctx->op_kernel().input_type(1))) { |
46 | ctx->CtxFailure(errors::InvalidArgument( |
47 | "Integers to negative integer powers are not allowed" )); |
48 | } else { |
49 | ctx->CtxFailure( |
50 | errors::Internal("Unexpected error in binary operator " |
51 | "(only integer div and mod should have errors)" )); |
52 | } |
53 | } |
54 | |
55 | BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) |
56 | : in0(ctx->input(0)), |
57 | in1(ctx->input(1)), |
58 | bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) { |
59 | if (!bcast.IsValid()) { |
60 | bool incompatible_shape_error; |
61 | bool has_attr = |
62 | TryGetNodeAttr(ctx->op_kernel().def(), "incompatible_shape_error" , |
63 | &(incompatible_shape_error)); |
64 | if (has_attr && !incompatible_shape_error) { |
65 | const string& op = ctx->op_kernel().type_string(); |
66 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); |
67 | result = (op == "NotEqual" ); |
68 | return; |
69 | } |
70 | |
71 | ctx->SetStatus(errors::InvalidArgument( |
72 | "Incompatible shapes: " , in0.shape().DebugString(), " vs. " , |
73 | in1.shape().DebugString())); |
74 | return; |
75 | } |
76 | |
77 | const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); |
78 | out_num_elements = output_shape.num_elements(); |
79 | in0_num_elements = in0.NumElements(); |
80 | in1_num_elements = in1.NumElements(); |
81 | OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( |
82 | {0, 1}, 0, output_shape, &out)); |
83 | |
84 | ndims = static_cast<int>(bcast.x_reshape().size()); |
85 | } |
86 | |
87 | } // namespace tensorflow |
88 | |