Skip to content

Commit

Permalink
feat!: support multiple generative models (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
reugn authored Oct 23, 2024
1 parent de16134 commit 84f9c7e
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 24 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ If you don't already have one, create a key in [Google AI Studio](https://makers
The system chat message must begin with an exclamation mark and is used for internal operations.
A short list of supported system commands:

| Command | Description
| --- | ---
| !q | Quit the application
| !p | Delete the history used as chat context by the model
| !m | Toggle input mode (single-line <-> multi-line)
| Command | Description |
|---------|------------------------------------------------------|
| !q | Quit the application |
| !p | Delete the history used as chat context by the model |
| !i | Toggle input mode (single-line <-> multi-line) |
| !m | Select generative model |

### CLI help
```console
Expand All @@ -54,7 +55,8 @@ Usage:
Flags:
-f, --format render markdown-formatted response (default true)
-h, --help help for this command
-m, --multiline read input as a multi-line string
-m, --model string generative model name (default "gemini-pro")
--multiline read input as a multi-line string
-s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto")
-t, --term string multi-line input terminator (default "$")
-v, --version version for this command
Expand Down
5 changes: 3 additions & 2 deletions cmd/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ func run() int {
}

var opts cli.ChatOpts
rootCmd.Flags().StringVarP(&opts.Model, "model", "m", gemini.DefaultModel, "generative model name")
rootCmd.Flags().BoolVarP(&opts.Format, "format", "f", true, "render markdown-formatted response")
rootCmd.Flags().StringVarP(&opts.Style, "style", "s", "auto",
"markdown format style (ascii, dark, light, pink, notty, dracula)")
rootCmd.Flags().BoolVarP(&opts.Multiline, "multiline", "m", false, "read input as a multi-line string")
rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, "read input as a multi-line string")
rootCmd.Flags().StringVarP(&opts.Terminator, "term", "t", "$", "multi-line input terminator")

rootCmd.RunE = func(_ *cobra.Command, _ []string) error {
apiKey := os.Getenv(apiKeyEnv)
chatSession, err := gemini.NewChatSession(context.Background(), apiKey)
chatSession, err := gemini.NewChatSession(context.Background(), opts.Model, apiKey)
if err != nil {
return err
}
Expand Down
44 changes: 38 additions & 6 deletions gemini/chat_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,36 @@ package gemini

import (
"context"
"sync"

"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)

// ChatSession represents a gemini-pro powered chat session.
const DefaultModel = "gemini-pro"

// ChatSession represents a gemini powered chat session.
type ChatSession struct {
ctx context.Context
ctx context.Context

client *genai.Client
session *genai.ChatSession

loadModels sync.Once
models []string
}

// NewChatSession returns a new ChatSession.
func NewChatSession(ctx context.Context, apiKey string) (*ChatSession, error) {
// NewChatSession returns a new [ChatSession].
func NewChatSession(ctx context.Context, model, apiKey string) (*ChatSession, error) {
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
if err != nil {
return nil, err
}

return &ChatSession{
ctx: ctx,
client: client,
session: client.GenerativeModel("gemini-pro").StartChat(),
session: client.GenerativeModel(model).StartChat(),
}, nil
}

Expand All @@ -37,12 +45,36 @@ func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResp
return c.session.SendMessageStream(c.ctx, genai.Text(input))
}

// SetGenerativeModel sets the name of the generative model for the chat.
// It preserves the history from the previous chat session.
func (c *ChatSession) SetGenerativeModel(model string) {
history := c.session.History
c.session = c.client.GenerativeModel(model).StartChat()
c.session.History = history
}

// ListModels returns a list of the supported generative model names.
func (c *ChatSession) ListModels() []string {
c.loadModels.Do(func() {
c.models = []string{DefaultModel}
iter := c.client.ListModels(c.ctx)
for {
modelInfo, err := iter.Next()
if err != nil {
break
}
c.models = append(c.models, modelInfo.Name)
}
})
return c.models
}

// ClearHistory clears chat history.
func (c *ChatSession) ClearHistory() {
c.session.History = make([]*genai.Content, 0)
}

// Close closes the genai.Client.
// Close closes the chat session.
func (c *ChatSession) Close() error {
return c.client.Close()
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/charmbracelet/glamour v0.7.0
github.com/chzyer/readline v1.5.1
github.com/google/generative-ai-go v0.18.0
github.com/manifoldco/promptui v0.9.0
github.com/muesli/termenv v0.15.2
github.com/spf13/cobra v1.8.1
google.golang.org/api v0.196.0
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd3
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng=
github.com/charmbracelet/glamour v0.7.0/go.mod h1:jUMh5MeihljJPQbJ/wf4ldw2+yBP59+ctV36jASy7ps=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
Expand Down Expand Up @@ -93,6 +96,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA=
github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
Expand Down Expand Up @@ -169,6 +174,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
1 change: 1 addition & 0 deletions internal/cli/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

// ChatOpts represents Chat configuration options.
type ChatOpts struct {
Model string
Format bool
Style string
Multiline bool
Expand Down
50 changes: 40 additions & 10 deletions internal/cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const (
systemCmdPrefix = "!"
systemCmdQuit = "!q"
systemCmdPurgeHistory = "!p"
systemCmdToggleInputMode = "!m"
systemCmdSelectInputMode = "!i"
systemCmdSelectModel = "!m"
)

type command interface {
Expand All @@ -44,18 +45,39 @@ func (c *systemCommand) run(message string) bool {
case systemCmdPurgeHistory:
c.chat.model.ClearHistory()
c.print("Cleared the chat history.")
case systemCmdToggleInputMode:
case systemCmdSelectInputMode:
multiline, err := selectInputMode(c.chat.opts.Multiline)
if err != nil {
c.error(err)
break
}
if multiline == c.chat.opts.Multiline {
c.printSelectedCurrent()
break
}
c.chat.opts.Multiline = multiline
if c.chat.opts.Multiline {
c.print("Switched to single-line input mode.")
c.chat.reader.HistoryEnable()
c.chat.opts.Multiline = false
} else {
c.print("Switched to multi-line input mode.")
// disable history for multi-line messages since it is
// unusable for future requests
c.chat.reader.HistoryDisable()
c.chat.opts.Multiline = true
} else {
c.print("Switched to single-line input mode.")
c.chat.reader.HistoryEnable()
}
case systemCmdSelectModel:
model, err := selectModel(c.chat.opts.Model, c.chat.model.ListModels())
if err != nil {
c.error(err)
break
}
if model == c.chat.opts.Model {
c.printSelectedCurrent()
break
}
c.chat.opts.Model = model
c.chat.model.SetGenerativeModel(model)
c.print(fmt.Sprintf("Selected '%s' generative model.", model))
default:
c.print("Unknown system command.")
}
Expand All @@ -66,6 +88,14 @@ func (c *systemCommand) print(message string) {
fmt.Printf("%s%s\n", c.chat.prompt.cli, message)
}

func (c *systemCommand) printSelectedCurrent() {
fmt.Printf("%sThe selection is unchanged.\n", c.chat.prompt.cli)
}

func (c *systemCommand) error(err error) {
fmt.Printf(color.Red("%s%s\n"), c.chat.prompt.cli, err)
}

type geminiCommand struct {
chat *Chat
spinner *spinner
Expand Down Expand Up @@ -104,7 +134,7 @@ func (c *geminiCommand) runBlocking(message string) {
var buf strings.Builder
for _, candidate := range response.Candidates {
for _, part := range candidate.Content.Parts {
buf.WriteString(fmt.Sprintf("%s", part))
fmt.Fprintf(&buf, "%s", part)
}
}
output, err := glamour.Render(buf.String(), c.chat.opts.Style)
Expand Down Expand Up @@ -138,6 +168,6 @@ func (c *geminiCommand) runStreaming(message string) {
}

func (c *geminiCommand) printFlush(message string) {
fmt.Fprintf(c.writer, "%s", message)
c.writer.Flush()
_, _ = c.writer.WriteString(message)
_ = c.writer.Flush()
}
51 changes: 51 additions & 0 deletions internal/cli/select.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package cli

import (
"slices"

"github.com/manifoldco/promptui"
)

var (
inputMode = []string{"single-line", "multi-line"}
)

// selectModel returns the selected generative model name.
func selectModel(current string, models []string) (string, error) {
prompt := promptui.Select{
Label: "Select generative model",
HideSelected: true,
Items: models,
CursorPos: slices.Index(models, current),
}

_, result, err := prompt.Run()
if err != nil {
return "", err
}

return result, nil
}

// selectInputMode returns true if multiline input is selected;
// otherwise, it returns false.
func selectInputMode(multiline bool) (bool, error) {
var cursorPos int
if multiline {
cursorPos = 1
}

prompt := promptui.Select{
Label: "Select input mode",
HideSelected: true,
Items: inputMode,
CursorPos: cursorPos,
}

_, result, err := prompt.Run()
if err != nil {
return false, err
}

return result == inputMode[1], nil
}

0 comments on commit 84f9c7e

Please sign in to comment.