1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
5namespace torch {
6namespace jit {
7
// Loop-unrolling pass over `graph`.
// Returns true if the graph is modified.
// NOTE(review): `graph` is taken by non-const reference (unlike
// PeelProfilingLoops); presumably the pass may reseat it — confirm against
// the implementation before relying on that.
TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph);

// Only unrolls constant loops. Will unroll them regardless of loop block size.
// NOTE(review): "constant loops" presumably means loops whose trip count is a
// constant, and the return value presumably mirrors UnrollLoops (true if the
// graph is modified) — confirm against the implementation.
TORCH_API bool UnrollConstantLoops(std::shared_ptr<Graph>& graph);

// Peels iterations off the loop node `n`.
// NOTE(review): assumes `n` is a loop node and that `times` is the number of
// iterations to peel; the meaning of the returned Node* is not visible from
// this header — confirm against the implementation.
TORCH_API Node* PeelLoop(Node* n, size_t times);

// Peels profiling loops in `graph`.
// Returns true if the graph is modified.
// NOTE(review): which loops count as "profiling loops" is defined in the
// implementation file — confirm there.
TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
18
19struct TORCH_API LoopsPeeler {
20 LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1)
21 : callback_(std::move(callback)), num_iterations_(num_iterations) {}
22
23 bool run(const std::shared_ptr<Graph>& graph);
24
25 private:
26 void collectLoop(Node* n);
27 void collectLoops(Block* block);
28 void peelLoops();
29
30 std::function<bool(Node* n)> callback_ = nullptr;
31 Node* in_loop_ = nullptr;
32 std::list<Node*> loops_to_peel_;
33 size_t num_iterations_ = 1;
34};
35} // namespace jit
36} // namespace torch
37