/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/model.h"

#include <algorithm>
#include <cmath>
#include <list>
#include <memory>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/time/clock.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {
namespace data {
namespace model {

constexpr int64_t Model::kOptimizationPeriodMinMs;
constexpr int64_t Model::kOptimizationPeriodMaxMs;

namespace {

// The number of most recent gap times used to compute the target time for
// stage-based optimization.
constexpr int32_t kGapTimeWindow = 100;
// Gap time threshold: any gap time over this duration will be dropped.
constexpr uint64_t kGapDurationThresholdUsec = 10000000;  // 10 seconds
// In outlier computation, points that are more than `kOutlierSigmas` standard
// deviations above the mean are considered outliers.
constexpr double kOutlierSigmas = 2.0;

// A class to prune outliers given a set of points. To use it, instantiate an
// object and call the `GetCleanPoints()` method.
class OutlierPruner {
 public:
  explicit OutlierPruner(const std::vector<uint64_t>& points)
      : points_(points.begin(), points.end()) {}

  // Returns the remaining points after removing outliers from the original set
  // of points.
  std::vector<uint64_t> GetCleanPoints() {
    if (points_.empty()) {
      return points_;
    }
    // Compute the outlier threshold.
    double mean;
    double standard_deviation;
    ComputeMeanAndStandardDeviation(&mean, &standard_deviation);
    double threshold = mean + standard_deviation * kOutlierSigmas;
    std::vector<uint64_t> clean_points;
    for (auto point : points_) {
      if (static_cast<double>(point) > threshold) {
        continue;
      }
      clean_points.push_back(point);
    }
    return clean_points;
  }

 private:
  void ComputeMeanAndStandardDeviation(double* mean,
                                       double* standard_deviation) {
    // Accumulate in `uint64_t`; an `int` initial value would make
    // `std::accumulate` sum (and potentially overflow) in `int`.
    uint64_t sum =
        std::accumulate(points_.begin(), points_.end(), uint64_t{0});
    *mean = static_cast<double>(sum) / static_cast<double>(points_.size());
    if (points_.size() < 2) {
      // The sample standard deviation is undefined for fewer than 2 points;
      // avoid dividing by `points_.size() - 1 == 0` below.
      *standard_deviation = 0.0;
      return;
    }
    double accum = 0.0;
    for (auto point : points_) {
      accum += (static_cast<double>(point) - *mean) *
               (static_cast<double>(point) - *mean);
    }
    *standard_deviation = std::sqrt(accum / (points_.size() - 1));
  }

  // Points to prune.
  std::vector<uint64_t> points_;
};
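
// A minimal usage sketch (illustrative values, not from production traces):
//
//   std::vector<uint64_t> gap_times = {100, 100, 100, 100, 100, 10000};
//   OutlierPruner pruner(gap_times);
//   std::vector<uint64_t> clean = pruner.GetCleanPoints();
//
// Here the mean is 1750 and the sample standard deviation is ~4042, so the
// threshold is 1750 + 2.0 * 4042 ~= 9833; the 10000 point exceeds it and is
// dropped, leaving the five points of value 100.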

// A priority queue that holds stage roots where the top of the priority queue
// is the node with the largest total time.
class ModelTimingPriorityQueue {
 public:
  explicit ModelTimingPriorityQueue(ModelTiming& model_timing) {
    std::vector<std::shared_ptr<Node>> stage_roots =
        model_timing.GetStageRoots();
    if (stage_roots.empty()) {
      return;
    }
    for (auto& root : stage_roots) {
      const ModelTiming::NodeTiming* root_timing =
          model_timing.GetTiming(root.get());
      DCHECK(root_timing != nullptr);
      stage_roots_queue_.emplace(
          root_timing->total_time_nsec * root_timing->pipeline_ratio,
          root.get());
    }
  }

  // Pops the top item from the queue, i.e. the node with the largest total
  // time.
  StatusOr<std::pair<double, Node*>> PopSlowestStageRoot() {
    if (stage_roots_queue_.empty()) {
      return errors::Internal(
          "Model timing priority queue is empty during stage-based "
          "optimization");
    }
    std::pair<double, Node*> top_item = stage_roots_queue_.top();
    stage_roots_queue_.pop();
    return top_item;
  }

  // Pushes a node together with its total time onto the queue.
  void Push(Node* node, const ModelTiming::NodeTiming& node_timing) {
    stage_roots_queue_.emplace(
        node_timing.total_time_nsec * node_timing.pipeline_ratio, node);
  }

 private:
  std::priority_queue<std::pair<double, Node*>> stage_roots_queue_;
};
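
// Illustrative flow during stage-based optimization (a sketch; assumes a
// populated `ModelTiming` instance):
//
//   ModelTimingPriorityQueue queue(model_timing);
//   TF_ASSIGN_OR_RETURN(std::pair<double, Node*> slowest,
//                       queue.PopSlowestStageRoot());
//   // ... raise the parallelism of `slowest.second`, recompute its timing,
//   // then push it back so the next slowest stage surfaces on top.
//   queue.Push(slowest.second, *model_timing.GetTiming(slowest.second));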

// A cache that looks up the `parallelism` parameters of nodes the first time
// they are requested and saves them for subsequent requests.
class NodeParallelismParameters {
 public:
  NodeParallelismParameters() = default;

  // Returns the `parallelism` parameter given a node.
  Parameter* Get(const Node* node) {
    // Look for the `parallelism` parameter of this node in the cache.
    auto it = node_parallelism_.find(node);
    if (it != node_parallelism_.end()) {
      return it->second;
    }
    // Find the `parallelism` parameter of this node and cache it.
    Node::ModelParameters parameters = node->CollectNodeTunableParameters();
    Node::ModelParameters::iterator parameter_pair = std::find_if(
        parameters.begin(), parameters.end(),
        [](const std::pair<std::string, std::shared_ptr<Parameter>>&
               parameter) { return parameter.second->name == kParallelism; });
    if (parameter_pair == parameters.end()) {
      return nullptr;
    }
    node_parallelism_[node] = parameter_pair->second.get();
    return parameter_pair->second.get();
  }

 private:
  absl::flat_hash_map<const Node*, Parameter*> node_parallelism_;
};
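
// Usage sketch (`node` is a hypothetical `Node*`): the first call walks the
// node's tunable parameters; later calls for the same node are cache hits.
//
//   NodeParallelismParameters node_parallelism;
//   Parameter* parallelism = node_parallelism.Get(node);
//   if (parallelism != nullptr) {
//     parallelism->value += 1.0;  // e.g. hill-climb one step
//   }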

// Returns true if all parameters have reached their max values.
bool AreAllParametersMax(const Model::ModelParameters& parameters) {
  for (const auto& pair : parameters) {
    if (pair.second->value < pair.second->max) {
      return false;
    }
  }
  return true;
}

// Records the RAM usage of the hill climbing algorithm.
void RecordAutotuneRamUsage(int64_t ram_budget, double max_buffered_bytes) {
  if (ram_budget == 0) {
    return;
  }
  const auto memory_info = port::GetMemoryInfo();
  // Records the ratio of memory used since RootDataset was created over the
  // RAM budget.
  const auto original_free_memory = ram_budget / kRamBudgetShare;
  const auto current_free_memory = memory_info.free;
  metrics::RecordTFDataAutotuneUsedRamBudgetRatio(
      (original_free_memory - current_free_memory) / ram_budget);
  // Records the ratio of the maximum buffer bytes tf.data could use over the
  // RAM budget.
  metrics::RecordTFDataAutotuneMaxBufferBudgetRatio(
      max_buffered_bytes / static_cast<double>(ram_budget));
}

// Helper function for node traversal that doesn't skip any nodes.
inline bool IsAnyNode(const std::shared_ptr<Node>& node) { return true; }

// Helper function for node traversal that filters out nodes for which
// autotuning is disabled.
inline bool IsAutotuneNode(const std::shared_ptr<Node>& node) {
  return node->autotune();
}

// Helper function for node traversal that returns only synchronous nodes.
inline bool IsSyncNode(const std::shared_ptr<Node>& node) {
  return !node->IsAsync();
}

// Helper function for node traversal that returns only asynchronous nodes.
inline bool IsAsyncNode(const std::shared_ptr<Node>& node) {
  return node->IsAsync();
}

// Wrapper for the square function to reduce verbosity.
inline double Square(double x) { return x * x; }

// Collects "essential" parallelism parameters and buffer size parameters in
// the tree rooted in the given node. Which parallelism parameters are
// essential is determined by the relative processing time spent in the
// corresponding transformation. The collected parameters are returned via maps
// that map node names to their respective parameters.
inline void CollectParameters(std::shared_ptr<Node> node,
                              const Node::ModelParameters& parameters,
                              Node::ModelParameters* parallelism_parameters,
                              Node::ModelParameters* buffer_size_parameters) {
  // A parallelism parameter is considered essential if the corresponding
  // transformation's processing time is greater than the essential rate times
  // the average transformation self processing time.
  constexpr double kEssentialRate = 0.3L;

  Node::NodeValues processing_times;
  double processing_time = node->TotalProcessingTime(&processing_times);
  double uniform_share =
      processing_time / static_cast<double>(processing_times.size());
  for (auto& pair : parameters) {
    if (pair.second->name == kParallelism &&
        processing_times[pair.first] > kEssentialRate * uniform_share) {
      parallelism_parameters->push_back(pair);
    } else if (pair.second->name == kBufferSize) {
      buffer_size_parameters->push_back(pair);
    }
  }
}
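
// Worked example (hypothetical numbers): if `processing_time` is 400us and
// `processing_times` has four entries, then `uniform_share` is 100us and the
// essential threshold is 0.3 * 100us = 30us. A `parallelism` parameter is
// collected only if its transformation's processing time exceeds 30us, while
// every `buffer_size` parameter is collected unconditionally.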

// Applies the gradient descent method once and updates the parameter values.
// If a new value is out of range, it is clamped to the interval between the
// minimum and maximum values.
inline void UpdateParameterValues(const Node::ParameterGradients& gradients,
                                  Node::ModelParameters* parameters) {
  // Gradient descent step size.
  constexpr double kDescentStep = 0.1L;

  double max_abs_derivative = 1.0;
  for (auto& pair : *parameters) {
    if (std::round(pair.second->value) != pair.second->max) {
      auto* gradient = gtl::FindOrNull(
          gradients, std::make_pair(pair.first, pair.second->name));
      if (gradient) {
        max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
      }
    }
  }
  for (auto& pair : *parameters) {
    auto* gradient = gtl::FindOrNull(
        gradients, std::make_pair(pair.first, pair.second->name));
    if (gradient) {
      double new_value =
          pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
      // Projection on the feasible interval.
      if (new_value > pair.second->max) {
        pair.second->value = pair.second->max;
      } else if (new_value < pair.second->min) {
        pair.second->value = pair.second->min;
      } else {
        pair.second->value = new_value;
      }
    }
  }
}
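
// Worked example (hypothetical numbers): with gradients -5.0 and -20.0 for
// two parameters, `max_abs_derivative` is 20.0, so the updates are
// `value - 0.1 * (-5.0) / 20.0 = value + 0.025` and
// `value - 0.1 * (-20.0) / 20.0 = value + 0.1`; normalizing by the largest
// derivative magnitude caps every step at `kDescentStep` before clamping.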

// Copies the parameter values (which are for optimization tuning) and updates
// the state values (which are for the input pipeline to follow).
inline void UpdateStateValues(Node::ModelParameters* parameters) {
  for (auto& pair : *parameters) {
    auto& parameter = pair.second;
    VLOG(2) << "Setting tunable parameter " << pair.first
            << ":: " << parameter->name << " to " << parameter->value;
    mutex_lock l(*parameter->state->mu);
    parameter->state->value = parameter->value;
    parameter->state->cond_var->notify_all();
  }
}

// Produces protos for all nodes in the subtree rooted in the `output` node,
// in breadth-first order, and appends them to the nodes of the given model.
Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
  model->set_output(output->id());
  std::list<std::shared_ptr<Node>> to_serialize = {output};
  auto& nodes = *model->mutable_nodes();
  while (!to_serialize.empty()) {
    const std::shared_ptr<Node> node = to_serialize.front();
    to_serialize.pop_front();
    TF_RETURN_IF_ERROR(node->ToProto(&(nodes[node->id()])));
    for (auto input : node->inputs()) {
      to_serialize.push_back(input);
    }
  }
  return OkStatus();
}

// Restores the node tree rooted in `output` from the given model proto.
Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
  if (model.nodes().empty()) {
    return errors::Internal(
        "Cannot restore model from proto because it has no nodes.");
  }
  TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
                                     /*output=*/nullptr, output));
  std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
  while (!to_restore_inputs.empty()) {
    std::shared_ptr<Node> node = to_restore_inputs.front();
    to_restore_inputs.pop_front();
    for (int64_t input_id : model.nodes().at(node->id()).inputs()) {
      std::shared_ptr<Node> input;
      TF_RETURN_IF_ERROR(
          Node::FromProto(model.nodes().at(input_id), node, &input));
      node->add_input(input);
      to_restore_inputs.push_back(input);
    }
  }
  return OkStatus();
}
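
// Round-trip sketch (assumes `root` is the output node of an existing model):
//
//   ModelProto proto;
//   TF_RETURN_IF_ERROR(ModelToProtoHelper(root, &proto));
//   std::shared_ptr<Node> restored;
//   TF_RETURN_IF_ERROR(ModelFromProtoHelper(proto, &restored));
//   // `restored` now mirrors the node tree rooted in `root`.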

// The first input of InterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class InterleaveMany : public Node {
 public:
  using Node::Node;

  InterleaveMany(Node::Args args,
                 std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~InterleaveMany() {}

  // The ratio of an InterleaveMany node is `1/cycle_length`. If the cycle
  // length is not available, we approximate it by `1/input_size`. The input
  // size does not include the original input dataset that generates the other
  // input datasets of interleave nodes.
  double Ratio() const override {
    auto* cycle_length = gtl::FindOrNull(parameters_, kCycleLength);
    if (cycle_length != nullptr) {
      return 1.0 / (*cycle_length)->value;
    }
    // After cl/436244658, `cycle_length` cannot be `nullptr`. The remaining
    // part of this function is used to approximate `Ratio()` of this node for
    // model protos that were created before that CL.

    // The cycle length is not available; use 1/input_size as the ratio.
    std::size_t input_size = 1;
    {
      mutex_lock l(mu_);
      if (inputs_.size() >= 2) {
        auto first_input = inputs_.begin();
        auto second_input = std::next(first_input);
        // Some interleave datasets have two kinds of inputs: the original
        // input dataset plus the input datasets generated while the
        // interleave is iterated; others only have the generated inputs.
        if ((*first_input)->name() == (*second_input)->name()) {
          input_size = std::max(inputs_.size(), input_size);
        } else {
          input_size = std::max(inputs_.size() - 1, input_size);
        }
      }
    }
    if (input_size == 0) {
      return 1.0;
    }
    return 1.0 / static_cast<double>(input_size);
  }
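
  // For example (hypothetical values), an interleave with `cycle_length` 4
  // yields `Ratio()` = 0.25: on average, producing one output element
  // consumes a quarter of an element from each of the 4 interleaved inputs.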

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<InterleaveMany>(
        Args{id_, name_, std::move(output)});
  }

  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for the InterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except the first input) to return an
    // element. Regardless of the `block_length` parameter of the
    // InterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the average
  // output time of inputs comprising the interleave "cycle".
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    double inputs_output_time =
        (OutputTimeForInputs(*output_times) -
         (*output_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient /= static_cast<double>(num_inputs() - 1);
        }
      }

      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients) -
          (*output_time_gradients)[inputs_.front()->long_name()];

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in
      // the first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
    }
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the
  // average processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
    return OkStatus();
  }
};

// The first input of AsyncInterleaveMany corresponds to the input dataset
// whose elements are used to create the (derived) input datasets whose
// elements are interleaved as output.
//
// TODO(jsimsa): model the first input
class AsyncInterleaveMany : public Node {
 public:
  AsyncInterleaveMany(Node::Args args,
                      std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncInterleaveMany() {}

  bool IsAsync() const override { return true; }

  // The ratio of an AsyncInterleaveMany node is `1/cycle_length`. If the
  // cycle length is not available, we use 1/parallelism.
  double Ratio() const override {
    auto* cycle_length = gtl::FindOrNull(parameters_, kCycleLength);
    if (cycle_length != nullptr) {
      return 1.0 / (*cycle_length)->value;
    }
    // After cl/436244658, `cycle_length` cannot be `nullptr`. The remaining
    // part of this function is used to approximate `Ratio()` of this node for
    // model protos that were created before that CL.

    // The cycle length is not available; use 1/min(input_size, parallelism)
    // as the ratio.
    double parallelism = 1.0;
    {
      mutex_lock l(mu_);
      if (inputs_.size() >= 2) {
        auto first_input = inputs_.begin();
        auto second_input = std::next(first_input);
        // Some interleave datasets have two kinds of inputs: the original
        // input dataset plus the input datasets generated while the
        // interleave is iterated; others only have the generated inputs.
        if ((*first_input)->name() == (*second_input)->name()) {
          parallelism = std::max(inputs_.size(), size_t{1});
        } else {
          parallelism = std::max(inputs_.size() - 1, size_t{1});
        }
      }
    }
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      parallelism = std::min(parallelism, (*parameter)->value);
    }
    return 1.0 / parallelism;
  }

  double ComputeSelfTime() const override {
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }
    if (num_elements_ == 0) {
      return 0;
    }
    {
      tf_shared_lock l(mu_);
      return processing_time_ema_ / parallelism;
    }
  }

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncInterleaveMany>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for the AsyncInterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except the first input) to return an
    // element. Regardless of the `block_length` parameter of the
    // AsyncInterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of self processing time and expected wait time
  // from the buffer model estimated using `ComputeWaitTime(producer_time,
  // consumer_time, parallelism, ...)`, where `producer_time` is the average
  // output time of inputs comprising the interleave "cycle" divided by
  // `parallelism`, `consumer_time` is the `input_time` specified through
  // `input_times` divided by `num_inputs() - 1`, and if the node has a
  // parallelism parameter, then `buffer_size` is derived from `parallelism`.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());
    consumer_time = input_time / static_cast<double>(num_inputs() - 1);
    double parallelism = num_inputs() - 1;  // default to cycle length
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      parallelism = std::min(parallelism, (*parameter)->value);
    }
    double output_time_for_inputs =
        OutputTimeForInputs(*output_times) -
        (*output_times)[inputs_.front()->long_name()];
    producer_time = output_time_for_inputs /
                    static_cast<double>(num_inputs() - 1) / parallelism;

    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der +
          producer_time_der * inputs_time_der_sum / parallelism;

      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (producer_time_der /
                        static_cast<double>(num_inputs() - 1) / parallelism);
        }
      }

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in
      // the first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
      // Add derivative w.r.t. own parallelism parameter.
      if (parameter && (*parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(), (*parameter)->name)] =
            buffer_size_der - producer_time_der * producer_time / parallelism;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the
  // average processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    auto* parameter = gtl::FindOrNull(parameters_, kMaxBufferedElements);
    if (parameter == nullptr) {
      parameter = gtl::FindOrNull(parameters_, kParallelism);
      if (parameter == nullptr) {
        return 0.0;
      }
    }
    return (*parameter)->value * AverageBufferedElementSize();
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
    return OkStatus();
  }
};

class KnownRatio : public Node {
 public:
  KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}

  virtual ~KnownRatio() {}

  double Ratio() const override { return ratio_; }

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
                                        ratio_);
  }

  // The input time is the sum of inherited input time and self processing
  // time, divided by `ratio_`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (ratio_ == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the product of
  // `ratio_` and the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (ratio_ == 0) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= ratio_;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the
  // product of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    return OkStatus();
  }

 private:
  const double ratio_;
};

class AsyncRatio : public Node {
 public:
  AsyncRatio(Node::Args args, double ratio, double memory_ratio,
             std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncRatio() {}

  bool IsAsync() const override { return true; }

  double Ratio() const override { return ratio_; }

  double ComputeSelfTime() const override {
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }
    if (num_elements_ == 0) {
      return 0;
    }
    {
      tf_shared_lock l(mu_);
      return processing_time_ema_ / parallelism;
    }
  }

 protected:
  virtual double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    return ratio_;
  }

  double MemoryRatio() const { return memory_ratio_; }

  // The input time is the sum of inherited input time and parallelism-
  // adjusted self processing time, divided by `Ratio()`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }

    auto ratio = RatioLocked();
    if (ratio == 0.0) {
      (*input_times)[long_name()] =
          inherited_input_time + SelfProcessingTimeLocked() / parallelism;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
        ratio;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of parallelism-adjusted self processing time
  // and expected wait time from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the product of `Ratio()` and the sum of output times of
  // inputs, `consumer_time` is the product of `Ratio()` and the `input_time`
  // specified through `input_times` (since for each element stored in the
  // buffer, the inputs need to be called `Ratio()` times), and if the node has
  // a parallelism parameter, then `buffer_size` is derived from `parallelism`.
  //
  // The current implementation assumes that there is at most 1 parameter per
  // node.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    auto ratio = RatioLocked();
    double parallelism = 1.0;
    double buffer_size = 0.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
      if (ratio == 0.0) {
        buffer_size = parallelism;
      } else {
        // Currently, MapAndBatch is the only transformation that creates
        // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
        // `parallelism` threads to apply the function on elements from the
        // input dataset, while one element in the buffer actually corresponds
        // to `Ratio()` elements from the input dataset. So we adjust the
        // `buffer_size` by dividing it by `Ratio()`.
        buffer_size = parallelism / ratio;
      }
    } else if (buffer_size_parameter) {
      buffer_size = (*buffer_size_parameter)->value;
    }
    double self_processing_time = SelfProcessingTimeLocked();
    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());

    if (ratio == 0.0) {
      consumer_time = input_time;
      producer_time = 0.0L;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }

        double producer_time_der = 0.0L;
        double consumer_time_der = 0.0L;
        double buffer_size_der = 0.0L;
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    &producer_time_der, &consumer_time_der,
                                    &buffer_size_der);
        (*output_time_gradients)[long_name()] = consumer_time_der;
        if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
          (*gradients)[std::make_pair(long_name(),
                                      (*parallelism_parameter)->name)] =
              -(1.0L + consumer_time_der) * self_processing_time /
                  Square(parallelism) +
              buffer_size_der;
        } else if (buffer_size_parameter &&
                   (*buffer_size_parameter)->state->tunable) {
          (*gradients)[std::make_pair(
              long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
        }
      } else {
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    /*producer_time_derivative=*/nullptr,
                                    /*consumer_time_derivative=*/nullptr,
                                    /*buffer_size_derivative=*/nullptr);
      }
      output_time = self_processing_time / parallelism + wait_time;
      (*output_times)[long_name()] = output_time;
      return;
    }

    consumer_time = input_time * ratio;
    producer_time = ratio * OutputTimeForInputs(*output_times);
    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der + producer_time_der * inputs_time_der_sum;

      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (ratio * producer_time_der);
        }
      }

      // Add derivative w.r.t. own parameter if it's tunable.
      if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(),
                                    (*parallelism_parameter)->name)] =
            buffer_size_der / ratio -
            (1.0L + consumer_time_der +
             producer_time_der * inputs_time_der_sum) *
                self_processing_time / Square(parallelism);
      } else if (buffer_size_parameter &&
                 (*buffer_size_parameter)->state->tunable) {
        (*gradients)[std::make_pair(
            long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time / parallelism + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the
  // product of `Ratio()` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    auto ratio = RatioLocked();
    if (ratio == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (!parameter) {
      parameter = gtl::FindOrNull(parameters_, kParallelism);
    }

    if (parameter) {
      if (memory_ratio_ == 0) {
        result += (*parameter)->value * AverageBufferedElementSize();
      } else {
        // The estimate is currently not accurate for MapAndBatchDataset,
        // because the maximum buffer size does not match the
        // `num_parallel_calls` parameter.
        result +=
            (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
      }
    }
    return result;
  }

 private:
  // Identifies how many input elements need to be created to construct an
  // element for the dataset.
  //
  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset, and
  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
  const double ratio_;
  // For parallelism nodes, identifies how many parallel calls are introduced
  // by one buffered element. The value is defined to correctly estimate the
  // RAM budget bound with a given num_parallel_calls (or buffer_size)
  // combined with the estimated average size of buffered elements.
  const double memory_ratio_;
};

class UnknownRatio : public Node {
 public:
  using Node::Node;

  virtual ~UnknownRatio() {}

  double Ratio() const override {
    tf_shared_lock l(mu_);
    return RatioLocked();
  }

 protected:
  double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    // TODO(wilsin): Consistent with UnknownRatio, current implementation
    // assumes that the number of input elements consumed per output is the
    // same across all inputs.
    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      return 0.0;
    }
    return static_cast<double>(inputs_.front()->num_elements()) /
           static_cast<double>(num_elements_);
  }

  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
  }

  // The input time is the sum of inherited input time and self processing
  // time, divided by the ratio estimate.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the product of
  // the ratio estimate and the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }
    // TODO(jsimsa): The current implementation assumes that the number of
    // input elements consumed per output is the same across all inputs.
    double ratio = static_cast<double>(inputs_.front()->num_elements()) /
                   static_cast<double>(num_elements_);
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= ratio;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the
  // product of the ratio estimate and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (inputs_.empty() || num_elements_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // TODO(jsimsa): The current implementation assumes that the number of
    // input elements consumed per output is the same across all inputs.
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    double inputs_processing_time =
        ratio * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
    return OkStatus();
  }
};

class Unknown : public Node {
 public:
  using Node::Node;

  virtual ~Unknown() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
  }

  // The input time is the inherited input time.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    (*input_times)[long_name()] = inherited_input_time;
  }

  // The output time is the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
    if (gradients) {
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
  }

  // The processing time is the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    if (processing_times) {
      (*processing_times)[long_name()] = SelfProcessingTimeLocked();
    }
    (*total_processing_times)[long_name()] =
        TotalProcessingTimeForInputs(*total_processing_times);
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::UNKNOWN);
    return OkStatus();
  }
};

class AsyncKnownRatio : public AsyncRatio {
 public:
  AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                  std::vector<std::shared_ptr<Parameter>> parameters)
      : AsyncRatio(args, ratio, memory_ratio, std::move(parameters)) {}

  virtual ~AsyncKnownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncKnownRatio>(
        Args{id_, name_, std::move(output)}, Ratio(), MemoryRatio(),
        parameters);
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
    node_proto->set_ratio(Ratio());
    node_proto->set_memory_ratio(MemoryRatio());
    return OkStatus();
  }
};

class AsyncUnknownRatio : public AsyncRatio {
 public:
  AsyncUnknownRatio(Node::Args args,
                    std::vector<std::shared_ptr<Parameter>> parameters)
      : AsyncRatio(args, /*ratio=*/0.0, /*memory_ratio=*/0.0,
                   std::move(parameters)) {}

  virtual ~AsyncUnknownRatio() {}

  double Ratio() const override {
    tf_shared_lock l(mu_);
    return RatioLocked();
  }

 protected:
  double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) override {
    // TODO(wilsin): Consistent with UnknownRatio, current implementation
    // assumes that the number of input elements consumed per output is the
    // same across all inputs.
    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      return 0.0;
    }
    return static_cast<double>(inputs_.front()->num_elements()) /
           static_cast<double>(num_elements_);
  }

  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncUnknownRatio>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO);
    return OkStatus();
  }
};

}  // namespace

thread_local int64_t Node::work_start_;

std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max) {
  return std::make_shared<Parameter>(name, state, min, max);
}

std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
                                                   double value) {
  return std::make_shared<Parameter>(name, nullptr, /*min=*/value,
                                     /*max=*/value);
}
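
// Usage sketch (`shared_state` is a hypothetical
// `std::shared_ptr<SharedState>`): a tunable `parallelism` parameter ranges
// over [min, max], while a non-tunable parameter pins min == max == value so
// the optimizer leaves it alone.
//
//   auto parallelism =
//       MakeParameter(kParallelism, shared_state, /*min=*/1.0, /*max=*/16.0);
//   auto cycle_length = MakeNonTunableParameter(kCycleLength, /*value=*/4.0);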

std::shared_ptr<Node> MakeInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  DCHECK(absl::c_any_of(parameters,
                        [](const std::shared_ptr<Parameter>& parameter) {
                          return parameter->name == kCycleLength;
                        }));
  return std::make_shared<InterleaveMany>(std::move(args),
                                          std::move(parameters));
}

std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  DCHECK(absl::c_any_of(parameters,
                        [](const std::shared_ptr<Parameter>& parameter) {
                          return parameter->name == kCycleLength;
                        }));
  return std::make_shared<AsyncInterleaveMany>(std::move(args),
                                               std::move(parameters));
}

std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
  return std::make_shared<KnownRatio>(std::move(args), ratio);
}

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio, double memory_ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio,
                                           memory_ratio, std::move(parameters));
}

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
                                 /*memory_ratio=*/ratio, std::move(parameters));
}

std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
  return MakeKnownRatioNode(std::move(args), 0);
}

std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
  return std::make_shared<UnknownRatio>(std::move(args));
}

std::shared_ptr<Node> MakeAsyncUnknownRatioNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncUnknownRatio>(std::move(args),
                                             std::move(parameters));
}

std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
  return std::make_shared<Unknown>(std::move(args));
}
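
// A sketch of wiring nodes with these factories (ids, names, and the batch
// ratio are illustrative; in production the dataset iterators create and
// connect model nodes):
//
//   std::shared_ptr<Node> batch = MakeKnownRatioNode(
//       {/*id=*/1, /*name=*/"BatchV2", /*output=*/nullptr}, /*ratio=*/32.0);
//   std::shared_ptr<Node> range =
//       MakeSourceNode({/*id=*/2, /*name=*/"Range", /*output=*/batch});
//   batch->add_input(range);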

double Node::ComputeWaitTime(const double& producer_time,
                             const double& consumer_time,
                             const double& buffer_size,
                             double* producer_time_derivative,
                             double* consumer_time_derivative,
                             double* buffer_size_derivative) {
  // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
  // p=`p_buffer_empty`, T=`wait_time`, then we have:
  // if y = 0, then p = 0;
  // elif x = 0, then p = 1;
  // elif x = y, then p = 1 / (n+1);
  // else p = [1 - x/y] / [1 - power(x/y, n+1)].
  //
  // We also have T = p * y, and derivatives of T w.r.t. x, y, n are computed:
  // dT/dx = dp/dx * y,
  // dT/dy = p + dp/dy * y,
  // dT/dn = dp/dn * y.
  // Then the remaining work is to compute dp/dx, dp/dy, dp/dn by considering
  // different cases and substituting the values into the above formulas.

  // Case 1: the producer is infinitely fast. The buffer will always be full,
  // and the wait time will always be 0.
  if (producer_time == 0) {
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = 0` since p=0 on
      // the line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy
      // at (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy,
      // where p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
      if (buffer_size == 0 || consumer_time == 0) {
        *producer_time_derivative = 1.0L;
      } else {
        *producer_time_derivative = 0.0L;
      }
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = 0.0L;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return 0.0L;
  }

  // Case 2: the consumer is infinitely fast. The wait time is always the time
  // to produce an output.
  if (consumer_time == 0) {
    if (producer_time_derivative) {
      *producer_time_derivative = 1.0L;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since p=1 on
      // the line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx
      // at (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx,
      // where p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
      if (buffer_size == 0) {
        *consumer_time_derivative = 0.0L;
      } else {
        *consumer_time_derivative = -1.0L;
      }
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return producer_time;
  }

  // Case 3: the consumer and the producer are equally fast. The expected wait
  // time decreases linearly with the size of the buffer.
  if (consumer_time == producer_time) {
    const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
    const double p_buffer_empty_der =
        -buffer_size / (2.0L * buffer_size + 2.0L);
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = p_buffer_empty`
      // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there.
      // Actually to compute dp/dy at (y,y), we need to consider lim_{dy->0}
      // [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1), p(y,y+dy) = [1 -
      // y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
      *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since
      // p=1/(n+1) on the line x=y doesn't imply dp/dx = 0 there. Actually to
      // compute dp/dx at (x,x), we need to consider lim_{dx->0}
      // [p(x+dx,x)-p(x,x)] / dx, where p(x,x)=1/(n+1), p(x+dx,x) = [1 -
      // (x+dx)/x] / [1 - power((x+dx)/x, n+1)].
      *consumer_time_derivative = p_buffer_empty_der;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
    }
    return p_buffer_empty * producer_time;
  }

  // Case 4: the consumer is slower than the producer and neither is
  // infinitely fast. Case 4 and Case 5 actually follow the same formula; they
  // are separated for numerical computation reasons.
  if (consumer_time > producer_time) {
    const double ratio = producer_time / consumer_time;
    const double ratio_pow = std::pow(ratio, buffer_size);
    const double p_buffer_empty =
        ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
    const double p_buffer_empty_der =
        (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
        ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
    if (producer_time_derivative) {
      *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                                std::log(ratio) * producer_time;
    }
    return p_buffer_empty * producer_time;
  }

  // Case 5: the producer is slower than the consumer and neither is
  // infinitely fast.
  const double ratio = consumer_time / producer_time;
  const double ratio_pow = std::pow(ratio, buffer_size);
  const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
  const double p_buffer_empty_der =
      ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
      Square(1.0L - ratio_pow * ratio);
  if (producer_time_derivative) {
    *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
  }
  if (consumer_time_derivative) {
    *consumer_time_derivative = p_buffer_empty_der;
  }
  if (buffer_size_derivative) {
    *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                              ratio_pow * ratio * std::log(ratio) *
                              producer_time;
  }
  return p_buffer_empty * producer_time;
}
1511
1512Node::ModelParameters Node::CollectTunableParametersLocked() const {
1513 Node::ModelParameters parameters;
  // Collect tunable parameters from the leaves of the node tree to the root.
1515 for (const auto& node :
1516 CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1517 tf_shared_lock l(node->mu_);
1518 node->CollectTunableParametersHelper(&parameters);
1519 }
1520 CollectTunableParametersHelper(&parameters);
1521 return parameters;
1522}
1523
1524Node::ModelParameters Node::CollectTunableParameters() const {
1525 tf_shared_lock l(mu_);
1526 return CollectTunableParametersLocked();
1527}
1528
1529Node::ModelParameters Node::CollectNodeTunableParameters() const {
1530 tf_shared_lock l(mu_);
1531 Node::ModelParameters parameters;
1532 CollectTunableParametersHelper(&parameters);
1533 return parameters;
1534}
1535
1536string Node::DebugString() const {
1537 absl::flat_hash_map<string, string> debug_strings;
1538 tf_shared_lock l(mu_);
  // Build up the debug string from the leaves of the node tree to the root.
1540 for (const auto& node :
1541 CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1542 tf_shared_lock l(node->mu_);
1543 node->DebugStringHelper(&debug_strings);
1544 }
1545 DebugStringHelper(&debug_strings);
1546
1547 return debug_strings[long_name()];
1548}
1549
1550void Node::FlushMetrics() {
1551 if (!record_metrics_) {
1552 return;
1553 }
1554 metrics_.record_bytes_consumed(bytes_consumed_);
1555 metrics_.record_bytes_produced(bytes_produced_);
1556 metrics_.record_num_elements(num_elements_);
1557}
1558
1559double Node::OutputTime(Node::NodeValues* input_times,
1560 Node::ParameterGradients* gradients) const {
1561 // To store the output time gradient w.r.t. input time (if `gradients` is not
1562 // `nullptr`) and the output time for each node.
1563 Node::NodeValues output_time_gradients, output_times;
1564 tf_shared_lock l(mu_);
1565 auto nodes = CollectNodesLocked(TraversalOrder::BFS, IsAutotuneNode);
1566
  // Computes and stores the input time for each node, from the root of the
  // node tree to the leaves.
1569 InputTimeLocked(input_times);
1570 for (const auto& node : nodes) {
1571 tf_shared_lock l(node->mu_);
1572 node->InputTimeLocked(input_times);
1573 }
1574
1575 std::reverse(nodes.begin(), nodes.end());
  // Computes and stores the output time and the output time gradient w.r.t.
  // the input time (if `gradients` is not `nullptr`) for each node, from the
  // leaves of the node tree to the root.
1579 for (const auto& node : nodes) {
1580 tf_shared_lock l(node->mu_);
1581 node->OutputTimeLocked(*input_times, gradients, &output_times,
1582 &output_time_gradients);
1583 }
1584 OutputTimeLocked(*input_times, gradients, &output_times,
1585 &output_time_gradients);
1586
1587 return output_times[long_name()];
1588}
1589
1590double Node::ComputeSelfTime() const {
1591 if (num_elements_ == 0) {
1592 return 0;
1593 }
1594 tf_shared_lock l(mu_);
1595 return processing_time_ema_;
1596}
1597
1598std::shared_ptr<Node> Node::Snapshot() const {
1599 NodePairList node_pairs;
1600 auto result = SnapshotHelper(nullptr, &node_pairs);
1601
1602 while (!node_pairs.empty()) {
1603 auto node_pair = node_pairs.front();
1604 node_pairs.pop_front();
1605 std::shared_ptr<Node> current = node_pair.first,
1606 cloned_output = node_pair.second;
1607 cloned_output->add_input(
1608 current->SnapshotHelper(cloned_output, &node_pairs));
1609 }
1610 return result;
1611}
1612
1613double Node::SelfProcessingTime() const {
1614 tf_shared_lock l(mu_);
1615 return SelfProcessingTimeLocked();
1616}
1617
1618double Node::TotalBufferedBytes() const {
1619 Node::NodeValues total_bytes;
1620 tf_shared_lock l(mu_);
  // Compute total buffered bytes from the leaves of the node tree to the root.
1622 for (const auto& node :
1623 CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1624 tf_shared_lock l(node->mu_);
1625 node->TotalBufferedBytesHelper(&total_bytes);
1626 }
1627 TotalBufferedBytesHelper(&total_bytes);
1628
1629 return total_bytes[long_name()];
1630}
1631
1632double Node::TotalMaximumBufferedBytes() const {
1633 Node::NodeValues total_bytes;
1634 tf_shared_lock l(mu_);
  // Compute total maximum buffered bytes from the leaves of the node tree to
  // the root.
1637 for (const auto& node :
1638 CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1639 tf_shared_lock l(node->mu_);
1640 node->TotalMaximumBufferedBytesHelper(&total_bytes);
1641 }
1642 TotalMaximumBufferedBytesHelper(&total_bytes);
1643
1644 return total_bytes[long_name()];
1645}
1646
1647double Node::TotalProcessingTime(Node::NodeValues* processing_times) {
  // Create a hash map to store the per-element CPU time spent in the subtree
  // rooted at each node.
1650 Node::NodeValues total_processing_times;
1651 tf_shared_lock l(mu_);
1652
  // Computes the per-element CPU time spent in the subtree rooted at each
  // node, from the leaves of the node tree to the root.
1655 for (const auto& node :
1656 CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1657 tf_shared_lock l(node->mu_);
1658 node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1659 }
1660 TotalProcessingTimeLocked(processing_times, &total_processing_times);
1661
1662 return total_processing_times[long_name()];
1663}
1664
1665double Node::AverageBufferedElementSize() const {
1666 DCHECK_GE(num_elements_, 0);
1667 DCHECK_GE(buffered_elements_, 0);
1668 if (num_elements_ <= 0) {
1669 if (buffered_elements_ <= 0) {
1670 // If there are no produced elements or buffered elements recorded, return
1671 // 0.
1672 return 0;
1673 }
1674 // If there are no produced elements but some buffered elements, return the
1675 // average size of all buffered elements.
1676 return static_cast<double>(buffered_bytes_) /
1677 static_cast<double>(buffered_elements_);
1678 }
1679
1680 if (buffered_elements_ <= 0) {
1681 // If there are no buffered elements but some produced elements, return the
1682 // average size of all produced elements.
1683 return static_cast<double>(bytes_produced_) /
1684 static_cast<double>(num_elements_);
1685 }
1686
  // Otherwise, return the mean of the average size of all produced elements
  // and the average size of all buffered elements.
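  // For example (hypothetical values), with bytes_produced_ = 100,
  // num_elements_ = 10, buffered_bytes_ = 40 and buffered_elements_ = 2, this
  // returns (100 / 10 + 40 / 2) / 2 = 15 bytes.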
1689 return (static_cast<double>(bytes_produced_) /
1690 static_cast<double>(num_elements_) +
1691 static_cast<double>(buffered_bytes_) /
1692 static_cast<double>(buffered_elements_)) /
1693 2.0;
1694}
1695
1696double Node::OutputTimeForInputs(const Node::NodeValues& output_times) const {
1697 double sum = 0;
1698 for (auto& input : inputs_) {
1699 // Inputs for which autotuning is disabled are excluded.
1700 if (input->autotune()) {
1701 sum += output_times.at(input->long_name());
1702 }
1703 }
1704 return sum;
1705}
1706
1707double Node::OutputTimeGradientsForInputs(
1708 const Node::NodeValues& output_time_gradients) const {
1709 double sum = 0;
1710 for (auto& input : inputs_) {
1711 // Inputs for which autotuning is disabled are excluded.
1712 if (input->autotune()) {
1713 sum +=
1714 gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1715 }
1716 }
1717 return sum;
1718}
1719
1720double Node::TotalProcessingTimeForInputs(
1721 const Node::NodeValues& total_processing_times) {
1722 // If the number of elements produced by an input is smaller than this
1723 // constant, then its processing time is estimated using a weighted average of
1724 // the empirical processing time and processing time history.
1725 constexpr int kNumElementsThreshold = 30;
1726
1727 // Identifies the minimum number of input processing times to collect before
1728 // the processing time history is used as a prior.
1729 constexpr int kCountThreshold = 30;
1730
1731 double sum = 0;
1732 for (auto& input : inputs_) {
1733 // Inputs for which autotuning is disabled are excluded.
1734 if (input->autotune()) {
1735 double input_processing_time =
1736 total_processing_times.at(input->long_name());
1737 int64_t num_elements = input->num_elements();
1738 if (num_elements < kNumElementsThreshold) {
1739 if (input_processing_time_count_ < kCountThreshold) {
1740 sum += input_processing_time;
1741 } else {
1742 // The fewer elements the input has produced so far, the more weight
1743 // is assigned to the prior to reduce volatility.
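          // `2 << num_elements` equals 2^(num_elements + 1), so the prior's
          // weight halves with each additional element produced.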
1744 double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
1745 double prior =
1746 input_processing_time_sum_ / input_processing_time_count_;
1747 sum += (1.0L - prior_weight) * input_processing_time +
1748 prior_weight * prior;
1749 }
1750 } else {
1751 sum += input_processing_time;
1752 input_processing_time_count_++;
1753 input_processing_time_sum_ += input_processing_time;
1754 }
1755 }
1756 }
1757 return sum;
1758}
1759
1760double Node::SelfProcessingTimeLocked() const {
1761 if (num_elements_ == 0) {
1762 return 0;
1763 }
1764 return static_cast<double>(processing_time_) /
1765 static_cast<double>(num_elements_);
1766}
1767
1768Node::NodeVector Node::CollectNodes(
1769 TraversalOrder order,
1770 bool collect_node(const std::shared_ptr<Node>)) const {
1771 tf_shared_lock l(mu_);
1772 return CollectNodesLocked(order, collect_node);
1773}
1774
1775bool Node::TryDownsizeBuffer() {
1776 if (!IsAsync()) {
1777 return false;
1778 }
1779 Node::ModelParameters tunable_parameters;
1780 {
1781 tf_shared_lock l(mu_);
1782 if (buffered_elements_low_ > buffered_elements_high_) {
1783 // No element is stored in the buffer yet. Do nothing.
1784 return false;
1785 }
1786 CollectTunableParametersHelper(&tunable_parameters);
1787 }
1788 Node::ModelParameters buffer_size_parameters;
1789 for (auto& parameter : tunable_parameters) {
1790 if (parameter.second->name != kBufferSize) {
1791 continue;
1792 }
1793 buffer_size_parameters.push_back(std::move(parameter));
1794 }
1795 bool downsized = false;
  // Sync buffer state values to parameter values.
1797 for (auto& [node_name, parameter] : buffer_size_parameters) {
1798 tf_shared_lock l(*parameter->state->mu);
1799 parameter->value = parameter->state->value;
1800 }
1801 {
    // Downsize buffers.
1803 tf_shared_lock l(mu_);
1804 for (auto& [node_name, parameter] : buffer_size_parameters) {
1805 if (buffered_elements_low_ > 0 &&
1806 (buffered_elements_high_ - buffered_elements_low_ + 1) <
1807 parameter->value) {
1808 double old_value = parameter->value;
        // By default, upsizing doubles buffer sizes when there is enough RAM.
        // We cap the downsize at 1/4 of the current size (i.e. the new size
        // is at least 75% of the old one) to avoid undoing the previous
        // upsize.
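        // For example (hypothetical values), with old_value = 100,
        // buffered_elements_low_ = 10 and buffered_elements_high_ = 40, the
        // new value is max(40 - 10 + 1, 75) = 75.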
1812 parameter->value =
1813 std::max(buffered_elements_high_ - buffered_elements_low_ + 1,
1814 static_cast<int64_t>(old_value * 0.75));
1815 if (old_value != parameter->value) {
1816 VLOG(2) << "Downsize buffer " << long_name()
1817 << "::" << parameter->name << " from " << old_value << " to "
1818 << parameter->value;
1819 downsized = true;
1820 }
1821 }
1822 }
1823 }
  // Since the `SharedState` locks are the same locks used by the op iterators,
  // the optimization thread should hold them for as short a time as possible.
1826 if (downsized) {
1827 UpdateStateValues(&buffer_size_parameters);
1828 }
1829 return downsized;
1830}
1831
1832void Node::CollectBufferParametersToUpsize(
1833 absl::flat_hash_map<Node*, Parameter*>& node_parameters) {
1834 {
1835 tf_shared_lock l(mu_);
1836 for (auto& [node_name, parameter] : parameters_) {
1837 if ((parameter->name != kBufferSize) ||
1838 (parameter->state == nullptr || !parameter->state->tunable)) {
1839 continue;
1840 }
1841 if (buffered_elements_low_ <= 0 &&
1842 buffered_elements_high_ >= parameter->value) {
1843 parameter->value = parameter->state->value;
1844 node_parameters[this] = parameter.get();
1845 }
1846 }
1847 }
1848 for (auto& [node, parameter] : node_parameters) {
1849 tf_shared_lock l(*parameter->state->mu);
1850 parameter->value = parameter->state->value;
1851 }
1852}
1853
1854Node::NodeVector Node::CollectNodesLocked(
1855 TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1856 TF_SHARED_LOCKS_REQUIRED(mu_) {
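  // Note that the result contains only the (transitive) inputs that satisfy
  // `collect_node`; the node this is called on is itself not included.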
1857 NodeVector node_vector;
1858 std::list<std::shared_ptr<Node>> temp_list;
1859
1860 for (auto& input : inputs_) {
1861 if (collect_node(input)) {
1862 node_vector.push_back(input);
1863 temp_list.push_back(input);
1864 }
1865 }
1866
1867 while (!temp_list.empty()) {
1868 auto cur_node = temp_list.front();
1869 temp_list.pop_front();
1870 tf_shared_lock l(cur_node->mu_);
1871 for (auto& input : cur_node->inputs_) {
1872 if (collect_node(input)) {
1873 node_vector.push_back(input);
1874 temp_list.push_back(input);
1875 }
1876 }
1877 }
1878
1879 if (order == TraversalOrder::REVERSE_BFS) {
1880 std::reverse(node_vector.begin(), node_vector.end());
1881 }
1882 return node_vector;
1883}
1884
1885void Node::CollectTunableParametersHelper(
1886 Node::ModelParameters* parameters) const TF_SHARED_LOCKS_REQUIRED(mu_) {
  // If autotune is turned off or no elements have been recorded, we do not
  // collect the node's parameters.
1889 if (!autotune_ || num_elements_ <= 0) {
1890 return;
1891 }
1892 for (auto& pair : parameters_) {
1893 if (pair.second->state != nullptr && pair.second->state->tunable) {
1894 parameters->push_back(std::make_pair(long_name(), pair.second));
1895 }
1896 }
1897}
1898
1899void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
1900 const TF_SHARED_LOCKS_REQUIRED(mu_) {
1901 string result;
1902 strings::StrAppend(&result, long_name(), ":\n");
1903 strings::StrAppend(&result, " autotune=", autotune_.load(), "\n");
1904 strings::StrAppend(&result, " buffered_bytes=", buffered_bytes_.load(),
1905 "\n");
1906 strings::StrAppend(&result, " buffered_elements=", buffered_elements_.load(),
1907 "\n");
1908 strings::StrAppend(&result, " bytes_consumed=", bytes_consumed_.load(),
1909 "\n");
1910 strings::StrAppend(&result, " bytes_produced=", bytes_produced_.load(),
1911 "\n");
1912 strings::StrAppend(&result, " processing_time=", processing_time_.load(),
1913 "\n");
1914 strings::StrAppend(&result, " num_elements=", num_elements_.load(), "\n");
1915 string inputs;
1916 for (auto& input : inputs_) {
1917 strings::StrAppend(&inputs, input->long_name(), ",");
1918 }
1919 strings::StrAppend(&result, " inputs={", inputs, "}\n");
1920 for (auto& input : inputs_) {
1921 strings::StrAppend(&result, debug_strings->at(input->long_name()));
1922 }
1923 debug_strings->insert(std::make_pair(long_name(), result));
1924}
1925
1926std::shared_ptr<Node> Node::SnapshotHelper(
1927 std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
1928 tf_shared_lock l(mu_);
1929
  // Clone the current node (`this`) and set the clone of its output node
  // (`cloned_output`) as the output of the cloned node (`cloned_current`).
1933 std::shared_ptr<Node> cloned_current = Clone(cloned_output);
1934 {
1935 cloned_current->autotune_.store(autotune_);
1936 cloned_current->buffered_bytes_.store(buffered_bytes_);
1937 cloned_current->buffered_elements_.store(buffered_elements_);
1938 cloned_current->buffered_elements_low_.store(buffered_elements_low_);
1939 cloned_current->buffered_elements_high_.store(buffered_elements_high_);
1940 cloned_current->bytes_consumed_.store(bytes_consumed_);
1941 cloned_current->bytes_produced_.store(bytes_produced_);
1942 cloned_current->num_elements_.store(num_elements_);
1943 cloned_current->record_metrics_.store(false);
1944 cloned_current->processing_time_.store(processing_time_);
1945 {
1946 mutex_lock l2(cloned_current->mu_);
1947 cloned_current->parameters_ = parameters_;
1948 cloned_current->previous_processing_time_ = previous_processing_time_;
1949 cloned_current->processing_time_ema_ = processing_time_ema_;
1950 }
1951 }
1952
1953 for (auto& input : inputs_) {
1954 node_pairs->push_back(std::make_pair(input, cloned_current));
1955 }
1956 return cloned_current;
1957}
1958
1959void Node::TotalBufferedBytesHelper(Node::NodeValues* total_bytes) const
1960 TF_SHARED_LOCKS_REQUIRED(mu_) {
1961 if (!autotune_) {
1962 total_bytes->insert(std::make_pair(long_name(), 0));
1963 return;
1964 }
1965
1966 double result = 0;
1967 auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1968 if (!parameter) {
1969 parameter = gtl::FindOrNull(parameters_, kParallelism);
1970 }
1971 if (parameter) {
1972 result = buffered_bytes_;
1973 }
1974 for (auto& input : inputs_) {
1975 result += total_bytes->at(input->long_name());
1976 }
1977 total_bytes->insert(std::make_pair(long_name(), result));
1978}
1979
1980void Node::TotalMaximumBufferedBytesHelper(Node::NodeValues* total_bytes) const
1981 TF_SHARED_LOCKS_REQUIRED(mu_) {
1982 if (!autotune_) {
1983 total_bytes->insert(std::make_pair(long_name(), 0));
1984 return;
1985 }
1986
1987 double result = MaximumBufferedBytes();
1988 for (auto& input : inputs_) {
1989 result += total_bytes->at(input->long_name());
1990 }
1991 total_bytes->insert(std::make_pair(long_name(), result));
1992}
1993
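// The base implementation assumes the node maintains no internal buffer;
// subclasses that own buffers are expected to override this.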
1994double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1995 return 0;
1996}
1997
1998Status Node::ToProto(ModelProto::Node* node_proto) const {
1999 tf_shared_lock l(mu_);
2000 node_proto->set_id(id_);
2001 node_proto->set_name(name_);
2002 node_proto->set_autotune(autotune_);
2003 node_proto->set_buffered_bytes(buffered_bytes_);
2004 node_proto->set_buffered_elements(buffered_elements_);
2005 node_proto->set_bytes_consumed(bytes_consumed_);
2006 node_proto->set_bytes_produced(bytes_produced_);
2007 node_proto->set_num_elements(num_elements_);
2008 node_proto->set_processing_time(processing_time_);
2009 node_proto->set_record_metrics(record_metrics_);
2010
2011 // Produce protos for all parameters.
2012 for (auto const& parameter : parameters_) {
2013 ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
2014 parameter_proto->set_name(parameter.first);
2015 parameter_proto->set_value(parameter.second->value);
2016 parameter_proto->set_min(parameter.second->min);
2017 parameter_proto->set_max(parameter.second->max);
2018 if (parameter.second->state != nullptr) {
2019 parameter_proto->set_state_value(parameter.second->state->value);
2020 parameter_proto->set_tunable(parameter.second->state->tunable);
2021 }
2022 }
2023
2024 // Add input node ids.
2025 for (auto const& input : inputs_) {
2026 node_proto->add_inputs(input->id());
2027 }
2028 return OkStatus();
2029}
2030
2031Status Node::FromProtoHelper(ModelProto::Node node_proto,
2032 std::shared_ptr<Node> node) {
2033 {
2034 tf_shared_lock l(node->mu_);
2035 node->autotune_.store(node_proto.autotune());
2036 node->buffered_bytes_.store(node_proto.buffered_bytes());
2037 node->buffered_elements_.store(node_proto.buffered_elements());
2038 if (node_proto.buffered_elements() == 0) {
2039 node->buffered_elements_low_.store(std::numeric_limits<int64_t>::max());
2040 node->buffered_elements_high_.store(std::numeric_limits<int64_t>::min());
2041 } else {
2042 node->buffered_elements_low_.store(node_proto.buffered_elements());
2043 node->buffered_elements_high_.store(node_proto.buffered_elements());
2044 }
2045 node->bytes_consumed_.store(node_proto.bytes_consumed());
2046 node->bytes_produced_.store(node_proto.bytes_produced());
2047 node->num_elements_.store(node_proto.num_elements());
2048 node->processing_time_.store(node_proto.processing_time());
2049 node->record_metrics_.store(node_proto.record_metrics());
2050
2051 // Restore parameters.
2052 int64_t num_parameters = node_proto.parameters_size();
2053 for (int i = 0; i < num_parameters; i++) {
2054 const ModelProto::Node::Parameter& parameter_proto =
2055 node_proto.parameters(i);
2056 std::shared_ptr<SharedState> state;
2057 if (parameter_proto.tunable()) {
2058 state = std::make_shared<SharedState>(
2059 kAutotune, std::make_shared<mutex>(),
2060 std::make_shared<condition_variable>());
2061 state->value = parameter_proto.state_value();
2062 } else {
2063 state = std::make_shared<SharedState>(
2064 parameter_proto.state_value(), std::make_shared<mutex>(),
2065 std::make_shared<condition_variable>());
2066 }
2067 node->parameters_[parameter_proto.name()] =
2068 MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
2069 parameter_proto.max());
2070 node->parameters_[parameter_proto.name()]->value =
2071 std::max(parameter_proto.min(), parameter_proto.value());
2072 }
2073 }
2074 {
2075 mutex_lock l(node->mu_);
2076 node->UpdateProcessingTimeEma();
2077 }
2078 return OkStatus();
2079}
2080
2081Status Node::FromProto(ModelProto::Node node_proto,
2082 std::shared_ptr<Node> output,
2083 std::shared_ptr<Node>* node) {
2084 // Note that parameters are restored in `FromProtoHelper`.
2085 Args args = {node_proto.id(), node_proto.name(), std::move(output)};
2086 switch (node_proto.node_class()) {
2087 case NodeClass::INTERLEAVE_MANY:
2088 *node = std::make_shared<InterleaveMany>(args);
2089 break;
2090 case NodeClass::ASYNC_INTERLEAVE_MANY:
2091 *node = std::make_shared<AsyncInterleaveMany>(
2092 args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2093 break;
2094 case NodeClass::KNOWN_RATIO:
2095 *node = std::make_shared<KnownRatio>(args, node_proto.ratio());
2096 break;
2097 case NodeClass::ASYNC_KNOWN_RATIO:
2098 *node = std::make_shared<AsyncKnownRatio>(
2099 args, node_proto.ratio(), node_proto.memory_ratio(),
2100 /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2101 break;
2102 case NodeClass::UNKNOWN_RATIO:
2103 *node = std::make_shared<UnknownRatio>(args);
2104 break;
2105 case NodeClass::ASYNC_UNKNOWN_RATIO:
2106 *node = std::make_shared<AsyncUnknownRatio>(
2107 args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2108 break;
2109 default:
2110 *node = std::make_shared<Unknown>(args);
2111 }
2112 return FromProtoHelper(node_proto, *node);
2113}
2114
2115Model::Model() : optimization_period_ms_(kOptimizationPeriodMinMs) {
2116 model_gauge_cell_ = metrics::GetTFDataModelGauge(
2117 strings::StrCat(reinterpret_cast<uint64>(this)));
2118 model_gauge_cell_->Set([&]() { return DebugString(); });
2119}
2120
2121Model::~Model() {
  // Before the model is destroyed, we record an empty string in the gauge to
  // prevent a race condition in which the gauge callback is invoked after the
  // model has been destroyed.
2125 model_gauge_cell_->Set([]() { return std::string(); });
2126}
2127
2128void Model::AddNode(Node::Factory factory, const string& name,
2129 std::shared_ptr<Node> parent,
2130 std::shared_ptr<Node>* out_node) {
  // The name captures the sequence of iterators joined by `::`. We only use
  // the last element of the sequence as the node name.
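  // For example (illustrative), a name of "Iterator::Prefetch::Map" yields
  // the node name "Map".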
2133 auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
2134 mutex_lock l(mu_);
2135 std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
2136 if (!output_) {
2137 output_ = node;
2138 }
2139 if (parent) {
2140 VLOG(3) << "Adding " << node->long_name() << " as input for "
2141 << parent->long_name();
2142 parent->add_input(node);
2143 } else {
2144 VLOG(3) << "Adding " << node->long_name();
2145 }
2146 *out_node = std::move(node);
2147 // TODO(jsimsa): Reset the optimization period when a node is added so that
2148 // autotuning adapts to changes to the input pipeline faster. Initial attempt
2149 // to enable this functionality caused a regression (see b/179812091).
2150}
2151
2152void Model::FlushMetrics() {
2153 std::deque<std::shared_ptr<Node>> queue;
2154 {
2155 tf_shared_lock l(mu_);
2156 if (output_) queue.push_back(output_);
2157 }
2158 while (!queue.empty()) {
2159 auto node = queue.front();
2160 queue.pop_front();
2161 node->FlushMetrics();
2162 for (auto input : node->inputs()) {
2163 queue.push_back(input);
2164 }
2165 }
2166}
2167
2168void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
2169 int64_t ram_budget, double model_input_time,
2170 CancellationManager* cancellation_manager) {
2171 std::shared_ptr<Node> snapshot;
2172 {
2173 tf_shared_lock l(mu_);
2174 snapshot = output_->Snapshot();
2175 }
2176 if (!port::JobName().empty()) {
2177 RecordAutotuneRamUsage(ram_budget, TotalMaximumBufferedBytes(snapshot));
2178 }
2179 OptimizationParams optimization_params;
2180 optimization_params.set_algorithm(algorithm);
2181 optimization_params.set_cpu_budget(cpu_budget);
2182 optimization_params.set_ram_budget(ram_budget);
2183 optimization_params.set_model_input_time(model_input_time);
2184 switch (algorithm) {
2185 case AutotuneAlgorithm::DEFAULT:
2186 case AutotuneAlgorithm::MAX_PARALLELISM:
2187 OptimizeMaxParallelism(snapshot, optimization_params,
2188 cancellation_manager);
2189 break;
2190 case AutotuneAlgorithm::HILL_CLIMB:
2191 OptimizeHillClimb(snapshot, optimization_params, cancellation_manager);
2192 break;
2193 case AutotuneAlgorithm::GRADIENT_DESCENT:
2194 OptimizeGradientDescent(snapshot, optimization_params,
2195 cancellation_manager);
2196 break;
2197 case AutotuneAlgorithm::STAGE_BASED:
2198 OptimizeStageBased(snapshot, optimization_params, cancellation_manager);
2199 break;
2200 default:
2201 VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
2202 "optimization.";
2203 return;
2204 }
2205 if (experiment_ == "autotune_buffer_optimization") {
2206 OptimizeBuffers(snapshot, optimization_params.ram_budget());
2207 }
2208}
2209
2210void Model::RemoveNode(std::shared_ptr<Node> node) {
2211 mutex_lock l(mu_);
2212 if (node) {
2213 if (node->output()) {
2214 node->output()->remove_input(node);
2215 }
2216 VLOG(3) << "Removing " << node->long_name();
2217 }
2218}
2219
2220Model::ModelParameters Model::CollectTunableParameters(
2221 std::shared_ptr<Node> node) {
2222 return node->CollectTunableParameters();
2223}
2224
2225bool Model::DownsizeBuffers(std::shared_ptr<Node> snapshot) {
2226 Node::NodeVector nodes =
2227 snapshot->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2228 nodes.push_back(snapshot);
2229 bool downsized = false;
2230 for (auto& node : nodes) {
2231 if (node->TryDownsizeBuffer()) {
2232 downsized = true;
2233 }
2234 }
2235 return downsized;
2236}
2237
2238absl::flat_hash_map<Node*, Parameter*> Model::CollectBufferParametersToUpsize(
2239 std::shared_ptr<Node> snapshot) {
2240 Node::NodeVector nodes =
2241 snapshot->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2242 absl::flat_hash_map<Node*, Parameter*> node_parameters;
2243 if (snapshot->IsAsync()) {
2244 snapshot->CollectBufferParametersToUpsize(node_parameters);
2245 }
2246 for (auto& node : nodes) {
2247 node->CollectBufferParametersToUpsize(node_parameters);
2248 }
2249 return node_parameters;
2250}
2251
2252bool Model::ShouldStop(int64_t cpu_budget, int64_t ram_budget,
2253 const Model::ModelParameters& parameters,
2254 const Model::ModelParameters& parallelism_parameters,
2255 const Model::ModelParameters& buffer_size_parameters,
2256 std::shared_ptr<Node> snapshot,
2257 bool* cpu_budget_reached) {
2258 if (!(*cpu_budget_reached)) {
    // If the combined parallelism of the essential transformations reaches
    // the CPU budget, only the buffer size parameters are tuned in future
    // iterations.
2261 int64_t model_parallelism = 0;
2262 for (auto& pair : parallelism_parameters) {
2263 model_parallelism += std::round(pair.second->value);
2264 }
2265 *cpu_budget_reached = (model_parallelism > cpu_budget);
2266 }
2267
2268 bool all_max = AreAllParametersMax(
2269 *cpu_budget_reached ? buffer_size_parameters : parameters);
2270
  // If all parameters have reached their maximum values or the RAM budget is
  // exceeded, we stop iterating.
2273 return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
2274}
2275
2276// TODO(jsimsa): Add support for tracking and using the model input time.
2277Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
2278 int64_t ram_budget,
2279 CancellationManager* cancellation_manager) {
2280 std::function<void()> unused;
2281 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
2282 cancellation_manager,
2283 [this]() {
2284 mutex_lock l(mu_);
2285 optimize_cond_var_.notify_all();
2286 },
2287 /*deregister_fn=*/&unused));
2288
2289 int64_t last_optimization_ms = 0;
2290 int64_t current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2291 while (true) {
2292 {
2293 mutex_lock l(mu_);
2294 while (!cancellation_manager->IsCancelled() &&
2295 last_optimization_ms + optimization_period_ms_ > current_time_ms) {
2296 auto wait_ms =
2297 last_optimization_ms + optimization_period_ms_ - current_time_ms;
2298 VLOG(2) << "Waiting for " << wait_ms << " ms.";
2299 optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
2300 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2301 }
2302 if (cancellation_manager->IsCancelled()) {
2303 return OkStatus();
2304 }
2305 }
2306
2307 int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2308 double model_input_time = 0.0;
    // For historical reasons, the model input time is set to 0 for all
    // optimization algorithms except the stage-based algorithm, where it is
    // used as the target optimization time for all stages in the pipeline.
2313 if (algorithm == AutotuneAlgorithm::STAGE_BASED) {
2314 model_input_time = ComputeTargetTimeNsec();
2315 }
2316 Optimize(algorithm, cpu_budget, ram_budget, model_input_time,
2317 cancellation_manager);
2318 int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2319 VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
2320
2321 // Exponentially increase the period of running the optimization until a
2322 // threshold is reached.
2323 {
2324 mutex_lock l(mu_);
2325 optimization_period_ms_ =
2326 std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
2327 }
2328 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2329 last_optimization_ms = current_time_ms;
2330 FlushMetrics();
2331 }
2332}
2333
2334void Model::OptimizeGradientDescent(
2335 std::shared_ptr<Node> snapshot,
2336 const OptimizationParams& optimization_params,
2337 CancellationManager* cancellation_manager) {
2338 VLOG(2) << "Starting optimization of tunable parameters with Gradient "
2339 "Descent.";
2340 auto parameters = CollectTunableParameters(snapshot);
2341 if (parameters.empty()) {
2342 VLOG(2) << "The Gradient Descent optimization is terminated since no node "
2343 "with tunable parameters has recorded elements.";
2344 return;
2345 }
2346 VLOG(2) << "Number of tunable parameters: " << parameters.size();
2347
2348 // The vectors of "essential" parallelism parameters and buffer size
2349 // parameters.
2350 Model::ModelParameters parallelism_parameters, buffer_size_parameters;
2351 CollectParameters(snapshot, parameters, &parallelism_parameters,
2352 &buffer_size_parameters);
2353
  // Initialize the parameter values to their minimums before tuning.
2355 for (auto& pair : parameters) {
2356 pair.second->value = pair.second->min;
2357 }
2358
2359 // Optimization is stopped once the `OutputTime` improvement is smaller than
2360 // this value.
2361 constexpr double kOptimizationPrecision = 100.0L;
2362
2363 // Maximum number of iterations for optimization.
2364 constexpr int64_t kMaxIterations = 1000;
2365
2366 double output_time = 0;
2367 double new_output_time;
2368
2369 // When the CPU budget is reached, the parallelism parameter values are fixed
2370 // and we only increase the buffer size parameters.
2371 bool cpu_budget_reached = false;
2372
2373 for (int i = 0; i < kMaxIterations; ++i) {
2374 if (cancellation_manager->IsCancelled() ||
2375 ShouldStop(optimization_params.cpu_budget(),
2376 optimization_params.ram_budget(), parameters,
2377 parallelism_parameters, buffer_size_parameters, snapshot,
2378 &cpu_budget_reached)) {
2379 break;
2380 }
2381 Model::ParameterGradients gradients;
2382 new_output_time = OutputTime(
2383 snapshot, optimization_params.model_input_time(), &gradients);
2384 // We also terminate once the improvement of the output latency is too
2385 // small.
2386 if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
2387 break;
2388 }
2389
2390 UpdateParameterValues(
2391 gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
2392 output_time = new_output_time;
2393 }
2394
2395 for (auto& pair : parameters) {
2396 pair.second->value = std::round(pair.second->value);
2397 }
2398 UpdateStateValues(&parameters);
2399}
2400
2401void Model::OptimizeHillClimbHelper(
2402 std::shared_ptr<Node> snapshot,
2403 const OptimizationParams& optimization_params,
2404 CancellationManager* cancellation_manager, StopPredicate should_stop) {
2405 VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
2406 const double processing_time = TotalProcessingTime(snapshot);
2407 auto parameters = CollectTunableParameters(snapshot);
2408 if (parameters.empty()) {
2409 VLOG(2) << "There are no tunable parameters.";
2410 return;
2411 }
2412 VLOG(2) << "Number of tunable parameters: " << parameters.size();
2413
  // A buffer size parameter is only incremented if the output latency
  // improvement is greater than this constant.
2416 constexpr double kBufferSizeMinDelta = 1.0L;
2417
2418 // Skip buffer size optimization if we are running the new buffering
2419 // algorithm.
2420 bool skip_buffer_sizes = (experiment_ == "autotune_buffer_optimization");
2421 if (skip_buffer_sizes) {
2422 constexpr float TEN_MINUTES = 60.0 * 10.0;
2423 LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
2424 << "Skipping buffer_size parameters in HillClimb (message logged "
2425 "every "
2426 "10 minutes).";
2427 }
  // Initialize the parameter values to their minimums before tuning.
2429 for (auto& pair : parameters) {
2430 if (skip_buffer_sizes && (pair.second->name == kBufferSize)) {
2431 continue;
2432 }
2433 pair.second->value = pair.second->min;
2434 }
2435 while (!cancellation_manager->IsCancelled()) {
2436 const double output_time =
2437 OutputTime(snapshot, optimization_params.model_input_time(),
2438 /*gradients=*/nullptr);
2439 if (should_stop(parameters, processing_time, output_time,
2440 TotalMaximumBufferedBytes(snapshot))) {
2441 break;
2442 }
2443
2444 double best_delta = -1.0L;
2445 Parameter* best_parameter = nullptr;
2446 for (auto& pair : parameters) {
2447 if (pair.second->value >= pair.second->max ||
2448 (skip_buffer_sizes && (pair.second->name == kBufferSize))) {
2449 continue;
2450 }
2451 pair.second->value++;
2452 double new_output_time =
2453 OutputTime(snapshot, optimization_params.model_input_time(),
2454 /*gradients=*/nullptr);
2455 double delta = output_time - new_output_time;
2456 if (delta > best_delta &&
2457 (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
2458 best_delta = delta;
2459 best_parameter = pair.second.get();
2460 }
2461 pair.second->value--;
2462 }
2463 if (!best_parameter) {
2464 VLOG(2) << "Failed to find a tunable parameter that would further "
2465 "decrease the output time. This suggests that the hill-climb "
2466 "optimization got stuck in a local maximum. The optimization "
2467 "attempt will stop now.";
2468 break;
2469 }
2470 best_parameter->value++;
2471 }
2472 UpdateStateValues(&parameters);
2473}

void Model::RecordIteratorGapTime(uint64_t duration_usec) {
2475 mutex_lock l(gap_mu_);
2476 // Drop duration if it is too large.
2477 if (duration_usec >= kGapDurationThresholdUsec) {
2478 return;
2479 }
2480 gap_times_usec_.push_back(duration_usec);
  // Keep only the latest `kGapTimeWindow` gap times; drop the oldest ones.
2482 while (gap_times_usec_.size() > kGapTimeWindow) {
2483 gap_times_usec_.pop_front();
2484 }
2485}
2486
2487double Model::ComputeTargetTimeNsec() {
2488 tf_shared_lock l(gap_mu_);
2489 if (gap_times_usec_.empty()) {
2490 return 0.0;
2491 }
2492 // Remove outliers.
2493 std::vector<uint64_t> clean_gap_times_usec =
2494 OutlierPruner({gap_times_usec_.begin(), gap_times_usec_.end()})
2495 .GetCleanPoints();
2496 if (clean_gap_times_usec.empty()) {
2497 return 0.0;
2498 }
  // Compute the mean of the remaining gap times and convert it from
  // microseconds to nanoseconds.
  double sum_gap_time_usec = std::accumulate(
      clean_gap_times_usec.begin(), clean_gap_times_usec.end(), 0.0);
  return sum_gap_time_usec / static_cast<double>(clean_gap_times_usec.size()) *
         1.0e3;
2504}
2505
2506void Model::OptimizeStageBased(std::shared_ptr<Node> snapshot,
2507 const OptimizationParams& optimization_params,
2508 CancellationManager* cancellation_manager) {
2509 return OptimizeStageBasedParallelism(
2510 snapshot, optimization_params.model_input_time(), optimization_params,
2511 cancellation_manager);
2512}
2513
2514void Model::OptimizeStageBasedParallelism(
2515 std::shared_ptr<Node> snapshot, double target_time_nsec,
2516 const OptimizationParams& optimization_params,
2517 CancellationManager* cancellation_manager) {
2518 VLOG(2) << "Starting optimization of tunable parameters with Stage-Based "
2519 "optimization with a target time of "
2520 << optimization_params.model_input_time() << " nanoseconds.";
2521 Node::ModelParameters tunable_parameters = CollectTunableParameters(snapshot);
  // Initialize the parallelism parameter values to their minimums before
  // tuning.
2523 for (std::pair<string, std::shared_ptr<Parameter>>& pair :
2524 tunable_parameters) {
2525 if (pair.second->name != kParallelism) {
2526 continue;
2527 }
2528 pair.second->value = pair.second->min;
2529 }
2530 ModelTiming model_timing(snapshot);
2531 ModelTimingPriorityQueue priority_queue(model_timing);
2532 StatusOr<std::pair<double, Node*>> critical_root_status =
2533 priority_queue.PopSlowestStageRoot();
2534 if (!critical_root_status.ok()) {
2535 return;
2536 }
2537 NodeParallelismParameters node_parallelism;
2538 std::pair<double, Node*> critical_root = critical_root_status.value();
2539 while (critical_root.first > target_time_nsec) {
2540 Parameter* parallelism_parameter =
2541 node_parallelism.Get(critical_root.second);
2542 // Stop optimization if the critical stage has no `parallelism` parameter or
2543 // it has reached the max parallelism value.
2544 if (parallelism_parameter == nullptr ||
2545 parallelism_parameter->value >= parallelism_parameter->max) {
2546 break;
2547 }
2548 parallelism_parameter->value += 1.0;
2549 if (TotalMaximumBufferedBytes(snapshot) >
2550 optimization_params.ram_budget()) {
      // Increasing the parallelism by 1 exceeded the RAM budget. Revert the
      // increase and stop optimizing because we cannot improve the most
      // critical stage. There is also a decent chance that the current
      // optimization iteration is under-optimized; for that reason, return
      // immediately without updating the parameter state values.
2556 parallelism_parameter->value -= 1.0;
2557 return;
2558 }
2559 // Compute the new total time and put the node back in the queue after its
2560 // parallelism value has been increased by 1.
2561 model_timing.ComputeNodeTotalTime(*critical_root.second);
2562 const ModelTiming::NodeTiming* root_timing =
2563 model_timing.GetTiming(critical_root.second);
2564 // If timing has not improved, stop optimizing.
2565 if (critical_root.first <= root_timing->total_time_nsec) {
2566 parallelism_parameter->value -= 1.0;
2567 break;
2568 }
2569 // Push it back to the priority queue.
2570 priority_queue.Push(critical_root.second, *root_timing);
2571 // Get the next critical stage root.
2572 critical_root_status = priority_queue.PopSlowestStageRoot();
2573 if (!critical_root_status.ok()) {
2574 break;
2575 }
2576 critical_root = critical_root_status.value();
2577 }
2578 UpdateStateValues(&tunable_parameters);
2579}
2580
2581void Model::OptimizeBuffers(std::shared_ptr<Node> snapshot,
2582 int64_t ram_budget) {
2583 VLOG(2) << "Starting optimization of buffer_size parameters.";
2584 constexpr float TEN_MINUTES = 60.0 * 10.0;
2585 LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
2586 << "Starting optimization of buffer_size parameters (message logged "
2587 "every 10 minutes).";
  // Reset node watermarks if any node's buffer was upsized or downsized. We
  // reset the watermarks of all nodes, not only those whose sizes changed,
  // because the optimization algorithm works on a snapshot of the nodes and
  // there are no back references from the snapshot nodes to the original
  // nodes. We could add such back references, but it is probably not
  // necessary.
2593 bool downsized = DownsizeBuffers(snapshot);
2594 bool upsized = UpsizeBuffers(snapshot, ram_budget);
2595 if (downsized || upsized) {
2596 ResetBufferWatermarks();
2597 }
2598}
2599
2600bool Model::UpsizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget) {
2601 // Find buffers that should be up-sized.
2602 absl::flat_hash_map<Node*, Parameter*> node_parameters =
2603 CollectBufferParametersToUpsize(snapshot);
2604
2605 // Compute available memory.
2606 double available_ram_bytes =
2607 static_cast<double>(ram_budget) - TotalMaximumBufferedBytes(snapshot);
2608
2609 // Compute the max memory used by all buffers that should be upsized.
2610 double max_buffered_bytes = 0;
2611 for (auto& [node, parameter] : node_parameters) {
2612 if (node->buffered_elements() == 0) {
2613 continue;
2614 }
2615 max_buffered_bytes += static_cast<double>(node->buffered_bytes()) /
2616 static_cast<double>(node->buffered_elements()) *
2617 parameter->value;
2618 }
2619
2620 // Compute a uniform scaling factor for all buffers. Cap the factor at 2.
2621 double scaling_factor = 2.0;
2622 if (max_buffered_bytes > 0) {
2623 scaling_factor =
2624 1.0 + std::min(1.0, available_ram_bytes / max_buffered_bytes);
2625 }
2626
2627 bool upsized = false;
2628 // Up-size all buffers by the scaling factor.
2629 for (auto& [node, parameter] : node_parameters) {
2630 double old_value = parameter->value;
    // Scale the buffer_size value, truncating it to an integer and using 1 if
    // the result is less than 1.
2632 double new_value = std::max(1.0, static_cast<double>(static_cast<int64_t>(
2633 parameter->value * scaling_factor)));
2634 // Cap the new buffer_size value at its max value.
2635 parameter->value = std::min(parameter->max, new_value);
2636 VLOG(2) << "Upsize buffer " << node->long_name() << "::" << parameter->name
2637 << " from " << old_value << " to " << parameter->value;
2638 if (parameter->value != parameter->state->value) {
2639 {
2640 mutex_lock l(*parameter->state->mu);
2641 parameter->state->value = parameter->value;
2642 parameter->state->cond_var->notify_all();
2643 }
2644 upsized = true;
2645 }
2646 }
2647 return upsized;
2648}
2649
2650void Model::ResetBufferWatermarks() {
2651 Node::NodeVector nodes =
2652 output()->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2653 nodes.push_back(output());
2654 for (auto& node : nodes) {
2655 node->ResetBufferWatermarks();
2656 }
2657}
2658
2659void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
2660 const OptimizationParams& optimization_params,
2661 CancellationManager* cancellation_manager) {
2662 auto should_stop = [&optimization_params](const ModelParameters& parameters,
2663 double processing_time,
2664 double output_time,
2665 double buffered_bytes) {
2666 const bool all_max = AreAllParametersMax(parameters);
2667 const bool output_time_budget_exceeded =
2668 output_time < processing_time / optimization_params.cpu_budget();
2669 const bool ram_budget_exceeded =
2670 buffered_bytes > optimization_params.ram_budget();
2671 if (all_max) {
2672 metrics::RecordTFDataAutotuneStoppingCriteria("all_max");
2673 }
2674 if (output_time_budget_exceeded) {
2675 metrics::RecordTFDataAutotuneStoppingCriteria("output_time");
2676 }
2677 if (ram_budget_exceeded) {
2678 metrics::RecordTFDataAutotuneStoppingCriteria("max_buffered_bytes");
2679 }
2680 return all_max || output_time_budget_exceeded || ram_budget_exceeded;
2681 };
2682 OptimizeHillClimbHelper(snapshot, optimization_params, cancellation_manager,
2683 should_stop);
2684}
2685
2686void Model::OptimizeMaxParallelism(
2687 std::shared_ptr<Node> snapshot,
2688 const OptimizationParams& optimization_params,
2689 CancellationManager* cancellation_manager) {
2690 auto should_stop = [&optimization_params](const ModelParameters& parameters,
2691 double processing_time,
2692 double output_time,
2693 double buffered_bytes) {
2694 const bool all_max = AreAllParametersMax(parameters);
2695 const bool ram_budget_exceeded =
2696 buffered_bytes > optimization_params.ram_budget();
2697 if (all_max) {
2698 metrics::RecordTFDataAutotuneStoppingCriteria("all_max");
2699 }
2700 if (ram_budget_exceeded) {
2701 metrics::RecordTFDataAutotuneStoppingCriteria("max_buffered_bytes");
2702 }
2703 return all_max || ram_budget_exceeded;
2704 };
2705 OptimizeHillClimbHelper(snapshot, optimization_params, cancellation_manager,
2706 should_stop);
2707}
2708
2709double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
2710 Model::ParameterGradients* gradients) {
2711 // To store the input time for each node.
2712 Model::NodeValues input_times = {{kModelInputTimeKey, model_input_time}};
2713
  // TODO(jsimsa): Now that we are accounting for buffer size in the wait time
  // computation, assuming that the input is infinitely fast will result in
  // inaccurate estimates of the output latency.
  //
  // We should compute the output latency as a fixed point of the following
  // equation: `output_time = node->OutputTime(input_times(1, output_time))`.
2720
2721 return node->OutputTime(&input_times, gradients);
2722}
2723
2724double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
2725 return node->TotalBufferedBytes();
2726}
2727
2728double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
2729 return node->TotalMaximumBufferedBytes();
2730}
2731
2732double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
2733 return node->TotalProcessingTime(/*processing_times=*/nullptr);
2734}
2735
2736Status Model::ToProto(ModelProto* model_proto) {
2737 tf_shared_lock l(mu_);
2738 model_proto->set_id_counter(id_counter_);
2739 return ModelToProtoHelper(output_, model_proto);
2740}
2741
2742Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
2743 std::unique_ptr<Model> restored_model = std::make_unique<Model>();
2744 mutex_lock l(restored_model->mu_);
2745 TF_RETURN_IF_ERROR(
2746 ModelFromProtoHelper(model_proto, &restored_model->output_));
2747 restored_model->id_counter_ = model_proto.id_counter();
2748 *model = std::move(restored_model);
2749 return OkStatus();
2750}
2751
2752Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
2753 const OptimizationParams& optimization_params) {
2754 ModelProto model_proto;
2755 std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
2756 {
2757 mutex_lock l(model_snapshot->mu_);
2758 model_snapshot->output_ = std::move(snapshot);
2759 model_snapshot->id_counter_ = id_counter_;
2760 }
2761 TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
2762 OptimizationParams* saved_optimization_params =
2763 model_proto.mutable_optimization_params();
2764 *saved_optimization_params = optimization_params;
2765 return WriteBinaryProto(Env::Default(), fname, model_proto);
2766}
2767
2768Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
2769 OptimizationParams* optimization_params) {
2770 ModelProto model_proto;
2771 TF_RETURN_IF_ERROR(
2772 ReadTextOrBinaryProto(Env::Default(), fname, &model_proto));
2773 TF_RETURN_IF_ERROR(FromProto(model_proto, model));
2774 const OptimizationParams restored_optimization_params =
2775 model_proto.optimization_params();
2776 *optimization_params = restored_optimization_params;
2777 return OkStatus();
2778}
2779
2780std::string Model::DebugString() {
2781 constexpr int64_t kMinSecondsBetweenCalls = 30;
2782 if (absl::Now() < cache_until_) return cached_debug_string_;
2783 std::shared_ptr<Node> snapshot;
2784 {
2785 tf_shared_lock l(mu_);
2786 if (!output_) return cached_debug_string_;
2787 snapshot = output_->Snapshot();
2788 }
2789 // TODO(jsimsa): Populate OptimizationParams.
2790 ModelProto model_proto;
2791 Status s = ModelToProtoHelper(snapshot, &model_proto);
2792 if (s.ok()) {
2793 cached_debug_string_ = model_proto.DebugString();
2794 } else {
2795 LOG(WARNING) << s.error_message();
2796 }
2797 cache_until_ = absl::Now() + absl::Seconds(kMinSecondsBetweenCalls);
2798 return cached_debug_string_;
2799}
2800
2801ModelTiming::ModelTiming(std::shared_ptr<Node> root) : root_(root) {
2802 DCHECK(root_.get() != nullptr);
2803 auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
2804 auto reverse_bfs_nodes = bfs_nodes;
2805 std::reverse(reverse_bfs_nodes.begin(), reverse_bfs_nodes.end());
2806 ComputePipelineRatios(bfs_nodes);
2807 ComputeTotalTimes(reverse_bfs_nodes);
2808}
2809
2810Node::NodeVector ModelTiming::CollectNodes(
2811 std::shared_ptr<Node> root, TraversalOrder order,
2812 bool collect_node(const std::shared_ptr<Node>)) const {
2813 if (root == nullptr) {
2814 return Node::NodeVector({});
2815 }
2816 auto subtree_nodes = root->CollectNodes(order, collect_node);
2817 Node::NodeVector nodes;
2818 if (order == TraversalOrder::BFS) {
2819 nodes.push_back(root);
2820 nodes.insert(nodes.end(), subtree_nodes.begin(), subtree_nodes.end());
2821 } else {
2822 nodes.insert(nodes.end(), subtree_nodes.begin(), subtree_nodes.end());
2823 nodes.push_back(root);
2824 }
2825 return nodes;
2826}
2827
2828const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
2829 if (timing_nodes_.find(node) == timing_nodes_.end()) {
2830 return nullptr;
2831 }
2832 return &(timing_nodes_.at(node));
2833}
2834
2835void ModelTiming::ComputePipelineRatios(const Node::NodeVector& bfs_nodes) {
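  // A node's `pipeline_ratio` is the number of elements it must produce for
  // the pipeline to produce one element, i.e. the product of the `Ratio()`
  // values of all of its ancestors. For example (illustrative), if the root
  // is a batch node with `Ratio() == 32`, its input has a pipeline_ratio of
  // 32.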
2836 for (const auto& node : bfs_nodes) {
2837 auto& node_timing = timing_nodes_[node.get()];
2838 if (!node->autotune()) {
2839 // These are inactive nodes marked by parallel interleave
2840 // transformations.
2841 node_timing.pipeline_ratio = 0.0;
2842 continue;
2843 }
2844 double parent_pipeline_ratio = 1.0;
2845 double parent_ratio = 1.0;
    if (node->output() != nullptr && timing_nodes_.contains(node->output())) {
2847 const auto& output_timing = timing_nodes_[node->output()];
2848 parent_pipeline_ratio = output_timing.pipeline_ratio;
2849 parent_ratio = node->output()->Ratio();
2850 if (parent_ratio <= 0.0) {
        // The parent ratio is unknown; use 1.0 as a guess.
2852 parent_ratio = 1.0;
2853 }
2854 }
2855 node_timing.pipeline_ratio = parent_pipeline_ratio * parent_ratio;
2856 }
2857}
2858
2859void ModelTiming::ComputeNonAsyncInterleaveManyTotalTime(const Node& node) {
2860 DCHECK(timing_nodes_.contains(&node));
2861 auto& node_timing = timing_nodes_[&node];
2862 double input_total_time_nsec = 0.0;
2863 for (auto input : node.inputs()) {
2864 if (input->IsAsync()) {
2865 continue;
2866 }
2867 if (!input->autotune() || input->num_elements() <= 0) {
2868 continue;
2869 }
2870 DCHECK(timing_nodes_.contains(input.get()))
2871 << "Input " << input->long_name() << " of node " << node.long_name()
2872 << " has no timing node.";
2873
2874 input_total_time_nsec += timing_nodes_[input.get()].total_time_nsec;
2875 }
2876 node_timing.total_time_nsec =
2877 node_timing.self_time_nsec + input_total_time_nsec * node.Ratio();
2878}
2879
2880void ModelTiming::ComputeAsyncInterleaveManyTotalTime(const Node& node) {
2881 DCHECK(timing_nodes_.contains(&node));
2882 auto& node_timing = timing_nodes_[&node];
2883 double max_input_total_time_nsec = 0.0;
2884 double sum_input_throughput = 0.0;
2885 auto inputs = node.inputs();
  // `ParallelInterleave` is often used to interleave the processing of
  // datasets generated from the first input, e.g. reading from IO, where the
  // first input holds the list of all filenames. The first input is typically
  // not the bottleneck. We exclude the timing of the first input from the
  // throughput computation of the remaining inputs; it is also excluded from
  // the total time computation of the async interleave node.
2892 auto input = std::next(inputs.begin());
2893 // `num_active_inputs` holds the number of inputs that the
2894 // `ParallelInterleave` is reading from, not including those that are warm
2895 // starting, which can be detected by checking the value of `autotune()`. It
2896 // also does not count async inputs because they would be in their own
2897 // stages. This number is typically the same as `cycle_length`. It will be
2898 // used below to scale the throughput of inputs if `cycle_length` is smaller
2899 // than `num_active_inputs`.
2900 int num_active_inputs = 0;
2901 for (; input != inputs.end(); ++input) {
2902 if ((*input)->IsAsync()) {
2903 continue;
2904 }
2905 if (!(*input)->autotune() || (*input)->num_elements() <= 0) {
2906 continue;
2907 }
2908 DCHECK(timing_nodes_.contains((*input).get()))
2909 << "Input " << (*input)->long_name() << " of node " << node.long_name()
2910 << " has no timing node.";
2911 auto input_total_time_nsec = timing_nodes_[(*input).get()].total_time_nsec;
2912 max_input_total_time_nsec =
2913 std::max(input_total_time_nsec, max_input_total_time_nsec);
2914 if (input_total_time_nsec > 0.0) {
2915 sum_input_throughput += 1.0 / input_total_time_nsec;
2916 }
2917 ++num_active_inputs;
2918 }
2919 auto parallelism_param = node.ParameterValue(kParallelism);
2920 double parallelism = num_active_inputs;
2921 if (parallelism_param.ok()) {
2922 parallelism = parallelism_param.value();
2923 }
  // After cl/445005635, there should always be a `deterministic` parameter for
  // an ASYNC_INTERLEAVE_MANY node. The "not-ok" check allows the code to work
  // with protos saved and restored before that CL; similarly for
  // `cycle_length` with cl/436244658.
2928 auto deterministic_param = node.ParameterValue(kDeterministic);
2929 bool deterministic = false;
2930 if (deterministic_param.ok()) {
2931 deterministic = deterministic_param.value() == 1.0;
2932 }
2933 auto cycle_length_param = node.ParameterValue(kCycleLength);
2934 double cycle_length = num_active_inputs;
2935 if (cycle_length_param.ok()) {
2936 cycle_length = cycle_length_param.value();
2937 }
2938 double input_total_time_nsec = 0.0;
2939 if (deterministic) {
2940 // If deterministic = true, then the total time is `max input total time /
2941 // min(parallelism, cycle_length)`.
2942 input_total_time_nsec =
2943 max_input_total_time_nsec / std::min(parallelism, cycle_length);
2944 } else if (sum_input_throughput > 0.0) {
    // If deterministic = false, then the total time is
    // `1 / sum_input_throughput`. Scale the throughput according to
    // `parallelism` and `cycle_length` if either is smaller than the number
    // of active inputs. `cycle_length` and `parallelism` could theoretically
    // be larger than the number of active inputs when some inputs are async
    // and some are sync.
2950 if (std::min(cycle_length, parallelism) < num_active_inputs) {
2951 sum_input_throughput *=
2952 std::min(parallelism, cycle_length) / num_active_inputs;
2953 }
2954 input_total_time_nsec = 1.0 / sum_input_throughput;
2955 }
2956 node_timing.total_time_nsec =
2957 node_timing.self_time_nsec + input_total_time_nsec;
2958}
2959
2960void ModelTiming::ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes) {
2961 for (const auto& node : reverse_bfs_nodes) {
2962 ComputeNodeTotalTime(*(node.get()));
2963 }
2964}
2965
2966void ModelTiming::ComputeNodeTotalTime(const Node& node) {
2967 NodeTiming& node_timing = timing_nodes_[&node];
2968 node_timing.self_time_nsec = node.ComputeSelfTime();
2969 if (!node.autotune() || node.num_elements() <= 0) {
2970 return;
2971 }
2972#if !defined(IS_MOBILE_PLATFORM)
  // This block of code is compiled only for non-mobile platforms because the
  // mobile platform build disables RTTI, which `dynamic_cast` requires.
2975 if (dynamic_cast<const AsyncInterleaveMany*>(&node) != nullptr) {
2976 ComputeAsyncInterleaveManyTotalTime(node);
2977 } else {
2978 ComputeNonAsyncInterleaveManyTotalTime(node);
2979 }
2980#else // !IS_MOBILE_PLATFORM
2981 ComputeNonAsyncInterleaveManyTotalTime(node);
2982#endif // !IS_MOBILE_PLATFORM
2983}
2984
2985std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
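  // A stage root is an async node; the pipeline root is also treated as a
  // stage root if it is sync. Each stage consists of its root plus the sync
  // nodes below it (see `GetStageNodes`).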
2986 auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
2987 std::vector<std::shared_ptr<Node>> roots;
2988 if (!bfs_nodes.empty() && !bfs_nodes[0]->IsAsync()) {
2989 roots.push_back(bfs_nodes[0]);
2990 }
2991 for (auto& node : bfs_nodes) {
2992 if (node->IsAsync()) {
2993 roots.push_back(node);
2994 }
2995 }
2996 return roots;
2997}
2998
2999std::vector<std::shared_ptr<Node>> ModelTiming::GetStageNodes(
3000 std::shared_ptr<Node> stage_root) const {
3001 return CollectNodes(stage_root, TraversalOrder::BFS, IsSyncNode);
3002}
3003
3004} // namespace model
3005} // namespace data
3006} // namespace tensorflow
3007