Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

npy: first stab at a n-dim array with support for ragged-arrays #22

Merged
merged 5 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
runs-on: ${{ matrix.platform }}
steps:
- name: Install Go
uses: actions/setup-go@v3
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go-version }}

Expand Down Expand Up @@ -80,11 +80,11 @@ jobs:
run: |
go run ./ci/run-tests.go $TAGS
- name: static-check
uses: dominikh/staticcheck-action@v1.2.0
uses: dominikh/staticcheck-action@v1
with:
install-go: false
cache-key: ${{ matrix.platform }}
version: "2022.1"
version: "2023.1.5"
- name: Upload-Coverage
if: matrix.platform == 'ubuntu-latest'
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v3
14 changes: 5 additions & 9 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ package npyio

import (
"bytes"
"errors"
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/sbinet/npyio/npy"
Expand Down Expand Up @@ -128,15 +128,11 @@ func display(o io.Writer, f io.Reader, fname string) error {

fmt.Fprintf(o, "npy-header: %v\n", r.Header)

rt := npy.TypeFrom(r.Header.Descr.Type)
if rt == nil {
return fmt.Errorf("npyio: no reflect type for %q", r.Header.Descr.Type)
}
rv := reflect.New(reflect.SliceOf(rt))
err = r.Read(rv.Interface())
if err != nil && err != io.EOF {
var arr npy.Array
err = r.Read(&arr)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("npyio: read error: %w", err)
}
fmt.Fprintf(o, "data = %v\n", rv.Elem().Interface())
fmt.Fprintf(o, "data = %v\n", arr.Data())
return nil
}
8 changes: 8 additions & 0 deletions dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ func TestDump(t *testing.T) {
name: "testdata/data_float64_forder.npz",
want: "testdata/data_float64_forder.npz.txt",
},
{
name: "testdata/ragged-array.npy",
want: "testdata/ragged-array.npy.txt",
},
{
name: "testdata/ragged-array-mixed.npy",
want: "testdata/ragged-array-mixed.npy.txt",
},
} {
t.Run(tc.name, func(t *testing.T) {
f, err := os.Open(tc.name)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go 1.20

require (
github.com/campoy/embedmd v1.0.0
github.com/nlpodyssey/gopickle v0.3.0
golang.org/x/text v0.14.0
gonum.org/v1/gonum v0.14.0
)

Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
221 changes: 221 additions & 0 deletions npy/array.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright 2023 The npyio Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package npy

import (
"fmt"
"strings"

py "github.com/nlpodyssey/gopickle/types"
)

// Array is a multidimensional, homogeneous array of fixed-size items.
type Array struct {
descr ArrayDescr
shape []int
strides []int
fortran bool

data any
}

var (
_ py.PyNewable = (*Array)(nil)
_ py.PyStateSettable = (*Array)(nil)
)

func (*Array) PyNew(args ...any) (any, error) {
var (
subtype = args[0]
descr = args[1].(*ArrayDescr)
shape = args[2].([]int)
strides = args[3].([]int)
data = args[4].([]byte)
flags = args[5].(int)
)

return newArray(subtype, *descr, shape, strides, data, flags)
}

func newArray(subtype any, descr ArrayDescr, shape, strides []int, data []byte, flags int) (*Array, error) {
switch subtype := subtype.(type) {
case *Array:
// ok.
default:
return nil, fmt.Errorf("subtyping ndarray with %T is not (yet?) supported", subtype)
}

arr := &Array{
descr: descr,
shape: shape,
strides: strides,
data: data,
}
return arr, nil
}

func (arr *Array) PySetState(arg any) error {
tuple, ok := arg.(*py.Tuple)
if !ok {
return fmt.Errorf("invalid argument type %T", arg)
}

var (
vers = 0
shape py.Tuple
raw any
)
switch tuple.Len() {
case 5:
err := parseTuple(tuple, &vers, &shape, &arr.descr, &arr.fortran, nil)
if err != nil {
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
}
raw = tuple.Get(4)
case 4:
err := parseTuple(tuple, &shape, &arr.descr, &arr.fortran, nil)
if err != nil {
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
}
raw = tuple.Get(3)
default:
return fmt.Errorf("invalid length (%d) for ndarray.__setstate__ tuple", tuple.Len())
}

arr.shape = nil
for i := range shape {
v, ok := shape.Get(i).(int)
if !ok {
return fmt.Errorf("invalid shape[%d]: got=%T, want=int", i, shape.Get(i))
}
arr.shape = append(arr.shape, v)
}

err := arr.setupStrides()
if err != nil {
return fmt.Errorf("ndarray.__setstate__ could not infer strides: %w", err)
}

switch raw := raw.(type) {
case *py.List:
arr.data = raw

case []byte:
data, err := arr.descr.unmarshal(raw, arr.shape)
if err != nil {
return fmt.Errorf("ndarray.__setstate__ could not unmarshal raw data: %w", err)
}
arr.data = data
}

return nil
}

func (arr *Array) setupStrides() error {
// TODO(sbinet): complete implementation.
// see: _array_fill_strides in numpy/_core/multiarray/ctors.c

if arr.shape == nil {
arr.strides = nil
return nil
}

strides := make([]int, len(arr.shape))
// FIXME(sbinet): handle non-contiguous arrays
// FIXME(sbinet): handle FORTRAN arrays

var (
// notCFContig bool
noDim bool // a dimension != 1 was found
)

// check if array is both FORTRAN- and C-contiguous
for _, dim := range arr.shape {
if dim != 1 {
if noDim {
// notCFContig = true
break
}
noDim = true
}
}

itemsize := arr.descr.itemsize()
switch {
case arr.fortran:
for i, dim := range arr.shape {
strides[i] = itemsize
switch {
case dim != 0:
itemsize *= dim
default:
// notCFContig = false
}
}

default:
for i := len(arr.shape) - 1; i >= 0; i-- {
dim := arr.shape[i]
strides[i] = itemsize
switch {
case dim != 0:
itemsize *= dim
default:
// notCFContig = false
}
}
}

arr.strides = strides
return nil
}

// Descr returns the array's data type descriptor.
func (arr Array) Descr() ArrayDescr {
return arr.descr
}

// Shape returns the array's shape.
func (arr Array) Shape() []int {
return arr.shape
}

// Strides returns the array's strides in bytes.
func (arr Array) Strides() []int {
return arr.strides
}

// Fortran returns whether the array's data is stored in FORTRAN-order
// (ie: column-major) instead of C-order (ie: row-major.)
func (arr Array) Fortran() bool {
return arr.fortran
}

// Data returns the array's underlying data.
func (arr Array) Data() any {
return arr.data
}

func (arr Array) String() string {
o := new(strings.Builder)
fmt.Fprintf(o, "Array{descr: %v, ", arr.descr)
switch arr.shape {
case nil:
fmt.Fprintf(o, "shape: nil, ")
default:
fmt.Fprintf(o, "shape: %v, ", arr.shape)
}
switch arr.strides {
case nil:
fmt.Fprintf(o, "strides: nil, ")
default:
fmt.Fprintf(o, "strides: %v, ", arr.strides)
}
fmt.Fprintf(o, "fortran: %v, data: %+v}",
arr.fortran,
arr.data,
)
return o.String()
}
51 changes: 51 additions & 0 deletions npy/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2023 The npyio Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package npy

import (
"fmt"
"os"
"reflect"
"testing"
)

func TestArrayStringer(t *testing.T) {
f, err := os.Open("../testdata/data_float64_2x3x4_corder.npy")
if err != nil {
t.Fatalf("could not open testdata: %+v", err)
}
defer f.Close()

var arr Array
err = Read(f, &arr)
if err != nil {
t.Fatalf("could not read data: %+v", err)
}

var (
want = `Array{descr: ArrayDescr{kind: 'f', order: '<', flags: 0, esize: 8, align: 8, subarr: <nil>, names: [], fields: {}, meta: map[]}, shape: [2 3 4], strides: [96 32 8], fortran: false, data: [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]}`
got = fmt.Sprintf("%v", arr)
)

if got != want {
t.Fatalf("invalid array display:\ngot= %s\nwant=%s", got, want)
}

if got, want := arr.Descr().kind, byte('f'); got != want {
t.Fatalf("invalid kind: got=%c, want=%c", got, want)
}

if got, want := arr.Shape(), []int{2, 3, 4}; !reflect.DeepEqual(got, want) {
t.Fatalf("invalid shape:\ngot= %+v\nwant=%+v", got, want)
}

if got, want := arr.Strides(), []int{96, 32, 8}; !reflect.DeepEqual(got, want) {
t.Fatalf("invalid strides:\ngot= %+v\nwant=%+v", got, want)
}

if got, want := arr.Fortran(), false; got != want {
t.Fatalf("invalid fortran:\ngot= %+v\nwant=%+v", got, want)
}
}
Loading
Loading