/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <climits> // INT_MAX is used for s32 saturation below.

#include "utils/parallel.hpp"

#include "reorder/reorder.hpp"

namespace reorder {

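// Reference implementation of reorder used for correctness validation.
// It computes the reordered destination and, when requested via output flags,
// the s8s8 and zero-point compensation buffers. `prim_ref` is not used here.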
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    const dnn_mem_t &src = args.find(DNNL_ARG_FROM);
    const dnn_mem_t &dst = args.find(DNNL_ARG_TO);
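    // Compensation buffers are passed as extra source arguments; they are
    // only meaningful when their data type is s32 (see need_*_comp below).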
    const dnn_mem_t &s8_comp = args.find(DNNL_ARG_SRC_1);
    const dnn_mem_t &zp_comp = args.find(DNNL_ARG_SRC_2);

    float *dst_ptr = (float *)dst;

    const auto dst_dt = prb->ddt;
    const auto nelems = src.nelems();
    const auto &src_scales = prb->attr.scales.get(DNNL_ARG_FROM);
    const auto &dst_scales = prb->attr.scales.get(DNNL_ARG_TO);
    const int src_scale_mask = attr_t::get_default_mask(src_scales.policy);
    const int dst_scale_mask = attr_t::get_default_mask(dst_scales.policy);
    // These zero points are native to reorder; they come from the reorder
    // attributes.
    const int src_zero_point = prb->src_zp ? prb->src_zp[0] : 0;
    const int dst_zero_point = prb->dst_zp ? prb->dst_zp[0] : 0;

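    // A sum post-op blends in the existing destination value; its scale acts
    // as `beta` in the formula below.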
    float beta = 0;
    const auto &po = prb->attr.post_ops;
    const int beta_idx = po.find(attr_t::post_ops_t::kind_t::SUM);
    if (beta_idx >= 0) beta = po.entry[beta_idx].sum.scale;

    // These are non-native compensations requested by other primitives with
    // s8s8 or zero-point support: the compensated part is pre-computed here
    // and applied at the end of their computations.
    const bool need_s8_comp = s8_comp.dt() == dnnl_s32;
    const bool need_zp_comp = zp_comp.dt() == dnnl_s32;
    const bool need_comp = need_s8_comp || need_zp_comp;
    // The rescale factor (the library's `adjust_scale`) participates only
    // with s8s8 compensation.
    const float s8_scale_factor = need_s8_comp ? reorder_rescale_factor() : 1.f;

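    // Per-element reference computation:
    //   dst = sat(round((rescale * src_scale * (src - src_zp)
    //           + beta * (dst - dst_zp)) / dst_scale + dst_zp))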
    benchdnn_parallel_nd(nelems, [&](int64_t idx) {
        float s = src.get_elem(idx) - src_zero_point;
        float d = 0;
        if (beta_idx >= 0) d = dst.get_elem(idx) - dst_zero_point;

        float src_scale = 1.f, dst_scale = 1.f;
        if (!src_scales.is_def()) {
            int64_t src_mask_idx = src.get_scale_idx(idx, src_scale_mask);
            src_scale = prb->src_scales[src_mask_idx];
        }
        if (!dst_scales.is_def()) {
            int64_t dst_mask_idx = dst.get_scale_idx(idx, dst_scale_mask);
            dst_scale = prb->dst_scales[dst_mask_idx];
        }
        float value = (s8_scale_factor * src_scale * s + beta * d) / dst_scale
                + dst_zero_point;
        value = maybe_saturate(dst_dt, value);
        if (dst_dt == dnnl_s32 && value >= (float)INT_MAX)
            value = BENCHDNN_S32_TO_F32_SAT_CONST;

        dst_ptr[idx] = round_to_nearest_representable(dst_dt, value);
    });

    if (!need_comp) return;

    int *s8_comp_ptr = (int *)s8_comp;
    int *zp_comp_ptr = (int *)zp_comp;

    // Mostly following benchdnn/ref_reduction.cpp/compute_ref.
    const auto nelems_s8_comp = s8_comp.nelems();
    const auto nelems_zp_comp = zp_comp.nelems();
    const auto nelems_comp = MAX2(nelems_s8_comp, nelems_zp_comp);
    const auto ndims = src.ndims();
    assert(nelems_comp > 0);
    assert(IMPLICATION(
            need_s8_comp && need_zp_comp, nelems_s8_comp == nelems_zp_comp));

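    // Pick the compensation mask from the problem output flags; it marks the
    // dimensions that keep their own compensation value, all other dimensions
    // are reduced over.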
    int comp_mask = 0;
    for (const auto &i_oflag : prb->oflag) {
        if ((i_oflag.first == FLAG_S8S8_COMP || i_oflag.first == FLAG_ZP_COMP)
                && i_oflag.second != FLAG_NONE) {
            comp_mask = i_oflag.second;
            break;
        }
    }

    dims_t comp_dims(ndims, 1); // src_dims with '1' at non-masked dims.
    dims_t reduce_dims(ndims, 1); // complementary to the above.
    for (int i = 0; i < ndims; ++i) {
        if (comp_mask & (1 << i)) {
            comp_dims[i] = src.dims()[i];
            reduce_dims[i] = 1;
        } else {
            comp_dims[i] = 1;
            reduce_dims[i] = src.dims()[i];
        }
    }

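    // For each kept (masked) position, accumulate the negated, rescaled and
    // saturated source values over all reduced dimensions.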
    const auto nelems_reduce = nelems / nelems_comp;
    benchdnn_parallel_nd(nelems_comp, [&](int64_t f) {
        dims_t idle_pos = off2dims_idx(comp_dims, f);
        const int64_t src_idle_off = md_off_v(src, idle_pos.data());
        int comp_val = 0;
        for (int64_t r = 0; r < nelems_reduce; ++r) {
            dims_t reduce_pos = off2dims_idx(reduce_dims, r);
            const int64_t src_reduce_off = md_off_v(src, reduce_pos.data());
            const int64_t src_off = src_idle_off + src_reduce_off;

            float src_scale = 1.f, dst_scale = 1.f;
            if (!src_scales.is_def()) {
                int64_t src_mask_idx
                        = src.get_scale_idx(src_off, src_scale_mask);
                src_scale = prb->src_scales[src_mask_idx];
            }
            if (!dst_scales.is_def()) {
                int64_t dst_mask_idx
                        = dst.get_scale_idx(src_off, dst_scale_mask);
                dst_scale = prb->dst_scales[dst_mask_idx];
            }

            const float alpha = src_scale / dst_scale;
            const float value
                    = src.get_elem(src_off) * alpha * s8_scale_factor;
            comp_val -= maybe_saturate(dst_dt, value);
        }
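        // Zero-point compensation stores the negated sum as is; s8s8
        // compensation is additionally scaled by 128 (assumption: this
        // matches the +128 shift applied to s8 data when s8s8 multiplication
        // is emulated via u8s8).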
        if (need_zp_comp) zp_comp_ptr[f] = comp_val;
        comp_val *= 128;
        if (need_s8_comp) s8_comp_ptr[f] = comp_val;
    });
}

} // namespace reorder