1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ |
17 | |
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
23 | |
24 | namespace tensorflow { |
25 | |
26 | class Summary; |
27 | class SummaryWriterInterface; |
28 | namespace data { |
29 | |
30 | // A `StatsAggregator` accumulates statistics incrementally. A |
31 | // `StatsAggregator` can accumulate multiple different statistics, distinguished |
32 | // by a string name. |
33 | // |
34 | // The class currently supports accumulating `Histogram`, `scalar` objects and |
35 | // tfstreamz metrics, and we expect to add other methods in future. |
36 | // |
37 | // NOTE(mrry): `StatsAggregator` is a virtual interface because we anticipate |
38 | // that many different implementations will have the same interface. For |
39 | // example, we have different implementations in "stats_aggregator_ops.cc" for |
40 | // simple in-memory implementation that integrates with the pull-based summary |
41 | // API, and for the push-based `SummaryWriterInterface`, and we may add |
42 | // implementations that work well with other custom monitoring services. |
43 | class StatsAggregator { |
44 | public: |
45 | virtual ~StatsAggregator() {} |
46 | |
47 | // Add the given `values` to the histogram with the given `name`. Each |
48 | // element of `values` will be treated as a separate sample in the histogram. |
49 | virtual void AddToHistogram(const string& name, |
50 | gtl::ArraySlice<double> values, |
51 | int64_t global_step) = 0; |
52 | |
53 | // TODO(shivaniagrawal): consistency in double and float usage. |
54 | // Add the given `value` as Scalar with the given `name`. |
55 | virtual void AddScalar(const string& name, float value, |
56 | int64_t global_step) = 0; |
57 | |
58 | // Stores a protocol buffer representation of the aggregator state in the |
59 | // given `out_summary`. |
60 | virtual void EncodeToProto(Summary* out_summary) = 0; |
61 | |
62 | // Sets a `summary_writer` with this stats_aggregator. |
63 | virtual Status SetSummaryWriter(SummaryWriterInterface* summary_writer) = 0; |
64 | |
65 | // Increment the `label` cell of metrics mapped with `name` by given `value`. |
66 | virtual void IncrementCounter(const string& name, const string& label, |
67 | int64_t val) = 0; |
68 | }; |
69 | |
70 | // A `StatsAggregatorResource` wraps a sharable `StatsAggregator` as a resource |
71 | // in the TensorFlow resource manager. |
72 | // |
73 | // NOTE(mrry): This class is separate from `StatsAggregator` in order to |
74 | // simplify the memory management of the shared object. Most users of |
75 | // `StatsAggregator` interact with a `std::shared_ptr<StatsAggregator>` whereas |
76 | // the `ResourceBase` API requires explicit reference counting. |
77 | class StatsAggregatorResource : public ResourceBase { |
78 | public: |
79 | // Creates a new resource from the given `stats_aggregator`. |
80 | StatsAggregatorResource(std::unique_ptr<StatsAggregator> stats_aggregator) |
81 | : stats_aggregator_(stats_aggregator.release()) {} |
82 | |
83 | // Returns the wrapped `StatsAggregator`. |
84 | std::shared_ptr<StatsAggregator> stats_aggregator() const { |
85 | return stats_aggregator_; |
86 | } |
87 | |
88 | string DebugString() const override { return "StatsAggregatorResource" ; } |
89 | |
90 | private: |
91 | const std::shared_ptr<StatsAggregator> stats_aggregator_; |
92 | }; |
93 | |
94 | } // namespace data |
95 | } // namespace tensorflow |
96 | |
97 | #endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ |
98 | |