1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef LNORM_HPP
18#define LNORM_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <numeric>
23#include <stdint.h>
24
25#include <iostream>
26
27#include "common.hpp"
28#include "dnn_types.hpp"
29#include "dnnl_common.hpp"
30#include "dnnl_debug.hpp"
31#include "utils/perf_report.hpp"
32#include "utils/settings.hpp"
33
34#include "bnorm/bnorm.hpp"
35
36namespace lnorm {
37
// The check-algorithm type, flag type, flag values and flags-to-string
// conversion are aliases of the batch-normalization driver's definitions, so
// both drivers share the same values and printing.
using check_alg_t = bnorm::check_alg_t;
using flags_t = bnorm::flags_t;
const flags_t NONE = bnorm::NONE;
const flags_t GLOB_STATS = bnorm::GLOB_STATS;
const flags_t USE_SCALE = bnorm::USE_SCALE;
const flags_t USE_SHIFT = bnorm::USE_SHIFT;
const auto flags2str = bnorm::flags2str;
// Converts a flags string into a flags_t value; defined outside this header.
// NOTE(review): unlike flags2str, this is not an alias of the bnorm version,
// presumably because lnorm accepts a different set of flag letters — confirm
// against the definition.
flags_t str2flags(const char *str);
46
47struct settings_t : public base_settings_t {
48 settings_t() = default;
49
50 // ctor to save certain fields from resetting
51 settings_t(const char *perf_template) : settings_t() {
52 this->perf_template = perf_template;
53 }
54
55 prb_dims_t prb_dims;
56
57 std::vector<dir_t> dir {FWD_D};
58 std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}};
59 std::vector<std::vector<std::string>> tag {{tag::abx, tag::any}};
60 std::vector<std::string> stat_tag {tag::any};
61 std::vector<flags_t> flags {NONE};
62 check_alg_t check_alg = check_alg_t::ALG_AUTO;
63
64 const char *perf_template_csv() const {
65 static const std::string args = "%dir%,%dt%,%tag%,%stat_tag%,%flags%";
66 return perf_template_csv_base(args);
67 }
68
69 void reset() { *this = settings_t(perf_template); }
70};
71
72struct prb_t : public prb_dims_t {
73 prb_t(const prb_dims_t &prb_dims, const std::vector<std::string> &tag,
74 const std::string &stat_tag, dir_t dir,
75 const std::vector<dnnl_data_type_t> &dt, flags_t flags,
76 const attr_t &attr, const thr_ctx_t &ctx_init,
77 const thr_ctx_t &ctx_exe, bool inplace, check_alg_t check_alg)
78 : prb_dims_t(prb_dims)
79 , check_alg(check_alg)
80 , tag(tag)
81 , stat_tag(stat_tag)
82 , dir(dir)
83 , dt(dt)
84 , flags(flags)
85 , inplace(inplace)
86 , attr(attr)
87 , ctx_init(ctx_init)
88 , ctx_exe(ctx_exe) {
89 n = 1;
90 for (int d = 0; d < ndims - 1; d++)
91 n *= dims[d];
92 c = dims[ndims - 1];
93 eps = 1.f / 16;
94
95 // Broadcast data types if needed
96 if (dt.size() == 1) {
97 const auto val = dt[0]; // Need a copy here.
98 this->dt.assign(2, val);
99 }
100 if (tag.size() == 1) { this->tag.push_back(tag::any); }
101 }
102
103 check_alg_t check_alg;
104 std::vector<std::string> tag;
105 std::string stat_tag;
106 dir_t dir;
107 std::vector<dnnl_data_type_t> dt;
108 flags_t flags;
109 bool inplace;
110 attr_t attr;
111 const thr_ctx_t ctx_init, ctx_exe;
112 int64_t n, c;
113 float eps;
114
115 bool use_sc() const { return flags & USE_SCALE; }
116 bool use_sh() const { return flags & USE_SHIFT; }
117};
118
// Streams a textual description of `prb`; defined outside this header.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
120
121struct perf_report_t : public base_perf_report_t {
122 perf_report_t(const prb_t *prb, const char *perf_template)
123 : base_perf_report_t(perf_template)
124 , p_(prb)
125 , stat_tag_(normalize_tag(p_->stat_tag, p_->ndims - 1)) {
126 for (size_t d = 0; d < p_->tag.size(); d++)
127 tag_.push_back(normalize_tag(p_->tag[d], p_->ndims));
128 }
129
130 void dump_desc(std::ostream &s) const override {
131 s << static_cast<const prb_dims_t &>(*p_);
132 }
133
134 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
135
136 void dump_flags(std::ostream &s) const override {
137 s << flags2str(p_->flags);
138 }
139
140 const attr_t *attr() const override { return &p_->attr; }
141 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
142 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
143 const std::string *name() const override { return &p_->name; }
144 const dir_t *dir() const override { return &p_->dir; }
145 const std::vector<dnnl_data_type_t> *sdt() const override {
146 return &p_->dt;
147 }
148 const std::vector<std::string> *stag() const override { return &tag_; }
149 const std::string *stat_tag() const override { return &stat_tag_; }
150
151private:
152 const prb_t *p_;
153 std::vector<std::string> tag_;
154 std::string stat_tag_;
155};
156
// Driver entry points; all are defined outside this header.
// NOTE(review): the skip_* functions presumably mark `res` as skipped for
// problems that are unimplemented / invalid — confirm against definitions.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation over `args`. `prim_ref` defaults to nullptr;
// presumably an optional reference primitive to use instead — TODO confirm.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Runs a single problem, reporting into `res`; returns an int status
// (presumably OK/FAIL — confirm against the definition).
int doit(const prb_t *prb, res_t *res);
// lnorm driver entry point taking the driver's command-line arguments.
int bench(int argc, char **argv);
164
165} // namespace lnorm
166
167#endif
168