1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef SOFTMAX_HPP
18#define SOFTMAX_HPP
19
20#include <iostream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "common.hpp"
25#include "dnn_types.hpp"
26#include "dnnl_common.hpp"
27#include "utils/perf_report.hpp"
28#include "utils/settings.hpp"
29
30namespace softmax {
31
// Softmax flavor exercised by the driver.
enum alg_t {
    UNDEF = 0, // no algorithm selected
    SOFTMAX = 1, // regular softmax
    LOGSOFTMAX = 2, // log-softmax
    // Alternative spellings of the two values above (presumably chosen to
    // mirror dnnl_softmax_accurate / dnnl_softmax_log - confirm).
    softmax_accurate = SOFTMAX,
    softmax_log = LOGSOFTMAX,
};
39alg_t str2alg(const char *str);
40const char *alg2str(alg_t alg);
41dnnl_alg_kind_t alg2alg_kind(alg_t alg);
42
43struct settings_t : public base_settings_t {
44 settings_t() = default;
45
46 // ctor to save certain fields from resetting
47 settings_t(const char *perf_template) : settings_t() {
48 this->perf_template = perf_template;
49 }
50
51 prb_dims_t prb_dims;
52
53 std::vector<dir_t> dir {FWD_D};
54 std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32};
55 std::vector<std::string> stag {tag::abx}, dtag {tag::any};
56 std::vector<alg_t> alg {SOFTMAX};
57 std::vector<int> axis {1};
58
59 const char *perf_template_csv() const {
60 static const std::string args
61 = "%dir%,%sdt%,%ddt%,%stag%,%dtag%,%alg%,%axis%";
62 return perf_template_csv_base(args);
63 }
64
65 void reset() { *this = settings_t(perf_template); }
66};
67
68struct prb_t : public prb_dims_t {
69 prb_t(const prb_dims_t &prb_dims, dir_t dir, dnnl_data_type_t sdt,
70 dnnl_data_type_t ddt, const std::string &stag,
71 const std::string &dtag, alg_t alg, int axis, bool inplace,
72 const attr_t &attr, const thr_ctx_t &ctx_init,
73 const thr_ctx_t &ctx_exe, int64_t mb = 0)
74 : prb_dims_t(prb_dims)
75 , dir(dir)
76 , sdt(sdt)
77 , ddt(ddt)
78 , stag(stag)
79 , dtag(dtag)
80 , alg(alg)
81 , axis(axis)
82 , inplace(inplace)
83 , attr(attr)
84 , ctx_init(ctx_init)
85 , ctx_exe(ctx_exe)
86 , user_mb(mb) {
87 if (mb) dims[0] = mb;
88 }
89
90 dir_t dir;
91 dnnl_data_type_t sdt, ddt;
92 std::string stag, dtag;
93 alg_t alg;
94 int axis;
95 bool inplace;
96 attr_t attr;
97 thr_ctx_t ctx_init, ctx_exe;
98 int64_t user_mb;
99};
100std::ostream &operator<<(std::ostream &s, const prb_t &prb);
101
102struct perf_report_t : public base_perf_report_t {
103 perf_report_t(const prb_t *prb, const char *perf_template)
104 : base_perf_report_t(perf_template)
105 , p_(prb)
106 , sdt_({p_->sdt})
107 , ddt_(p_->ddt)
108 , stag_({normalize_tag(p_->stag, p_->ndims)})
109 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {}
110
111 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
112
113 void dump_desc(std::ostream &s) const override {
114 s << static_cast<const prb_dims_t &>(*p_);
115 }
116
117 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
118
119 const attr_t *attr() const override { return &p_->attr; }
120 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
121 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
122 const std::string *name() const override { return &p_->name; }
123 const int *axis() const override { return &p_->axis; }
124 const dir_t *dir() const override { return &p_->dir; }
125 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
126 const dnnl_data_type_t *ddt() const override { return &ddt_; }
127 const int64_t *user_mb() const override { return &p_->user_mb; }
128 const std::vector<std::string> *stag() const override { return &stag_; }
129 const std::string *dtag() const override { return &dtag_; }
130
131private:
132 const prb_t *p_;
133 std::vector<dnnl_data_type_t> sdt_;
134 dnnl_data_type_t ddt_;
135 std::vector<std::string> stag_;
136 std::string dtag_;
137};
138
139inline void map_off_to_mb_ic(
140 const prb_t *prb, int64_t off, int64_t &mb, int64_t &ic) {
141 for (int i = prb->ndims - 1; i > 1; i--)
142 off /= prb->dims[i];
143
144 ic = off % prb->dims[1];
145 off /= prb->dims[1];
146 mb = off % prb->dims[0];
147 off /= prb->dims[0];
148 assert(off == 0);
149}
150
151inline void get_sizes(const prb_t *prb, int64_t &outer_size,
152 int64_t &inner_size, int64_t &axis_size) {
153 outer_size = inner_size = axis_size = 1;
154 for (int i = 0; i < prb->axis; i++)
155 outer_size *= prb->dims[i];
156 for (int i = prb->axis + 1; i < prb->ndims; i++)
157 inner_size *= prb->dims[i];
158 axis_size = prb->dims[prb->axis];
159}
160
161void skip_unimplemented_prb(const prb_t *prb, res_t *res);
162void skip_invalid_prb(const prb_t *prb, res_t *res);
163void compute_ref(const prb_t *prb, const args_t &args,
164 dnnl_primitive_t prim_ref = nullptr);
165
166int doit(const prb_t *prb, res_t *res);
167int bench(int argc, char **argv);
168
169} // namespace softmax
170
171#endif
172