1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @example bnorm_u8_via_binary_postops.cpp |
18 | /// @copybrief bnorm_u8_via_binary_postops_cpp |
19 | /// > Annotated version: @ref bnorm_u8_via_binary_postops_cpp |
20 | /// |
21 | /// @page bnorm_u8_via_binary_postops_cpp_short |
22 | /// Bnorm u8 via binary postops example. |
23 | /// |
24 | /// @page bnorm_u8_via_binary_postops_cpp Bnorm u8 by binary post-ops example |
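/// This example implements u8 batch normalization with a single binary
/// primitive and a chain of binary post-ops: binary_sub(src, mean),
/// binary_div(tmp_dst, variance), binary_mul(tmp_dst, scale),
/// binary_add(tmp_dst, shift), and a final binary_mul by an output scale.
/// Note that the tensor is divided by the variance itself, not by
/// sqrt(variance + epsilon) as in a full batch normalization.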
28 | /// |
/// Key takeaways:
///
/// * How tensors are implemented and submitted to primitives.
/// * How primitives are created.
/// * How to use multiple binary post-operations.
/// * How to mix data types (u8 source, f32 operands) in a binary primitive.
35 | /// |
36 | /// @include bnorm_u8_via_binary_postops.cpp |
37 | |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "dnnl.hpp"
#include "example_utils.hpp"
46 | |
47 | using namespace dnnl; |
48 | |
49 | using tag = memory::format_tag; |
50 | using dt = memory::data_type; |
51 | |
52 | void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind) { |
53 | |
54 | // Create execution dnnl::engine. |
55 | dnnl::engine engine(engine_kind, 0); |
56 | |
57 | // Create dnnl::stream. |
58 | dnnl::stream engine_stream(engine); |
59 | |
60 | // Tensor dimensions. |
61 | const memory::dim N = 3, // batch size |
62 | IC = 3, // channels |
63 | IH = 150, // tensor height |
64 | IW = 150; // tensor width |
65 | |
66 | // Tensors dimensions. |
67 | memory::dims src_dims = {N, IC, IH, IW}; |
68 | memory::dims params_dims = {1, IC, 1, 1}; |
69 | |
70 | // Allocate buffers. |
71 | std::vector<float> src_data(product(src_dims)); |
72 | std::vector<float> mean_data(product(params_dims)); |
73 | std::vector<float> variance_data(product(params_dims)); |
74 | std::vector<float> scale_data(product(params_dims)); |
75 | std::vector<float> shift_data(product(params_dims)); |
76 | std::vector<float> oscale_data(product(params_dims)); |
77 | |
78 | // Initialize |
79 | std::generate(src_data.begin(), src_data.end(), []() { |
80 | static int i = 0; |
81 | return std::cos(i++ / 10.f); |
82 | }); |
83 | std::generate(mean_data.begin(), mean_data.end(), []() { |
84 | static int i = 0; |
85 | return std::sin(i++ * 2.f); |
86 | }); |
87 | std::generate(variance_data.begin(), variance_data.end(), []() { |
88 | static int i = 0; |
89 | float value = std::abs(std::sin(i++ * 4.f)); |
90 | // Avoid division by zero. Variance should be positive. |
91 | return value == 0.f ? 1.f : value; |
92 | }); |
93 | std::generate(scale_data.begin(), scale_data.end(), []() { |
94 | static int i = 0; |
95 | return std::sin(i++ * 6.f); |
96 | }); |
97 | std::generate(shift_data.begin(), shift_data.end(), []() { |
98 | static int i = 0; |
99 | return std::sin(i++ * 8.f); |
100 | }); |
101 | std::generate( |
102 | oscale_data.begin(), oscale_data.end(), []() { return 0.5f; }); |
103 | |
104 | // Create descriptors. |
105 | auto src_md = memory::desc(src_dims, dt::u8, tag::nhwc); |
106 | auto mean_md = memory::desc(params_dims, dt::f32, tag::nhwc); |
107 | auto variance_md = memory::desc(params_dims, dt::f32, tag::nhwc); |
108 | auto scale_md = memory::desc(params_dims, dt::f32, tag::nhwc); |
109 | auto shift_md = memory::desc(params_dims, dt::f32, tag::nhwc); |
110 | auto oscale_md = memory::desc(params_dims, dt::f32, tag::nhwc); |
111 | |
112 | // Create src memory objects. |
113 | auto src_mem = memory(src_md, engine); |
114 | auto mean_mem = memory(mean_md, engine); |
115 | auto variance_mem = memory(variance_md, engine); |
116 | auto scale_mem = memory(scale_md, engine); |
117 | auto shift_mem = memory(shift_md, engine); |
118 | auto oscale_mem = memory(oscale_md, engine); |
119 | |
120 | // Write data to memory object's handle. |
121 | write_to_dnnl_memory(src_data.data(), src_mem); |
122 | write_to_dnnl_memory(mean_data.data(), mean_mem); |
123 | write_to_dnnl_memory(variance_data.data(), variance_mem); |
124 | write_to_dnnl_memory(scale_data.data(), scale_mem); |
125 | write_to_dnnl_memory(shift_data.data(), shift_mem); |
126 | write_to_dnnl_memory(oscale_data.data(), oscale_mem); |
127 | |
128 | // Bnorm operation with scale and shift |
129 | post_ops binary_ops; |
130 | // dst_tmp = dst_tmp / variance |
131 | binary_ops.append_binary(algorithm::binary_div, variance_md); |
132 | // dst_tmp = dst_tmp * scale |
133 | binary_ops.append_binary(algorithm::binary_mul, scale_md); |
134 | // dst_tmp = dst_tmp + shift |
135 | binary_ops.append_binary(algorithm::binary_add, shift_md); |
136 | // dst = dst_tmp * output_scale (only for re-quantization) |
137 | binary_ops.append_binary(algorithm::binary_mul, oscale_md); |
138 | primitive_attr binary_attr; |
139 | binary_attr.set_post_ops(binary_ops); |
140 | |
141 | // Create primitive descriptor. |
142 | // dst_tmp = src - mean |
143 | auto binary_pd = binary::primitive_desc(engine, algorithm::binary_sub, |
144 | src_md, mean_md, src_md, binary_attr); |
145 | |
146 | // Create the primitive. |
147 | auto binary_prim = binary(binary_pd); |
148 | |
149 | // Primitive arguments. |
150 | std::unordered_map<int, memory> binary_args; |
151 | binary_args.insert({DNNL_ARG_SRC_0, src_mem}); |
152 | binary_args.insert({DNNL_ARG_SRC_1, mean_mem}); |
153 | // In-place mode (dst is src) |
154 | binary_args.insert({DNNL_ARG_DST, src_mem}); |
155 | binary_args.insert( |
156 | {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, variance_mem}); |
157 | binary_args.insert( |
158 | {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, scale_mem}); |
159 | binary_args.insert( |
160 | {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, shift_mem}); |
161 | binary_args.insert( |
162 | {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, oscale_mem}); |
163 | |
164 | // Primitive execution |
165 | binary_prim.execute(engine_stream, binary_args); |
166 | |
167 | // Wait for the computation to finalize. |
168 | engine_stream.wait(); |
169 | |
170 | // Read data from memory object's handle. |
171 | read_from_dnnl_memory(src_data.data(), src_mem); |
172 | } |
173 | |
174 | int main(int argc, char **argv) { |
175 | return handle_example_errors( |
176 | bnorm_u8_via_binary_postops, parse_engine_kind(argc, argv)); |
177 | } |
178 | |