1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/fused_eigen_output_kernels.h" |
17 | |
18 | #include <string> |
19 | |
20 | #include "absl/strings/str_join.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | Status InitializeFusedComputation( |
25 | OpKernelConstruction* context, const std::string& kernel_name, |
26 | const std::vector<FusedComputationPattern>& patterns, |
27 | FusedComputationType* fused_computation, |
28 | FusedComputationArgs* fused_computation_args) { |
29 | // 'fused_ops' and 'num_args' attributes are specified by the Grappler |
30 | // Remapper optimizer (see grappler/optimizers/remapper.cc). |
31 | |
32 | std::vector<std::string> fused_ops; |
33 | TF_RETURN_IF_ERROR(context->GetAttr("fused_ops" , &fused_ops)); |
34 | if (fused_ops.empty()) { |
35 | return errors::InvalidArgument("Fused " , kernel_name, |
36 | " must have at least one fused op." ); |
37 | } |
38 | |
39 | int num_args; |
40 | TF_RETURN_IF_ERROR(context->GetAttr("num_args" , &num_args)); |
41 | int num_host_args; |
42 | if (!context->GetAttr("num_host_args" , &num_host_args).ok()) { |
43 | num_host_args = 0; // default value |
44 | } |
45 | int num_inputs = context->num_inputs(); |
46 | if (num_inputs != 2 + num_args + num_host_args) { |
47 | return errors::InvalidArgument( |
48 | "Fused " , kernel_name, |
49 | " must have the number of inputs equal to 2 + num_args + num_host_args " |
50 | "but in fact the number of inputs is " , |
51 | num_inputs, " and num_args is " , num_args, " and num_host_args is " , |
52 | num_host_args); |
53 | } |
54 | |
55 | // TODO(ezhulenev): Add support for fusion element-wise op chains defined |
56 | // at runtime, e.g. Relu+Sqrt+Tanh+etc. |
57 | |
58 | // Reset fused computation type. |
59 | *fused_computation = FusedComputationType::kUndefined; |
60 | |
61 | // Match op fusion to one of the supported patterns. |
62 | for (const auto& pattern : patterns) { |
63 | if (fused_ops == pattern.fused_ops) { |
64 | *fused_computation = pattern.fused_computation; |
65 | break; |
66 | } |
67 | } |
68 | if (*fused_computation == FusedComputationType::kUndefined) { |
69 | return errors::Unimplemented("Fusion is not implemented: [" , |
70 | absl::StrJoin(fused_ops, "," ), "]" ); |
71 | } |
72 | |
73 | // Depending on a picked fusion type validate fusion-specific arguments. |
74 | if (*fused_computation == FusedComputationType::kBiasAdd || |
75 | *fused_computation == FusedComputationType::kBiasAddWithRelu || |
76 | *fused_computation == FusedComputationType::kBiasAddWithRelu6 || |
77 | *fused_computation == FusedComputationType::kBiasAddWithTanh || |
78 | *fused_computation == FusedComputationType::kBiasAddWithSigmoid || |
79 | *fused_computation == FusedComputationType::kBiasAddWithElu || |
80 | *fused_computation == FusedComputationType::kBiasAddWithLeakyRelu || |
81 | *fused_computation == FusedComputationType::kBiasAddWithGeluApproximate || |
82 | *fused_computation == FusedComputationType::kBiasAddWithGeluExact) { |
83 | if (num_args != 1 && !(num_args == 2 && num_host_args == 2)) { |
84 | return errors::InvalidArgument( |
85 | "Fused " , kernel_name, |
86 | " with BiasAdd must have one extra argument: bias" |
87 | " or 4 extra arguments: bias, side_input, conv_input_scale and " |
88 | "side_input_scale" ); |
89 | } |
90 | constexpr int kConvInput = 0; |
91 | constexpr int kFilter = 1; |
92 | constexpr int kBias = 2; |
93 | constexpr int kSideInput = 3; |
94 | if (context->input_type(kConvInput) == DT_INT8) { |
95 | if (num_inputs != 6) { |
96 | return errors::InvalidArgument("Fused " , kernel_name, |
97 | " for int8 must have 6 inputs and " , |
98 | num_inputs, " is provided" ); |
99 | } |
100 | if (context->input_type(kFilter) != DT_INT8) { |
101 | return errors::InvalidArgument("Fused " , kernel_name, |
102 | " for int8 has filter type " , |
103 | context->input_type(kFilter), |
104 | " that does not match the input fype " , |
105 | context->input_type(kConvInput)); |
106 | } |
107 | if (context->input_type(kBias) != DT_FLOAT) { |
108 | return errors::InvalidArgument( |
109 | "Fused " , kernel_name, " for int8 has bias type " , |
110 | context->input_type(kBias), " that must have the float type" ); |
111 | } |
112 | if (context->input_type(kSideInput) != DT_INT8) { |
113 | return errors::InvalidArgument("Fused " , kernel_name, |
114 | " for int8 has side_input type " , |
115 | context->input_type(kSideInput), |
116 | " that does not match the input fype " , |
117 | context->input_type(kConvInput)); |
118 | } |
119 | } |
120 | if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) { |
121 | TF_RETURN_IF_ERROR(context->GetAttr( |
122 | "leakyrelu_alpha" , &fused_computation_args->leakyrelu_alpha)); |
123 | } |
124 | } |
125 | |
126 | if (*fused_computation == FusedComputationType::kFusedBatchNorm || |
127 | *fused_computation == FusedComputationType::kFusedBatchNormWithRelu || |
128 | *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 || |
129 | *fused_computation == FusedComputationType::kFusedBatchNormWithElu || |
130 | *fused_computation == |
131 | FusedComputationType::kFusedBatchNormWithLeakyRelu) { |
132 | if (num_args != 4) { |
133 | return errors::InvalidArgument( |
134 | "Fused " , kernel_name, |
135 | " with FusedBatchNorm must have four extra arguments: scale, offset, " |
136 | "mean, variance." ); |
137 | } |
138 | TF_RETURN_IF_ERROR( |
139 | context->GetAttr("epsilon" , &fused_computation_args->epsilon)); |
140 | if (*fused_computation == |
141 | FusedComputationType::kFusedBatchNormWithLeakyRelu) { |
142 | TF_RETURN_IF_ERROR(context->GetAttr( |
143 | "leakyrelu_alpha" , &fused_computation_args->leakyrelu_alpha)); |
144 | } |
145 | } |
146 | |
147 | return OkStatus(); |
148 | } |
149 | |
150 | } // namespace tensorflow |
151 | |