From 7fc6bb586eb02e2adf51062a7f61a2134a78b9c1 Mon Sep 17 00:00:00 2001 From: luo-cheng-xi Date: Sat, 21 Sep 2024 20:44:41 -0400 Subject: [PATCH] fix: the characters that width is 2 will cost 2 cells, the characters that width is 1 will cost 1 cell. It will change the way it handle is character. --- README.md | 8 ++- example/basic/main.go | 7 ++- go.mod | 2 + term.go | 124 +++++++++++++++--------------------------- test/term_test.go | 18 +++--- 5 files changed, 68 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index 708a0d5..d428302 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ package main import ( "fmt" "github.com/chengxilo/virtualterm" + "log" ) func main() { @@ -44,11 +45,16 @@ func main() { vt := virtualterm.NewDefault() vt.Write([]byte(str)) fmt.Println(str == "virtual-terminal") - fmt.Println(vt.String() == "virtual-terminal") + str,err := vt.String() + if err != nil { + log.Fatal(err) + } + fmt.Println(str == "virtual-terminal") // Output: // false // true } + ``` Use `virtualterm.Process` function. You will not need to create a virtual terminal and input on your own. diff --git a/example/basic/main.go b/example/basic/main.go index 0a3f4b3..b9b22b4 100644 --- a/example/basic/main.go +++ b/example/basic/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "github.com/chengxilo/virtualterm" + "log" ) func main() { @@ -10,7 +11,11 @@ func main() { vt := virtualterm.NewDefault() vt.Write([]byte(str)) fmt.Println(str == "virtual-terminal") - fmt.Println(vt.String() == "virtual-terminal") + str, err := vt.String() + if err != nil { + log.Fatal(err) + } + fmt.Println(str == "virtual-terminal") // Output: // false // true diff --git a/go.mod b/go.mod index 7cdf4bd..36a7637 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,9 @@ require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.2.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/term.go b/term.go index a86b8a2..ea22d68 100644 --- a/term.go +++ b/term.go @@ -3,15 +3,19 @@ package virtualterm import ( "errors" "fmt" + "github.com/mattn/go-runewidth" "log" "math" "strconv" "strings" ) -var INF = math.MaxInt - 1 var ErrCannotHandle = errors.New("this csi is not supported or syntax error") -var ErrNonDetermistics = errors.New("non-deterministic") + +// ErrNonDeterministic is created for some situation that cannot be handled well. +// for example, if you write "你好\b啊" the result will be "你 啊" in Windows Powershell, and will be +// "你好啊" in git bash in Windows. I would treat it as an error caused by users' input. +var ErrNonDeterministic = errors.New("non-deterministic") // VirtualTerm this is created to simulate a terminal,handle the special character such as '\r','\b', "\033[1D". // For example: if you input "cute\rhat", the result of String() would be "hate" @@ -23,9 +27,6 @@ type VirtualTerm struct { // the content of virtual terminal content [][]rune - // xOffset is the offset on X. - xOffset int - // silence will shut down the log.By default, it is true silence bool } @@ -162,62 +163,9 @@ func (vt *VirtualTerm) WriteString(s string) (n int, err error) { return vt.Write([]byte(s)) } -// runeWidth -func (*VirtualTerm) runeWidth(r rune) int { - if len(string(r)) >= 3 { - return 2 - } else { - return 1 - } -} - // cursorMove can control the cursor func (vt *VirtualTerm) cursorMove(x int, y int) { - // handle the offset first - if vt.xOffset != 0 { - x += vt.xOffset - vt.xOffset = 0 - } - if x < 0 { - // move the cursor left - // this may cause offset - x = -x - var far int - for vt.cx > 0 && x > 0 { - far = vt.runeWidth(vt.content[vt.cy][vt.cx-1]) - if x >= far { - x -= far - vt.cx-- - } else { - break - } - } - if x == 1 && vt.cx != 0 { - vt.xOffset = -1 - } - } else if x > 0 { - // move the cursor right - // this situation will not cause offset - for x > 0 { - // if cx is out of bound, add empty element - if vt.cx >= len(vt.content[vt.cy])-1 { - vt.content[vt.cy] = append(vt.content[vt.cy], ' ') - } - far := vt.runeWidth(vt.content[vt.cy][vt.cx]) - if far <= x { - x -= far - vt.cx++ - } else { - break - } - } - if x == 1 { - vt.xOffset = -1 - vt.cx++ - } - } - - // avoid index out of bound + vt.cx = max(vt.cx+x, 0) vt.cy = max(vt.cy+y, 0) for vt.cy >= len(vt.content) { vt.content = append(vt.content, []rune{' '}) @@ -231,22 +179,18 @@ func (vt *VirtualTerm) cursorMove(x int, y int) { func (vt *VirtualTerm) cursorHome() { vt.cx = 0 vt.cy = 0 - vt.xOffset = 0 } // writeRune write Rune to content. func (vt *VirtualTerm) writeRune(r rune) error { - // if the offset of cursor is not zero, means that there will be non-deterministic for the output - // For example, if your output is "你好\bCOOL", than it might be "你好OOL"(git bash in Windows) or 你 COOL("Windows powershell") - // So it should be treated as an error. - if vt.xOffset != 0 { - return ErrNonDetermistics + wid := runewidth.RuneWidth(r) + // write according to the width of rune. + // For example: '中' need two cells, but 'a' only need one cell + for wid > 0 { + vt.content[vt.cy][vt.cx] = r + vt.cursorMove(1, 0) + wid-- } - // get the width of rune - far := vt.runeWidth(r) - vt.content[vt.cy][vt.cx] = r - vt.cursorMove(far, 0) - return nil } @@ -256,11 +200,11 @@ func (vt *VirtualTerm) WriteRunes(p []rune) (n int, err error) { switch p[i] { case '\r': // Carriage Return - vt.cursorMove(-INF, 0) + vt.cursorMove(-math.MaxInt, 0) case '\n': // NewLine // If the cursor is on the last line, add a new line - vt.cursorMove(-INF, 1) + vt.cursorMove(-math.MaxInt, 1) case '\b': vt.cursorMove(-1, 0) case '\033': @@ -292,26 +236,44 @@ func (vt *VirtualTerm) WriteRunes(p []rune) (n int, err error) { } // writeString write String to content -func (vt *VirtualTerm) writeString(s string) { +func (vt *VirtualTerm) writeString(s string) error { for _, c := range s { - vt.writeRune(c) + if err := vt.writeRune(c); err != nil { + return err + } } + return nil } -func (vt *VirtualTerm) String() string { +// String get the result of the prediction. +// If there is some non-deterministic,you will get an error. +func (vt *VirtualTerm) String() (string, error) { builder := strings.Builder{} - for i, line := range vt.content { - for j, c := range line { - if j == len(line)-1 { + for i := 0; i < len(vt.content); i++ { + for j := 0; j < len(vt.content[i]); j++ { + if j == len(vt.content[i])-1 { break } - builder.WriteRune(c) + c := vt.content[i][j] + wid := runewidth.RuneWidth(c) + // if it is a character cost 2 cells,such as '中','ひ','안',or emoji + if wid == 1 { + // append the character to result + builder.WriteRune(c) + } else if wid == 2 { + if vt.content[i][j] != vt.content[i][j+1] { + return "", ErrNonDeterministic + } + builder.WriteRune(c) + j++ + } + } if i != len(vt.content)-1 { builder.WriteRune('\n') } } - return builder.String() + return builder.String(), nil } // Clear all the content in virtual terminal @@ -334,5 +296,5 @@ func Process(input string) (string, error) { if err != nil { return "", err } - return vt.String(), err + return vt.String() } diff --git a/test/term_test.go b/test/term_test.go index 627067f..f508f61 100644 --- a/test/term_test.go +++ b/test/term_test.go @@ -51,7 +51,7 @@ func TestCarriageReturn(t *testing.T) { if err != nil { t.Fatal(err) } - actual := vt.String() + actual, _ := vt.String() assert.Equal(t, te.output, actual) vt.Clear() } @@ -74,7 +74,7 @@ func TestNewLine(t *testing.T) { if err != nil { t.Fatal(err) } - actual := vt.String() + actual, _ := vt.String() assert.Equal(t, te.output, actual) vt.Clear() } @@ -95,7 +95,7 @@ func TestBackspace(t *testing.T) { if err != nil { t.Fatal(err) } - actual := vt.String() + actual, _ := vt.String() assert.Equal(t, te.output, actual) vt.Clear() } @@ -107,6 +107,7 @@ func TestCSI(t *testing.T) { input string expected string }{ + {"123\r嗨", "嗨3"}, {"\033[123*", ""}, {"你好\r\033[4C啊", "你好啊"}, {"你好\r\033[C", "你好"}, @@ -136,7 +137,7 @@ func TestCSI(t *testing.T) { if err != nil { t.Fatal(err, i) } - actual := vt.String() + actual, _ := vt.String() if actual != te.expected { log.Print("actual: "+actual+"expected: ", te.expected, "test index: ", i) t.Fail() @@ -150,6 +151,7 @@ func TestInvalidInput(t *testing.T) { tests := []struct { input string }{ + {"你好\ba"}, {"我是\b猫"}, {"我是\033[1D猫"}, {"我是\bhero"}, @@ -157,10 +159,10 @@ func TestInvalidInput(t *testing.T) { {"\bレ\033[Dモン"}, } for _, te := range tests { - _, err := vt.WriteString(te.input) - if !errors.Is(err, virtualterm.ErrNonDetermistics) { - t.Logf("error is not expected") - t.Fail() + vt.WriteString(te.input) + if _, err := vt.String(); !errors.Is(err, virtualterm.ErrNonDeterministic) { + t.Fatal(err) } + vt.Clear() } }