Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of upload files api. #1703

Merged
merged 12 commits into from
Feb 18, 2024
16 changes: 16 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
// Make sure directories exists
os.MkdirAll(options.ImageDir, 0755)
os.MkdirAll(options.AudioDir, 0755)
os.MkdirAll(options.UploadDir, 0755)
os.MkdirAll(options.Loader.ModelPath, 0755)

// Load upload json
openai.LoadUploadConfig(options.UploadDir)

modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
Expand All @@ -244,6 +248,18 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))

// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))

// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
Expand Down
207 changes: 207 additions & 0 deletions api/openai/files.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package openai

import (
"encoding/json"
"errors"
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"os"
"path/filepath"
"time"
)

var uploadedFiles []File

// File represents the structure of a file object from the OpenAI API.
type File struct {
ID string `json:"id"` // Unique identifier for the file
Object string `json:"object"` // Type of the object (e.g., "file")
Bytes int `json:"bytes"` // Size of the file in bytes
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
Filename string `json:"filename"` // The name of the file
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
}

func saveUploadConfig(uploadDir string) {
file, err := json.MarshalIndent(uploadedFiles, "", " ")
if err != nil {
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
}

err = os.WriteFile(filepath.Join(uploadDir, "uploadedFiles.json"), file, 0644)
if err != nil {
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
}
}

func LoadUploadConfig(uploadPath string) {
file, err := os.ReadFile(filepath.Join(uploadPath, "uploadedFiles.json"))
if err != nil {
log.Error().Msgf("Failed to read file: %s", err)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit: maybe we should check first if the file does not exist, I guess the first time this is starting it would always display an error message otherwise which could be misleading, but we can address this later

} else {
err = json.Unmarshal(file, &uploadedFiles)
if err != nil {
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
}
}
}

// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := c.FormFile("file")
if err != nil {
return err
}

// Check the file size
if file.Size > int64(o.UploadLimitMB*1024*1024) {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB))
}

purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
if purpose == "" {
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
}

// Sanitize the filename to prevent directory traversal
filename := utils.SanitizeFileName(file.Filename)

savePath := filepath.Join(o.UploadDir, filename)

// Check if file already exists
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
}

err = c.SaveFile(file, savePath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
}

f := File{
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
Object: "file",
Bytes: int(file.Size),
CreatedAt: time.Now(),
Filename: file.Filename,
Purpose: purpose,
}

uploadedFiles = append(uploadedFiles, f)
saveUploadConfig(o.UploadDir)
return c.Status(fiber.StatusOK).JSON(f)
}
}

// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
type ListFiles struct {
Data []File
Object string
}

return func(c *fiber.Ctx) error {
var listFiles ListFiles

purpose := c.Query("purpose")
if purpose == "" {
listFiles.Data = uploadedFiles
} else {
for _, f := range uploadedFiles {
if purpose == f.Purpose {
listFiles.Data = append(listFiles.Data, f)
}
}
}
listFiles.Object = "list"
return c.Status(fiber.StatusOK).JSON(listFiles)
}
}

func getFileFromRequest(c *fiber.Ctx) (*File, error) {
id := c.Params("file_id")
if id == "" {
return nil, fmt.Errorf("file_id parameter is required")
}

for _, f := range uploadedFiles {
if id == f.ID {
return &f, nil
}
}

return nil, fmt.Errorf("unable to find file id %s", id)
}

// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}

return c.JSON(file)
}
}

// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
type DeleteStatus struct {
Id string
Object string
Deleted bool
}

return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}

err = os.Remove(filepath.Join(o.UploadDir, file.Filename))
if err != nil {
// If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
}
}

// Remove upload from list
for i, f := range uploadedFiles {
if f.ID == file.ID {
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
break
}
}

saveUploadConfig(o.UploadDir)
return c.JSON(DeleteStatus{
Id: file.ID,
Object: "file",
Deleted: true,
})
}
}

// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}

fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename))
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}

return c.Send(fileContents)
}
}
Loading
Loading