diff --git a/apps/garbled/streaming.go b/apps/garbled/streaming.go index 9cee6332..3ff96152 100644 --- a/apps/garbled/streaming.go +++ b/apps/garbled/streaming.go @@ -20,6 +20,11 @@ import ( ) func streamEvaluatorMode(oti ot.OT, input input, once bool) error { + inputSizes, err := circuit.InputSizes(input) + if err != nil { + return err + } + ln, err := net.Listen("tcp", port) if err != nil { return err @@ -34,6 +39,18 @@ func streamEvaluatorMode(oti ot.OT, input input, once bool) error { fmt.Printf("New connection from %s\n", nc.RemoteAddr()) conn := p2p.NewConn(nc) + + err = conn.SendInputSizes(inputSizes) + if err != nil { + conn.Close() + return err + } + err = conn.Flush() + if err != nil { + conn.Close() + return err + } + outputs, result, err := circuit.StreamEvaluator(conn, oti, input, verbose) conn.Close() @@ -52,6 +69,14 @@ func streamEvaluatorMode(oti ot.OT, input input, once bool) error { func streamGarblerMode(params *utils.Params, oti ot.OT, input input, args []string) error { + inputSizes := make([][]int, 2) + + sizes, err := circuit.InputSizes(input) + if err != nil { + return err + } + inputSizes[0] = sizes + if len(args) != 1 || !strings.HasSuffix(args[0], ".mpcl") { return fmt.Errorf("streaming mode takes single MPCL file") } @@ -62,8 +87,14 @@ func streamGarblerMode(params *utils.Params, oti ot.OT, input input, conn := p2p.NewConn(nc) defer conn.Close() + sizes, err = conn.ReceiveInputSizes() + if err != nil { + return err + } + inputSizes[1] = sizes + outputs, result, err := compiler.New(params).StreamFile( - conn, oti, args[0], input) + conn, oti, args[0], input, inputSizes) if err != nil { return err } diff --git a/compiler/ast/builtin.go b/compiler/ast/builtin.go index 68ed885b..73282e1c 100644 --- a/compiler/ast/builtin.go +++ b/compiler/ast/builtin.go @@ -228,30 +228,53 @@ func lenEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator, switch arg := args[0].(type) { case *VariableRef: - var b ssa.Binding - var ok bool + var typeInfo types.Info if len(arg.Name.Package) > 0 { - var pkg *Package - pkg, ok = ctx.Packages[arg.Name.Package] + // Check if the package name is bound to a value. + b, ok := env.Get(arg.Name.Package) + if ok { + if b.Type.Type != types.TStruct { + return ssa.Undefined, false, ctx.Errorf(loc, + "%s undefined", arg.Name) + } + ok = false + for _, f := range b.Type.Struct { + if f.Name == arg.Name.Name { + typeInfo = f.Type + ok = true + break + } + } + if !ok { + return ssa.Undefined, false, ctx.Errorf(loc, + "undefined variable '%s'", arg.Name) + } + } else { + // Resolve name from the package. + pkg, ok := ctx.Packages[arg.Name.Package] + if !ok { + return ssa.Undefined, false, ctx.Errorf(loc, + "package '%s' not found", arg.Name.Package) + } + b, ok := pkg.Bindings.Get(arg.Name.Name) + if !ok { + return ssa.Undefined, false, ctx.Errorf(loc, + "undefined variable '%s'", arg.Name) + } + typeInfo = b.Type + } + } else { + b, ok := env.Get(arg.Name.Name) if !ok { return ssa.Undefined, false, ctx.Errorf(loc, - "package '%s' not found", arg.Name.Package) + "undefined variable '%s'", arg.Name) } - b, ok = pkg.Bindings.Get(arg.Name.Name) - } else { - b, ok = env.Get(arg.Name.Name) - } - if !ok { - return ssa.Undefined, false, ctx.Errorf(loc, - "undefined variable '%s'", arg.Name.String()) + typeInfo = b.Type } - var typeInfo types.Info - if b.Type.Type == types.TPtr { - typeInfo = *b.Type.ElementType - } else { - typeInfo = b.Type + if typeInfo.Type == types.TPtr { + typeInfo = *typeInfo.ElementType } switch typeInfo.Type { @@ -265,7 +288,7 @@ func lenEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator, default: return ssa.Undefined, false, ctx.Errorf(loc, - "invalid argument 1 (type %s) for len", b.Type) + "invalid argument 1 (type %s) for len", typeInfo) } default: diff --git a/compiler/ast/codegen.go b/compiler/ast/codegen.go index 1261fac3..ea045d5a 100644 --- a/compiler/ast/codegen.go +++ b/compiler/ast/codegen.go @@ -1,7 +1,7 @@ // // ast.go // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -19,28 +19,32 @@ import ( // Codegen implements compilation stack. type Codegen struct { - logger *utils.Logger - Params *utils.Params - Verbose bool - Package *Package - Packages map[string]*Package - Stack []Compilation - Types map[types.ID]*TypeInfo - Native map[string]*circuit.Circuit - HeapID int + logger *utils.Logger + Params *utils.Params + Verbose bool + Package *Package + Packages map[string]*Package + MainInputSizes [][]int + Stack []Compilation + Types map[types.ID]*TypeInfo + Native map[string]*circuit.Circuit + HeapID int } // NewCodegen creates a new compilation. func NewCodegen(logger *utils.Logger, pkg *Package, - packages map[string]*Package, params *utils.Params) *Codegen { + packages map[string]*Package, params *utils.Params, + mainInputSizes [][]int) *Codegen { + return &Codegen{ - logger: logger, - Params: params, - Verbose: params.Verbose, - Package: pkg, - Packages: packages, - Types: make(map[types.ID]*TypeInfo), - Native: make(map[string]*circuit.Circuit), + logger: logger, + Params: params, + Verbose: params.Verbose, + Package: pkg, + Packages: packages, + MainInputSizes: mainInputSizes, + Types: make(map[types.ID]*TypeInfo), + Native: make(map[string]*circuit.Circuit), } } diff --git a/compiler/ast/package.go b/compiler/ast/package.go index b18d9e39..25e446d2 100644 --- a/compiler/ast/package.go +++ b/compiler/ast/package.go @@ -67,15 +67,29 @@ func (pkg *Package) Compile(ctx *Codegen) (*ssa.Program, Annotations, error) { // Arguments. var inputs circuit.IO - for _, arg := range main.Args { + for idx, arg := range main.Args { typeInfo, err := arg.Type.Resolve(NewEnv(ctx.Start()), ctx, gen) if err != nil { return nil, nil, ctx.Errorf(arg, "invalid argument type: %s", err) } if typeInfo.Bits == 0 { - return nil, nil, - ctx.Errorf(arg, "argument %s of %s has unspecified type", + if ctx.MainInputSizes == nil { + return nil, nil, + ctx.Errorf(arg, "argument %s of %s has unspecified type", + arg.Name, main) + } + // Specify unspecified argument type. + if idx >= len(ctx.MainInputSizes) { + return nil, nil, ctx.Errorf(arg, + "not enough values for argument %s of %s", arg.Name, main) + } + err = typeInfo.InstantiateWithSizes(ctx.MainInputSizes[idx]) + if err != nil { + return nil, nil, ctx.Errorf(arg, + "can't specify unspecified argument %s of %s: %s", + arg.Name, main, err) + } } // Define argument in block. a := gen.NewVal(arg.Name, typeInfo, ctx.Scope()) @@ -265,10 +279,6 @@ func (pkg *Package) defineType(def *TypeInfo, ctx *Codegen, if err != nil { return err } - if info.Bits == 0 { - return ctx.Errorf(field, - "unspecified size for struct field %s", field.Name) - } field := types.StructField{ Name: field.Name, Type: info, diff --git a/compiler/compiler.go b/compiler/compiler.go index 5d36c2cf..f3b0ff68 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -80,7 +80,7 @@ func (c *Compiler) compile(source string, in io.Reader) ( return nil, nil, err } - ctx := ast.NewCodegen(logger, pkg, c.packages, c.params) + ctx := ast.NewCodegen(logger, pkg, c.packages, c.params, nil) program, annotation, err := pkg.Compile(ctx) if err != nil { @@ -99,18 +99,19 @@ func (c *Compiler) compile(source string, in io.Reader) ( // StreamFile compiles the input program and uses the streaming mode // to garble and stream the circuit to the evaluator node. func (c *Compiler) StreamFile(conn *p2p.Conn, oti ot.OT, file string, - input []string) (circuit.IO, []*big.Int, error) { + input []string, inputSizes [][]int) (circuit.IO, []*big.Int, error) { f, err := os.Open(file) if err != nil { return nil, nil, err } defer f.Close() - return c.stream(conn, oti, file, f, input) + return c.stream(conn, oti, file, f, input, inputSizes) } func (c *Compiler) stream(conn *p2p.Conn, oti ot.OT, source string, - in io.Reader, inputFlag []string) (circuit.IO, []*big.Int, error) { + in io.Reader, inputFlag []string, inputSizes [][]int) ( + circuit.IO, []*big.Int, error) { timing := circuit.NewTiming() @@ -120,7 +121,7 @@ func (c *Compiler) stream(conn *p2p.Conn, oti ot.OT, source string, return nil, nil, err } - ctx := ast.NewCodegen(logger, pkg, c.packages, c.params) + ctx := ast.NewCodegen(logger, pkg, c.packages, c.params, inputSizes) program, _, err := pkg.Compile(ctx) if err != nil { diff --git a/p2p/protocol.go b/p2p/protocol.go index a75f6bf3..8d722e98 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -252,6 +252,19 @@ func (c *Conn) SendString(val string) error { return c.SendData([]byte(val)) } +// SendInputSizes sends the input sizes. +func (c *Conn) SendInputSizes(sizes []int) error { + if err := c.SendUint32(len(sizes)); err != nil { + return err + } + for i := 0; i < len(sizes); i++ { + if err := c.SendUint32(sizes[i]); err != nil { + return err + } + } + return nil +} + // ReceiveByte receives a byte value. func (c *Conn) ReceiveByte() (byte, error) { if c.ReadStart+1 > c.ReadEnd { @@ -340,6 +353,23 @@ func (c *Conn) ReceiveString() (string, error) { return string(data), nil } +// ReceiveInputSizes receives input sizes. +func (c *Conn) ReceiveInputSizes() ([]int, error) { + count, err := c.ReceiveUint32() + if err != nil { + return nil, err + } + result := make([]int, count) + for i := 0; i < count; i++ { + size, err := c.ReceiveUint32() + if err != nil { + return nil, err + } + result[i] = size + } + return result, nil +} + // Receive implements OT receive for the bit value of a wire. func (c *Conn) Receive(receiver *ot.Receiver, wire, bit uint) ([]byte, error) { if err := c.SendUint32(int(wire)); err != nil { diff --git a/types/types.go b/types/types.go index 455dc37e..bfe09561 100644 --- a/types/types.go +++ b/types/types.go @@ -1,7 +1,7 @@ // // types.go // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -244,6 +244,53 @@ func (i *Info) Instantiate(o Info) bool { } } +// InstantiateWithSizes creates a concrete type of the unspecified +// type with given element sizes. +func (i *Info) InstantiateWithSizes(sizes []int) error { + if len(sizes) == 0 { + return fmt.Errorf("not enought sizes for type %v", i) + } + + switch i.Type { + case TBool: + + case TInt, TUint, TFloat: + if i.Bits == 0 { + i.Bits = Size(sizes[0]) + } + + case TStruct: + var structBits Size + for idx := range i.Struct { + if idx >= len(sizes) { + return fmt.Errorf("not enought sizes for type %v", i) + } + err := i.Struct[idx].Type.InstantiateWithSizes(sizes[idx:]) + if err != nil { + return err + } + i.Struct[idx].Type.Offset = structBits + structBits += i.Struct[idx].Type.Bits + } + i.Bits = structBits + + case TArray: + if i.ElementType == nil || i.ElementType.Bits == 0 { + return fmt.Errorf("array element type unspecified: %v", i) + } + i.ArraySize = Size(sizes[0]) / i.ElementType.Bits + if Size(sizes[0])%i.ElementType.Bits != 0 { + i.ArraySize++ + } + i.Bits = i.ArraySize * i.ElementType.Bits + + default: + return fmt.Errorf("can't specify %v", i) + } + + return nil +} + // Equal tests if the argument type is equal to this type info. func (i Info) Equal(o Info) bool { if i.Type != o.Type {