1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/op/type_relations.h
22 * \brief A set of utilities and common functionality
23 * for type relations.
24 */
25#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
26#define TVM_RELAY_OP_TYPE_RELATIONS_H_
27
28#include <tvm/relay/error.h>
29#include <tvm/relay/type.h>
30
31#include <string>
32
33namespace tvm {
34namespace relay {
35/*!
36 * \brief The identity type relation, all the types are equal.
37 *
38 * \param types The input and output types to the relation.
39 * \param num_inputs The number of input arguments.
40 * \param attrs The attributes
41 * \param reporter The reporter.
42 * \return true whether relation has been resolved.
43 */
44bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
45 const TypeReporter& reporter);
46
47/*!
48 * \brief The broadcast type relation, implements the broadcasting
49 * rule over the two input types producing the broadcasted type.
50 *
51 * \param types The input and output types to the relation.
52 * \param num_inputs The number of input arguments.
53 * \param attrs The attributes
54 * \param reporter The reporter.
55 * \return true whether relation has been resolved.
56 */
57bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
58 const TypeReporter& reporter);
59
60/*!
61 * \brief Determine the broadcasted shape from two input shapes
62 * \param t1 One of two Tensortype whose shapes are broadcasted
63 * \param t2 One of two Tensortype whose shapes are broadcasted
64 * \param output_dtype dtype of the output TensorType
65 * \return A TensorType whose shape is broadcasted from two input TensorType.
66 */
67TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);
68
69/*!
70 * \brief The broadcast type relation, implements the broadcasting
71 * rule over the two input types producing the broadcasted type.
72 *
73 * This differs from BroadcastRel in the return dtype,
74 * it instead returns bool(uint8), for use in comparsion operators
75 * such as equal, not_equal, lt, and so on.
76 *
77 * \param types The input and output types to the relation.
78 * \param num_inputs The number of input arguments.
79 * \param attrs The attributes
80 * \param reporter The reporter.
81 * \return true whether relation has been resolved.
82 */
83bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
84 const TypeReporter& reporter);
85
86bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
87 const TypeReporter& reporter);
88
89Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
90
91/*!
92 * \brief The shape of type relation.
93 *
94 * \param types The input and output types to the relation.
95 * \param num_inputs The number of input arguments.
96 * \param attrs The attributes
97 * \param reporter The reporter.
98 * \return true whether relation has been resolved.
99 */
100bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
101 const TypeReporter& reporter);
102
103} // namespace relay
104} // namespace tvm
105
106#endif // TVM_RELAY_OP_TYPE_RELATIONS_H_
107