1#pragma once
2
#include <functional>
#include <map>
#include <mutex>
#include <unordered_map>
#include <vector>

#include "taichi/common/core.h"
8
9namespace taichi {
10
/// A plain (non-atomic) integer reference counter that starts at one.
///
/// No synchronization is performed here; callers that share a RefCount
/// across threads must provide their own locking.
class RefCount {
 public:
  /// Increments the count by one.
  void inc() {
    ref_count_++;
  }

  /// Decrements the count by one and returns the new value.
  int dec() {
    return --ref_count_;
  }

  /// Returns the current count without modifying it.
  int count() const {
    return ref_count_;
  }

 private:
  // A freshly constructed counter represents one live reference.
  int ref_count_{1};
};
26
27template <class T, bool sync>
28class RefCountedPool {
29 public:
30 void inc(T obj) {
31 if constexpr (sync) {
32 gc_pool_lock_.lock();
33 }
34
35 auto iter = counts_.find(obj);
36
37 if (iter == counts_.end()) {
38 counts_[obj] = RefCount();
39 } else {
40 iter->second.inc();
41 }
42
43 if constexpr (sync) {
44 gc_pool_lock_.unlock();
45 }
46 }
47
48 void dec(T obj) {
49 if constexpr (sync) {
50 gc_pool_lock_.lock();
51 }
52
53 auto iter = counts_.find(obj);
54
55 if (iter == counts_.end()) {
56 TI_ERROR("Can not find counted reference");
57 } else {
58 int c = iter->second.dec();
59 if (c == 0) {
60 gc_pool_.push_back(iter->first);
61 counts_.erase(iter);
62 }
63 }
64
65 if constexpr (sync) {
66 gc_pool_lock_.unlock();
67 }
68 }
69
70 T gc_pop_one(T null) {
71 if constexpr (sync) {
72 gc_pool_lock_.lock();
73 }
74
75 T obj = null;
76
77 if (gc_pool_.size()) {
78 obj = gc_pool_.back();
79 gc_pool_.pop_back();
80 }
81
82 if constexpr (sync) {
83 gc_pool_lock_.unlock();
84 }
85
86 return obj;
87 }
88
89 void gc_remove_all(std::function<void(T)> deallocator) {
90 std::lock_guard<std::mutex> lg(gc_pool_lock_);
91
92 for (T obj : gc_pool_) {
93 deallocator(obj);
94 }
95 gc_pool_.clear();
96 }
97
98 private:
99 std::unordered_map<T, RefCount> counts_;
100 std::vector<T> gc_pool_;
101 std::mutex gc_pool_lock_;
102};
103
104} // namespace taichi
105