1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_MODULE_EQUALITY_H_
20#define TVM_META_SCHEDULE_MODULE_EQUALITY_H_
21
22#include <tvm/ir/module.h>
23
24#include <memory>
25#include <string>
26
27namespace tvm {
28namespace meta_schedule {
29
30/*! \brief Method to compute hash and determine equality of modules */
31class ModuleEquality {
32 public:
33 virtual ~ModuleEquality() = default;
34
35 virtual size_t Hash(IRModule mod) const = 0;
36 virtual bool Equal(IRModule lhs, IRModule rhs) const = 0;
37
38 /*!
39 * \brief Create a ModuleEquality instance
40 * \param mod_eq_name A string to specify the module equality testing and hashing method.
41 * It must be one of the followings:
42 * - "structural": Use StructuralEqual/Hash
43 * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
44 * equality testing and hashing.
45 * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
46 * given module. The "ignore-ndarray" varint is used for the extracted blocks
47 * or in case no anchor block is found.
48 * For the definition of the anchor block, see tvm/tir/analysis.h.
49 * \return An owning pointer to the created instance
50 */
51 static std::unique_ptr<ModuleEquality> Create(const std::string& mod_eq_name);
52};
53
54/*! \brief Functor to compute hash a module using the provided method. */
55class ModuleHash {
56 public:
57 explicit ModuleHash(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
58 size_t operator()(const IRModule& mod) const { return mod_eq_.Hash(mod); }
59
60 private:
61 const ModuleEquality& mod_eq_;
62};
63
64/*! \brief Functor to determine equality of modules using the provided method. */
65class ModuleEqual {
66 public:
67 explicit ModuleEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
68 bool operator()(const IRModule& lhs, const IRModule& rhs) const {
69 return mod_eq_.Equal(lhs, rhs);
70 }
71
72 private:
73 const ModuleEquality& mod_eq_;
74};
75
76} // namespace meta_schedule
77} // namespace tvm
78
79#endif // TVM_META_SCHEDULE_MODULE_EQUALITY_H_
80