Skip to content

Commit

Permalink
npy: first stab at an n-dim array with support for ragged-arrays
Browse files Browse the repository at this point in the history
Fixes #20.

Signed-off-by: Sebastien Binet <binet@cern.ch>
  • Loading branch information
sbinet committed Nov 24, 2023
1 parent 862adbe commit b28492a
Show file tree
Hide file tree
Showing 15 changed files with 3,286 additions and 11 deletions.
11 changes: 3 additions & 8 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/sbinet/npyio/npy"
Expand Down Expand Up @@ -128,15 +127,11 @@ func display(o io.Writer, f io.Reader, fname string) error {

fmt.Fprintf(o, "npy-header: %v\n", r.Header)

rt := npy.TypeFrom(r.Header.Descr.Type)
if rt == nil {
return fmt.Errorf("npyio: no reflect type for %q", r.Header.Descr.Type)
}
rv := reflect.New(reflect.SliceOf(rt))
err = r.Read(rv.Interface())
var arr npy.Array
err = r.Read(&arr)
if err != nil && err != io.EOF {
return fmt.Errorf("npyio: read error: %w", err)
}
fmt.Fprintf(o, "data = %v\n", rv.Elem().Interface())
fmt.Fprintf(o, "data = %v\n", arr.Data())
return nil
}
4 changes: 4 additions & 0 deletions dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func TestDump(t *testing.T) {
name: "testdata/data_float64_forder.npz",
want: "testdata/data_float64_forder.npz.txt",
},
{
name: "testdata/ragged-array.npy",
want: "testdata/ragged-array.npy.txt",
},
} {
t.Run(tc.name, func(t *testing.T) {
f, err := os.Open(tc.name)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go 1.20

require (
github.com/campoy/embedmd v1.0.0
github.com/nlpodyssey/gopickle v0.2.1-0.20231124153821-2139434d2287
golang.org/x/text v0.14.0
gonum.org/v1/gonum v0.14.0
)

Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/nlpodyssey/gopickle v0.2.1-0.20231124153821-2139434d2287 h1:cqTk2IOiApRFn/e3YoAVQDgvOc6yU1zihhrH1w14WTg=
github.com/nlpodyssey/gopickle v0.2.1-0.20231124153821-2139434d2287/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
231 changes: 231 additions & 0 deletions npy/array.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Copyright 2023 The npyio Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package npy

import (
"fmt"
"strings"

py "github.com/nlpodyssey/gopickle/types"
)

// Array is a multidimensional, homogeneous array of fixed-size items.
//
// Its fields are populated either by newArray or by PySetState.
type Array struct {
// descr is the data type descriptor of the array's items.
descr ArrayDescr
// shape holds the size of each dimension; it may be nil.
shape []int
// strides holds, for each dimension, the step in bytes between two
// consecutive items; it is nil when shape is nil (see setupStrides).
strides []int
// fortran reports whether the data is stored in FORTRAN-order
// (column-major) instead of C-order (row-major).
fortran bool

// data is the array's payload. Depending on how the array was built it
// is the raw []byte handed to newArray, a *py.List, or the value returned
// by ArrayDescr.unmarshal (see PySetState).
data any
}

// Compile-time checks that *Array satisfies the gopickle interfaces
// listed below.
var (
_ py.Callable = (*Array)(nil)
_ py.PyNewable = (*Array)(nil)
_ py.PyStateSettable = (*Array)(nil)
)

// Call returns a new, empty *Array.
//
// It must be given between one and three arguments; their values are not
// inspected. Any other argument count yields an error.
//
// NOTE(review): presumably this is the py.Callable entry point used while
// unpickling an ndarray — confirm against the unpickler setup.
func (*Array) Call(args ...any) (any, error) {
	if n := len(args); n == 0 || n > 3 {
		return nil, fmt.Errorf("invalid tuple length (got=%d)", n)
	}
	return new(Array), nil
}

// PyNew creates an Array from positional construction arguments.
//
// At least six arguments are required, in this order:
//
//	subtype                 forwarded as-is to newArray
//	descr    *ArrayDescr    must not be nil
//	shape    []int
//	strides  []int
//	data     []byte
//	flags    int
//
// Any argument past the sixth is ignored. A missing or mistyped argument is
// reported as an error instead of panicking, so that a malformed input
// cannot crash the caller.
//
// NOTE(review): presumably this mirrors numpy's ndarray.__new__ as seen by
// the unpickler — confirm.
func (*Array) PyNew(args ...any) (any, error) {
	const nargs = 6
	if len(args) < nargs {
		return nil, fmt.Errorf(
			"invalid number of arguments (got=%d, want=%d)",
			len(args), nargs,
		)
	}

	subtype := args[0]

	descr, ok := args[1].(*ArrayDescr)
	if !ok {
		return nil, fmt.Errorf("invalid descr argument: got=%T, want=%T", args[1], descr)
	}
	if descr == nil {
		return nil, fmt.Errorf("invalid descr argument: nil %T", descr)
	}

	shape, ok := args[2].([]int)
	if !ok {
		return nil, fmt.Errorf("invalid shape argument: got=%T, want=%T", args[2], shape)
	}

	strides, ok := args[3].([]int)
	if !ok {
		return nil, fmt.Errorf("invalid strides argument: got=%T, want=%T", args[3], strides)
	}

	data, ok := args[4].([]byte)
	if !ok {
		return nil, fmt.Errorf("invalid data argument: got=%T, want=%T", args[4], data)
	}

	flags, ok := args[5].(int)
	if !ok {
		return nil, fmt.Errorf("invalid flags argument: got=%T, want=%T", args[5], flags)
	}

	arr, err := newArray(subtype, *descr, shape, strides, data, flags)
	if err != nil {
		// Return an untyped nil: a nil *Array stored in the `any` result
		// would otherwise compare as non-nil.
		return nil, err
	}
	return arr, nil
}

// newArray builds an Array out of its descriptor, shape, strides and raw
// data buffer.
//
// subtype must be a *Array: subtyping with any other type is rejected with
// an error. The raw data is stored as-is, without being decoded.
//
// NOTE(review): flags is accepted but not consulted here, and the fortran
// field is left at its zero value — confirm this is intended.
func newArray(subtype any, descr ArrayDescr, shape, strides []int, data []byte, flags int) (*Array, error) {
	if _, isArray := subtype.(*Array); !isArray {
		return nil, fmt.Errorf("subtyping ndarray with %T is not (yet?) supported", subtype)
	}

	return &Array{
		descr:   descr,
		shape:   shape,
		strides: strides,
		data:    data,
	}, nil
}

// PySetState restores the array from an ndarray.__setstate__ state tuple.
//
// arg must be a *py.Tuple with one of the following layouts:
//
//	(version, shape, descr, fortran, rawdata)   5 items
//	(shape, descr, fortran, rawdata)            4 items
//
// shape must only hold int values. rawdata is either a *py.List, which is
// stored as-is, or a []byte, which is decoded through the array descriptor
// using the parsed shape. Any other payload type is reported as an error
// rather than silently leaving the array without data.
//
// Strides are recomputed from the parsed shape, descriptor and ordering.
// On error, the receiver may have been partially updated.
func (arr *Array) PySetState(arg any) error {
	tuple, ok := arg.(*py.Tuple)
	if !ok {
		return fmt.Errorf("invalid argument type %T", arg)
	}

	var (
		vers  = 0
		shape py.Tuple
		raw   any
	)
	// The last tuple item is given a nil target and fetched untyped
	// afterwards, as its concrete type depends on the array's content.
	switch tuple.Len() {
	case 5:
		err := parseTuple(tuple, &vers, &shape, &arr.descr, &arr.fortran, nil)
		if err != nil {
			return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
		}
		raw = tuple.Get(4)
	case 4:
		err := parseTuple(tuple, &shape, &arr.descr, &arr.fortran, nil)
		if err != nil {
			return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
		}
		raw = tuple.Get(3)
	default:
		return fmt.Errorf("invalid length (%d) for ndarray.__setstate__ tuple", tuple.Len())
	}

	// Validate the whole shape before publishing it, so that a bad element
	// does not leave a truncated shape behind. dims stays nil for an empty
	// shape tuple, which setupStrides relies on.
	var dims []int
	for i := range shape {
		v, ok := shape.Get(i).(int)
		if !ok {
			return fmt.Errorf("invalid shape[%d]: got=%T, want=int", i, shape.Get(i))
		}
		dims = append(dims, v)
	}
	arr.shape = dims

	err := arr.setupStrides()
	if err != nil {
		return fmt.Errorf("ndarray.__setstate__ could not infer strides: %w", err)
	}

	switch raw := raw.(type) {
	case *py.List:
		arr.data = raw

	case []byte:
		data, err := arr.descr.unmarshal(raw, arr.shape)
		if err != nil {
			return fmt.Errorf("ndarray.__setstate__ could not unmarshal raw data: %w", err)
		}
		arr.data = data

	default:
		return fmt.Errorf("ndarray.__setstate__ invalid raw data type %T", raw)
	}

	return nil
}

// setupStrides derives the byte strides of the array from its shape, its
// item size and its memory ordering, and stores them in arr.strides.
//
// A nil shape yields nil strides. Otherwise, the fastest-varying axis (the
// first one for FORTRAN-order, the last one for C-order) gets a stride of
// one item, and each following axis gets the stride of the previous one
// multiplied by that axis' length. Zero-length axes do not contribute to
// the running product.
//
// The returned error is currently always nil.
//
// TODO(sbinet): complete implementation.
// see: _array_fill_strides in numpy/_core/multiarray/ctors.c
// FIXME(sbinet): handle non-contiguous arrays
func (arr *Array) setupStrides() error {
	if arr.shape == nil {
		arr.strides = nil
		return nil
	}

	var (
		ndim  = len(arr.shape)
		out   = make([]int, ndim)
		step  = arr.descr.itemsize()
		first = ndim - 1
		incr  = -1
	)
	if arr.fortran {
		// column-major: walk the axes front to back.
		first, incr = 0, +1
	}

	for n, axis := 0, first; n < ndim; n, axis = n+1, axis+incr {
		out[axis] = step
		if extent := arr.shape[axis]; extent != 0 {
			step *= extent
		}
	}

	arr.strides = out
	return nil
}

// Descr returns the array's data type descriptor.
// The descriptor is returned by value.
func (arr Array) Descr() ArrayDescr {
return arr.descr
}

// Shape returns the array's shape, i.e. the size of each of its dimensions.
//
// The returned slice is the one held by the array, not a copy. It may be nil.
func (arr Array) Shape() []int {
return arr.shape
}

// Strides returns the array's strides in bytes.
//
// The returned slice is the one held by the array, not a copy. It may be nil.
func (arr Array) Strides() []int {
return arr.strides
}

// Fortran reports whether the array's data is stored in FORTRAN-order
// (i.e. column-major) instead of C-order (i.e. row-major).
func (arr Array) Fortran() bool {
return arr.fortran
}

// Data returns the array's underlying data.
//
// The dynamic type of the returned value depends on how the array was
// populated: newArray stores the raw []byte it was given, while PySetState
// stores either a *py.List or the value decoded by ArrayDescr.unmarshal.
// The value is returned as held by the array, without copying.
func (arr Array) Data() any {
return arr.data
}

// String returns a human-readable, single-line description of the array,
// listing its descriptor, shape, strides, memory ordering and data.
func (arr Array) String() string {
	// fmt renders a nil slice as "[]": spell it out as "nil" instead, so an
	// unset shape (or strides) can be told apart from an empty one.
	dims := func(v []int) string {
		if v == nil {
			return "nil"
		}
		return fmt.Sprint(v)
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "Array{descr: %v, ", arr.descr)
	sb.WriteString("shape: " + dims(arr.shape) + ", ")
	sb.WriteString("strides: " + dims(arr.strides) + ", ")
	fmt.Fprintf(&sb, "fortran: %v, data: %+v}", arr.fortran, arr.data)
	return sb.String()
}
Loading

0 comments on commit b28492a

Please sign in to comment.