1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GPU_BATCH_NORMALIZATION_PD_HPP
18#define GPU_GPU_BATCH_NORMALIZATION_PD_HPP
19
20#include <assert.h>
21
22#include "common/batch_normalization_pd.hpp"
23#include "common/c_types_map.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30
31namespace {
32template <typename pd_t>
33inline void gpu_init_default_ws(pd_t *self, memory_desc_t &ws_md) {
34 auto mdw = memory_desc_wrapper(self->src_md(0));
35 ws_md = *mdw.md_;
36 ws_md.data_type = data_type::s8;
37}
38} // namespace
39
40struct gpu_batch_normalization_fwd_pd_t : public batch_normalization_fwd_pd_t {
41 using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t;
42
43protected:
44 void init_default_ws(size_t bits_per_element) override {
45 UNUSED(bits_per_element);
46 gpu_init_default_ws(this, ws_md_);
47 }
48};
49
50struct gpu_batch_normalization_bwd_pd_t : public batch_normalization_bwd_pd_t {
51 using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t;
52
53 void init_default_ws(size_t bits_per_element) override {
54 UNUSED(bits_per_element);
55 gpu_init_default_ws(this, ws_md_);
56 }
57};
58
59} // namespace gpu
60} // namespace impl
61} // namespace dnnl
62
63#endif
64