1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/memory_tracking.hpp"
20#include "common/nstl.hpp"
21
22#include "cpu/cpu_primitive.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace cpu {
27
namespace {
// Number of float entries written when a single (mask == 0) weights scale is
// broadcast, and the minimum number of entries booked for the precomputed
// scales buffer. Presumably one 512-bit vector of fp32 values -- TODO confirm.
constexpr size_t scales_simd_w {16};
} // namespace
31
32void book_precomputed_scales(memory_tracking::registrar_t &scratchpad,
33 const arg_scales_t &attr_scales, size_t oc) {
34 using namespace dnnl::impl::memory_tracking::names;
35
36 const bool with_src_scales
37 = !attr_scales.get(DNNL_ARG_SRC).has_default_values();
38 const bool with_wei_scales
39 = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
40 if (with_src_scales && with_wei_scales) {
41 const int wei_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_;
42 const size_t precomputed_scales_size = wei_mask == 0
43 ? scales_simd_w
44 : nstl::max(static_cast<size_t>(oc), scales_simd_w);
45 scratchpad.template book<float>(
46 memory_tracking::names::key_precomputed_scales,
47 precomputed_scales_size);
48 }
49}
50
51const float *precompute_scales(const memory_tracking::grantor_t &scratchpad,
52 const float *src_scales, const float *wei_scales, dim_t oc,
53 const primitive_attr_t *attr) {
54 using namespace dnnl::impl::memory_tracking::names;
55
56 const auto &attr_scales = attr->scales_;
57 bool with_src_scales = !attr_scales.get(DNNL_ARG_SRC).has_default_values();
58 bool with_wei_scales
59 = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
60 int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_;
61 dim_t wei_scale_count = wei_scale_mask == 0 ? 1 : oc;
62
63 const float *scales = nullptr;
64 if (with_src_scales && with_wei_scales) {
65 size_t size = 0;
66 auto loc_scales
67 = scratchpad.template get<float>(key_precomputed_scales, &size);
68 if (wei_scale_mask == 0) {
69 const size_t count = nstl::min(size / sizeof(float), scales_simd_w);
70 utils::array_set(loc_scales, src_scales[0] * wei_scales[0], count);
71 } else {
72 const dim_t count = nstl::min(
73 static_cast<dim_t>(size / sizeof(float)), wei_scale_count);
74 PRAGMA_OMP_SIMD()
75 for (dim_t c = 0; c < count; c++)
76 loc_scales[c] = src_scales[0] * wei_scales[c];
77 }
78 scales = loc_scales;
79 } else if (with_src_scales) {
80 scales = src_scales;
81 } else {
82 scales = wei_scales;
83 }
84
85 return scales;
86}
87
88} // namespace cpu
89} // namespace impl
90} // namespace dnnl
91