Skip to content

Commit 48dbd64

Browse files
committed
Fix caching during conditional execution
1 parent 356cdd5 commit 48dbd64

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

include/avalanche/ExecutionCache.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class ExecutionCache : private CachedItemsMap {
2121
explicit ExecutionCache(DeviceIndex device_idx);
2222
explicit ExecutionCache(BufferPoolRef buffer_pool);
2323
bool get_from_cache_no_counter(NodeId node_id, MultiArrayRef &result) const;
24+
bool is_cached(const NodeId node_id) const;
2425
void zero_reuse_counters();
2526
void put(const NodeId node_id, const MultiArrayRef &array);
2627
void set_node_params(NodeId node_id,

src/avalanche/ExecutionCache.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,10 @@ bool ExecutionCache::get_info(NodeId node_id, CachedItem &info) const {
9696
return false;
9797
}
9898

99+
bool ExecutionCache::is_cached(const NodeId node_id) const {
100+
auto cached = find(node_id);
101+
return cached != this->end() && cached->second.data;
102+
}
103+
99104

100105
} // namespace

src/avalanche/conditional_nodes.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,48 @@ Cond::Cond(const NodeRef &condition, const NodeRef &true_node,
4545
}
4646
}
4747

48+
/**
49+
* Makes sure all inputs of the given node have been evaluated, without
50+
* evaluating the node itself, unless it has already been evaluated
51+
* and cached previously. This is necessary for proper work of the Cond
52+
* node for two reasons:
53+
* 1. To match the behaviour of `cond` from TF, which is this (a quote):
54+
*
55+
* Note that the conditional execution applies only to the operations
56+
* defined in true_fn and false_fn. Consider the following simple program:
57+
*
58+
* z = tf.multiply(a, b)
59+
* result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
60+
*
61+
* If x < y, the tf.add operation will be executed and tf.square
62+
* operation will not be executed. Since z is needed for at least
63+
* one branch of the cond, the tf.multiply operation is always executed,
64+
* unconditionally. Although this behavior is consistent with
65+
* the dataflow model of TensorFlow, it has occasionally surprised
66+
* some users who expected a lazier semantics.
67+
*
68+
* https://www.tensorflow.org/api_docs/python/tf/cond
69+
*
70+
* 2. Such evaluation helps to make sure we don't have any values stored
71+
* in cache with counters > 0 waiting to be used during the run.
72+
* By evaluating those nodes we imitate usage of them as inputs, thus
73+
* making sure that caching works as expected.
74+
*/
75+
76+
void evaluate_inputs_of(const NodeRef &node, Context &context,
77+
ExecutionCache &cache) {
78+
if (cache.is_cached(node->id)) {
79+
// the node (and its inputs) has already been evaluated before,
80+
// because it's present in the cache. We evaluate it once more
81+
// (no actual work will be done) to decrease the cache counter.
82+
node->eval(context, cache);
83+
} else {
84+
for (const auto &inp: node->inputs()) {
85+
inp->eval(context, cache);
86+
}
87+
}
88+
}
89+
4890
MultiArrayRef Cond::eval(Context &context, ExecutionCache &cache) const {
4991
MultiArrayRef result;
5092
if (!cache.get(id, result)) {
@@ -53,8 +95,10 @@ MultiArrayRef Cond::eval(Context &context, ExecutionCache &cache) const {
5395
cond_value->fetch_data_into(condition);
5496
if (condition[0]) {
5597
result = _true_node->eval(context, cache);
98+
evaluate_inputs_of(_false_node, context, cache);
5699
} else {
57100
result = _false_node->eval(context, cache);
101+
evaluate_inputs_of(_true_node, context, cache);
58102
}
59103
cache.put(id, result);
60104
}

test/test_tree_evaluation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,18 @@ TEST_CASE("Conditional evaluation") {
141141
auto update1 = F<UpdateAdd>(var1, one);
142142
auto update2 = F<UpdateAdd>(var2, one);
143143
auto output = Cond::make(condition, update1, update2);
144-
// Only var2 should be incremented, because condition == 0
144+
INFO("Check chat only var2 should be incremented, because condition == 0");
145145
context->init<std::int8_t>(condition, {0});
146146
evaluate_and_check<float>(output, {1}, Shape(), context);
147147
evaluate_and_check<float>(var1, {0}, Shape(), context);
148148
evaluate_and_check<float>(var2, {1}, Shape(), context);
149-
// Now only var1 should be incremented, because condition == 1
149+
INFO("Now only var1 should be incremented, because condition == 1");
150150
context->init<std::int8_t>(condition, {1});
151151
evaluate_and_check<float>(output, {1}, Shape(), context);
152152
evaluate_and_check<float>(var1, {1}, Shape(), context);
153153
evaluate_and_check<float>(var2, {1}, Shape(), context);
154-
// Again only var1 should be incremented, because condition is still == 1
154+
INFO("Again only var1 should be incremented, "
155+
"because condition is still == 1");
155156
evaluate_and_check<float>(output, {2}, Shape(), context);
156157
evaluate_and_check<float>(var1, {2}, Shape(), context);
157158
evaluate_and_check<float>(var2, {1}, Shape(), context);

0 commit comments

Comments
 (0)