1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "module_equality.h"
20
21#include <tvm/ir/module.h>
22#include <tvm/node/structural_equal.h>
23#include <tvm/node/structural_hash.h>
24#include <tvm/tir/analysis.h>
25
26#include <memory>
27
28#include "../node/ndarray_hash_equal.h"
29
30namespace tvm {
31namespace meta_schedule {
32
33class ModuleEqualityStructural : public ModuleEquality {
34 public:
35 size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); }
36 bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); }
37};
38
39class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
40 public:
41 SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {}
42
43 protected:
44 bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
45 const Optional<ObjectPathPair>& current_paths) {
46 if (auto lhs_ptr = lhs.as<runtime::NDArray::Container>(),
47 rhs_ptr = rhs.as<runtime::NDArray::Container>();
48 lhs_ptr && rhs_ptr) {
49 SEqualReducer reducer(this, nullptr, map_free_vars);
50 return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false);
51 }
52 return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths);
53 }
54};
55
56class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
57 protected:
58 void DispatchSHash(const ObjectRef& object, bool map_free_vars) override {
59 ICHECK(object.defined());
60 if (auto ndarray = object.as<runtime::NDArray::Container>()) {
61 SHashReducer hash_reduce(this, map_free_vars);
62 NDArrayHash(ndarray, &hash_reduce, false);
63 } else {
64 SHashHandlerDefault::DispatchSHash(object, map_free_vars);
65 }
66 }
67};
68
69class ModuleEqualityIgnoreNDArray : public ModuleEquality {
70 public:
71 size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); }
72 bool Equal(IRModule lhs, IRModule rhs) const {
73 return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false);
74 }
75};
76
77// The NDArray-ignoring variant of structural equal / hash is used for the module equality
78// on the extracted anchor blocks.
79class ModuleEqualityAnchorBlock : public ModuleEquality {
80 size_t Hash(IRModule mod) const {
81 auto anchor_block = tir::FindAnchorBlock(mod);
82 if (anchor_block) {
83 return SHashHandlerIgnoreNDArray().Hash(GetRef<tir::Block>(anchor_block), false);
84 }
85 return ModuleEqualityIgnoreNDArray().Hash(mod);
86 }
87 bool Equal(IRModule lhs, IRModule rhs) const {
88 auto anchor_block_lhs = tir::FindAnchorBlock(lhs);
89 auto anchor_block_rhs = tir::FindAnchorBlock(rhs);
90 if (anchor_block_lhs && anchor_block_rhs) {
91 return SEqualHandlerIgnoreNDArray().Equal(GetRef<tir::Block>(anchor_block_lhs),
92 GetRef<tir::Block>(anchor_block_rhs), false);
93 }
94 return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs);
95 }
96};
97
98std::unique_ptr<ModuleEquality> ModuleEquality::Create(const std::string& mod_eq_name) {
99 if (mod_eq_name == "structural") {
100 return std::make_unique<ModuleEqualityStructural>();
101 } else if (mod_eq_name == "ignore-ndarray") {
102 return std::make_unique<ModuleEqualityIgnoreNDArray>();
103 } else if (mod_eq_name == "anchor-block") {
104 return std::make_unique<ModuleEqualityAnchorBlock>();
105 }
106 LOG(FATAL) << "Unknown module equality " << mod_eq_name;
107}
108
109} // namespace meta_schedule
110} // namespace tvm
111