1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/op/type_relations.h |
22 | * \brief A set of utilities and common functionality |
23 | * for type relations. |
24 | */ |
25 | #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ |
26 | #define TVM_RELAY_OP_TYPE_RELATIONS_H_ |
27 | |
28 | #include <tvm/relay/error.h> |
29 | #include <tvm/relay/type.h> |
30 | |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | /*! |
36 | * \brief The identity type relation, all the types are equal. |
37 | * |
38 | * \param types The input and output types to the relation. |
39 | * \param num_inputs The number of input arguments. |
40 | * \param attrs The attributes |
41 | * \param reporter The reporter. |
42 | * \return true whether relation has been resolved. |
43 | */ |
44 | bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
45 | const TypeReporter& reporter); |
46 | |
47 | /*! |
48 | * \brief The broadcast type relation, implements the broadcasting |
49 | * rule over the two input types producing the broadcasted type. |
50 | * |
51 | * \param types The input and output types to the relation. |
52 | * \param num_inputs The number of input arguments. |
53 | * \param attrs The attributes |
54 | * \param reporter The reporter. |
55 | * \return true whether relation has been resolved. |
56 | */ |
57 | bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
58 | const TypeReporter& reporter); |
59 | |
60 | /*! |
61 | * \brief Determine the broadcasted shape from two input shapes |
62 | * \param t1 One of two Tensortype whose shapes are broadcasted |
63 | * \param t2 One of two Tensortype whose shapes are broadcasted |
64 | * \param output_dtype dtype of the output TensorType |
65 | * \return A TensorType whose shape is broadcasted from two input TensorType. |
66 | */ |
67 | TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype); |
68 | |
69 | /*! |
70 | * \brief The broadcast type relation, implements the broadcasting |
71 | * rule over the two input types producing the broadcasted type. |
72 | * |
73 | * This differs from BroadcastRel in the return dtype, |
74 | * it instead returns bool(uint8), for use in comparsion operators |
75 | * such as equal, not_equal, lt, and so on. |
76 | * |
77 | * \param types The input and output types to the relation. |
78 | * \param num_inputs The number of input arguments. |
79 | * \param attrs The attributes |
80 | * \param reporter The reporter. |
81 | * \return true whether relation has been resolved. |
82 | */ |
83 | bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
84 | const TypeReporter& reporter); |
85 | |
86 | bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
87 | const TypeReporter& reporter); |
88 | |
89 | Array<IndexExpr> RankShape(const Array<IndexExpr>& shape); |
90 | |
91 | /*! |
92 | * \brief The shape of type relation. |
93 | * |
94 | * \param types The input and output types to the relation. |
95 | * \param num_inputs The number of input arguments. |
96 | * \param attrs The attributes |
97 | * \param reporter The reporter. |
98 | * \return true whether relation has been resolved. |
99 | */ |
100 | bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
101 | const TypeReporter& reporter); |
102 | |
103 | } // namespace relay |
104 | } // namespace tvm |
105 | |
106 | #endif // TVM_RELAY_OP_TYPE_RELATIONS_H_ |
107 | |