/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example bnorm_u8_via_binary_postops.cpp
/// @copybrief bnorm_u8_via_binary_postops_cpp
/// > Annotated version: @ref bnorm_u8_via_binary_postops_cpp
///
/// @page bnorm_u8_via_binary_postops_cpp_short
/// Bnorm u8 via binary post-ops example.
///
/// @page bnorm_u8_via_binary_postops_cpp Bnorm u8 via binary post-ops example
/// The example implements u8 batch normalization via the following sequence
/// of operations: binary_sub(src, mean), binary_div(tmp_dst, variance),
/// binary_mul(tmp_dst, scale), binary_add(tmp_dst, shift), followed by
/// binary_mul(tmp_dst, output_scale) for re-quantization.
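///
/// In effect, the primitive and its post-ops compute
/// dst = ((src - mean) / variance * scale + shift) * output_scale.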
///
/// Some key takeaways include:
///
/// * How memory objects (tensors) are created and passed to primitives.
/// * How primitives are created.
/// * How to chain multiple binary post operations.
/// * How to mix different data types in binary operations.
///
/// @include bnorm_u8_via_binary_postops.cpp

#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "dnnl.hpp"
#include "example_utils.hpp"

using namespace dnnl;

using tag = memory::format_tag;
using dt = memory::data_type;

void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind) {

    // Create execution dnnl::engine.
    dnnl::engine engine(engine_kind, 0);

    // Create dnnl::stream.
    dnnl::stream engine_stream(engine);

    // Tensor dimensions.
    const memory::dim N = 3, // batch size
            IC = 3, // channels
            IH = 150, // tensor height
            IW = 150; // tensor width

    // Source and per-channel parameter tensor dimensions.
    memory::dims src_dims = {N, IC, IH, IW};
    memory::dims params_dims = {1, IC, 1, 1};
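    // The parameter tensors (mean, variance, scale, shift, output scale) keep
    // only the channel dimension; the size-1 N, H, and W dimensions are
    // implicitly broadcast by the binary primitive and its binary post-ops.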

    // Allocate buffers.
    std::vector<float> src_data(product(src_dims));
    std::vector<float> mean_data(product(params_dims));
    std::vector<float> variance_data(product(params_dims));
    std::vector<float> scale_data(product(params_dims));
    std::vector<float> shift_data(product(params_dims));
    std::vector<float> oscale_data(product(params_dims));

    // Initialize src and parameter data with synthetic values.
    std::generate(src_data.begin(), src_data.end(), []() {
        static int i = 0;
        return std::cos(i++ / 10.f);
    });
    std::generate(mean_data.begin(), mean_data.end(), []() {
        static int i = 0;
        return std::sin(i++ * 2.f);
    });
    std::generate(variance_data.begin(), variance_data.end(), []() {
        static int i = 0;
        float value = std::abs(std::sin(i++ * 4.f));
        // Avoid division by zero. Variance should be positive.
        return value == 0.f ? 1.f : value;
    });
    std::generate(scale_data.begin(), scale_data.end(), []() {
        static int i = 0;
        return std::sin(i++ * 6.f);
    });
    std::generate(shift_data.begin(), shift_data.end(), []() {
        static int i = 0;
        return std::sin(i++ * 8.f);
    });
    std::generate(
            oscale_data.begin(), oscale_data.end(), []() { return 0.5f; });

    // Create descriptors.
    auto src_md = memory::desc(src_dims, dt::u8, tag::nhwc);
    auto mean_md = memory::desc(params_dims, dt::f32, tag::nhwc);
    auto variance_md = memory::desc(params_dims, dt::f32, tag::nhwc);
    auto scale_md = memory::desc(params_dims, dt::f32, tag::nhwc);
    auto shift_md = memory::desc(params_dims, dt::f32, tag::nhwc);
    auto oscale_md = memory::desc(params_dims, dt::f32, tag::nhwc);
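    // Note: src/dst are u8, while all normalization parameters are f32. The
    // binary primitive and its binary post-ops support mixed data types, which
    // is one of the points this example demonstrates.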

    // Create memory objects.
    auto src_mem = memory(src_md, engine);
    auto mean_mem = memory(mean_md, engine);
    auto variance_mem = memory(variance_md, engine);
    auto scale_mem = memory(scale_md, engine);
    auto shift_mem = memory(shift_md, engine);
    auto oscale_mem = memory(oscale_md, engine);

    // Write data to the memory objects' handles.
    write_to_dnnl_memory(src_data.data(), src_mem);
    write_to_dnnl_memory(mean_data.data(), mean_mem);
    write_to_dnnl_memory(variance_data.data(), variance_mem);
    write_to_dnnl_memory(scale_data.data(), scale_mem);
    write_to_dnnl_memory(shift_data.data(), shift_mem);
    write_to_dnnl_memory(oscale_data.data(), oscale_mem);

    // Build the chain of binary post-ops that completes the batch
    // normalization with scale and shift.
    post_ops binary_ops;
    // dst_tmp = dst_tmp / variance
    binary_ops.append_binary(algorithm::binary_div, variance_md);
    // dst_tmp = dst_tmp * scale
    binary_ops.append_binary(algorithm::binary_mul, scale_md);
    // dst_tmp = dst_tmp + shift
    binary_ops.append_binary(algorithm::binary_add, shift_md);
    // dst = dst_tmp * output_scale (only for re-quantization)
    binary_ops.append_binary(algorithm::binary_mul, oscale_md);
    primitive_attr binary_attr;
    binary_attr.set_post_ops(binary_ops);
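    // Post-ops are applied to the result of the base operation in the order
    // they were appended above: div, mul (scale), add (shift), mul (oscale).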

    // Create primitive descriptor.
    // dst_tmp = src - mean
    auto binary_pd = binary::primitive_desc(engine, algorithm::binary_sub,
            src_md, mean_md, src_md, binary_attr);
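    // Note: dst uses the same memory descriptor as src_0, which allows the
    // primitive to run in place (dst is src), as set up below.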

    // Create the primitive.
    auto binary_prim = binary(binary_pd);

    // Primitive arguments.
    std::unordered_map<int, memory> binary_args;
    binary_args.insert({DNNL_ARG_SRC_0, src_mem});
    binary_args.insert({DNNL_ARG_SRC_1, mean_mem});
    // In-place mode (dst is src).
    binary_args.insert({DNNL_ARG_DST, src_mem});
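    // Each binary post-op takes one extra input. It is passed under the
    // argument index DNNL_ARG_ATTR_MULTIPLE_POST_OP(n) | DNNL_ARG_SRC_1, where
    // n is the position of the post-op in the chain appended above.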
    binary_args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, variance_mem});
    binary_args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, scale_mem});
    binary_args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, shift_mem});
    binary_args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, oscale_mem});

    // Primitive execution: batch normalization via binary with post-ops.
    binary_prim.execute(engine_stream, binary_args);

    // Wait for the computation to finish.
    engine_stream.wait();

    // Read the result back from the (in-place) src memory object.
    read_from_dnnl_memory(src_data.data(), src_mem);
}

int main(int argc, char **argv) {
    return handle_example_errors(
            bnorm_u8_via_binary_postops, parse_engine_kind(argc, argv));
}