Skip to content

Commit

Permalink
Simple assert prop pass
Browse files Browse the repository at this point in the history
Related to #18
  • Loading branch information
dubiousconst282 committed Dec 23, 2023
1 parent b226732 commit 290bd0b
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 40 deletions.
1 change: 1 addition & 0 deletions src/DistIL.Cli/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ static void RunPasses(OptimizerOptions options, Compilation comp)

manager.AddPasses()
.Apply<ValueNumbering>()
.Apply<AssertionProp>()
.Apply<PresizeLists>()
.Apply<LoopStrengthReduction>()
.IfChanged(simplifySeg);
Expand Down
64 changes: 27 additions & 37 deletions src/DistIL/Analysis/DominatorTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@ public class DominatorTree : IMethodAnalysis
readonly Node _root;
bool _hasDfsIndices = false; // whether Node.{PreIndex, PostIndex} have been calculated

public MethodBody Method { get; }

public DominatorTree(MethodBody method)
{
Method = method;

var nodes = CreateNodes();
_root = nodes[^1];
var nodes = CreateNodes(method);
_root = nodes[0];
ComputeDom(nodes);
ComputeChildren(nodes);
}
Expand All @@ -33,6 +29,7 @@ public bool Dominates(BasicBlock parent, BasicBlock child)
if (!_hasDfsIndices) {
ComputeDfsIndices();
}

var parentNode = GetNode(parent);
var childNode = GetNode(child);

Expand Down Expand Up @@ -77,58 +74,51 @@ public IEnumerable<BasicBlock> GetChildren(BasicBlock block)
}
}

public int GetPreIndex(BasicBlock block)
{
if (!_hasDfsIndices) {
ComputeDfsIndices();
}
return GetNode(block).PreIndex;
}

// NOTE: Querying nodes from unreachable blocks leads to a KeyNotFoundException.
// Unsure how to best handle them, but ideally passes would not even consider unreachable blocks in the first place.
private Node GetNode(BasicBlock block)
{
return _block2node[block];
}

