Skip to content

Commit

Permalink
Implement CSV import/export of tables
Browse files Browse the repository at this point in the history
Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>
  • Loading branch information
ostafen committed Jul 24, 2024
1 parent bbb1608 commit 139ee49
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 3 deletions.
1 change: 1 addition & 0 deletions cmd/immuadmin/command/commandline.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func (cl *commandline) Register(rootCmd *cobra.Command) *cobra.Command {
cl.stats(rootCmd)
cl.serverConfig(rootCmd)
cl.database(rootCmd)

return rootCmd
}

Expand Down
236 changes: 236 additions & 0 deletions cmd/immuadmin/command/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@ limitations under the License.
package immuadmin

import (
"encoding/csv"
"fmt"
"io"
"os"
"path"
"strconv"
"strings"
"time"

"github.com/codenotary/immudb/cmd/helper"
c "github.com/codenotary/immudb/cmd/helper"
"github.com/codenotary/immudb/embedded/sql"
"github.com/codenotary/immudb/embedded/store"
"github.com/codenotary/immudb/embedded/tbtree"
"github.com/codenotary/immudb/pkg/api/schema"
"github.com/codenotary/immudb/pkg/client"
"github.com/codenotary/immudb/pkg/database"
"github.com/codenotary/immudb/pkg/replication"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -383,10 +389,240 @@ func (cl *commandline) database(cmd *cobra.Command) {
dbCmd.AddCommand(flushCmd)
dbCmd.AddCommand(compactCmd)
dbCmd.AddCommand(truncateCmd)
dbCmd.AddCommand(cl.createExportCmd())
dbCmd.AddCommand(cl.createImportCmd())

cmd.AddCommand(dbCmd)
}

func (cl *commandline) createExportCmd() *cobra.Command {
exportCmd := &cobra.Command{
Use: "export",
Short: "Dump an SQL table to a CSV file",
Aliases: []string{"e"},
ArgAliases: []string{"table"},
PersistentPreRunE: cl.ConfigChain(cl.connect),
PersistentPostRun: cl.disconnect,
RunE: func(cmd *cobra.Command, args []string) error {
table := args[0]

outputPath, _ := cmd.Flags().GetString("o")
if outputPath == "" {
wd, err := os.Getwd()
if err != nil {
return err
}
outputPath = path.Join(wd, table) + ".csv"
}

reader, err := cl.immuClient.SQLQueryReader(cl.context, fmt.Sprintf("SELECT * FROM %s", table), nil)
if err != nil {
return err
}
defer reader.Close()

csvFile, err := os.Create(outputPath)
if err != nil {
return err
}
defer csvFile.Close()

sep, err := cmd.Flags().GetString("s")
if err != nil {
return err
}
if len(sep) != 1 {
return fmt.Errorf("invalid separator")
}

writer := csv.NewWriter(csvFile)
writer.Comma = rune(sep[0])
writer.UseCRLF = true
defer writer.Flush()

cols := reader.Columns()

colNames := make([]string, len(cols))
for i, col := range cols {
colNames[i] = formatColName(col.Name)
}

if err := writer.Write(colNames); err != nil {
return err
}

out := make([]string, len(cols))
for reader.Next() {
row, err := reader.Read()
if err != nil {
return err
}

if err := rowToCSV(row, cols, out); err != nil {
return err
}

if err := writer.Write(out); err != nil {
return err
}
}
return writer.Error()
},
Args: cobra.ExactArgs(1),
}
exportCmd.Flags().String("o", "", "output")
exportCmd.Flags().String("s", ",", "separator")

return exportCmd
}

func rowToCSV(row client.Row, cols []client.Column, out []string) error {
for i, v := range row {
colType := cols[i].Type
rv, err := renderValue(v, colType)
if err != nil {
return err
}
out[i] = rv
}
return nil
}

func renderValue(v interface{}, colType string) (string, error) {
switch colType {
case sql.VarcharType, sql.JSONType, sql.UUIDType:
s, isStr := v.(string)
if !isStr {
return "", fmt.Errorf("invalid value received")
}
return s, nil
default:
sqlVal, err := schema.AsSQLValue(v)
if err != nil {
return "", err
}
return schema.RenderValue(sqlVal.Value), nil
}
}

func (cl *commandline) createImportCmd() *cobra.Command {
importCmd := &cobra.Command{
Use: "import",
Short: "Insert data to an existing table from a csv file",
Aliases: []string{"i"},
ArgAliases: []string{"file"},
PersistentPreRunE: cl.ConfigChain(cl.connect),
PersistentPostRun: cl.disconnect,
RunE: func(cmd *cobra.Command, args []string) error {
inputPath := args[0]

csvFile, err := os.Open(inputPath)
if err != nil {
return err
}
defer csvFile.Close()

sep, err := cmd.Flags().GetString("s")
if err != nil {
return err
}
if len(sep) != 1 {
return fmt.Errorf("invalid separator")
}

reader := csv.NewReader(csvFile)
reader.Comma = rune(sep[0])
reader.ReuseRecord = true

hasHeader, err := cmd.Flags().GetBool("h")
if err != nil {
return err
}

table, err := cmd.Flags().GetString("t")
if err != nil {
return err
}
if table == "" {
return fmt.Errorf("table name not specified")
}

if hasHeader {
_, err := reader.Read()
if err != nil && err != io.EOF {
return nil
}
}

// fetch column information
res, err := cl.immuClient.SQLQuery(cl.context, fmt.Sprintf("SELECT * FROM %s WHERE 0 = 0", table), nil, false)
if err != nil {
return err
}

cols := make([]string, len(res.Columns))
for i, col := range res.Columns {
cols[i] = formatColName(col.Name)
}

row, err := reader.Read()
for err == nil {
if len(row) != len(cols) {
return fmt.Errorf("wrong number of columns")
}

for i, v := range row {
row[i] = formatInsertValue(v, res.Columns[i].Type)
}

_, err = cl.immuClient.SQLExec(
cl.context,
fmt.Sprintf("INSERT INTO %s(%s) VALUES (%s)", table, strings.Join(cols, ","), strings.Join(row, ",")),
nil,
)
if err != nil {
return err
}
row, err = reader.Read()
}
if err != io.EOF {
return err
}
return nil
},
Args: cobra.ExactArgs(1),
}
importCmd.Flags().String("t", "", "table")
importCmd.Flags().Bool("h", true, "interpret the first column as header")
importCmd.Flags().String("s", ",", "separator")

return importCmd
}

func formatColName(col string) string {
idx := strings.Index(col, ".")
if idx >= 0 {
return col[idx+1 : len(col)-1]
}
return col
}

func formatInsertValue(v string, colType string) string {
if v == "NULL" {
return v
}

switch colType {
case sql.VarcharType:
return fmt.Sprintf("'%s'", v)
case sql.TimestampType, sql.JSONType, sql.UUIDType:
return fmt.Sprintf("CAST ('%s' AS %s)", v, colType)
case sql.BLOBType:
return fmt.Sprintf("x'%s'", v)
}
return v
}

func prepareDatabaseNullableSettings(flags *pflag.FlagSet) (*schema.DatabaseNullableSettings, error) {
var err error

Expand Down
4 changes: 2 additions & 2 deletions pkg/api/schema/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func EncodeParams(params map[string]interface{}) ([]*NamedParam, error) {

i := 0
for n, v := range params {
sqlVal, err := asSQLValue(v)
sqlVal, err := AsSQLValue(v)
if err != nil {
return nil, err
}
Expand All @@ -52,7 +52,7 @@ func NamedParamsFromProto(protoParams []*NamedParam) map[string]interface{} {
return params
}

func asSQLValue(v interface{}) (*SQLValue, error) {
func AsSQLValue(v interface{}) (*SQLValue, error) {
if v == nil {
return &SQLValue{Value: &SQLValue_Null{}}, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/schema/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestAsSQLValue(t *testing.T) {
},
} {
t.Run(d.n, func(t *testing.T) {
sqlVal, err := asSQLValue(d.val)
sqlVal, err := AsSQLValue(d.val)
require.EqualValues(t, d.sqlVal, sqlVal)
if d.isErr {
require.ErrorIs(t, err, sql.ErrInvalidValue)
Expand Down

0 comments on commit 139ee49

Please sign in to comment.