1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/op/nn/pooling.h
 * \brief Shared helpers for constructing calls to the pooling operators.
23 */
24#ifndef TVM_RELAY_OP_NN_POOLING_H_
25#define TVM_RELAY_OP_NN_POOLING_H_
26
27#include <tvm/relay/attrs/nn.h>
28#include <tvm/relay/op.h>
29
30#include <utility>
31
32namespace tvm {
33namespace relay {
34
35template <typename T>
36inline Expr MakeMaxPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
37 Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
38 String out_layout, bool ceil_mode, String op_name) {
39 auto attrs = make_object<T>();
40 attrs->pool_size = std::move(pool_size);
41 attrs->strides = std::move(strides);
42 attrs->dilation = std::move(dilation);
43 attrs->padding = std::move(padding);
44 attrs->layout = std::move(layout);
45 attrs->out_layout = std::move(out_layout);
46 attrs->ceil_mode = ceil_mode;
47 static const Op& op = Op::Get(op_name);
48 return Call(op, {data}, Attrs(attrs), {});
49}
50
51template <typename T>
52inline Expr MakeAvgPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
53 Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
54 String out_layout, bool ceil_mode, bool count_include_pad, String op_name) {
55 auto attrs = make_object<T>();
56 attrs->pool_size = std::move(pool_size);
57 attrs->strides = std::move(strides);
58 attrs->dilation = std::move(dilation);
59 attrs->padding = std::move(padding);
60 attrs->layout = std::move(layout);
61 attrs->out_layout = std::move(out_layout);
62 attrs->ceil_mode = ceil_mode;
63 attrs->count_include_pad = count_include_pad;
64 static const Op& op = Op::Get(op_name);
65 return Call(op, {data}, Attrs(attrs), {});
66}
67
68} // namespace relay
69} // namespace tvm
70#endif // TVM_RELAY_OP_NN_POOLING_H_
71