Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLM capabilities to search through all databases #91

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/graph/model/models_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/graph/schema.graphqls
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://gqlgen.com/getting-started/

enum DatabaseType {
All,
Postgres,
MySQL,
Sqlite3,
Expand Down
14 changes: 12 additions & 2 deletions core/graph/schema.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions core/src/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type AuthKey string

const (
AuthKey_Token AuthKey = "Token"
AuthKey_Profiles AuthKey = "Profiles"
AuthKey_Credentials AuthKey = "Credentials"
)

Expand All @@ -28,6 +29,14 @@ func GetCredentials(ctx context.Context) *engine.Credentials {
return credentials.(*engine.Credentials)
}

func GetProfiles(ctx context.Context) []*engine.Credentials {
profiles := ctx.Value(AuthKey_Profiles)
if profiles == nil {
return nil
}
return profiles.([]*engine.Credentials)
}

func isPublicRoute(r *http.Request) bool {
return !strings.HasPrefix(r.URL.Path, "/api/") && r.URL.Path != "/api"
}
Expand Down Expand Up @@ -70,6 +79,22 @@ func AuthMiddleware(next http.Handler) http.Handler {
return
}

allProfiles := []*engine.Credentials{}
profileCookie, err := r.Cookie(string(AuthKey_Profiles))
if err == nil {
decodedValue, err := base64.StdEncoding.DecodeString(profileCookie.Value)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}

err = json.Unmarshal(decodedValue, &allProfiles)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
}

if credentials.Id != nil {
profiles := src.GetLoginProfiles()
for i, loginProfile := range profiles {
Expand All @@ -87,6 +112,7 @@ func AuthMiddleware(next http.Handler) http.Handler {

ctx := r.Context()
ctx = context.WithValue(ctx, AuthKey_Credentials, credentials)
ctx = context.WithValue(ctx, AuthKey_Profiles, allProfiles)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand Down
55 changes: 51 additions & 4 deletions core/src/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,66 @@ func Login(ctx context.Context, input *model.LoginCredentials) (*model.StatusRes
return nil, err
}

cookieValue := base64.StdEncoding.EncodeToString(loginInfoJSON)
tokenValue := base64.StdEncoding.EncodeToString(loginInfoJSON)

cookie := &http.Cookie{
tokenCookie := &http.Cookie{
Name: string(AuthKey_Token),
Value: cookieValue,
Value: tokenValue,
Path: "/",
HttpOnly: true,
Expires: time.Now().Add(24 * time.Hour),
}
http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), tokenCookie)

http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), cookie)
var profiles []model.LoginCredentials
profilesCookie, err := ctx.Value(common.RouterKey_Request).(*http.Request).Cookie(string(AuthKey_Profiles))
if err == nil {
decodedProfiles, err := base64.StdEncoding.DecodeString(profilesCookie.Value)
if err == nil {
json.Unmarshal(decodedProfiles, &profiles)
}
}

profiles = append(profiles, *input)

profiles = removeDuplicateProfiles(profiles)

profilesJSON, err := json.Marshal(profiles)
if err != nil {
return nil, err
}

profilesValue := base64.StdEncoding.EncodeToString(profilesJSON)

profilesCookie = &http.Cookie{
Name: string(AuthKey_Profiles),
Value: profilesValue,
Path: "/",
HttpOnly: true,
Expires: time.Now().Add(24 * time.Hour),
}
http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), profilesCookie)

return &model.StatusResponse{
Status: true,
}, nil
}

func removeDuplicateProfiles(profiles []model.LoginCredentials) []model.LoginCredentials {
uniqueProfiles := make([]model.LoginCredentials, 0)
profileMap := make(map[string]bool)

for _, profile := range profiles {
key := generateProfileKey(profile)
if !profileMap[key] {
uniqueProfiles = append(uniqueProfiles, profile)
profileMap[key] = true
}
}

return uniqueProfiles
}

func generateProfileKey(profile model.LoginCredentials) string {
return profile.Type + profile.Hostname + profile.Username + profile.Database
}
1 change: 1 addition & 0 deletions core/src/common/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ type RouterKey string

const (
RouterKey_ResponseWriter RouterKey = "ResponseWriter"
RouterKey_Request RouterKey = "Request"
)
8 changes: 8 additions & 0 deletions core/src/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,11 @@ func JoinWithQuotes(arr []string) string {

return strings.Join(quotedStrings, ", ")
}

func MapArrayPtr[T any, V any](items []*T, mapFunc func(*T) *V) []*V {
mappedItems := []*V{}
for _, item := range items {
mappedItems = append(mappedItems, mapFunc(item))
}
return mappedItems
}
27 changes: 26 additions & 1 deletion core/src/engine/engine.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package engine

import "github.com/clidey/whodb/core/graph/model"
import (
"fmt"

"github.com/clidey/whodb/core/graph/model"
)

type DatabaseType string

Expand Down Expand Up @@ -34,6 +38,27 @@ func (e *Engine) Choose(databaseType DatabaseType) *Plugin {
return nil
}

func (e *Engine) Chat(configs []*PluginConfig) ([]*ChatMessage, error) {
for _, config := range configs {
plugin := e.Choose(config.Credentials.Type)
schemas, err := plugin.GetSchema(config)
if err != nil {
return nil, err
}
for _, schema := range schemas {
storageUnits, err := plugin.GetStorageUnits(config, schema)
if err != nil {
return nil, err
}
for _, storageUnit := range storageUnits {
// use this to actually create the query
fmt.Sprintf("%v", storageUnit.Name)
}
}
}
return []*ChatMessage{}, nil
}

func GetStorageUnitModel(unit StorageUnit) *model.StorageUnit {
attributes := []*model.Record{}
for _, attribute := range unit.Attributes {
Expand Down
1 change: 1 addition & 0 deletions core/src/engine/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package engine

type Credentials struct {
Id *string
Type DatabaseType
Hostname string
Username string
Password string
Expand Down
1 change: 1 addition & 0 deletions core/src/router/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
func contextMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), common.RouterKey_ResponseWriter, w)
ctx = context.WithValue(ctx, common.RouterKey_Request, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
1 change: 1 addition & 0 deletions frontend/src/generated/graphql.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export type Column = {
};

export enum DatabaseType {
All = 'All',
ElasticSearch = 'ElasticSearch',
MariaDb = 'MariaDB',
MongoDb = 'MongoDB',
Expand Down
Loading