1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "reorder/reorder.hpp" |
20 | |
21 | namespace reorder { |
22 | |
// Reference (CPU) implementation of the reorder primitive used by benchdnn
// to validate the library. Element-wise it computes:
//   dst[i] = round(sat((f * src_scale * (src[i] - src_zp)
//                       + beta * (dst[i] - dst_zp)) / dst_scale + dst_zp))
// where `f` is the s8s8 rescale factor and `beta` comes from a SUM post-op.
// Afterwards, if requested via output flags, it fills the s8s8 and/or
// zero-point compensation buffers with negated sums reduced over the
// non-masked dimensions.
//
// `prim_ref` is part of the common compute_ref signature across benchdnn
// drivers; it is not used in this implementation.
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    const dnn_mem_t &src = args.find(DNNL_ARG_FROM);
    const dnn_mem_t &dst = args.find(DNNL_ARG_TO);
    // Compensation outputs are smuggled in as extra "source" args; their
    // presence is detected by an s32 data type (see `need_*_comp` below).
    const dnn_mem_t &s8_comp = args.find(DNNL_ARG_SRC_1);
    const dnn_mem_t &zp_comp = args.find(DNNL_ARG_SRC_2);

    // Reference results are always stored as f32 regardless of prb->ddt;
    // `round_to_nearest_representable` emulates the target type below.
    float *dst_ptr = (float *)dst;

    const auto dst_dt = prb->ddt;
    const auto nelems = src.nelems();
    const auto &src_scales = prb->attr.scales.get(DNNL_ARG_FROM);
    const auto &dst_scales = prb->attr.scales.get(DNNL_ARG_TO);
    const int src_scale_mask = attr_t::get_default_mask(src_scales.policy);
    const int dst_scale_mask = attr_t::get_default_mask(dst_scales.policy);
    // This is native to reorder zero point which comes from reorder attributes.
    // Only the first element is read -- common (non-per-channel) zero point.
    const int src_zero_point = prb->src_zp ? prb->src_zp[0] : 0;
    const int dst_zero_point = prb->dst_zp ? prb->dst_zp[0] : 0;

    // `beta` is non-zero only when a SUM post-op is present; it blends the
    // pre-existing dst values into the result.
    float beta = 0;
    const auto &po = prb->attr.post_ops;
    const int beta_idx = po.find(attr_t::post_ops_t::kind_t::SUM);
    if (beta_idx >= 0) beta = po.entry[beta_idx].sum.scale;

    // These are non-native compensations coming from other primitives with
    // s8s8 or zero-points support to pre-compute compensated part and apply it
    // at the end of computations.
    const bool need_s8_comp = s8_comp.dt() == dnnl_s32;
    const bool need_zp_comp = zp_comp.dt() == dnnl_s32;
    const bool need_comp = need_s8_comp || need_zp_comp;
    // `adjust_scale` participates only with s8s8 compensation.
    const float s8_scale_factor = need_s8_comp ? reorder_rescale_factor() : 1.f;

    // Main element-wise conversion loop, parallel over all elements.
    benchdnn_parallel_nd(nelems, [&](int64_t idx) {
        float s = src.get_elem(idx) - src_zero_point;
        float d = 0;
        // Read the current dst value only when the SUM post-op needs it.
        if (beta_idx >= 0) d = dst.get_elem(idx) - dst_zero_point;

        float src_scale = 1.f, dst_scale = 1.f;
        if (!src_scales.is_def()) {
            // Map the element offset to its scale entry via the policy mask.
            int64_t src_mask_idx = src.get_scale_idx(idx, src_scale_mask);
            src_scale = prb->src_scales[src_mask_idx];
        }
        if (!dst_scales.is_def()) {
            int64_t dst_mask_idx = dst.get_scale_idx(idx, dst_scale_mask);
            dst_scale = prb->dst_scales[dst_mask_idx];
        }
        float value = (s8_scale_factor * src_scale * s + beta * d) / dst_scale
                + dst_zero_point;
        value = maybe_saturate(dst_dt, value);
        // (float)INT_MAX rounds up to 2^31, which overflows s32; replace such
        // values with a constant known to survive the f32<->s32 round trip.
        if (dst_dt == dnnl_s32 && value >= (float)INT_MAX)
            value = BENCHDNN_S32_TO_F32_SAT_CONST;

        dst_ptr[idx] = round_to_nearest_representable(dst_dt, value);
    });

    if (!need_comp) return;

    int *s8_comp_ptr = (int *)s8_comp;
    int *zp_comp_ptr = (int *)zp_comp;

    // mostly following benchdnn/ref_reduction.cpp/compute_ref
    const auto nelems_s8_comp = s8_comp.nelems();
    const auto nelems_zp_comp = zp_comp.nelems();
    const auto nelems_comp = MAX2(nelems_s8_comp, nelems_zp_comp);
    const auto &ndims = src.ndims();
    assert(nelems_comp > 0);
    // When both compensations are requested they must cover the same space.
    assert(IMPLICATION(
            need_s8_comp && need_zp_comp, nelems_s8_comp == nelems_zp_comp));

    // Pull the compensation mask out of the output flags; `second` carries
    // the mask bits, FLAG_NONE means the flag has no mask attached. The first
    // matching flag wins.
    int comp_mask = 0;
    for (const auto &i_oflag : prb->oflag) {
        if ((i_oflag.first == FLAG_S8S8_COMP || i_oflag.first == FLAG_ZP_COMP)
                && i_oflag.second != FLAG_NONE) {
            comp_mask = i_oflag.second;
            break;
        }
    }

    // Split src dims into "kept" (masked bit set) and "reduced" (bit clear)
    // parts; each dim contributes its full extent to exactly one of the two.
    dims_t comp_dims(ndims, 1); // src_dims with '1' at non-masked dims.
    dims_t reduce_dims(ndims, 1); // complementary to above.
    for (int i = 0; i < ndims; ++i) {
        if (comp_mask & (1 << i)) {
            comp_dims[i] = src.dims()[i];
            reduce_dims[i] = 1;
        } else {
            comp_dims[i] = 1;
            reduce_dims[i] = src.dims()[i];
        }
    }

    const auto nelems_reduce = nelems / nelems_comp;
    // One compensation value per kept-point `f`, reducing over the rest.
    benchdnn_parallel_nd(nelems_comp, [&](int64_t f) {
        dims_t idle_pos = off2dims_idx(comp_dims, f);
        const int64_t src_idle_off = md_off_v(src, idle_pos.data());
        int comp_val = 0;
        for (int64_t r = 0; r < nelems_reduce; ++r) {
            dims_t reduce_pos = off2dims_idx(reduce_dims, r);
            const int64_t src_reduce_off = md_off_v(src, reduce_pos.data());
            // The two offsets are additive because the kept and reduced
            // position sets occupy disjoint dimensions.
            const int64_t src_off = src_idle_off + src_reduce_off;

            float src_scale = 1.f, dst_scale = 1.f;
            if (!src_scales.is_def()) {
                int64_t src_mask_idx
                        = src.get_scale_idx(src_off, src_scale_mask);
                src_scale = prb->src_scales[src_mask_idx];
            }
            if (!dst_scales.is_def()) {
                // NOTE(review): the dst scale index is derived from the src
                // offset -- presumably valid because src and dst share
                // logical dims; confirm against dnn_mem_t::get_scale_idx.
                int64_t dst_mask_idx
                        = dst.get_scale_idx(src_off, dst_scale_mask);
                dst_scale = prb->dst_scales[dst_mask_idx];
            }

            const float alpha = src_scale / dst_scale;
            const float value = src.get_elem(src_off) * alpha * s8_scale_factor;
            // Compensation is the NEGATED sum of (saturated) scaled sources.
            comp_val -= maybe_saturate(dst_dt, value);
        }
        // Zero-point compensation stores the plain negated sum; s8s8
        // compensation additionally carries the x128 factor, so the order of
        // the next three statements is significant.
        if (need_zp_comp) zp_comp_ptr[f] = comp_val;
        comp_val *= 128;
        if (need_s8_comp) s8_comp_ptr[f] = comp_val;
    });
}
145 | |
146 | } // namespace reorder |
147 | |