Skip to content

Commit

Permalink
Changed wire allocator API to use Value as the key instead of Value.S…
Browse files Browse the repository at this point in the history
…tring().
  • Loading branch information
markkurossi committed Aug 21, 2023
1 parent d7d0cb3 commit 987a587
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 65 deletions.
58 changes: 29 additions & 29 deletions compiler/ssa/circuitgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
instr := step.Instr
var wires [][]*circuits.Wire
for _, in := range instr.In {
w, err := prog.Wires(in.String(), in.Type.Bits)
w, err := prog.Wires(in, in.Type.Bits)
if err != nil {
return err
}
wires = append(wires, w)
}
switch instr.Op {
case Iadd, Uadd:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -96,7 +96,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Isub, Usub:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -106,7 +106,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Imult, Umult:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -117,7 +117,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Idiv, Udiv:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -128,7 +128,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Imod, Umod:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}
o[bit] = w
}
err = prog.SetWires(instr.Out.String(), o)
err = prog.SetWires(*instr.Out, o)
if err != nil {
return err
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}
o[bit] = w
}
err = prog.SetWires(instr.Out.String(), o)
err = prog.SetWires(*instr.Out, o)
if err != nil {
return err
}
Expand Down Expand Up @@ -225,13 +225,13 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
for bit := to - from; int(bit) < len(o); bit++ {
o[bit] = cc.ZeroWire()
}
err = prog.SetWires(instr.Out.String(), o)
err = prog.SetWires(*instr.Out, o)
if err != nil {
return err
}

case Index:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -247,7 +247,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Ilt, Ult:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -257,7 +257,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Ile, Ule:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -267,7 +267,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Igt, Ugt:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -277,7 +277,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Ige, Uge:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -287,7 +287,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Eq:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -297,7 +297,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Neq:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -312,7 +312,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
return fmt.Errorf("%s unsupported index type %T: %s",
instr.Op, instr.In[1], err)
}
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -327,7 +327,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
return fmt.Errorf("%s unsupported index type %T: %s",
instr.Op, instr.In[1], err)
}
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -337,7 +337,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case And:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -347,7 +347,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Or:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -357,7 +357,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Band:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -367,7 +367,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Bclr:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -377,7 +377,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Bor:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand All @@ -387,7 +387,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Bxor:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand Down Expand Up @@ -415,7 +415,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}
o[bit] = w
}
err := prog.SetWires(instr.Out.String(), o)
err := prog.SetWires(*instr.Out, o)
if err != nil {
return err
}
Expand Down Expand Up @@ -457,13 +457,13 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}
o[bit] = w
}
err = prog.SetWires(instr.Out.String(), o)
err = prog.SetWires(*instr.Out, o)
if err != nil {
return err
}

case Phi:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand Down Expand Up @@ -502,7 +502,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
var circOut []*circuits.Wire

for _, r := range instr.Ret {
o, err := prog.Wires(r.String(), r.Type.Bits)
o, err := prog.Wires(r, r.Type.Bits)
if err != nil {
return err
}
Expand Down Expand Up @@ -549,7 +549,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error {
}

case Builtin:
o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits)
o, err := prog.Wires(*instr.Out, instr.Out.Type.Bits)
if err != nil {
return err
}
Expand Down
14 changes: 7 additions & 7 deletions compiler/ssa/instructions.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright (c) 2020-2022 Markku Rossi
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
Expand Down Expand Up @@ -151,7 +151,7 @@ type Instr struct {
Label *Block
Circ *circuit.Circuit
Builtin circuits.Builtin
GC string
GC *Value
Ret []Value
}

Expand Down Expand Up @@ -541,10 +541,10 @@ func NewBuiltinInstr(builtin circuits.Builtin, a, b, r Value) Instr {
}

// NewGCInstr creates a new GC instruction.
func NewGCInstr(v string) Instr {
func NewGCInstr(v Value) Instr {
return Instr{
Op: GC,
GC: v,
GC: &v,
}
}

Expand All @@ -560,7 +560,7 @@ func (i Instr) StringTyped() string {
func (i Instr) string(maxLen int, typesOnly bool) string {
result := i.Op.String()

if len(i.In) == 0 && i.Out == nil && i.Label == nil && len(i.GC) == 0 {
if len(i.In) == 0 && i.Out == nil && i.Label == nil && i.GC == nil {
return result
}

Expand Down Expand Up @@ -590,9 +590,9 @@ func (i Instr) string(maxLen int, typesOnly bool) string {
if i.Circ != nil {
result += fmt.Sprintf(" {G=%d, W=%d}", i.Circ.NumGates, i.Circ.NumWires)
}
if len(i.GC) > 0 {
if i.GC != nil {
result += " "
result += i.GC
result += i.GC.String()
}
for _, r := range i.Ret {
result += " "
Expand Down
9 changes: 6 additions & 3 deletions compiler/ssa/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ func NewProgram(params *utils.Params, in, out circuit.IO,
if len(arg.Name) == 0 {
arg.Name = fmt.Sprintf("arg{%d}", idx)
}
wires, err := prog.Wires(arg.Name, types.Size(arg.Size))
wires, err := prog.Wires(Value{
Const: true,
Name: arg.Name,
}, types.Size(arg.Size))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -194,7 +197,7 @@ func (prog *Program) GC() {
if !live {
// Input is not live.
gcs = append(gcs, Step{
Instr: NewGCInstr(in.String()),
Instr: NewGCInstr(in),
})
}
}
Expand Down Expand Up @@ -255,7 +258,7 @@ func (prog *Program) DefineConstants(zero, one *circuits.Wire) error {
wires = append(wires, w)
}

err := prog.SetWires(c.String(), wires)
err := prog.SetWires(c, wires)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 987a587

Please sign in to comment.