From 141b50d761f1f7b6fa7269d9bfdcad3df8341dc4 Mon Sep 17 00:00:00 2001 From: Hemang Kandwal Date: Mon, 19 Aug 2024 01:05:54 +0530 Subject: [PATCH] feat(core): add all database logic and multiple profiles in cookies feat(frontend): generate graphql schema --- core/graph/model/models_gen.go | 4 ++- core/graph/schema.graphqls | 1 + core/graph/schema.resolvers.go | 14 ++++++-- core/src/auth/auth.go | 26 ++++++++++++++ core/src/auth/login.go | 55 +++++++++++++++++++++++++++--- core/src/common/context.go | 1 + core/src/common/utils.go | 8 +++++ core/src/engine/engine.go | 27 ++++++++++++++- core/src/engine/plugin.go | 1 + core/src/router/middleware.go | 1 + frontend/src/generated/graphql.tsx | 1 + 11 files changed, 131 insertions(+), 8 deletions(-) diff --git a/core/graph/model/models_gen.go b/core/graph/model/models_gen.go index 8d9422b..476d1ec 100644 --- a/core/graph/model/models_gen.go +++ b/core/graph/model/models_gen.go @@ -92,6 +92,7 @@ type StorageUnit struct { type DatabaseType string const ( + DatabaseTypeAll DatabaseType = "All" DatabaseTypePostgres DatabaseType = "Postgres" DatabaseTypeMySQL DatabaseType = "MySQL" DatabaseTypeSqlite3 DatabaseType = "Sqlite3" @@ -102,6 +103,7 @@ const ( ) var AllDatabaseType = []DatabaseType{ + DatabaseTypeAll, DatabaseTypePostgres, DatabaseTypeMySQL, DatabaseTypeSqlite3, @@ -113,7 +115,7 @@ var AllDatabaseType = []DatabaseType{ func (e DatabaseType) IsValid() bool { switch e { - case DatabaseTypePostgres, DatabaseTypeMySQL, DatabaseTypeSqlite3, DatabaseTypeMongoDb, DatabaseTypeRedis, DatabaseTypeElasticSearch, DatabaseTypeMariaDb: + case DatabaseTypeAll, DatabaseTypePostgres, DatabaseTypeMySQL, DatabaseTypeSqlite3, DatabaseTypeMongoDb, DatabaseTypeRedis, DatabaseTypeElasticSearch, DatabaseTypeMariaDb: return true } return false diff --git a/core/graph/schema.graphqls b/core/graph/schema.graphqls index c84ee16..df938e8 100644 --- a/core/graph/schema.graphqls +++ b/core/graph/schema.graphqls @@ -3,6 +3,7 @@ # https://gqlgen.com/getting-started/ enum DatabaseType { + All, Postgres, MySQL, Sqlite3, diff --git a/core/graph/schema.resolvers.go b/core/graph/schema.resolvers.go index 3c6ddac..f094a64 100644 --- a/core/graph/schema.resolvers.go +++ b/core/graph/schema.resolvers.go @@ -11,6 +11,7 @@ import ( "github.com/clidey/whodb/core/graph/model" "github.com/clidey/whodb/core/src" "github.com/clidey/whodb/core/src/auth" + "github.com/clidey/whodb/core/src/common" "github.com/clidey/whodb/core/src/engine" "github.com/clidey/whodb/core/src/llm" ) @@ -254,8 +255,17 @@ func (r *queryResolver) AIModel(ctx context.Context) ([]string, error) { // AIChat is the resolver for the AIChat field. func (r *queryResolver) AIChat(ctx context.Context, typeArg model.DatabaseType, schema string, input model.ChatInput) ([]*model.AIChatMessage, error) { - config := engine.NewPluginConfig(auth.GetCredentials(ctx)) - messages, err := src.MainEngine.Choose(engine.DatabaseType(typeArg)).Chat(config, schema, input.Model, input.PreviousConversation, input.Query) + var messages []*engine.ChatMessage + var err error + if typeArg == model.DatabaseTypeAll { + configs := common.MapArrayPtr(auth.GetProfiles(ctx), func(credential *engine.Credentials) *engine.PluginConfig { + return engine.NewPluginConfig(credential) + }) + messages, err = src.MainEngine.Chat(configs) + } else { + config := engine.NewPluginConfig(auth.GetCredentials(ctx)) + messages, err = src.MainEngine.Choose(engine.DatabaseType(typeArg)).Chat(config, schema, input.Model, input.PreviousConversation, input.Query) + } if err != nil { return nil, err diff --git a/core/src/auth/auth.go b/core/src/auth/auth.go index 9d3e33b..eb50583 100644 --- a/core/src/auth/auth.go +++ b/core/src/auth/auth.go @@ -17,6 +17,7 @@ type AuthKey string const ( AuthKey_Token AuthKey = "Token" + AuthKey_Profiles AuthKey = "Profiles" AuthKey_Credentials AuthKey = "Credentials" ) @@ -28,6 +29,14 @@ func GetCredentials(ctx context.Context) *engine.Credentials { return credentials.(*engine.Credentials) } +func GetProfiles(ctx context.Context) []*engine.Credentials { + profiles := ctx.Value(AuthKey_Profiles) + if profiles == nil { + return nil + } + return profiles.([]*engine.Credentials) +} + func isPublicRoute(r *http.Request) bool { return !strings.HasPrefix(r.URL.Path, "/api/") && r.URL.Path != "/api" } @@ -70,6 +79,22 @@ func AuthMiddleware(next http.Handler) http.Handler { return } + allProfiles := []*engine.Credentials{} + profileCookie, err := r.Cookie(string(AuthKey_Profiles)) + if err == nil { + decodedValue, err := base64.StdEncoding.DecodeString(profileCookie.Value) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + err = json.Unmarshal(decodedValue, &allProfiles) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + } + if credentials.Id != nil { profiles := src.GetLoginProfiles() for i, loginProfile := range profiles { @@ -87,6 +112,7 @@ func AuthMiddleware(next http.Handler) http.Handler { ctx := r.Context() ctx = context.WithValue(ctx, AuthKey_Credentials, credentials) + ctx = context.WithValue(ctx, AuthKey_Profiles, allProfiles) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/core/src/auth/login.go b/core/src/auth/login.go index a4fd7cf..e6aa5ad 100644 --- a/core/src/auth/login.go +++ b/core/src/auth/login.go @@ -17,19 +17,66 @@ func Login(ctx context.Context, input *model.LoginCredentials) (*model.StatusRes return nil, err } - cookieValue := base64.StdEncoding.EncodeToString(loginInfoJSON) + tokenValue := base64.StdEncoding.EncodeToString(loginInfoJSON) - cookie := &http.Cookie{ + tokenCookie := &http.Cookie{ Name: string(AuthKey_Token), - Value: cookieValue, + Value: tokenValue, Path: "/", HttpOnly: true, Expires: time.Now().Add(24 * time.Hour), } + http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), tokenCookie) - http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), cookie) + var profiles []model.LoginCredentials + profilesCookie, err := ctx.Value(common.RouterKey_Request).(*http.Request).Cookie(string(AuthKey_Profiles)) + if err == nil { + decodedProfiles, err := base64.StdEncoding.DecodeString(profilesCookie.Value) + if err == nil { + json.Unmarshal(decodedProfiles, &profiles) + } + } + + profiles = append(profiles, *input) + + profiles = removeDuplicateProfiles(profiles) + + profilesJSON, err := json.Marshal(profiles) + if err != nil { + return nil, err + } + + profilesValue := base64.StdEncoding.EncodeToString(profilesJSON) + + profilesCookie = &http.Cookie{ + Name: string(AuthKey_Profiles), + Value: profilesValue, + Path: "/", + HttpOnly: true, + Expires: time.Now().Add(24 * time.Hour), + } + http.SetCookie(ctx.Value(common.RouterKey_ResponseWriter).(http.ResponseWriter), profilesCookie) return &model.StatusResponse{ Status: true, }, nil } + +func removeDuplicateProfiles(profiles []model.LoginCredentials) []model.LoginCredentials { + uniqueProfiles := make([]model.LoginCredentials, 0) + profileMap := make(map[string]bool) + + for _, profile := range profiles { + key := generateProfileKey(profile) + if !profileMap[key] { + uniqueProfiles = append(uniqueProfiles, profile) + profileMap[key] = true + } + } + + return uniqueProfiles +} + +func generateProfileKey(profile model.LoginCredentials) string { + return profile.Type + profile.Hostname + profile.Username + profile.Database +} diff --git a/core/src/common/context.go b/core/src/common/context.go index 0efec95..8e6effb 100644 --- a/core/src/common/context.go +++ b/core/src/common/context.go @@ -4,4 +4,5 @@ type RouterKey string const ( RouterKey_ResponseWriter RouterKey = "ResponseWriter" + RouterKey_Request RouterKey = "Request" ) diff --git a/core/src/common/utils.go b/core/src/common/utils.go index 8867622..0f6ca78 100644 --- a/core/src/common/utils.go +++ b/core/src/common/utils.go @@ -75,3 +75,11 @@ func JoinWithQuotes(arr []string) string { return strings.Join(quotedStrings, ", ") } + +func MapArrayPtr[T any, V any](items []*T, mapFunc func(*T) *V) []*V { + mappedItems := []*V{} + for _, item := range items { + mappedItems = append(mappedItems, mapFunc(item)) + } + return mappedItems +} diff --git a/core/src/engine/engine.go b/core/src/engine/engine.go index b2a6715..ea05395 100644 --- a/core/src/engine/engine.go +++ b/core/src/engine/engine.go @@ -1,6 +1,10 @@ package engine -import "github.com/clidey/whodb/core/graph/model" +import ( + "fmt" + + "github.com/clidey/whodb/core/graph/model" +) type DatabaseType string @@ -34,6 +38,27 @@ func (e *Engine) Choose(databaseType DatabaseType) *Plugin { return nil } +func (e *Engine) Chat(configs []*PluginConfig) ([]*ChatMessage, error) { + for _, config := range configs { + plugin := e.Choose(config.Credentials.Type) + schemas, err := plugin.GetSchema(config) + if err != nil { + return nil, err + } + for _, schema := range schemas { + storageUnits, err := plugin.GetStorageUnits(config, schema) + if err != nil { + return nil, err + } + for _, storageUnit := range storageUnits { + // use this to actually create the query + fmt.Sprintf("%v", storageUnit.Name) + } + } + } + return []*ChatMessage{}, nil +} + func GetStorageUnitModel(unit StorageUnit) *model.StorageUnit { attributes := []*model.Record{} for _, attribute := range unit.Attributes { diff --git a/core/src/engine/plugin.go b/core/src/engine/plugin.go index 5af4379..b252fd8 100644 --- a/core/src/engine/plugin.go +++ b/core/src/engine/plugin.go @@ -2,6 +2,7 @@ package engine type Credentials struct { Id *string + Type DatabaseType Hostname string Username string Password string diff --git a/core/src/router/middleware.go b/core/src/router/middleware.go index 754586f..c59d871 100644 --- a/core/src/router/middleware.go +++ b/core/src/router/middleware.go @@ -10,6 +10,7 @@ import ( func contextMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), common.RouterKey_ResponseWriter, w) + ctx = context.WithValue(ctx, common.RouterKey_Request, r) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/frontend/src/generated/graphql.tsx b/frontend/src/generated/graphql.tsx index 8b295e8..b2c0743 100644 --- a/frontend/src/generated/graphql.tsx +++ b/frontend/src/generated/graphql.tsx @@ -37,6 +37,7 @@ export type Column = { }; export enum DatabaseType { + All = 'All', ElasticSearch = 'ElasticSearch', MariaDb = 'MariaDB', MongoDb = 'MongoDB',