1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifdef INTEL_MKL
17
18// This file contains the registration of MKL-DNN array ops.
19
20#include "tensorflow/core/framework/common_shape_fns.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/shape_inference.h"
23#include "tensorflow/core/framework/tensor.pb.h"
24#include "tensorflow/core/util/mirror_pad_mode.h"
25#include "tensorflow/core/util/padding.h"
26#include "tensorflow/core/util/strided_slice_op.h"
27#include "tensorflow/core/util/tensor_format.h"
28
29namespace tensorflow {
30
31using shape_inference::DimensionHandle;
32using shape_inference::InferenceContext;
33using shape_inference::ShapeHandle;
34using shape_inference::UnchangedShape;
35
// Register QuantizedConcatV2 here so that the MKL graph rewrite can replace it
// with _MklQuantizedConcatV2.
38REGISTER_OP("QuantizedConcatV2")
39 .Input("values: N * T")
40 .Input("axis: Tidx")
41 .Input("input_mins: N * float32")
42 .Input("input_maxes: N * float32")
43 .Output("output: T")
44 .Output("output_min: float")
45 .Output("output_max: float")
46 .Attr("N: int >= 2")
47 .Attr("T: type")
48 .Attr("Tidx: {int32, int64} = DT_INT32")
49 .SetShapeFn([](InferenceContext* c) {
50 const int n = (c->num_inputs() - 1) / 3;
51 TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
52 ShapeHandle unused;
53 for (int i = n + 1; i < c->num_inputs(); ++i) {
54 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
55 }
56 c->set_output(1, c->Scalar());
57 c->set_output(2, c->Scalar());
58 return Status::OK();
59 });
60
61REGISTER_OP("_MklQuantizedConcatV2")
62 .Input("values: N * T")
63 .Input("axis: Tidx")
64 .Input("input_mins: N * float32")
65 .Input("input_maxes: N * float32")
66 .Output("output: T")
67 .Output("output_min: float")
68 .Output("output_max: float")
69 .Attr("N: int >= 2")
70 .Attr("T: type")
71 .Attr("Tidx: {int32, int64} = DT_INT32")
72 .SetShapeFn([](InferenceContext* c) {
73 const int n = (c->num_inputs() / 2 - 1) / 3;
74 TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
75 ShapeHandle unused;
76 for (int i = n + 1; i < c->num_inputs() / 2; ++i) {
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
78 }
79 c->set_output(1, c->Scalar());
80 c->set_output(2, c->Scalar());
81 return Status::OK();
82 });
83
84REGISTER_OP("_MklQuantizeV2")
85 .Input("input: float")
86 .Input("min_range: float")
87 .Input("max_range: float")
88 .Output("output: T")
89 .Output("output_min: float")
90 .Output("output_max: float")
91 .Attr("T: quantizedtype")
92 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'")
93 .Attr(
94 "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
95 "'HALF_AWAY_FROM_ZERO'")
96 .Attr("narrow_range: bool = false")
97 .Attr("axis: int = -1")
98 .Attr("ensure_minimum_range: float = 0.01")
99 .SetShapeFn(shape_inference::QuantizeV2Shape);
100
101REGISTER_OP("_MklDequantize")
102 .Input("input: T")
103 .Input("min_range: float")
104 .Input("max_range: float")
105 .Output("output: float")
106 .Attr("T: quantizedtype")
107 .Attr("narrow_range: bool = false")
108 .Attr("axis: int = -1")
109 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'")
110 .Attr("dtype: {bfloat16, float} = DT_FLOAT")
111 .SetShapeFn([](InferenceContext* c) {
112 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
113 ShapeHandle unused;
114 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
115 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
116 return Status::OK();
117 });
118
119} // namespace tensorflow
120
121#endif // INTEL_MKL
122