1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_BATCH_NORMALIZATION_S8_HPP
18#define CPU_X64_JIT_UNI_BATCH_NORMALIZATION_S8_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_batch_normalization_pd.hpp"
28#include "cpu/x64/cpu_isa_traits.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35namespace bnorm_s8_impl {
36template <cpu_isa_t isa>
37struct driver_t;
38}
39
40template <cpu_isa_t isa>
41struct jit_uni_batch_normalization_s8_fwd_t : public primitive_t {
42 struct pd_t : public cpu_batch_normalization_fwd_pd_t {
43 pd_t(const batch_normalization_desc_t *adesc,
44 const primitive_attr_t *attr,
45 const batch_normalization_fwd_pd_t *hint_fwd_pd)
46 : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
47
48 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("bnorm_s8_jit:", isa, ""),
49 jit_uni_batch_normalization_s8_fwd_t);
50
51 status_t init(engine_t *engine);
52 };
53
54 typedef int8_t data_t;
55
56 jit_uni_batch_normalization_s8_fwd_t(const pd_t *apd);
57 ~jit_uni_batch_normalization_s8_fwd_t();
58
59 status_t init(engine_t *engine) override;
60
61 status_t execute(const exec_ctx_t &ctx) const override;
62
63private:
64 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
65
66 bnorm_s8_impl::driver_t<isa> *bnorm_driver_;
67};
68
69} // namespace x64
70} // namespace cpu
71} // namespace impl
72} // namespace dnnl
73
74#endif
75
76// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
77