@@ -45,6 +45,48 @@ Cond::Cond(const NodeRef &condition, const NodeRef &true_node,
45
45
}
46
46
}
47
47
48
+ /* *
49
+ * Makes sure all inputs of the given node have been evaluated, without
50
+ * evaluating the node itself, unless it has already been evaluated
51
+ * and cached previously. This it necessary for proper work of the Cond
52
+ * node by two reasons:
53
+ * 1. To match the behaviour of `cond` from TF, which is this (a quote):
54
+ *
55
+ * Note that the conditional execution applies only to the operations
56
+ * defined in true_fn and false_fn. Consider the following simple program:
57
+ *
58
+ * z = tf.multiply(a, b)
59
+ * result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
60
+ *
61
+ * If x < y, the tf.add operation will be executed and tf.square
62
+ * operation will not be executed. Since z is needed for at least
63
+ * one branch of the cond, the tf.multiply operation is always executed,
64
+ * unconditionally. Although this behavior is consistent with
65
+ * the dataflow model of TensorFlow, it has occasionally surprised
66
+ * some users who expected a lazier semantics.
67
+ *
68
+ * https://www.tensorflow.org/api_docs/python/tf/cond
69
+ *
70
+ * 2. Such evaluation helps to make sure we don't have any values stored
71
+ * in cache with counters > 0 waiting to be used during the run.
72
+ * By evaluating those nodes we imitate usage of them as inputs, thus
73
+ * making sure that caching works as expected.
74
+ */
75
+
76
+ void evaluate_inputs_of (const NodeRef &node, Context &context,
77
+ ExecutionCache &cache) {
78
+ if (cache.is_cached (node->id )) {
79
+ // the node (and its inputs) has already been evaluated before,
80
+ // because it's present in the cache. We evaluate it once more
81
+ // (no actual work will be done) to decrease the cache counter.
82
+ node->eval (context, cache);
83
+ } else {
84
+ for (const auto &inp: node->inputs ()) {
85
+ inp->eval (context, cache);
86
+ }
87
+ }
88
+ }
89
+
48
90
MultiArrayRef Cond::eval (Context &context, ExecutionCache &cache) const {
49
91
MultiArrayRef result;
50
92
if (!cache.get (id, result)) {
@@ -53,8 +95,10 @@ MultiArrayRef Cond::eval(Context &context, ExecutionCache &cache) const {
53
95
cond_value->fetch_data_into (condition);
54
96
if (condition[0 ]) {
55
97
result = _true_node->eval (context, cache);
98
+ evaluate_inputs_of (_false_node, context, cache);
56
99
} else {
57
100
result = _false_node->eval (context, cache);
101
+ evaluate_inputs_of (_true_node, context, cache);
58
102
}
59
103
cache.put (id, result);
60
104
}
0 commit comments