1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_ARG_INFO_H_
20#define TVM_META_SCHEDULE_ARG_INFO_H_
21
22#include <tvm/ir/module.h>
23#include <tvm/node/node.h>
24#include <tvm/node/reflection.h>
25#include <tvm/runtime/container/shape_tuple.h>
26#include <tvm/runtime/data_type.h>
27#include <tvm/runtime/object.h>
28#include <tvm/tir/function.h>
29
30namespace tvm {
31namespace meta_schedule {
32
33/*! \brief The argument information. */
34class ArgInfoNode : public runtime::Object {
35 public:
36 static constexpr const char* _type_key = "meta_schedule.ArgInfo";
37 TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object);
38
39 public:
40 /*! \brief Default destructor. */
41 virtual ~ArgInfoNode() = default;
42 /*! \brief Converts the ArgInfo to its corresponding JSON representation. */
43 virtual ObjectRef AsJSON() const = 0;
44};
45
46/*!
47 * \brief Managed reference to ArgInfoNode
48 * \sa ArgInfoNode
49 */
50class ArgInfo : public runtime::ObjectRef {
51 public:
52 /*!
53 * \brief Parse the argument information from a JSON object.
54 * \param json_obj The json object to parse.
55 * \return The argument information parsed.
56 */
57 TVM_DLL static ArgInfo FromJSON(const ObjectRef& json_obj);
58 /*!
59 * \brief Extract a list of the argument information from PrimFunc.
60 * \param func The PrimFunc to get argument information from.
61 * \return An array of the argument information derived.
62 */
63 TVM_DLL static Array<ArgInfo, void> FromPrimFunc(const tir::PrimFunc& func);
64 /*!
65 * \brief Extract a list of the argument information from the entry func of an IRModule
66 * \param mod The IRModule to extract argument information from.
67 * \param remove_preproc Whether to remove the preprocessing blocks.
68 * \return An array of the argument information derived.
69 */
70 TVM_DLL static Array<ArgInfo, void> FromEntryFunc(const IRModule& mod, bool remove_preproc);
71
72 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode);
73
74 protected:
75 ArgInfo() = default;
76};
77
78/*! \brief The tensor argument information. */
79class TensorInfoNode : public ArgInfoNode {
80 public:
81 /*! \brief The data type of the tensor. */
82 runtime::DataType dtype;
83 /*! \brief The shape of the tensor. */
84 runtime::ShapeTuple shape;
85
86 void VisitAttrs(tvm::AttrVisitor* v) {
87 v->Visit("dtype", &dtype);
88 v->Visit("shape", &shape);
89 }
90
91 static constexpr const char* _type_key = "meta_schedule.TensorInfo";
92 TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode);
93
94 public:
95 ObjectRef AsJSON() const;
96};
97
98/*!
99 * \brief Managed reference to TensorInfoNode
100 * \sa TensorInfoNode
101 */
102class TensorInfo : public ArgInfo {
103 public:
104 /*!
105 * \brief Constructor of TensorInfo.
106 * \param dtype The data type of the tensor argument.
107 * \param shape The shape tuple of the tensor argument.
108 */
109 TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape);
110 /*!
111 * \brief Parse the argument information from a JSON object.
112 * \param json_obj The json object to parse.
113 * \return The argument information parsed.
114 */
115 TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj);
116 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode);
117};
118
119} // namespace meta_schedule
120} // namespace tvm
121
122#endif // TVM_META_SCHEDULE_ARG_INFO_H_
123