1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file extract_constants.cc
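 * \brief Collects the constant data of PrimFuncs' AllocateConst nodes into the IRModule's
 * 'tvm::attr::kConstants' attribute array (identical constants are stored only once), and sets
 * each node's irmod_storage_idx to the index of its data in this array.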
24 * For more information, see the RFC:
25 * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0022-tir-non-scalar-constants.md
26 */
#include <tvm/arith/analyzer.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "ir_utils.h"
33
34namespace tvm {
35namespace tir {
36
37using ConstArrayType = Array<runtime::NDArray>;
38class Applicator : public tir::StmtMutator {
39 protected:
40 // returns index of the a in constant_array_, if not found - appends
41 size_t DeDup(const runtime::NDArray& a) {
42 tvm::SEqualReducer eql;
43 auto it = std::find_if(
44 constant_array_.begin(), constant_array_.end(), [&eql, a](const runtime::NDArray& v) {
45 return NDArrayContainerTrait::SEqualReduce(a.as<runtime::NDArray::Container>(),
46 v.as<runtime::NDArray::Container>(), eql);
47 });
48 if (it != constant_array_.end()) {
49 return it - constant_array_.begin();
50 }
51 constant_array_.push_back(std::move(a));
52 return constant_array_.size() - 1;
53 }
54
55 public:
56 Stmt Apply(tir::Stmt body, const ConstArrayType& constant_array) {
57 constant_array_ = constant_array;
58 return this->VisitStmt(body);
59 }
60
61 Stmt VisitStmt_(const tir::AllocateConstNode* acn) override {
62 // Check whether the data already defined within the module's attrs
63 // and add array index.
64 ICHECK(acn->data) << "data field should be defined";
65 auto node = CopyOnWrite(acn);
66 node->irmod_storage_idx = Optional<Integer>(Integer(DeDup(node->data.value())));
67 return Stmt(node);
68 }
69
70 ConstArrayType constant_array_;
71};
72
73namespace transform {
74
75tvm::transform::Pass ExtractPrimFuncConstants() {
76 auto prim_func_pass = [=](PrimFunc foo, IRModule m, tvm::transform::PassContext ctx) {
77 auto* func = foo.CopyOnWrite();
78 if (!m->attrs.defined()) {
79 m->attrs = DictAttrs(Map<String, ObjectRef>());
80 }
81 auto* attrs = m->attrs.CopyOnWrite();
82 ConstArrayType constant_array_ =
83 (attrs->dict.count(tvm::attr::kConstants))
84 ? Downcast<ConstArrayType>(attrs->dict[tvm::attr::kConstants])
85 : ConstArrayType();
86 Applicator a = Applicator();
87 func->body = a.Apply(func->body, constant_array_);
88 const ConstArrayType constant_list = a.constant_array_;
89 if (constant_list.size()) {
90 attrs->dict.Set(tvm::attr::kConstants, constant_list);
91 }
92 return GetRef<PrimFunc>(func);
93 };
94
95 auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) {
96 auto m = GetRef<IRModule>(module.CopyOnWrite());
97 for (const auto& kv : m->functions) {
98 BaseFunc f = kv.second;
99 if (f->IsInstance<PrimFuncNode>()) {
100 m->Update(kv.first, prim_func_pass(GetRef<PrimFunc>(f.as<PrimFuncNode>()), m, pc));
101 }
102 }
103 return m;
104 };
105
106 return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {});
107}
108
109TVM_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants")
110 .set_body_typed(ExtractPrimFuncConstants);
111
112} // namespace transform
113
114} // namespace tir
115} // namespace tvm
116