/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"

#include <string>

#include "absl/strings/str_join.h"

namespace tensorflow {

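// A typical caller looks like the sketch below (illustrative only; the
// pattern list is an example, not the full set registered by any particular
// kernel):
//
//   std::vector<FusedComputationPattern> patterns = {
//       {FusedComputationType::kBiasAdd, {"BiasAdd"}},
//       {FusedComputationType::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
//   };
//   FusedComputationType fused_computation = FusedComputationType::kUndefined;
//   FusedComputationArgs fused_computation_args;
//   TF_RETURN_IF_ERROR(InitializeFusedComputation(
//       context, /*kernel_name=*/"Conv2D", patterns, &fused_computation,
//       &fused_computation_args));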
Status InitializeFusedComputation(
    OpKernelConstruction* context, const std::string& kernel_name,
    const std::vector<FusedComputationPattern>& patterns,
    FusedComputationType* fused_computation,
    FusedComputationArgs* fused_computation_args) {
  // 'fused_ops' and 'num_args' attributes are specified by the Grappler
  // Remapper optimizer (see grappler/optimizers/remapper.cc).
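  // For example, the Remapper rewrites Conv2D+BiasAdd+Relu into a single
  // fused node with fused_ops = ["BiasAdd", "Relu"].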

  std::vector<std::string> fused_ops;
  TF_RETURN_IF_ERROR(context->GetAttr("fused_ops", &fused_ops));
  if (fused_ops.empty()) {
    return errors::InvalidArgument("Fused ", kernel_name,
                                   " must have at least one fused op.");
  }

  int num_args;
  TF_RETURN_IF_ERROR(context->GetAttr("num_args", &num_args));
  int num_host_args;
  if (!context->GetAttr("num_host_args", &num_host_args).ok()) {
    num_host_args = 0;  // default value
  }
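  // 'num_host_args' counts extra scalar inputs expected in host memory; only
  // the int8 path below uses them (conv_input_scale and side_input_scale), so
  // kernels without this attribute default to zero.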
  int num_inputs = context->num_inputs();
  if (num_inputs != 2 + num_args + num_host_args) {
    return errors::InvalidArgument(
        "Fused ", kernel_name,
        " must have 2 + num_args + num_host_args inputs, but the number of "
        "inputs is ",
        num_inputs, ", num_args is ", num_args, " and num_host_args is ",
        num_host_args);
  }

  // TODO(ezhulenev): Add support for fusing element-wise op chains defined
  // at runtime, e.g. Relu+Sqrt+Tanh+etc.

  // Reset fused computation type.
  *fused_computation = FusedComputationType::kUndefined;

  // Match op fusion to one of the supported patterns.
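  // Matching is by exact equality of the fused op name sequence, e.g.
  // fused_ops == {"BiasAdd", "Relu"} selects the kBiasAddWithRelu pattern if
  // the caller registered it.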
  for (const auto& pattern : patterns) {
    if (fused_ops == pattern.fused_ops) {
      *fused_computation = pattern.fused_computation;
      break;
    }
  }
  if (*fused_computation == FusedComputationType::kUndefined) {
    return errors::Unimplemented("Fusion is not implemented: [",
                                 absl::StrJoin(fused_ops, ","), "]");
  }

  // Depending on the picked fusion type, validate fusion-specific arguments.
  if (*fused_computation == FusedComputationType::kBiasAdd ||
      *fused_computation == FusedComputationType::kBiasAddWithRelu ||
      *fused_computation == FusedComputationType::kBiasAddWithRelu6 ||
      *fused_computation == FusedComputationType::kBiasAddWithTanh ||
      *fused_computation == FusedComputationType::kBiasAddWithSigmoid ||
      *fused_computation == FusedComputationType::kBiasAddWithElu ||
      *fused_computation == FusedComputationType::kBiasAddWithLeakyRelu ||
      *fused_computation ==
          FusedComputationType::kBiasAddWithGeluApproximate ||
      *fused_computation == FusedComputationType::kBiasAddWithGeluExact) {
    if (num_args != 1 && !(num_args == 2 && num_host_args == 2)) {
      return errors::InvalidArgument(
          "Fused ", kernel_name,
          " with BiasAdd must have one extra argument: bias,"
          " or four extra arguments: bias, side_input, conv_input_scale, and "
          "side_input_scale");
    }
    constexpr int kConvInput = 0;
    constexpr int kFilter = 1;
    constexpr int kBias = 2;
    constexpr int kSideInput = 3;
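    // For the quantized int8 path the 6 expected inputs are: conv_input,
    // filter, bias, side_input, and two host scalars (conv_input_scale and
    // side_input_scale).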
    if (context->input_type(kConvInput) == DT_INT8) {
      if (num_inputs != 6) {
        return errors::InvalidArgument("Fused ", kernel_name,
                                       " for int8 must have 6 inputs, but ",
                                       num_inputs, " were provided");
      }
      if (context->input_type(kFilter) != DT_INT8) {
        return errors::InvalidArgument("Fused ", kernel_name,
                                       " for int8 has filter type ",
                                       context->input_type(kFilter),
                                       " that does not match the input type ",
                                       context->input_type(kConvInput));
      }
      if (context->input_type(kBias) != DT_FLOAT) {
        return errors::InvalidArgument(
            "Fused ", kernel_name, " for int8 has bias type ",
            context->input_type(kBias), ", but the bias must be float");
      }
      if (context->input_type(kSideInput) != DT_INT8) {
        return errors::InvalidArgument("Fused ", kernel_name,
                                       " for int8 has side_input type ",
                                       context->input_type(kSideInput),
                                       " that does not match the input type ",
                                       context->input_type(kConvInput));
      }
    }
    if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
      TF_RETURN_IF_ERROR(context->GetAttr(
          "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
    }
  }

  if (*fused_computation == FusedComputationType::kFusedBatchNorm ||
      *fused_computation == FusedComputationType::kFusedBatchNormWithRelu ||
      *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 ||
      *fused_computation == FusedComputationType::kFusedBatchNormWithElu ||
      *fused_computation ==
          FusedComputationType::kFusedBatchNormWithLeakyRelu) {
    if (num_args != 4) {
      return errors::InvalidArgument(
          "Fused ", kernel_name,
          " with FusedBatchNorm must have four extra arguments: scale, "
          "offset, mean, variance.");
    }
    TF_RETURN_IF_ERROR(
        context->GetAttr("epsilon", &fused_computation_args->epsilon));
    if (*fused_computation ==
        FusedComputationType::kFusedBatchNormWithLeakyRelu) {
      TF_RETURN_IF_ERROR(context->GetAttr(
          "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
    }
  }

  return OkStatus();
}

}  // namespace tensorflow