1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils.hpp"
18
19#include "convolution_pd.hpp"
20
21namespace dnnl {
22namespace impl {
23
24using namespace prop_kind;
25
26memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
27 return desc->prop_kind == backward_data ? &desc->diff_src_desc
28 : &desc->src_desc;
29}
30
31memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
32 return desc->prop_kind == backward_weights ? &desc->diff_weights_desc
33 : &desc->weights_desc;
34}
35
36memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
37 return desc->prop_kind == backward_weights ? &desc->diff_bias_desc
38 : &desc->bias_desc;
39}
40
41memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
42 return utils::one_of(desc->prop_kind, forward_inference, forward_training)
43 ? &desc->dst_desc
44 : &desc->diff_dst_desc;
45}
46
47const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) {
48 return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc));
49}
50const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) {
51 return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc));
52}
53const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) {
54 return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc));
55}
56const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) {
57 return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc));
58}
59
60} // namespace impl
61} // namespace dnnl
62