1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_BLOCKING_COUNTER_H_
17#define RUY_RUY_BLOCKING_COUNTER_H_
18
19#include <atomic>
20#include <condition_variable> // NOLINT(build/c++11) // IWYU pragma: keep
21#include <mutex> // NOLINT(build/c++11) // IWYU pragma: keep
22
23#include "ruy/time.h"
24
25namespace ruy {
26
// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to finish working.
// Wait() first spin-waits on the atomic count_ until it reaches 0, and only
// falls back to a passive (condition-variable) wait once the spin duration
// passed to it has elapsed. Spinning is acceptable because in our usage
// pattern, BlockingCounter is used only to synchronize threads after
// short-lived tasks (performing parts of the same GEMM). It is not used
// for synchronizing longer waits (resuming work on the next GEMM).
35class BlockingCounter {
36 public:
37 BlockingCounter() : count_(0) {}
38
39 // Sets/resets the counter; initial_count is the number of
40 // decrementing events that the Wait() call will be waiting for.
41 void Reset(int initial_count);
42
43 // Decrements the counter; if the counter hits zero, signals
44 // the threads that were waiting for that, and returns true.
45 // Otherwise (if the decremented count is still nonzero),
46 // returns false.
47 bool DecrementCount();
48
49 // Waits for the N other threads (N having been set by Reset())
50 // to hit the BlockingCounter.
51 //
52 // Will first spin-wait for `spin_duration` before reverting to passive wait.
53 void Wait(const Duration spin_duration);
54
55 private:
56 std::atomic<int> count_;
57
58 // The condition variable and mutex allowing to passively wait for count_
59 // to reach the value zero, in the case of longer waits.
60 std::condition_variable count_cond_;
61 std::mutex count_mutex_;
62};
63
64} // namespace ruy
65
66#endif // RUY_RUY_BLOCKING_COUNTER_H_
67