1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file primfunc_utils.cc
 * \brief Small utility passes over PrimFuncs: binding a target, marking the entry function, and filtering functions by a predicate.
23 */
24
25#include <tvm/driver/driver_api.h>
26#include <tvm/tir/transform.h>
27
28namespace tvm {
29namespace tir {
30namespace transform {
31transform::Pass BindTarget(Target target) {
32 auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
33 return WithAttr(std::move(f), tvm::attr::kTarget, target);
34 };
35 return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
36}
37
38transform::Pass AnnotateEntryFunc() {
39 auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
40 ICHECK(m->functions.size() == 1);
41 return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true));
42 };
43 return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {});
44}
45
46transform::Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond) {
47 auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
48 if (fcond(f)) {
49 return f;
50 } else {
51 return tir::PrimFunc(nullptr);
52 }
53 };
54 return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {});
55}
56
57TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget);
58TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc);
59TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter);
60
61} // namespace transform
62} // namespace tir
63} // namespace tvm
64