/// <summary> Creates the tree nodes for all reachable blocks and returns them in DFS reverse post order (entry node first). </summary>
private Node[] CreateNodes()
private Span<Node> CreateNodes(MethodBody method)
{
Debug.Assert(Method.EntryBlock.NumPreds == 0);
Debug.Assert(method.EntryBlock.NumPreds == 0);

var nodes = new Node[Method.NumBlocks];
int index = 0;
var nodes = new Node[method.NumBlocks];
int index = nodes.Length;

Method.TraverseDepthFirst(postVisit: block => {
method.TraverseDepthFirst(postVisit: block => {
var node = new Node() {
Block = block,
PostIndex = index
PostIndex = nodes.Length - index
};
_block2node.Add(block, node);
nodes[index++] = node;
nodes[--index] = node;
});
return nodes;
// index will only be >0 if there are unreachable blocks.
return nodes.AsSpan(index);
}

// Algorithm from the paper "A Simple, Fast Dominance Algorithm"
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
private void ComputeDom(Node[] nodes)
private void ComputeDom(Span<Node> nodesRPO)
{
var entry = nodes[^1];
var entry = nodesRPO[0];
entry.IDom = entry; // entry block dominates itself

bool changed = true;
while (changed) {
changed = false;
// foreach block in reverse post order, except entry (at `len - 1`)
for (int i = nodes.Length - 2; i >= 0; i--) {
var node = nodes[i];
var block = node.Block;
// foreach block in reverse post order, except entry (at index 0)
foreach (var node in nodesRPO[1..]) {
var newDom = default(Node);

foreach (var predBlock in block.Preds) {
var pred = GetNode(predBlock);
foreach (var predBlock in node.Block.Preds) {
var pred = _block2node.GetValueOrDefault(predBlock);

if (pred.IDom != null) {
if (pred?.IDom != null) {
newDom = newDom == null ? pred : Intersect(pred, newDom);
}
}
Expand All @@ -153,10 +143,10 @@ static Node Intersect(Node b1, Node b2)
}
}

private static void ComputeChildren(Node[] nodes)
private static void ComputeChildren(Span<Node> nodes)
{
// Ignore entry node (^1) to avoid cycles in the children list
foreach (var node in nodes.AsSpan()[..^1]) {
// Ignore entry node (index 0) to avoid cycles in the children list
foreach (var node in nodes[1..]) {
var parent = node.IDom;
if (parent.FirstChild == null) {
parent.FirstChild = node;
Expand Down Expand Up @@ -219,9 +209,9 @@ public class DominanceFrontier : IMethodAnalysis
static readonly RefSet<BasicBlock> _emptySet = new();
readonly Dictionary<BasicBlock, RefSet<BasicBlock>> _df = new();

public DominanceFrontier(DominatorTree domTree)
public DominanceFrontier(MethodBody method, DominatorTree domTree)
{
foreach (var block in domTree.Method) {
foreach (var block in method) {
if (block.NumPreds < 2) continue;

var blockDom = domTree.IDom(block);
Expand All @@ -239,7 +229,7 @@ public DominanceFrontier(DominatorTree domTree)
}

static IMethodAnalysis IMethodAnalysis.Create(IMethodAnalysisManager mgr)
=> new DominanceFrontier(mgr.GetAnalysis<DominatorTree>());
=> new DominanceFrontier(mgr.Method, mgr.GetAnalysis<DominatorTree>());

public RefSet<BasicBlock> Of(BasicBlock block)
=> _df.GetValueOrDefault(block, _emptySet);
Expand Down
33 changes: 33 additions & 0 deletions src/DistIL/IR/Instructions/CompareInst.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,37 @@ public static CompareOp GetUnsigned(this CompareOp op)
/// <summary> Returns the signed version of the current operator, if it is unsigned; otherwise returns it unchanged. </summary>
public static CompareOp GetSigned(this CompareOp op)
=> op.IsUnsigned() ? op + (CompareOp.Slt - CompareOp.Ult) : op;

/// <summary> Checks whether the operator is a strict ordering comparison, i.e. less-than or greater-than without the "-or-equal" part. </summary>
public static bool IsStrict(this CompareOp op)
{
    switch (op) {
        // Signed and unsigned integer comparisons
        case CompareOp.Slt:
        case CompareOp.Sgt:
        case CompareOp.Ult:
        case CompareOp.Ugt:
        // Ordered and unordered float comparisons
        case CompareOp.FOlt:
        case CompareOp.FOgt:
        case CompareOp.FUlt:
        case CompareOp.FUgt:
            return true;

        default:
            return false;
    }
}

/// <summary>
/// Maps an inclusive ("-or-equal") operator to its strict counterpart: Sle -> Slt, Sge -> Sgt, ...
/// Any other operator is returned unchanged.
/// </summary>
public static CompareOp GetStrict(this CompareOp op)
{
    switch (op) {
        case CompareOp.Sle: return CompareOp.Slt;
        case CompareOp.Sge: return CompareOp.Sgt;
        case CompareOp.Ule: return CompareOp.Ult;
        case CompareOp.Uge: return CompareOp.Ugt;
        case CompareOp.FOle: return CompareOp.FOlt;
        case CompareOp.FOge: return CompareOp.FOgt;
        case CompareOp.FUle: return CompareOp.FUlt;
        case CompareOp.FUge: return CompareOp.FUgt;
        default: return op;
    }
}

/// <summary>
/// Maps a strict operator to its inclusive ("-or-equal") counterpart: Slt -> Sle, Sgt -> Sge, ...
/// Any other operator is returned unchanged.
/// </summary>
public static CompareOp GetNonStrict(this CompareOp op)
{
    switch (op) {
        case CompareOp.Slt: return CompareOp.Sle;
        case CompareOp.Sgt: return CompareOp.Sge;
        case CompareOp.Ult: return CompareOp.Ule;
        case CompareOp.Ugt: return CompareOp.Uge;
        case CompareOp.FOlt: return CompareOp.FOle;
        case CompareOp.FOgt: return CompareOp.FOge;
        case CompareOp.FUlt: return CompareOp.FUle;
        case CompareOp.FUgt: return CompareOp.FUge;
        default: return op;
    }
}
}
5 changes: 3 additions & 2 deletions src/DistIL/IR/Values/Const.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ StackType.Int or StackType.Long

public class ConstNull : Const
{
private ConstNull() { }
private ConstNull() { ResultType = PrimType.Object; }

public static ConstNull Create() => new() { ResultType = PrimType.Object };
static readonly ConstNull _instance = new();
public static ConstNull Create() => _instance;

public override void Print(PrintContext ctx)
{
Expand Down
183 changes: 183 additions & 0 deletions src/DistIL/Passes/AssertionProp.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
namespace DistIL.Passes;

using DistIL.Analysis;

/// <summary> A simple dominator-based "assertion propagation" pass. </summary>
/// <remarks>
/// Walks the dominator tree in pre-order while recording conditions that are known to hold
/// ("assertions"): branch conditions on the edge into a block, and non-null facts implied by
/// accesses through an object. Compare instructions whose outcome follows from an assertion
/// that is still in scope are replaced with a constant 0/1.
/// </remarks>
public class AssertionProp : IMethodPass
{
    // We mostly want this pass to eliminate null-checks/duplicated conditions and possibly
    // array range checks. These are quite few in typical methods, but it enables folding of
    // some LINQ dispatching code.
    //
    // Unfortunately, extending this impl to support array bounds-check elimination may be difficult since
    // the assert chain will not include info from other scopes of the dom tree, thus deriving accurate
    // range info from it might not be possible.
    //
    // Both GCC and LLVM use an on-demand approach to determine this kind of info, which looks relatively easy
    // to implement and could be faster as it would save some needless tracking:
    // - https://gcc.gnu.org/wiki/AndrewMacLeod/Ranger
    // - LLVM LazyValueInfo and ValueTracking
    //
    // SCCP is another option but it looks rather finicky, requiring insertion of "assertion" defs in the IR.
    // It would probably catch edge cases that this might miss?
    //
    // RyuJIT on the other hand appears to be using path sensitive DFA to compute all asserts + a later elimination pass,
    // which also sounds complicated/expensive.

    /// <summary> Runs the pass over the method in <paramref name="ctx"/>. </summary>
    /// <returns>
    /// <see cref="MethodInvalidations.DataFlow"/> if any use was rewritten,
    /// otherwise <see cref="MethodInvalidations.None"/>.
    /// </returns>
    public MethodPassResult Run(MethodTransformContext ctx)
    {
        var domTree = ctx.GetAnalysis<DominatorTree>();
        // Maps a value to the most recent assertion mentioning it (head of a linked chain).
        var impliedAsserts = new Dictionary<Value, Assertion>();
        bool changed = false;

        domTree.Traverse(preVisit: EnterBlock);

        return changed ? MethodInvalidations.DataFlow : MethodInvalidations.None;

        void EnterBlock(BasicBlock block)
        {
            // Add implications from the predecessor branch.
            //
            // A branch condition is only known to hold on the *edge* pred -> block. It may be assumed
            // for the whole block only when that edge is the sole way in; at a join point some other
            // predecessor can arrive with the condition being false:
            //   Entry: if (x != null) goto A; else goto B;
            //   A:     ...; goto B;
            //   B:     (reached from both Entry and A, so `x == null` must NOT be implied here)
            //
            // This is conservative: a loop header whose only other predecessors are backedges
            // could still inherit the condition from its entry edge, but that is not handled here.
            //
            // NOTE(review): if a branch can target the same block in both arms while the block still
            // reports a single predecessor, the implication below would be wrong - confirm that
            // such branches are never formed or are simplified before this pass.
            if (block.NumPreds == 1) {
                foreach (var pred in block.Preds) {
                    if (pred.Last is not BranchInst { Cond: CompareInst cmp } br) continue;

                    // Backedges must be ignored because otherwise they will lead to
                    // incorrect folding of loop conditions:
                    //  Header:
                    //    (cond is implied to be true here, will be incorrectly folded at body)
                    //  Body:
                    //    cond = cmp ...
                    //    if (cond) goto header;
                    if (domTree.Dominates(block, pred)) continue;

                    bool isThenTarget = block == br.Then;
                    Imply(block, isThenTarget ? cmp.Op : cmp.Op.GetNegated(), cmp.Left, cmp.Right);

                    // Replace existing uses of the predecessor branch condition
                    // that are dominated by this block, to catch cases like:
                    //   if (cond) { if (cond) { ... } }
                    var condResult = ConstInt.CreateI(isThenTarget ? 1 : 0);

                    foreach (var use in cmp.Uses()) {
                        if (domTree.Dominates(block, use.Parent.Block)) {
                            use.Operand = condResult;
                            changed = true;
                        }
                    }
                }
            }

            foreach (var inst in block.NonPhis()) {
                // Access to an object implies that it must be non-null afterwards.
                // FIXME: check if try..catch regions will mess with this
                if (inst is MemoryInst or
                            CallInst { IsVirtual: true } or
                            FieldAddrInst { IsInstance: true } or
                            ArrayAddrInst
                    && inst.Operands[0] is TrackedValue obj
                ) {
                    Imply(block, CompareOp.Ne, obj, ConstNull.Create());
                    continue;
                }

                // Fold compares whose result follows from an assertion that is still in scope.
                if (inst is CompareInst cmp && EvaluateCond(block, cmp.Op, cmp.Left, cmp.Right) is bool cond) {
                    cmp.ReplaceUses(ConstInt.CreateI(cond ? 1 : 0));
                    changed = true;
                    continue;
                }
            }
        }

        // Adds an assertion that implies a true condition.
        void Imply(BasicBlock block, CompareOp op, Value left, Value right)
        {
            // Nothing to record if the condition is already decided by existing assertions.
            if (EvaluateCond(block, op, left, right) is not null) return;

            // The assertion is pushed onto the chain of each tracked operand, so it
            // can be found when querying through either of them.
            for (int i = 0; i < 2; i++) {
                var operand = i == 0 ? left : right;

                // Don't bother tracking asserts related to consts for now
                if (operand is not TrackedValue) continue;

                ref var lastNode = ref impliedAsserts.GetOrAddRef(operand);
                lastNode = new Assertion() {
                    Block = block,
                    Op = op, Left = left, Right = right,
                    Prev = lastNode
                };
            }
        }

        // Tries to decide `left op right` at `activeBlock`, using the assertion chains of both operands.
        // Returns null when the outcome is unknown.
        bool? EvaluateCond(BasicBlock activeBlock, CompareOp op, Value left, Value right)
        {
            return Evaluate(GetActiveAssert(activeBlock, left), op, left, right) ??
                   Evaluate(GetActiveAssert(activeBlock, right), op, left, right);
        }

        // Returns the most recent assertion for `key` that is still in scope at `activeBlock`, or null.
        Assertion? GetActiveAssert(BasicBlock activeBlock, Value key)
        {
            if (!impliedAsserts.ContainsKey(key)) return null;

            ref var node = ref impliedAsserts.GetRef(key);

            // Lazily remove asserts that are out of scope, rather than
            // tracking and removing dirty state on PostVisit.
            // This shouldn't be too slow since dom queries are O(1).
            while (node != null && !domTree.Dominates(node.Block, activeBlock)) {
                node = node.Prev;
            }

            return node;
        }
    }

    /// <summary>
    /// Walks the assertion chain starting at <paramref name="assert"/> and tries to decide
    /// the condition <c>left op right</c>.
    /// </summary>
    /// <returns> True/false if some assertion decides the condition; null if none does. </returns>
    private static bool? Evaluate(Assertion? assert, CompareOp op, Value left, Value right)
    {
        for (; assert != null; assert = assert.Prev) {
            // Operator to test against this particular assertion. It must be derived from `op`
            // afresh for every entry: swapping `op` in place would leak into later entries whose
            // operands are in the original order, and produce inverted answers.
            CompareOp cond;

            if (assert.Left.Equals(left) && assert.Right.Equals(right)) {
                cond = op;
            } else if (assert.Left.Equals(right) && assert.Right.Equals(left)) {
                // Assertion has the operands the other way around; mirror the operator to match.
                cond = op.GetSwapped();
            } else {
                // TODO: evaluate relations, eg. x < 10 implies x < 20
                continue;
            }

            if (AssertImpliesCond(assert.Op, cond)) {
                return true;
            }
            if (AssertImpliesCond(assert.Op, cond.GetNegated())) {
                return false;
            }
        }
        return null;
    }

    /// <summary>
    /// Returns whether a known-true comparison with operator <paramref name="assert"/> implies that the
    /// comparison with operator <paramref name="cond"/> over the same operands is also true.
    /// </summary>
    private static bool AssertImpliesCond(CompareOp assert, CompareOp cond)
    {
        // x < y implies x <= y
        if (assert.IsStrict() && !cond.IsStrict()) {
            cond = cond.GetStrict();
        }
        return assert == cond;
    }

    // List of known assertions linked to a value, ordered in most recent order.
    class Assertion
    {
        public required CompareOp Op;
        public required Value Left, Right;
        // Block in which the assertion was established.
        public required BasicBlock Block;
        // Next older assertion in the same chain, if any.
        public Assertion? Prev;

        public override string ToString()
        {
            var sw = new StringWriter();
            var symTable = (Left as TrackedValue ?? Right as TrackedValue)?.GetSymbolTable();
            var pc = new PrintContext(sw, symTable ?? SymbolTable.Empty);
            pc.Print($"{Op.ToString()} {Left}, {Right}");
            return sw.ToString();
        }
    }
}
2 changes: 1 addition & 1 deletion src/DistIL/Passes/ValueNumbering.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public bool CheckAvail(Instruction def, Instruction user)
}

// Scan CFG backwards to find memory clobbers
// TODO: cache results or mark something on the avail list to avoid re-scans
// TODO: investigate using MemorySSA for fast avail-dep checks: https://llvm.org/docs/MemorySSA.html
var worklist = new DiscreteStack<BasicBlock>(user.Block);

while (worklist.TryPop(out var block)) {
Expand Down
Loading

0 comments on commit 290bd0b

Please sign in to comment.