LIP0007
| LIP | 7 |
|---|---|
| Title | Reusing tensors |
| Author | J. van de Wolfshaar |
| Status | Draft |
| Type | Standard |
| Discussion | Issue #37 |
| PR | |
| Created | Feb 21, 2018 |
Some tensors that are built at different phases of graph construction compute the very same result. This leads to an unnecessarily large computation graph. This LIP proposes a simple way to avoid this.
Let us consider MPE path construction. When computing the path, the current code calls `get_value` or `get_log_value` twice (once during upward graph construction and once during the construction of MPE path ops). In both cases, the concrete functions inside one of the `Node` subclass instances redefine the very same operations using the very same inputs per operation. For any particular occurrence of such a pattern, we could reduce the graph size by reusing the tensor that was computed before.
We propose two different ways of reusing tensors, each with its own merits.
One possible solution is to add members to `Sum` that hold e.g.

- The 'weighted' value tensor, i.e. the result after applying weighting
- The 'selected' value tensor, i.e. the result after applying IVs
- Anything else that might be reused in the up and down passes

Other types of sum nodes that might become part of future versions of LibSPN will also have to compute a weighted and a selected value tensor. Perhaps it is a good idea to enforce these as part of the implementation by adding an `AbstractSum` class that inherits from `OpNode` and defines a few abstract methods, e.g. `_weighted_value` and `_selected_value`. The `Sum` node will then be a subclass of `AbstractSum`. Any other kind of sum node (e.g. a `SumsLayer`) will also inherit from `AbstractSum`, and is thereby forced to define its own `_weighted_value` and `_selected_value`. We could provide a helper function in the `AbstractSum` class that will reuse already constructed tensors.
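A minimal sketch of this idea follows. The class and method names match the LIP, but the cache layout and the `_reuse_or_compute` helper are assumptions, and plain Python objects stand in for TensorFlow tensors so the sketch is self-contained:

```python
from abc import ABC, abstractmethod


class OpNode:
    """Stand-in for LibSPN's OpNode base class."""


class AbstractSum(OpNode, ABC):
    """Hypothetical base class enforcing the weighted/selected value methods."""

    def __init__(self):
        # Maps a key (e.g. 'weighted') to the tensor built for it earlier.
        self._tensor_cache = {}

    def _reuse_or_compute(self, key, compute_fn):
        # Return the tensor built earlier for this key, or build and store it.
        if key not in self._tensor_cache:
            self._tensor_cache[key] = compute_fn()
        return self._tensor_cache[key]

    @abstractmethod
    def _weighted_value(self):
        """Value tensor after applying weighting."""

    @abstractmethod
    def _selected_value(self):
        """Value tensor after applying IVs."""


class Sum(AbstractSum):
    def _weighted_value(self):
        # object() stands in for a TF op such as a weighted product of inputs.
        return self._reuse_or_compute('weighted', lambda: object())

    def _selected_value(self):
        return self._reuse_or_compute('selected', lambda: object())
```

Calling `_weighted_value` twice on the same instance then returns the identical tensor object instead of rebuilding the op, which is exactly the reuse the up and down passes need.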
Compared to the next solution, this is cleaner. However, it has a considerable impact on the code.
To reduce the impact on the current code and to maintain readability, we can use the `functools.lru_cache` decorator for memoization. It creates a least-recently-used cache for whichever function it is applied to. The cache's keys are the arguments of the function. Whenever a key is not in the cache, the body of the function runs as is; otherwise, the result obtained when the function was first evaluated with those arguments is looked up in the cache. Such a decorator can easily be applied to any repeated function call that constructs part of our graph.
The following piece of code illustrates the use of `functools.lru_cache`:
```python
import functools

import tensorflow as tf


class A:
    @functools.lru_cache(maxsize=1)
    def tensor_with_cache(self):
        return tf.ones(5)

    def tensor_without_cache(self):
        return tf.ones(5)


class B(A):
    pass


b0 = B()
b1 = B()

# Cache per instance (self is part of the cache key)
print(b0.tensor_with_cache())
print(b0.tensor_with_cache())

# Should create another tensor only once
print(b1.tensor_with_cache())
print(b1.tensor_with_cache())

# Creates a new tensor twice!
print(b1.tensor_without_cache())
print(b1.tensor_without_cache())

sess = tf.Session()
print(len(sess.graph.get_operations()))
```
With the following output:

```
Tensor("ones:0", shape=(5,), dtype=float32)
Tensor("ones:0", shape=(5,), dtype=float32)
Tensor("ones_1:0", shape=(5,), dtype=float32)
Tensor("ones_1:0", shape=(5,), dtype=float32)
Tensor("ones_2:0", shape=(5,), dtype=float32)
Tensor("ones_3:0", shape=(5,), dtype=float32)
4
```
Hence, we propose to add the `functools.lru_cache` decorator in any part of the code that computes a certain op with the same inputs.
Even better is to have a custom memoization, as this allows us to switch off the memoization easily by setting a flag in `spn.conf`. This will be the final implementation.
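Such a custom decorator could look like the sketch below. The `conf.memoize` flag name and the `(instance, args)` cache key are assumptions for illustration; `Node.value` is a toy stand-in for a graph-building method like `get_value`:

```python
import functools


class conf:
    """Stand-in for spn.conf; the `memoize` flag name is an assumption."""
    memoize = True


def memoize(fn):
    """Cache results per (instance, args); bypass the cache when conf.memoize is False."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(self, *args):
        if not conf.memoize:
            # Flag off: always rebuild the op, as in the current code.
            return fn(self, *args)
        key = (id(self), args)
        if key not in cache:
            cache[key] = fn(self, *args)
        return cache[key]

    return wrapper


class Node:
    """Toy node; in LibSPN, value() would construct a TensorFlow op."""

    @memoize
    def value(self):
        return object()  # stand-in for a tensor
```

Unlike `functools.lru_cache`, this keeps a single switch to fall back to rebuilding every op, which is useful for debugging and for measuring the effect of memoization, as done in the benchmark below.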
The results below display MNIST training for two classes, where we set a `memoize` flag to either `True` or `False`. It can be seen that the performance is generally better for the implementation with memoization. Also, the graph size is reduced by 9.7 percent on average. For the comparison below we used the multi nodes, as these are more efficient and will most likely become a common choice for building SPNs.
```
-----------------------
InferenceType: MPE
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
| op_name   | on_gpu   | multi_nodes   | spn_size   | tf_size   | memory_used   | input_dist   | setup_time   | weights_init_time   | first_run_time   | rest_run_time   | test_accuracy   | config        |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------|
| mnist_01  | True     | True          | 1400       | 51274     | 500793088     | MIXTURE      | 27.5845      | 2.74238             | 3286.26          | 1062.12         | 0.463357        | memoize=False |
| mnist_01  | True     | True          | 1400       | 46130     | 501031680     | MIXTURE      | 31.8005      | 2.66702             | 3046.4           | 996.867         | 0.463357        | memoize=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
-----------------------
InferenceType: MARGINAL
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
| op_name   | on_gpu   | multi_nodes   | spn_size   | tf_size   | memory_used   | input_dist   | setup_time   | weights_init_time   | first_run_time   | rest_run_time   | test_accuracy   | config        |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------|
| mnist_01  | True     | True          | 1400       | 51274     | 501031680     | MIXTURE      | 27.6599      | 2.64925             | 3388.85          | 1032.35         | 0.463357        | memoize=False |
| mnist_01  | True     | True          | 1400       | 46136     | 501031680     | MIXTURE      | 28.244       | 2.73143             | 3117.22          | 969.174         | 0.463357        | memoize=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
-----------------------
InferenceType: MPE-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
| op_name   | on_gpu   | multi_nodes   | spn_size   | tf_size   | memory_used   | input_dist   | setup_time   | weights_init_time   | first_run_time   | rest_run_time   | test_accuracy   | config        |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------|
| mnist_01  | True     | True          | 1400       | 52472     | 504583936     | MIXTURE      | 32.6242      | 2.74428             | 3268.27          | 1120.02         | 0.997636        | memoize=False |
| mnist_01  | True     | True          | 1400       | 47328     | 507256320     | MIXTURE      | 31.2596      | 2.91775             | 3299.9           | 987.384         | 0.998109        | memoize=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
-----------------------
InferenceType: MARGINAL-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
| op_name   | on_gpu   | multi_nodes   | spn_size   | tf_size   | memory_used   | input_dist   | setup_time   | weights_init_time   | first_run_time   | rest_run_time   | test_accuracy   | config        |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------|
| mnist_01  | True     | True          | 1400       | 59636     | 507256320     | MIXTURE      | 31.9085      | 3.15185             | 4279.56          | 1210.91         | 0.99669         | memoize=False |
| mnist_01  | True     | True          | 1400       | 54492     | 507256320     | MIXTURE      | 33.9269      | 3.30922             | 4263.64          | 1283.15         | 0.997163        | memoize=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+---------------+
```