1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef LNORM_HPP |
18 | #define LNORM_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <limits.h> |
22 | #include <numeric> |
23 | #include <stdint.h> |
24 | |
25 | #include <iostream> |
26 | |
27 | #include "common.hpp" |
28 | #include "dnn_types.hpp" |
29 | #include "dnnl_common.hpp" |
30 | #include "dnnl_debug.hpp" |
31 | #include "utils/perf_report.hpp" |
32 | #include "utils/settings.hpp" |
33 | |
34 | #include "bnorm/bnorm.hpp" |
35 | |
36 | namespace lnorm { |
37 | |
38 | using check_alg_t = bnorm::check_alg_t; |
39 | using flags_t = bnorm::flags_t; |
40 | const flags_t NONE = bnorm::NONE; |
41 | const flags_t GLOB_STATS = bnorm::GLOB_STATS; |
42 | const flags_t USE_SCALE = bnorm::USE_SCALE; |
43 | const flags_t USE_SHIFT = bnorm::USE_SHIFT; |
44 | const auto flags2str = bnorm::flags2str; |
45 | flags_t str2flags(const char *str); |
46 | |
47 | struct settings_t : public base_settings_t { |
48 | settings_t() = default; |
49 | |
50 | // ctor to save certain fields from resetting |
51 | settings_t(const char *perf_template) : settings_t() { |
52 | this->perf_template = perf_template; |
53 | } |
54 | |
55 | prb_dims_t prb_dims; |
56 | |
57 | std::vector<dir_t> dir {FWD_D}; |
58 | std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}}; |
59 | std::vector<std::vector<std::string>> tag {{tag::abx, tag::any}}; |
60 | std::vector<std::string> stat_tag {tag::any}; |
61 | std::vector<flags_t> flags {NONE}; |
62 | check_alg_t check_alg = check_alg_t::ALG_AUTO; |
63 | |
64 | const char *perf_template_csv() const { |
65 | static const std::string args = "%dir%,%dt%,%tag%,%stat_tag%,%flags%" ; |
66 | return perf_template_csv_base(args); |
67 | } |
68 | |
69 | void reset() { *this = settings_t(perf_template); } |
70 | }; |
71 | |
72 | struct prb_t : public prb_dims_t { |
73 | prb_t(const prb_dims_t &prb_dims, const std::vector<std::string> &tag, |
74 | const std::string &stat_tag, dir_t dir, |
75 | const std::vector<dnnl_data_type_t> &dt, flags_t flags, |
76 | const attr_t &attr, const thr_ctx_t &ctx_init, |
77 | const thr_ctx_t &ctx_exe, bool inplace, check_alg_t check_alg) |
78 | : prb_dims_t(prb_dims) |
79 | , check_alg(check_alg) |
80 | , tag(tag) |
81 | , stat_tag(stat_tag) |
82 | , dir(dir) |
83 | , dt(dt) |
84 | , flags(flags) |
85 | , inplace(inplace) |
86 | , attr(attr) |
87 | , ctx_init(ctx_init) |
88 | , ctx_exe(ctx_exe) { |
89 | n = 1; |
90 | for (int d = 0; d < ndims - 1; d++) |
91 | n *= dims[d]; |
92 | c = dims[ndims - 1]; |
93 | eps = 1.f / 16; |
94 | |
95 | // Broadcast data types if needed |
96 | if (dt.size() == 1) { |
97 | const auto val = dt[0]; // Need a copy here. |
98 | this->dt.assign(2, val); |
99 | } |
100 | if (tag.size() == 1) { this->tag.push_back(tag::any); } |
101 | } |
102 | |
103 | check_alg_t check_alg; |
104 | std::vector<std::string> tag; |
105 | std::string stat_tag; |
106 | dir_t dir; |
107 | std::vector<dnnl_data_type_t> dt; |
108 | flags_t flags; |
109 | bool inplace; |
110 | attr_t attr; |
111 | const thr_ctx_t ctx_init, ctx_exe; |
112 | int64_t n, c; |
113 | float eps; |
114 | |
115 | bool use_sc() const { return flags & USE_SCALE; } |
116 | bool use_sh() const { return flags & USE_SHIFT; } |
117 | }; |
118 | |
119 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
120 | |
121 | struct perf_report_t : public base_perf_report_t { |
122 | perf_report_t(const prb_t *prb, const char *perf_template) |
123 | : base_perf_report_t(perf_template) |
124 | , p_(prb) |
125 | , stat_tag_(normalize_tag(p_->stat_tag, p_->ndims - 1)) { |
126 | for (size_t d = 0; d < p_->tag.size(); d++) |
127 | tag_.push_back(normalize_tag(p_->tag[d], p_->ndims)); |
128 | } |
129 | |
130 | void dump_desc(std::ostream &s) const override { |
131 | s << static_cast<const prb_dims_t &>(*p_); |
132 | } |
133 | |
134 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
135 | |
136 | void dump_flags(std::ostream &s) const override { |
137 | s << flags2str(p_->flags); |
138 | } |
139 | |
140 | const attr_t *attr() const override { return &p_->attr; } |
141 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
142 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
143 | const std::string *name() const override { return &p_->name; } |
144 | const dir_t *dir() const override { return &p_->dir; } |
145 | const std::vector<dnnl_data_type_t> *sdt() const override { |
146 | return &p_->dt; |
147 | } |
148 | const std::vector<std::string> *stag() const override { return &tag_; } |
149 | const std::string *stat_tag() const override { return &stat_tag_; } |
150 | |
151 | private: |
152 | const prb_t *p_; |
153 | std::vector<std::string> tag_; |
154 | std::string stat_tag_; |
155 | }; |
156 | |
157 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
158 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
159 | void compute_ref(const prb_t *prb, const args_t &args, |
160 | dnnl_primitive_t prim_ref = nullptr); |
161 | |
162 | int doit(const prb_t *prb, res_t *res); |
163 | int bench(int argc, char **argv); |
164 | |
165 | } // namespace lnorm |
166 | |
167 | #endif |
168 | |