LIP0004
LIP | 4 |
---|---|
Title | Improved Pruning during Cloning and Serialization |
Author | A. Pronobis |
Status | Draft |
Type | Standard |
Discussion | Issue #28 |
PR | |
Created | Sep 1, 2017 |
The current implementation of SPN graph pruning operates directly on the current SPN graph. This has the following disadvantages:
- It is not possible to clear the TF graph of operations on the original SPN graph without destroying the operations for the pruned graph.
- Saving a pruned graph requires modifying the current graph and adds new TF ops.
In this proposal, we suggest including the pruning mechanism as optional functionality of cloning and saving/loading of the SPN graph.
As a result, we will be able to:
- Generate a new pruned copy of the SPN graph in the same TF graph (might still be convenient for small SPNs)
- Generate a pruned SPN graph in a new TF graph and destroy the old TF graph
- Save a pruned version of the current SPN graph to a file without affecting the current graph
This proposal depends on LIP3 and LIP2.
We propose the following workflow, which realizes pruning as part of cloning/saving:
- Create parameter snapshots (using `create_param_snapshots`, runs `traverse_graph`).
- Compute a path through the graph (top-down) which indicates which edges/nodes to keep (implemented with a new class `Pruner` using `compute_graph_up_down`).
- Re-create nodes by cloning or saving/loading them. This process should simplify nodes for which some operations have been eliminated by pruning and eliminate unnecessary nodes. This requires traversing the graph bottom-up (using `compute_graph_up`).
  - When cloning: For each node that should be kept, clone the node (possibly simplified) and connect it to the already created input nodes.
  - When saving/loading: For each node that should be kept, serialize the node to a file (possibly simplified). Then loading follows the bottom-up order and simply deserializes nodes and connects them to already existing input nodes.
- Restore parameter snapshots (using `restore_param_snapshots`, runs `traverse_graph`).
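
The two traversal steps of this workflow can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyNode`, `edge_kept`, `compute_keep` and `clone_pruned` are hypothetical names that do not exist in the library, parameter snapshots are left out, and the pruning rule (drop an input whose weight is zero) is assumed for the example.

```python
class ToyNode:
    """Hypothetical stand-in for an SPN node (not the libspn Node class)."""

    def __init__(self, name, inputs=(), weights=None):
        self.name = name
        self.inputs = list(inputs)
        # One weight per input for sum-like nodes; None for other nodes.
        self.weights = None if weights is None else list(weights)

    def edge_kept(self, i):
        # Assumed pruning rule: drop an input whose weight is zero.
        return self.weights is None or self.weights[i] > 0


def compute_keep(root):
    """Top-down pass: names of all nodes reachable through kept edges."""
    keep, stack = set(), [root]
    while stack:
        node = stack.pop()
        if node.name in keep:
            continue
        keep.add(node.name)
        stack.extend(c for i, c in enumerate(node.inputs) if node.edge_kept(i))
    return keep


def clone_pruned(root, keep):
    """Bottom-up pass: inputs are cloned before the node that consumes them."""
    nodes_by_name = {}

    def visit(node):
        if node.name not in nodes_by_name:
            kept = [i for i in range(len(node.inputs))
                    if node.edge_kept(i) and node.inputs[i].name in keep]
            new_inputs = [visit(node.inputs[i]) for i in kept]
            new_weights = (None if node.weights is None
                           else [node.weights[i] for i in kept])
            nodes_by_name[node.name] = ToyNode(node.name, new_inputs, new_weights)
        return nodes_by_name[node.name]

    return visit(root)


# Sum node "s" over three leaves; the edge to "b" has zero weight.
a, b, c = ToyNode("a"), ToyNode("b"), ToyNode("c")
s = ToyNode("s", [a, b, c], weights=[0.5, 0.0, 0.5])

keep = compute_keep(s)
pruned = clone_pruned(s, keep)
print(sorted(keep))                       # names of the surviving nodes
print([n.name for n in pruned.inputs])    # inputs of the pruned copy
```

The original graph rooted in `s` is left untouched, which is the property the proposal is after: the pruned copy can live elsewhere while the source graph is discarded.
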
We propose the following additions to the abstract `Node` class:
- `children_indicators = get_pruning_indicators(indicators)` - Calculates indicators that are used to decide if a node or an output of a node should be preserved after pruning. Relies on parameter node snapshots. Raises `RuntimeError` if a parameter snapshot is not available.
- `new_node = clone(nodes_by_name, pruning_indicators=None)` - Create a copy of the node. `nodes_by_name` is a dictionary with all already created nodes indexed by their names in the source SPN graph. The dictionary is used to look up the input nodes of the newly created node. The default implementation (in the `Node` class) creates a new instance of the node of the same type. Subclasses should implement additional specific functionality. If `pruning_indicators` are provided, retain only the outputs with positive indicators.
- `data = serialize(pruning_indicators=None)` - Serialize the node to a dictionary. If `pruning_indicators` are provided, retain only the outputs with positive indicators.
- `deserialize(nodes_by_name)` - De-serialize the node from a dictionary. `nodes_by_name` is a dictionary with all already created nodes indexed by their names in the original SPN graph. The dictionary is used to look up the input nodes of the newly created node.
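
To make the role of `nodes_by_name` and the snapshot check more concrete, here is a minimal sketch of how a sum-like node might implement two of these methods. `SketchSum` and its `snapshot` attribute are hypothetical. The sketch also simplifies the proposal: it uses one boolean flag per input rather than per-output indicators, and it treats `indicators` as a single flag saying whether the node itself is kept.

```python
class SketchSum:
    """Hypothetical sum-like node; not the libspn Node class."""

    def __init__(self, name, inputs, weights):
        self.name = name
        self.inputs = list(inputs)
        self.weights = list(weights)   # one weight per input
        self.snapshot = None           # set by a (hypothetical) snapshot step

    def get_pruning_indicators(self, indicators):
        """Return one flag per input, given whether this node is kept."""
        if self.snapshot is None:
            raise RuntimeError("parameter snapshot is not available")
        return [bool(indicators) and w > 0 for w in self.snapshot]

    def clone(self, nodes_by_name, pruning_indicators=None):
        """Copy the node, looking up already created inputs by source name."""
        if pruning_indicators is None:
            pruning_indicators = [True] * len(self.inputs)
        kept = [i for i, flag in enumerate(pruning_indicators) if flag]
        new_inputs = [nodes_by_name[self.inputs[i].name] for i in kept]
        return type(self)(self.name, new_inputs, [self.weights[i] for i in kept])


# Leaves are modelled as sum nodes without inputs to keep the sketch short.
leaf_a, leaf_b = SketchSum("a", [], []), SketchSum("b", [], [])
node = SketchSum("s", [leaf_a, leaf_b], [1.0, 0.0])

try:
    node.get_pruning_indicators(True)
    raised = False
except RuntimeError:
    raised = True   # no snapshot has been taken yet

node.snapshot = list(node.weights)
flags = node.get_pruning_indicators(True)

# Inputs are cloned first, so the dictionary already holds them.
nodes_by_name = {"a": leaf_a.clone({}), "b": leaf_b.clone({})}
copy = node.clone(nodes_by_name, pruning_indicators=flags)
```

Because the clone resolves its inputs through the dictionary instead of through the source node's references, the copy never points back into the source graph.
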
We propose adding a `Pruner` class performing calculation and storage of pruning indicators for an SPN graph. The class should implement the following interface:
- `__init__(root)` - Initialize the pruner for a specific graph.
- `indicators = get_pruning_indicators()` - Calculate the indicators for the graph (uses `compute_graph_up_down`). Returns the computed indicators (see below).
- `indicators` property - Provides access to a dictionary mapping nodes to calculated indicators. Returned using `MappingProxyType`.
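
A rough sketch of this interface on a toy graph follows. `ToyNode` and `ToyPruner` are hypothetical, a simple stack-based traversal stands in for `compute_graph_up_down`, and the rule that an input is kept when its weight is positive is an assumption. `types.MappingProxyType` is the standard-library read-only mapping view.

```python
from types import MappingProxyType


class ToyNode:
    """Hypothetical node: `weights[i]` belongs to `inputs[i]`."""

    def __init__(self, name, inputs=(), weights=()):
        self.name = name
        self.inputs = list(inputs)
        self.weights = list(weights)


class ToyPruner:
    """Sketch of the proposed interface for a toy graph."""

    def __init__(self, root):
        self._root = root
        self._indicators = {}

    def get_pruning_indicators(self):
        """Top-down pass; stands in for `compute_graph_up_down`."""
        indicators, stack = {}, [self._root]
        while stack:
            node = stack.pop()
            if node in indicators:
                continue
            flags = [w > 0 for w in node.weights]
            indicators[node] = flags
            # Only descend along edges that survive pruning.
            stack.extend(c for c, keep in zip(node.inputs, flags) if keep)
        self._indicators = indicators
        return self.indicators

    @property
    def indicators(self):
        # Read-only view: entries cannot be added or removed through it.
        return MappingProxyType(self._indicators)


a, b = ToyNode("a"), ToyNode("b")
root = ToyNode("s", [a, b], [0.0, 1.0])
pruner = ToyPruner(root)
indicators = pruner.get_pruning_indicators()
```

In this sketch a node that is absent from the mapping was never reached through a kept edge, so it is dropped as a whole.
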
We propose the following interface of the `Replicator` class:
- `__init__(src_sess=None, dst_sess=None, dst_graph=None)` - If graph or sessions are `None`, the current default values are used. The sessions are necessary for `with_param_vals==True`. Otherwise, it is sufficient to provide `dst_graph`.
- `node = find_node(name)` - Find node by its name in the source graph.
- `new_root = clone(root, prune=False, with_param_vals=True)` - Instantiate a new graph based on the graph rooted in `root`. Uses `Pruner` internally to obtain indicators.
We propose keeping the following serialization/deserialization functions (with modifications):
- `data = serialize_graph(root, prune=False, with_param_vals=True, sess=None)` - Serializes an SPN graph to a dictionary. Prunes the graph if `prune==True`. Uses `Pruner` internally to obtain indicators. If `sess` is `None`, the default session is used whenever `with_param_vals==True`. The session is not relevant if `with_param_vals==False`.
- `root = deserialize_graph(data, with_param_vals=True, sess=None, graph=None, nodes_by_name=None)` - De-serialize an SPN graph from a dictionary. If `graph` or `sess` is `None`, the current default values are used. If `with_param_vals==True`, `sess` is used to set the parameter values. Otherwise, it is sufficient to provide `graph`.
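
The bottom-up ordering that makes deserialization simple can be sketched as follows. The dictionary layout (`"root"`, `"nodes"`, per-node `"inputs"` as names), `ToyNode`, and the function names `serialize_toy_graph` and `deserialize_toy_graph` are all assumptions made for this example. TF sessions and parameter handling are left out.

```python
import json


class ToyNode:
    """Hypothetical node: `weights[i]` belongs to `inputs[i]`."""

    def __init__(self, name, inputs=(), weights=()):
        self.name = name
        self.inputs = list(inputs)
        self.weights = list(weights)


def serialize_toy_graph(root, prune=False):
    """Emit node records in bottom-up order: inputs before consumers."""
    records, seen = [], set()

    def visit(node):
        if node.name in seen:
            return
        seen.add(node.name)
        kept = [i for i in range(len(node.inputs))
                if not prune or node.weights[i] > 0]
        for i in kept:
            visit(node.inputs[i])
        records.append({"name": node.name,
                        "inputs": [node.inputs[i].name for i in kept],
                        "weights": [node.weights[i] for i in kept]})

    visit(root)
    return {"root": root.name, "nodes": records}


def deserialize_toy_graph(data, nodes_by_name=None):
    """Rebuild nodes in stored order; inputs are always already present."""
    nodes_by_name = {} if nodes_by_name is None else nodes_by_name
    for rec in data["nodes"]:
        inputs = [nodes_by_name[name] for name in rec["inputs"]]
        nodes_by_name[rec["name"]] = ToyNode(rec["name"], inputs, rec["weights"])
    return nodes_by_name[data["root"]]


a, b, c = ToyNode("a"), ToyNode("b"), ToyNode("c")
s = ToyNode("s", [a, b, c], [0.2, 0.0, 0.8])

# Round-trip through JSON to show the dictionary is plain data.
data = json.loads(json.dumps(serialize_toy_graph(s, prune=True)))
restored = deserialize_toy_graph(data)
```

Pruning happens while the records are written, so the source graph is never modified and no extra operations are added to it.
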
We propose the following interface of the `Saver` class:
- `__init__(path, sess=None)` - If `sess` is `None`, the current default session is used.
- `save(root, prune=False, with_param_vals=True)` - Save the graph structure (and possibly parameter values) to a file. Prunes the graph if `prune==True`. Internally uses `serialize_graph`.
We propose the following interface of the `Loader` class:
- `__init__(path, sess=None, graph=None)` - If graph or sessions are `None`, the current default values are used. `sess` is only relevant if `with_param_vals==True`.
- `root = load(with_param_vals=True)` - Instantiate a new SPN graph based on the saved data. Internally uses `deserialize_graph`.
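
As a rough illustration of saving a pruned copy without touching the in-memory graph, the sketch below stores a toy graph as JSON. `ToySaver`, `ToyLoader` and `pruned_copy` are hypothetical. The graph is a tree of nested dictionaries rather than a real SPN graph, and sessions and parameter values are ignored.

```python
import json
import os
import tempfile


def pruned_copy(node):
    """Return a new nested dict without zero-weight inputs (tree-shaped toy)."""
    kept = [(w, c) for w, c in zip(node["weights"], node["inputs"]) if w > 0]
    return {"name": node["name"],
            "weights": [w for w, _ in kept],
            "inputs": [pruned_copy(c) for _, c in kept]}


class ToySaver:
    def __init__(self, path):
        self._path = path

    def save(self, root, prune=False):
        # Pruning works on a copy, so `root` itself is never modified.
        data = pruned_copy(root) if prune else root
        with open(self._path, "w") as f:
            json.dump(data, f)


class ToyLoader:
    def __init__(self, path):
        self._path = path

    def load(self):
        with open(self._path) as f:
            return json.load(f)


def leaf(name):
    return {"name": name, "weights": [], "inputs": []}


root = {"name": "s", "weights": [0.7, 0.0, 0.3],
        "inputs": [leaf("a"), leaf("b"), leaf("c")]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_spn.json")
    ToySaver(path).save(root, prune=True)
    loaded = ToyLoader(path).load()
```
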
Create first graph and session:
graph1 = tf.Graph()
sess1 = tf.Session(graph=graph1)
Create and initialize an SPN in the first graph:
with graph1.as_default():
    model1 = spn.Poon11NaiveMixtureModel()
    root1 = model1.build()
    init1 = spn.initialize_weights(root1)
sess1.run(init1)
Create another graph and session:
graph2 = tf.Graph()
sess2 = tf.Session(graph=graph2)
Prune while cloning:
replicator = spn.Replicator(src_sess=sess1, dst_sess=sess2)
root2 = replicator.clone(root1, prune=True, with_param_vals=True)
Delete old graph and session:
del graph1
del sess1