1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_TENSOR_CONFIG_HPP
18#define GPU_JIT_CONV_TENSOR_CONFIG_HPP
19
20#include <vector>
21
22#include "gpu/jit/ir/tensor.hpp"
23#include "gpu/jit/utils/utils.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
30struct tensor_info_t {
31 std::string name;
32 int arg_key;
33 bool is_input;
34 bool is_output;
35 layout_t compute_layout;
36 layout_t user_layout;
37
38 bool needs_reorder;
39 bool needs_zero_out;
40};
41
42class tensor_config_t {
43public:
44 const std::vector<tensor_info_t> &tensors() const { return tensors_; }
45
46 void add_tensor(const std::string &name, int arg_key, bool is_input,
47 bool is_output, const layout_t &user_layout) {
48 tensors_.emplace_back();
49 auto &t = tensors_.back();
50 t.name = name;
51 t.arg_key = arg_key;
52 t.is_input = is_input;
53 t.is_output = is_output;
54 t.compute_layout = user_layout;
55 t.user_layout = user_layout;
56 t.needs_reorder = false;
57 t.needs_zero_out = false;
58 }
59
60 void add_tensor(const std::string &name, int arg_key, bool is_input,
61 bool is_output, const layout_t &compute_layout,
62 const layout_t &user_layout) {
63 tensors_.emplace_back();
64 auto &t = tensors_.back();
65 t.name = name;
66 t.arg_key = arg_key;
67 t.is_input = is_input;
68 t.is_output = is_output;
69 t.compute_layout = compute_layout;
70 t.user_layout = user_layout;
71 t.needs_reorder = (t.compute_layout != t.user_layout);
72 t.needs_zero_out = false;
73 }
74
75 void set_compute_layout(
76 const std::string &name, const layout_t &compute_layout) {
77 auto &t = find_tensor(name);
78 t.compute_layout = compute_layout;
79 t.needs_reorder = (t.compute_layout != t.user_layout);
80 }
81
82 const layout_t &compute_layout(const std::string &name) const {
83 return find_tensor(name).compute_layout;
84 }
85
86 const layout_t &user_layout(const std::string &name) const {
87 return find_tensor(name).user_layout;
88 }
89
90 void require_zero_out(const std::string &name) {
91 auto &t = find_tensor(name);
92 t.needs_zero_out = true;
93 }
94
95private:
96 const tensor_info_t &find_tensor(const std::string &name) const {
97 for (auto &t : tensors_) {
98 if (t.name == name) return t;
99 }
100 ir_error_not_expected() << "Can't find tensor " << name;
101 return tensors_.front();
102 }
103
104 tensor_info_t &find_tensor(const std::string &name) {
105 auto *const_this = const_cast<const tensor_config_t *>(this);
106 return const_cast<tensor_info_t &>(const_this->find_tensor(name));
107 }
108
109 std::vector<tensor_info_t> tensors_;
110};
111
112} // namespace jit
113} // namespace gpu
114} // namespace impl
115} // namespace dnnl
116
117#endif
118