1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "module_equality.h" |
20 | |
21 | #include <tvm/ir/module.h> |
22 | #include <tvm/node/structural_equal.h> |
23 | #include <tvm/node/structural_hash.h> |
24 | #include <tvm/tir/analysis.h> |
25 | |
26 | #include <memory> |
27 | |
28 | #include "../node/ndarray_hash_equal.h" |
29 | |
30 | namespace tvm { |
31 | namespace meta_schedule { |
32 | |
33 | class ModuleEqualityStructural : public ModuleEquality { |
34 | public: |
35 | size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); } |
36 | bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } |
37 | }; |
38 | |
39 | class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault { |
40 | public: |
41 | SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {} |
42 | |
43 | protected: |
44 | bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
45 | const Optional<ObjectPathPair>& current_paths) { |
46 | if (auto lhs_ptr = lhs.as<runtime::NDArray::Container>(), |
47 | rhs_ptr = rhs.as<runtime::NDArray::Container>(); |
48 | lhs_ptr && rhs_ptr) { |
49 | SEqualReducer reducer(this, nullptr, map_free_vars); |
50 | return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false); |
51 | } |
52 | return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); |
53 | } |
54 | }; |
55 | |
56 | class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { |
57 | protected: |
58 | void DispatchSHash(const ObjectRef& object, bool map_free_vars) override { |
59 | ICHECK(object.defined()); |
60 | if (auto ndarray = object.as<runtime::NDArray::Container>()) { |
61 | SHashReducer hash_reduce(this, map_free_vars); |
62 | NDArrayHash(ndarray, &hash_reduce, false); |
63 | } else { |
64 | SHashHandlerDefault::DispatchSHash(object, map_free_vars); |
65 | } |
66 | } |
67 | }; |
68 | |
69 | class ModuleEqualityIgnoreNDArray : public ModuleEquality { |
70 | public: |
71 | size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); } |
72 | bool Equal(IRModule lhs, IRModule rhs) const { |
73 | return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false); |
74 | } |
75 | }; |
76 | |
77 | // The NDArray-ignoring variant of structural equal / hash is used for the module equality |
78 | // on the extracted anchor blocks. |
79 | class ModuleEqualityAnchorBlock : public ModuleEquality { |
80 | size_t Hash(IRModule mod) const { |
81 | auto anchor_block = tir::FindAnchorBlock(mod); |
82 | if (anchor_block) { |
83 | return SHashHandlerIgnoreNDArray().Hash(GetRef<tir::Block>(anchor_block), false); |
84 | } |
85 | return ModuleEqualityIgnoreNDArray().Hash(mod); |
86 | } |
87 | bool Equal(IRModule lhs, IRModule rhs) const { |
88 | auto anchor_block_lhs = tir::FindAnchorBlock(lhs); |
89 | auto anchor_block_rhs = tir::FindAnchorBlock(rhs); |
90 | if (anchor_block_lhs && anchor_block_rhs) { |
91 | return SEqualHandlerIgnoreNDArray().Equal(GetRef<tir::Block>(anchor_block_lhs), |
92 | GetRef<tir::Block>(anchor_block_rhs), false); |
93 | } |
94 | return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); |
95 | } |
96 | }; |
97 | |
98 | std::unique_ptr<ModuleEquality> ModuleEquality::Create(const std::string& mod_eq_name) { |
99 | if (mod_eq_name == "structural" ) { |
100 | return std::make_unique<ModuleEqualityStructural>(); |
101 | } else if (mod_eq_name == "ignore-ndarray" ) { |
102 | return std::make_unique<ModuleEqualityIgnoreNDArray>(); |
103 | } else if (mod_eq_name == "anchor-block" ) { |
104 | return std::make_unique<ModuleEqualityAnchorBlock>(); |
105 | } |
106 | LOG(FATAL) << "Unknown module equality " << mod_eq_name; |
107 | } |
108 | |
109 | } // namespace meta_schedule |
110 | } // namespace tvm |
111 | |