1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef LRN_HPP
18#define LRN_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <stdint.h>
23
24#include <iostream>
25
26#include "common.hpp"
27#include "dnn_types.hpp"
28#include "dnnl_common.hpp"
29#include "dnnl_debug.hpp"
30#include "utils/perf_report.hpp"
31#include "utils/settings.hpp"
32
33namespace lrn {
34
// Normalization region kind for LRN. Enumerator values are implicit
// (ACROSS == 0, WITHIN == 1).
enum alg_t {
    ACROSS, // one-dimensional window of `ls` points (presumably across channels)
    WITHIN, // window of `ls` points along each spatial dimension
};
36alg_t str2alg(const char *str);
37const char *alg2str(alg_t alg);
38dnnl_alg_kind_t alg2alg_kind(alg_t alg);
39
// Problem descriptor for an LRN case: tensor dimensions, normalization
// window size, formula coefficients and an optional case name.
// Member order is significant: it defines aggregate-initialization order.
struct desc_t {
    int64_t mb; // minibatch
    int64_t ic; // channels
    int64_t id; // depth
    int64_t ih; // height
    int64_t iw; // width
    int64_t ls; // local size: window extent per normalized dimension
    float alpha;
    float beta;
    float k;
    std::string name;
    int ndims;
};
// Fills `desc` from its textual form; returns an int status
// (NOTE(review): status convention is defined elsewhere — confirm).
int str2desc(desc_t *desc, const char *str);
// Streams `d` in textual form (presumably the one str2desc accepts).
std::ostream &operator<<(std::ostream &s, const desc_t &d);
49
50struct settings_t : public base_settings_t {
51 settings_t() = default;
52
53 // ctor to save certain fields from resetting
54 settings_t(const char *perf_template) : settings_t() {
55 this->perf_template = perf_template;
56 }
57
58 desc_t desc {};
59
60 std::vector<dir_t> dir {FWD_D};
61 std::vector<dnnl_data_type_t> dt {dnnl_f32};
62 std::vector<std::string> tag {tag::abx};
63 std::vector<alg_t> alg {ACROSS};
64
65 const char *perf_template_csv() const {
66 static const std::string args = "%dir%,%dt%,%tag%,%alg%";
67 return perf_template_csv_base(args);
68 }
69
70 void reset() { *this = settings_t(perf_template); }
71};
72
73struct prb_t : public desc_t {
74 prb_t(const desc_t &desc, int64_t mb, dir_t dir, dnnl_data_type_t dt,
75 const std::string &tag, alg_t alg, const attr_t &attr,
76 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
77 : desc_t(desc)
78 , dir(dir)
79 , dt(dt)
80 , tag(tag)
81 , alg(alg)
82 , attr(attr)
83 , ctx_init(ctx_init)
84 , ctx_exe(ctx_exe)
85 , user_mb(mb) {
86 if (mb) this->mb = mb;
87 }
88 ~prb_t() {}
89
90 dir_t dir;
91 dnnl_data_type_t dt;
92 std::string tag;
93 alg_t alg;
94 attr_t attr;
95 thr_ctx_t ctx_init, ctx_exe;
96 int64_t user_mb;
97
98 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
99};
100std::ostream &operator<<(std::ostream &s, const prb_t &prb);
101
102struct perf_report_t : public base_perf_report_t {
103 perf_report_t(const prb_t *prb, const char *perf_template)
104 : base_perf_report_t(perf_template)
105 , p_(prb)
106 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
107
108 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
109
110 void dump_desc(std::ostream &s) const override {
111 s << static_cast<const desc_t &>(*p_);
112 }
113
114 void dump_desc_csv(std::ostream &s) const override {
115 s << p_->mb << ',' << p_->ic << ',' << p_->id << ',' << p_->ih << ','
116 << p_->iw << ',' << p_->ls << ',' << p_->alpha << ',' << p_->beta
117 << ',' << p_->k;
118 }
119
120 const int64_t *user_mb() const override { return &p_->user_mb; }
121 const attr_t *attr() const override { return &p_->attr; }
122 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
123 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
124 const std::string *name() const override { return &p_->name; }
125 const dir_t *dir() const override { return &p_->dir; }
126 const dnnl_data_type_t *dt() const override { return &p_->dt; }
127 const std::string *tag() const override { return &tag_; }
128
129private:
130 const prb_t *p_;
131 std::string tag_;
132};
133
134inline int compute_n_summands(const prb_t *prb) {
135 if (prb->alg == ACROSS) {
136 return prb->ls;
137 } else if (prb->alg == WITHIN) {
138 int n_summands = 1;
139 for (int64_t d = prb->ndims - 2; d > 0; --d)
140 n_summands *= prb->ls;
141 return n_summands;
142 } else {
143 assert(!"unknown algorithm");
144 return 1;
145 }
146}
147
148inline size_t data_off(const prb_t *prb, int64_t mb, int64_t c, int64_t d,
149 int64_t h, int64_t w) {
150 return (((mb * prb->ic + c) * prb->id + d) * prb->ih + h) * prb->iw + w;
151}
152
153void skip_unimplemented_prb(const prb_t *prb, res_t *res);
154void skip_invalid_prb(const prb_t *prb, res_t *res);
155void compute_ref(const prb_t *prb, const args_t &args,
156 dnnl_primitive_t prim_ref = nullptr);
157
158int doit(const prb_t *prb, res_t *res);
159int bench(int argc, char **argv);
160
161} // namespace lrn
162
163#endif
164