Skip to content

Commit

Permalink
Use dummy input variables during Scan rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 24, 2022
1 parent 6307c50 commit 04e2d8a
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions aesara/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
graph_inputs,
io_toposort,
is_in_ancestors,
replace_nominals_with_dummies,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
Expand Down Expand Up @@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
"""
if not isinstance(node.op, Scan):
return False

op = node.op
op_info = op.info
# We only need to take care of sequences and other arguments
Expand All @@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
st += op_info.n_sit_sot
st += op_info.n_shared_outs

op_ins = op.inner_inputs
op_outs = op.inner_outputs
op_ins, op_outs = replace_nominals_with_dummies(op.inner_inputs, op.inner_outputs)

# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
Expand Down Expand Up @@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
allow_gc=op.allow_gc,
)
nw_outs = nwScan(*nw_outer, return_list=True)

return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else:
return False
Expand All @@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
node_inputs, node_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
Expand Down Expand Up @@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
node_inputs, node_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
Expand Down Expand Up @@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node):

# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
inner_inputs, inner_outputs = replace_nominals_with_dummies(
op.inner_inputs, op.inner_outputs
)
args = ScanArgs(node.inputs, node.outputs, inner_inputs, inner_outputs, op.info)

clients = {}
local_fgraph_topo = io_toposort(
Expand Down Expand Up @@ -1694,6 +1701,8 @@ def merge(self, nodes):
inner_outs = [[] for nd in nodes]
outer_outs = []

# inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs)

def rename(ls, suffix):
for k in ls:
if k.name:
Expand Down Expand Up @@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node):
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.

inner_inputs, inner_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

a = ScanArgs(
node.inputs,
node.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
inner_inputs,
inner_outputs,
node.op.info,
)

Expand Down Expand Up @@ -2173,10 +2187,15 @@ def push_out_dot1_scan(fgraph, node):
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)

inner_inputs, inner_outputs = replace_nominals_with_dummies(
op.inner_inputs, op.inner_outputs
)

sitsot_ins = op.inner_sitsot(inner_inputs)
sitsot_outs = op.inner_sitsot_outs(inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inner_inputs)
seqs = op.inner_seqs(inner_inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):

if (
Expand Down Expand Up @@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node):
# First let us split all arguments according to their
# corresponding categories

inner_seqs = op.inner_seqs(op.inner_inputs)
inner_seqs = op.inner_seqs(inner_inputs)
outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inner_inputs)
inner_mitmot = op.inner_mitmot(inner_inputs)
outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs)
inner_mitsot = op.inner_mitsot(op.inner_inputs)
inner_mitmot_outs = op.inner_mitmot_outs(inner_outputs)
inner_mitsot = op.inner_mitsot(inner_inputs)
outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs)
inner_sitsot = op.inner_sitsot(op.inner_inputs)
inner_mitsot_outs = op.inner_mitsot_outs(inner_outputs)
inner_sitsot = op.inner_sitsot(inner_inputs)
outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs)
inner_nitsot_outs = op.inner_nitsot_outs(inner_outputs)
inner_shared = op.inner_shared(inner_inputs)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
inner_shared_outs = op.inner_shared_outs(inner_outputs)
inner_non_seqs = op.inner_non_seqs(inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)

new_info = dataclasses.replace(
Expand Down

0 comments on commit 04e2d8a

Please sign in to comment.