Andrzej Pronobis edited this page Sep 8, 2017 · 2 revisions

LIP 4 - Improved Pruning during Cloning and Serialization

LIP: 4
Title: Improved Pruning during Cloning and Serialization
Author: A. Pronobis
Status: Draft
Type: Standard
Discussion: Issue #28
PR:
Created: Sep 1, 2017

Introduction

The current implementation of SPN graph pruning operates directly on the current SPN graph. This approach has the following disadvantages:

  • It is not possible to clear the TF graph of operations on the original SPN graph without destroying the operations for the pruned graph.
  • Saving a pruned graph requires modifying the current graph and adds new TF ops.

In this proposal, we suggest making pruning an optional part of cloning and of saving/loading the SPN graph.

As a result, we will be able to:

  • Generate a new pruned copy of the SPN graph in the same TF graph (might still be convenient for small SPNs)
  • Generate a pruned SPN graph in a new TF graph and destroy the old TF graph
  • Save a pruned version of the current SPN graph to a file without affecting the current graph

Technical Background

This proposal depends on LIP3 and LIP2.

Proposal

We propose the following workflow, which realizes pruning as part of cloning/saving:

  1. Create parameter snapshots (using create_param_snapshots, runs traverse_graph)
  2. Compute a path through the graph (top-down) which indicates which edges/nodes to keep (implemented with a new class Pruner using compute_graph_up_down)
  3. Re-create nodes by cloning or saving/loading them. This process should simplify nodes for which some operations have been eliminated by pruning and eliminate unnecessary nodes. This requires traversing the graph bottom up (using compute_graph_up).
    • When cloning: For each node that should be kept, clone the node (possibly simplified) and connect it to the already created input nodes.
    • When saving/loading: For each node that should be kept, serialize the node to a file (possibly simplified). Then loading follows the bottom-up order and simply deserializes nodes and connects them to already existing input nodes.
  4. Restore parameter snapshots (using restore_param_snapshots, runs traverse_graph)
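Steps 2 and 3 above can be illustrated with a minimal, self-contained sketch. It does not use LibSPN: ToyNode, kept_edges, compute_keep, bottom_up_order and clone_pruned are hypothetical names, and the pruning rule (drop sum inputs whose weight is zero) is an assumption made only for this illustration. Parameter snapshots and the simplification of nodes are left out.

```python
from collections import namedtuple

# Hypothetical stand-in for an SPN node: child names plus optional sum weights.
# weights is None for nodes that are not sums (products, leaves).
ToyNode = namedtuple("ToyNode", ["children", "weights"])


def kept_edges(node):
    """Indices of child edges that survive pruning (assumed rule: weight > 0)."""
    return [i for i in range(len(node.children))
            if node.weights is None or node.weights[i] > 0]


def compute_keep(graph, root):
    """Top-down pass: collect the nodes reachable through kept edges."""
    keep, stack = set(), [root]
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        node = graph[name]
        stack.extend(node.children[i] for i in kept_edges(node))
    return keep


def bottom_up_order(graph, root):
    """Post-order traversal: every node appears after all of its inputs."""
    order, seen = [], set()

    def visit(name):
        if name in seen:
            return
        seen.add(name)
        for child in graph[name].children:
            visit(child)
        order.append(name)

    visit(root)
    return order


def clone_pruned(graph, root):
    """Bottom-up pass: re-create kept nodes without touching the source graph."""
    keep = compute_keep(graph, root)
    new_graph = {}
    for name in bottom_up_order(graph, root):
        if name not in keep:
            continue
        node = graph[name]
        idx = kept_edges(node)
        children = tuple(node.children[i] for i in idx)
        weights = None if node.weights is None else tuple(node.weights[i] for i in idx)
        new_graph[name] = ToyNode(children, weights)
    return new_graph


graph = {
    "root": ToyNode(("p1", "p2"), (1.0, 0.0)),  # sum node, second input has zero weight
    "p1": ToyNode(("x", "y"), None),
    "p2": ToyNode(("y", "z"), None),
    "x": ToyNode((), None),
    "y": ToyNode((), None),
    "z": ToyNode((), None),
}
pruned = clone_pruned(graph, "root")
print(sorted(pruned))
```

In this toy graph, p2 and z are reachable only through the zero-weight edge, so they are absent from the copy, while y survives because p1 still uses it. The source dictionary is left unchanged, which is the property the proposal is after.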

We propose the following additions to the abstract Node class:

  • children_indicators = get_pruning_indicators(indicators) - Calculates indicators that are used to decide if a node or an output of a node should be preserved after pruning. Relies on parameter node snapshots. Raises RuntimeError if a parameter snapshot is not available.
  • new_node = clone(nodes_by_name, pruning_indicators=None) - Create a copy of the node. nodes_by_name is a dictionary with all already created nodes indexed by their names in the source SPN graph. The dictionary is used to look up the input nodes of the newly created node. The default implementation (in the Node class) creates a new instance of the node of the same type. Subclasses should implement additional specific functionality. If pruning_indicators are provided, retain only the outputs with positive indicators.
  • data = serialize(pruning_indicators=None) - Serialize the node to a dictionary. If pruning_indicators are provided, retain only the outputs with positive indicators.
  • deserialize(nodes_by_name) - De-serialize the node from a dictionary. nodes_by_name is a dictionary with all already created nodes indexed by their names in the original SPN graph. The dictionary is used to look up the input nodes of the newly created node.
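The role of nodes_by_name and pruning_indicators in clone can be sketched on a toy class. ToyNode below is hypothetical and is not the LibSPN Node class. Representing outputs as a list of labels, and having the caller store each clone in nodes_by_name, are assumptions of this sketch.

```python
class ToyNode:
    """Hypothetical multi-output node used only for illustration."""

    def __init__(self, name, inputs=(), outputs=()):
        self.name = name
        self.inputs = list(inputs)    # upstream ToyNode objects
        self.outputs = list(outputs)  # labels of this node's outputs

    def clone(self, nodes_by_name, pruning_indicators=None):
        # Look up the inputs among the nodes that were already cloned,
        # using the names they have in the source graph.
        new_inputs = [nodes_by_name[inp.name] for inp in self.inputs]
        outputs = self.outputs
        if pruning_indicators is not None:
            # Retain only the outputs with positive indicators.
            outputs = [out for out, ind in zip(outputs, pruning_indicators)
                       if ind > 0]
        # Default behaviour: a new instance of the same type.
        return type(self)(self.name, new_inputs, outputs)


leaf = ToyNode("leaf", outputs=["a", "b", "c"])
top = ToyNode("top", inputs=[leaf], outputs=["out"])

# Clone bottom-up so that every input exists before it is looked up.
nodes_by_name = {}
for node, indicators in [(leaf, [1, 0, 1]), (top, None)]:
    nodes_by_name[node.name] = node.clone(nodes_by_name,
                                          pruning_indicators=indicators)

new_top = nodes_by_name["top"]
print(new_top.inputs[0].outputs)
```

The cloned top node is connected to the cloned leaf rather than to the original one, and the original leaf keeps all three outputs.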

We propose adding a Pruner class performing calculation and storage of pruning indicators for an SPN graph. The class should implement the following interface:

  • __init__(root) - Initialize the pruner for a specific graph
  • indicators = get_pruning_indicators() - Calculate the indicators for the graph (uses compute_graph_up_down). Returns the computed indicators (see below).
  • indicators property - Provides access to a dictionary mapping nodes to calculated indicators. Returned using MappingProxyType.
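A minimal sketch of this interface is shown below. ToyNode and ToyPruner are hypothetical and not part of LibSPN. The sketch assumes one boolean indicator per node and a zero-weight pruning rule, and it uses a plain traversal in place of compute_graph_up_down.

```python
from types import MappingProxyType


class ToyNode:
    """Hypothetical node: children plus optional sum weights (None for non-sums)."""

    def __init__(self, name, children=(), weights=None):
        self.name = name
        self.children = list(children)
        self.weights = weights


class ToyPruner:
    """Calculates and stores one keep/drop indicator per node."""

    def __init__(self, root):
        self._root = root
        self._indicators = {}

    def get_pruning_indicators(self):
        indicators = {}

        def mark_all(node):
            # Start with every node marked as pruned.
            if node not in indicators:
                indicators[node] = False
                for child in node.children:
                    mark_all(child)

        mark_all(self._root)

        # Top-down pass: keep what is reachable through non-zero-weight edges.
        stack = [self._root]
        while stack:
            node = stack.pop()
            if indicators[node]:
                continue
            indicators[node] = True
            for i, child in enumerate(node.children):
                if node.weights is None or node.weights[i] > 0:
                    stack.append(child)

        self._indicators = indicators
        return self.indicators

    @property
    def indicators(self):
        # Read-only view, so callers cannot modify the stored indicators.
        return MappingProxyType(self._indicators)


x, y = ToyNode("x"), ToyNode("y")
root = ToyNode("root", children=[x, y], weights=[0.5, 0.0])

pruner = ToyPruner(root)
indicators = pruner.get_pruning_indicators()
print({node.name: keep for node, keep in indicators.items()})
```

Returning a MappingProxyType means that an attempt such as pruner.indicators[x] = False raises TypeError, while the pruner itself can still replace the underlying dictionary on the next calculation.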

We propose the following interface of the Replicator class:

  • __init__(src_sess=None, dst_sess=None, dst_graph=None) - If the graph or the sessions are None, the current default values are used. The sessions are necessary when with_param_vals==True. Otherwise, it is sufficient to provide dst_graph.
  • node = find_node(name) - Find node by its name in the source graph.
  • new_root = clone(root, prune=False, with_param_vals=True) - Instantiate a new graph based on the graph rooted in root. Uses Pruner internally to obtain indicators.

We propose keeping the following serialization/deserialization functions (with modifications):

  • data = serialize_graph(root, prune=False, with_param_vals=True, sess=None) - Serializes an SPN graph to a dictionary. Prunes the graph if prune==True. Uses Pruner internally to obtain indicators. If sess is None, the default session is used whenever with_param_vals==True. sess is not relevant if with_param_vals==False.
  • root = deserialize_graph(data, with_param_vals=True, sess=None, graph=None, nodes_by_name=None) - De-serialize an SPN graph from a dictionary. If graph or sess is None, the current default values are used. If with_param_vals==True, sess is used to set the parameter values. Otherwise, it is sufficient to provide graph.
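The bottom-up ordering that makes deserialization simple can be sketched with plain dictionaries. toy_serialize_graph and toy_deserialize_graph are hypothetical helpers, not the proposed functions: they ignore pruning, parameter values, sessions and TF graphs. Using nodes_by_name to collect the created nodes is an assumption, since the proposal does not describe that argument.

```python
import json


def toy_serialize_graph(root):
    """Serialize a toy graph to a dictionary, listing nodes in bottom-up order."""
    nodes, seen = [], set()

    def visit(node):
        if node["name"] in seen:
            return
        seen.add(node["name"])
        for inp in node["inputs"]:
            visit(inp)  # inputs are written before the node that uses them
        nodes.append({"name": node["name"],
                      "inputs": [inp["name"] for inp in node["inputs"]]})

    visit(root)
    return {"root": root["name"], "nodes": nodes}


def toy_deserialize_graph(data, nodes_by_name=None):
    """Re-create nodes in stored order, connecting them to existing inputs."""
    nodes_by_name = {} if nodes_by_name is None else nodes_by_name
    for entry in data["nodes"]:
        inputs = [nodes_by_name[name] for name in entry["inputs"]]
        nodes_by_name[entry["name"]] = {"name": entry["name"], "inputs": inputs}
    return nodes_by_name[data["root"]]


# A diamond-shaped toy graph: both a and b use the same leaf.
leaf = {"name": "leaf", "inputs": []}
a = {"name": "a", "inputs": [leaf]}
b = {"name": "b", "inputs": [leaf]}
root = {"name": "root", "inputs": [a, b]}

# Round trip through JSON to show that the dictionary is plain data.
data = json.loads(json.dumps(toy_serialize_graph(root)))
restored = toy_deserialize_graph(data)
print([entry["name"] for entry in data["nodes"]])
```

Because every entry is stored after its inputs, loading needs only a single pass and a name lookup, and a node shared by several parents is re-created once.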

We propose the following interface of the Saver class:

  • __init__(path, sess=None) - If sess is None, the current default session is used.
  • save(root, prune=False, with_param_vals=True) - Save the graph structure (and possibly parameter values) to a file. Prunes the graph if prune==True. Internally uses serialize_graph.

We propose the following interface of the Loader class:

  • __init__(path, sess=None, graph=None) - If graph or sess is None, the current default values are used. sess is only relevant if with_param_vals==True.
  • root = load(with_param_vals=True) - Instantiate a new SPN graph based on the saved data. Internally uses deserialize_graph.

Examples

Create first graph and session:

import tensorflow as tf
import libspn as spn

graph1 = tf.Graph()
sess1 = tf.Session(graph=graph1)

Create and initialize an SPN in the first graph:

with graph1.as_default():
    model1 = spn.Poon11NaiveMixtureModel()
    root1 = model1.build()
    init1 = spn.initialize_weights(root1)
sess1.run(init1)

Create another graph and session:

graph2 = tf.Graph()
sess2 = tf.Session(graph=graph2)

Prune while cloning:

replicator = spn.Replicator(src_sess=sess1, dst_sess=sess2)
root2 = replicator.clone(root1, prune=True, with_param_vals=True)

Delete old graph and session:

sess1.close()  # release the session's resources before dropping the references
del sess1
del graph1

Decision
