Skip to content

Commit

Permalink
feat: allow to create entry without read limit
Browse files Browse the repository at this point in the history
if the expiration is nil, the entry should never expire
if the max reads is nil the entry can be read unlimited many times
  • Loading branch information
Ajnasz committed Jun 13, 2024
1 parent 0c320d3 commit 24e1ef6
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 141 deletions.
8 changes: 6 additions & 2 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,9 @@ func TestGetEntry(t *testing.T) {

keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter)
entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager)
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10)
expire := time.Second * 10
maxReads := 1
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), &expire, &maxReads)

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -444,7 +446,9 @@ func TestGetEntryJSON(t *testing.T) {

keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter)
entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager)
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10)
expire := time.Second * 10
maxReads := 1
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), &expire, &maxReads)
if err != nil {
t.Error(err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/createentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type CreateEntryParser interface {

// CreateEntryManager is an interface for creating entries
type CreateEntryManager interface {
CreateEntry(ctx context.Context, contentType string, body []byte, maxReads int, expiration time.Duration) (*services.EntryMeta, key.Key, error)
CreateEntry(ctx context.Context, contentType string, body []byte, expiration *time.Duration, maxReads *int) (*services.EntryMeta, key.Key, error)
}

// CreateEntryView is an interface for rendering the create entry response
Expand Down Expand Up @@ -63,7 +63,7 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error {

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, data.MaxReads, data.Expiration)
entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, &data.Expiration, &data.MaxReads)

if err != nil {
return err
Expand Down
5 changes: 2 additions & 3 deletions internal/api/createentry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ func (m *MockEntryManager) CreateEntry(
ctx context.Context,
contentType string,
body []byte,
maxReads int,
expiration time.Duration,
expiration *time.Duration,
maxReads *int,
) (*services.EntryMeta, key.Key, error) {
fmt.Println("content type", contentType)
args := m.Called(ctx, contentType, body, maxReads, expiration)

if args.Get(1) == nil {
Expand Down
7 changes: 2 additions & 5 deletions internal/api/generateentrykey.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"fmt"
"net/http"
"time"

Expand All @@ -18,7 +17,7 @@ type GenerateEntryKeyView interface {
}

type GenerateEntryKeyManager interface {
GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error)
GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire *time.Duration, maxReads *int) (*services.EntryKeyData, error)
}

type GenerateEntryKeyHandler struct {
Expand Down Expand Up @@ -46,12 +45,10 @@ func (g GenerateEntryKeyHandler) handle(w http.ResponseWriter, r *http.Request)
return err
}

fmt.Println(request)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, request.Expiration, request.MaxReads)
entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, &request.Expiration, &request.MaxReads)
if err != nil {
return err
}
Expand Down
7 changes: 6 additions & 1 deletion internal/api/generateentrykey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ type MockGenerateEntryKeyManager struct {
mock.Mock
}

func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error) {
func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context,
UUID string,
k key.Key,
expire *time.Duration,
maxReads *int,
) (*services.EntryKeyData, error) {
args := m.Called(ctx, UUID, k)
return args.Get(0).(*services.EntryKeyData), args.Error(2)
}
Expand Down
9 changes: 8 additions & 1 deletion internal/models/entrykey.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ type EntryKey struct {

type EntryKeyModel struct{}

func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*EntryKey, error) {
func (e *EntryKeyModel) Create(ctx context.Context,
tx *sql.Tx,
entryUUID string,
encryptedKey []byte,
hash []byte,
expire *time.Time,
remainingReads *int,
) (*EntryKey, error) {

now := time.Now()
res := tx.QueryRowContext(ctx, `
Expand Down
13 changes: 9 additions & 4 deletions internal/models/entrykey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error)

model := &EntryKeyModel{}

entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), time.Now().Add(time.Hour), 2)
expire := time.Now().Add(time.Hour)
maxReads := 2
entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), &expire, &maxReads)

if err != nil {
return "", "", err
Expand Down Expand Up @@ -73,7 +75,9 @@ func Test_EntryKeyModel_Create(t *testing.T) {

model := &EntryKeyModel{}

entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), time.Now().Add(time.Hour), 2)
expire := time.Now().Add(time.Hour)
remainingReads := 2
entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), &expire, &remainingReads)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand Down Expand Up @@ -135,7 +139,9 @@ func Test_EntryKeyModel_Get(t *testing.T) {
model := &EntryKeyModel{}

for i := 0; i < 10; i++ {
_, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), time.Now().Add(time.Hour), 2)
expire := time.Now().Add(time.Hour)
maxReads := 2
_, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), &expire, &maxReads)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand All @@ -156,7 +162,6 @@ func Test_EntryKeyModel_Get(t *testing.T) {
}

entryKeys, err := model.Get(ctx, tx, uid)
fmt.Println("ENTRY KEYS", entryKeys)

if err != nil {
if err := tx.Rollback(); err != nil {
Expand Down
33 changes: 29 additions & 4 deletions internal/services/entrykeymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ var ErrEntryCreateFailed = errors.New("entry create failed")
var ErrGetDEKFailed = errors.New("get DEK failed")

type EntryKeyModel interface {
Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*models.EntryKey, error)
Create(ctx context.Context,
tx *sql.Tx,
entryUUID string,
encryptedKey []byte,
hash []byte,
expire *time.Time,
remainingReads *int,
) (*models.EntryKey, error)
Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]models.EntryKey, error)
Delete(ctx context.Context, tx *sql.Tx, uuid string) error
SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error
Expand All @@ -42,7 +49,12 @@ func NewEntryKeyManager(db *sql.DB, model EntryKeyModel, hasher hasher.Hasher, e
}
}

func (e *EntryKeyManager) Create(ctx context.Context, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) {
func (e *EntryKeyManager) Create(ctx context.Context,
entryUUID string,
dek key.Key,
expire *time.Time,
maxRead *int,
) (*EntryKey, key.Key, error) {

tx, err := e.db.BeginTx(ctx, nil)
if err != nil {
Expand Down Expand Up @@ -87,7 +99,14 @@ func modelEntryKeyToEntryKey(m *models.EntryKey) *EntryKey {
}
}

func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) {
func (e *EntryKeyManager) CreateWithTx(ctx context.Context,
tx *sql.Tx,
entryUUID string,
dek key.Key,
expire *time.Time,
maxRead *int,
) (*EntryKey, key.Key,
error) {
k, err := key.NewGeneratedKey()

if err != nil {
Expand Down Expand Up @@ -202,7 +221,13 @@ func (e *EntryKeyManager) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID st
}

// GenerateEncryptionKey creates a new key for the entry
func (e EntryKeyManager) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) {
func (e EntryKeyManager) GenerateEncryptionKey(
ctx context.Context,
entryUUID string,
existingKey key.Key,
expire *time.Time,
maxRead *int,
) (*EntryKey, key.Key, error) {
tx, err := e.db.BeginTx(ctx, nil)
if err != nil {
return nil, nil, err
Expand Down
Loading

0 comments on commit 24e1ef6

Please sign in to comment.