diff --git a/api/api_test.go b/api/api_test.go index 8a8b492..bc41f69 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -387,7 +387,9 @@ func TestGetEntry(t *testing.T) { keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) - meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10) + expire := time.Second * 10 + maxReads := 1 + meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), &expire, &maxReads) if err != nil { t.Fatal(err) @@ -444,7 +446,9 @@ func TestGetEntryJSON(t *testing.T) { keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) - meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10) + expire := time.Second * 10 + maxReads := 1 + meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), &expire, &maxReads) if err != nil { t.Error(err) } diff --git a/internal/api/createentry.go b/internal/api/createentry.go index 29c76d0..d9391aa 100644 --- a/internal/api/createentry.go +++ b/internal/api/createentry.go @@ -20,7 +20,7 @@ type CreateEntryParser interface { // CreateEntryManager is an interface for creating entries type CreateEntryManager interface { - CreateEntry(ctx context.Context, contentType string, body []byte, maxReads int, expiration time.Duration) (*services.EntryMeta, key.Key, error) + CreateEntry(ctx context.Context, contentType string, body []byte, expiration *time.Duration, maxReads *int) (*services.EntryMeta, key.Key, error) } // CreateEntryView is an interface for rendering the create entry response @@ -63,7 +63,7 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, data.MaxReads, data.Expiration) + entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, &data.Expiration, &data.MaxReads) if err != nil { return err diff --git a/internal/api/createentry_test.go b/internal/api/createentry_test.go index 52c62c3..438bc8d 100644 --- a/internal/api/createentry_test.go +++ b/internal/api/createentry_test.go @@ -36,10 +36,9 @@ func (m *MockEntryManager) CreateEntry( ctx context.Context, contentType string, body []byte, - maxReads int, - expiration time.Duration, + expiration *time.Duration, + maxReads *int, ) (*services.EntryMeta, key.Key, error) { - fmt.Println("content type", contentType) args := m.Called(ctx, contentType, body, maxReads, expiration) if args.Get(1) == nil { diff --git a/internal/api/generateentrykey.go b/internal/api/generateentrykey.go index 2b84eeb..1042950 100644 --- a/internal/api/generateentrykey.go +++ b/internal/api/generateentrykey.go @@ -2,7 +2,6 @@ package api import ( "context" - "fmt" "net/http" "time" @@ -18,7 +17,7 @@ type GenerateEntryKeyView interface { } type GenerateEntryKeyManager interface { - GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error) + GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire *time.Duration, maxReads *int) (*services.EntryKeyData, error) } type GenerateEntryKeyHandler struct { @@ -46,12 +45,10 @@ func (g GenerateEntryKeyHandler) handle(w http.ResponseWriter, r *http.Request) return err } - fmt.Println(request) - ctx, cancel := context.WithCancel(context.Background()) defer cancel() - entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, request.Expiration, request.MaxReads) + entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, &request.Expiration, &request.MaxReads) if err != nil { return err } diff --git a/internal/api/generateentrykey_test.go b/internal/api/generateentrykey_test.go index 2290333..40ce1d6 100644 --- a/internal/api/generateentrykey_test.go +++ b/internal/api/generateentrykey_test.go @@ -30,7 +30,12 @@ type MockGenerateEntryKeyManager struct { mock.Mock } -func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error) { +func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, + UUID string, + k key.Key, + expire *time.Duration, + maxReads *int, +) (*services.EntryKeyData, error) { args := m.Called(ctx, UUID, k) return args.Get(0).(*services.EntryKeyData), args.Error(2) } diff --git a/internal/models/entrykey.go b/internal/models/entrykey.go index 7f0125b..3448cb6 100644 --- a/internal/models/entrykey.go +++ b/internal/models/entrykey.go @@ -18,7 +18,14 @@ type EntryKey struct { type EntryKeyModel struct{} -func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*EntryKey, error) { +func (e *EntryKeyModel) Create(ctx context.Context, + tx *sql.Tx, + entryUUID string, + encryptedKey []byte, + hash []byte, + expire *time.Time, + remainingReads *int, +) (*EntryKey, error) { now := time.Now() res := tx.QueryRowContext(ctx, ` diff --git a/internal/models/entrykey_test.go b/internal/models/entrykey_test.go index 9a75fb9..541ef6a 100644 --- a/internal/models/entrykey_test.go +++ b/internal/models/entrykey_test.go @@ -41,7 +41,9 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error) model := &EntryKeyModel{} - entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), time.Now().Add(time.Hour), 2) + expire := time.Now().Add(time.Hour) + maxReads := 2 + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), &expire, &maxReads) if err != nil { return "", "", err @@ -73,7 +75,9 @@ func Test_EntryKeyModel_Create(t *testing.T) { model := &EntryKeyModel{} - entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), time.Now().Add(time.Hour), 2) + expire := time.Now().Add(time.Hour) + remainingReads := 2 + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), &expire, &remainingReads) if err != nil { if err := tx.Rollback(); err != nil { @@ -135,7 +139,9 @@ func Test_EntryKeyModel_Get(t *testing.T) { model := &EntryKeyModel{} for i := 0; i < 10; i++ { - _, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), time.Now().Add(time.Hour), 2) + expire := time.Now().Add(time.Hour) + maxReads := 2 + _, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), &expire, &maxReads) if err != nil { if err := tx.Rollback(); err != nil { @@ -156,7 +162,6 @@ func Test_EntryKeyModel_Get(t *testing.T) { } entryKeys, err := model.Get(ctx, tx, uid) - fmt.Println("ENTRY KEYS", entryKeys) if err != nil { if err := tx.Rollback(); err != nil { diff --git a/internal/services/entrykeymanager.go b/internal/services/entrykeymanager.go index 2475e00..a04a232 100644 --- a/internal/services/entrykeymanager.go +++ b/internal/services/entrykeymanager.go @@ -18,7 +18,14 @@ var ErrEntryCreateFailed = errors.New("entry create failed") var ErrGetDEKFailed = errors.New("get DEK failed") type EntryKeyModel interface { - Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*models.EntryKey, error) + Create(ctx context.Context, + tx *sql.Tx, + entryUUID string, + encryptedKey []byte, + hash []byte, + expire *time.Time, + remainingReads *int, + ) (*models.EntryKey, error) Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]models.EntryKey, error) Delete(ctx context.Context, tx *sql.Tx, uuid string) error SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error @@ -42,7 +49,12 @@ func NewEntryKeyManager(db *sql.DB, model EntryKeyModel, hasher hasher.Hasher, e } } -func (e *EntryKeyManager) Create(ctx context.Context, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (e *EntryKeyManager) Create(ctx context.Context, + entryUUID string, + dek key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, key.Key, error) { tx, err := e.db.BeginTx(ctx, nil) if err != nil { @@ -87,7 +99,14 @@ func modelEntryKeyToEntryKey(m *models.EntryKey) *EntryKey { } } -func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (e *EntryKeyManager) CreateWithTx(ctx context.Context, + tx *sql.Tx, + entryUUID string, + dek key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, key.Key, + error) { k, err := key.NewGeneratedKey() if err != nil { @@ -202,7 +221,13 @@ func (e *EntryKeyManager) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID st } // GenerateEncryptionKey creates a new key for the entry -func (e EntryKeyManager) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (e EntryKeyManager) GenerateEncryptionKey( + ctx context.Context, + entryUUID string, + existingKey key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, key.Key, error) { tx, err := e.db.BeginTx(ctx, nil) if err != nil { return nil, nil, err diff --git a/internal/services/entrykeymanager_test.go b/internal/services/entrykeymanager_test.go index 9971945..544229b 100644 --- a/internal/services/entrykeymanager_test.go +++ b/internal/services/entrykeymanager_test.go @@ -17,7 +17,14 @@ type MockEntryKeyModel struct { mock.Mock } -func (m *MockEntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*models.EntryKey, error) { +func (m *MockEntryKeyModel) Create(ctx context.Context, + tx *sql.Tx, + entryUUID string, + encryptedKey []byte, + hash []byte, + expire *time.Time, + remainingReads *int, +) (*models.EntryKey, error) { args := m.Called(ctx, tx, entryUUID, encryptedKey, hash, expire, remainingReads) return args.Get(0).(*models.EntryKey), args.Error(1) } @@ -95,7 +102,7 @@ func TestEntryKeyManager_Create(t *testing.T) { encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) hasher.On("Hash", dek.Get()).Return(hash) - model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash, expire, maxRead).Return(&models.EntryKey{ + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash, &expire, &maxRead).Return(&models.EntryKey{ UUID: "test-uuid", EntryUUID: entryUUID, EncryptedKey: encryptedKey, @@ -112,7 +119,7 @@ func TestEntryKeyManager_Create(t *testing.T) { } manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.Create(ctx, entryUUID, *dek, expire, maxRead) + entryKey, key, err := manager.Create(ctx, entryUUID, *dek, &expire, &maxRead) model.AssertExpectations(t) encrypter.AssertExpectations(t) @@ -127,107 +134,112 @@ func TestEntryKeyManager_Create(t *testing.T) { assert.NotEmpty(t, key.Get()) } -// func TestEntryKeyManager_Create_NoExpire(t *testing.T) { -// db, sqlMock, err := sqlmock.New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// -// defer db.Close() -// -// sqlMock.ExpectBegin() -// sqlMock.ExpectCommit() -// -// ctx := context.Background() -// model := &MockEntryKeyModel{} -// hasher := &MockHasher{} -// encrypter := &EncrypterMock{} -// dek, err := key.NewGeneratedKey() -// assert.NoError(t, err) -// entryUUID := "test-entry-uuid" -// encryptedKey := []byte("test-encrypted-key") -// hash := []byte("test-hash") -// -// hasher.On("Hash", dek.Get()).Return(hash) -// encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) -// model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ -// UUID: "test-uuid", -// EntryUUID: entryUUID, -// EncryptedKey: encryptedKey, -// KeyHash: hash, -// Created: time.Now(), -// Expire: sql.NullTime{Time: time.Now(), Valid: false}, -// RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, -// }, nil) -// -// crypto := func(key key.Key) Encrypter { -// return encrypter -// } -// manager := NewEntryKeyManager(db, model, hasher, crypto) -// entryKey, key, err := manager.Create(ctx, entryUUID, *dek, 0, 0) -// -// hasher.AssertExpectations(t) -// encrypter.AssertExpectations(t) -// model.AssertExpectations(t) -// if sqlMock.ExpectationsWereMet() != nil { -// t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) -// } -// assert.NoError(t, err) -// assert.Equal(t, "test-uuid", entryKey.UUID) -// // assert.False(nil, entryKey.Expire) -// assert.NotEmpty(t, key.Get()) -// } - -// func TestEntryKeyManager_Create_NoMaxRead(t *testing.T) { -// db, sqlMock, err := sqlmock.New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// -// defer db.Close() -// -// sqlMock.ExpectBegin() -// sqlMock.ExpectCommit() -// -// ctx := context.Background() -// model := &MockEntryKeyModel{} -// hasher := &MockHasher{} -// encrypter := &EncrypterMock{} -// entryUUID := "test-entry-uuid" -// dek := []byte("test-dek") -// encryptedKey := []byte("test-encrypted-key") -// hash := []byte("test-hash") -// -// hasher.On("Hash", dek).Return(hash) -// encrypter.On("Encrypt", dek).Return(encryptedKey, nil) -// model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ -// UUID: "test-uuid", -// EntryUUID: entryUUID, -// EncryptedKey: encryptedKey, -// KeyHash: hash, -// Created: time.Now(), -// Expire: sql.NullTime{Time: time.Now(), Valid: false}, -// RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, -// }, nil) -// -// crypto := func(key key.Key) Encrypter { -// return encrypter -// } -// -// manager := NewEntryKeyManager(db, model, hasher, crypto) -// entryKey, key, err := manager.Create(ctx, entryUUID, dek, nil, nil) -// -// model.AssertExpectations(t) -// hasher.AssertExpectations(t) -// encrypter.AssertExpectations(t) -// if sqlMock.ExpectationsWereMet() != nil { -// t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) -// } -// assert.NoError(t, err) -// assert.Equal(t, "test-uuid", entryKey.UUID) -// // key.Get should not return an empty string -// assert.NotEmpty(t, key.Get()) -// } +func TestEntryKeyManager_Create_NoExpire(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + dek, err := key.NewGeneratedKey() + assert.NoError(t, err) + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + hash := []byte("test-hash") + + hasher.On("Hash", dek.Get()).Return(hash) + encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) + var maxRead int + var nullTime sql.NullTime + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash, mock.Anything, &maxRead).Return(&models.EntryKey{ + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: nullTime, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + + crypto := func(key key.Key) Encrypter { + return encrypter + } + manager := NewEntryKeyManager(db, model, hasher, crypto) + entryKey, key, err := manager.Create(ctx, entryUUID, *dek, nil, &maxRead) + + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + model.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "test-uuid", entryKey.UUID) + // assert.False(nil, entryKey.Expire) + assert.NotEmpty(t, key.Get()) +} + +func TestEntryKeyManager_Create_NoMaxRead(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + dek := []byte("test-dek") + encryptedKey := []byte("test-encrypted-key") + hash := []byte("test-hash") + expire := time.Now() + + hasher.On("Hash", dek).Return(hash) + encrypter.On("Encrypt", dek).Return(encryptedKey, nil) + var maxRead *int + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash, &expire, maxRead). + Return(&models.EntryKey{ + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: sql.NullTime{Time: expire, Valid: false}, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + + crypto := func(key key.Key) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + entryKey, key, err := manager.Create(ctx, entryUUID, dek, &expire, nil) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "test-uuid", entryKey.UUID) + // key.Get should not return an empty string + assert.NotEmpty(t, key.Get()) +} func TestEntryKeyManager_GetDEK(t *testing.T) { db, sqlMock, err := sqlmock.New() @@ -461,7 +473,7 @@ func TestEntryKeyManager_GenerateEncryptionKey(t *testing.T) { hasher.On("Hash", dek).Return(hash) encrypter.On("Encrypt", mock.Anything).Return(newEncryptedKey, nil) - model.On("Create", ctx, mock.Anything, entryUUID, newEncryptedKey, hash, expire, maxRead).Return(&models.EntryKey{ + model.On("Create", ctx, mock.Anything, entryUUID, newEncryptedKey, hash, &expire, &maxRead).Return(&models.EntryKey{ UUID: "new-test-uuid", EntryUUID: entryUUID, EncryptedKey: newEncryptedKey, @@ -479,7 +491,7 @@ func TestEntryKeyManager_GenerateEncryptionKey(t *testing.T) { manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.GenerateEncryptionKey(ctx, entryUUID, encryptedKey, expire, maxRead) + entryKey, key, err := manager.GenerateEncryptionKey(ctx, entryUUID, encryptedKey, &expire, &maxRead) model.AssertExpectations(t) hasher.AssertExpectations(t) diff --git a/internal/services/entrymanager.go b/internal/services/entrymanager.go index 2a30aa9..6d514b3 100644 --- a/internal/services/entrymanager.go +++ b/internal/services/entrymanager.go @@ -67,7 +67,7 @@ func NewEntryManager(db *sql.DB, model EntryModel, crypto EncrypterFactory, keyM // It stores the encrypted data in the database // It stores the key in the key manager // It returns the meta data of the entry and the key -func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, key.Key, error) { +func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data []byte, expire *time.Duration, remainingReads *int) (*EntryMeta, key.Key, error) { uid := uuid.NewUUIDString() tx, err := e.db.Begin() @@ -100,7 +100,13 @@ func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data return nil, nil, errors.Join(ErrCreateEntryFailed, err) } - expireAt := time.Now().Add(expire) + var expireAt *time.Time + + if expire != nil { + fromNow := time.Now().Add(*expire) + expireAt = &fromNow + } + entryKey, kek, err := e.keyManager.CreateWithTx(ctx, tx, uid, dek.Get(), expireAt, remainingReads) if err != nil { @@ -260,8 +266,15 @@ func (e *EntryManager) DeleteExpired(ctx context.Context) error { return nil } -func (e *EntryManager) GenerateEntryKey(ctx context.Context, entryUUID string, k key.Key, expire time.Duration, maxReads int) (*EntryKeyData, error) { - meta, kek, err := e.keyManager.GenerateEncryptionKey(ctx, entryUUID, k, time.Now().Add(expire), maxReads) +func (e *EntryManager) GenerateEntryKey(ctx context.Context, entryUUID string, k key.Key, expire *time.Duration, maxReads *int) (*EntryKeyData, error) { + var expireAt *time.Time + + if expire != nil { + fromNow := time.Now().Add(*expire) + expireAt = &fromNow + } + + meta, kek, err := e.keyManager.GenerateEncryptionKey(ctx, entryUUID, k, expireAt, maxReads) if err != nil { return nil, err } diff --git a/internal/services/entrymanager_test.go b/internal/services/entrymanager_test.go index 61f01ba..efe2dbc 100644 --- a/internal/services/entrymanager_test.go +++ b/internal/services/entrymanager_test.go @@ -57,7 +57,9 @@ func Test_EntryService_Create(t *testing.T) { }, *kek, nil) service := NewEntryManager(db, entryModel, crypto, keyManager) - meta, key, err := service.CreateEntry(ctx, "text/plain", data, 1, time.Minute) + expire := time.Minute + maxReads := 1 + meta, key, err := service.CreateEntry(ctx, "text/plain", data, &expire, &maxReads) assert.NoError(t, err) assert.NotNil(t, meta) @@ -116,7 +118,9 @@ func TestCreateError(t *testing.T) { keyManager := new(MockEntryKeyer) service := NewEntryManager(db, entryModel, crypto, keyManager) - meta, key, err := service.CreateEntry(ctx, "text/plain", data, 1, time.Minute) + expire := time.Minute + maxReads := 1 + meta, key, err := service.CreateEntry(ctx, "text/plain", data, &expire, &maxReads) assert.Error(t, err) assert.Nil(t, meta) @@ -483,7 +487,7 @@ func Test_EntryManager_GenerateEntryKey(t *testing.T) { service := NewEntryManager(nil, nil, nil, keyManager) - entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, expire, remainingReads) + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, &expire, &remainingReads) assert.NoError(t, err) assert.Equal(t, entryUUID, entryKey.EntryUUID) @@ -509,7 +513,7 @@ func Test_EntryManager_GenerateEntryKey(t *testing.T) { service := NewEntryManager(nil, nil, nil, keyManager) - entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, expire, remainingReads) + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, &expire, &remainingReads) assert.Error(t, err) assert.Nil(t, entryKey) diff --git a/internal/services/interfaces.go b/internal/services/interfaces.go index 7b4e27d..aabb0c2 100644 --- a/internal/services/interfaces.go +++ b/internal/services/interfaces.go @@ -22,9 +22,9 @@ type EntryModel interface { // EntryKeyer is the interface for the entry key manager // It is used to create, read and access entry keys type EntryKeyer interface { - CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (entryKey *EntryKey, kek key.Key, err error) + CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (entryKey *EntryKey, kek key.Key, err error) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID string, kek key.Key) (dek key.Key, entryKey *EntryKey, err error) - GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) + GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) UseTx(ctx context.Context, tx *sql.Tx, entryUUID string) error } diff --git a/internal/services/mocks.go b/internal/services/mocks.go index 32e768e..3da39dc 100644 --- a/internal/services/mocks.go +++ b/internal/services/mocks.go @@ -13,7 +13,12 @@ type MockEntryKeyer struct { mock.Mock } -func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) Create(ctx context.Context, + entryUUID string, + dek key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, key.Key, error) { args := m.Called(ctx, entryUUID, dek, expire, maxRead) if args.Get(1) == nil { return args.Get(0).(*EntryKey), nil, args.Error(2) @@ -21,7 +26,13 @@ func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek key.K return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) } -func (m *MockEntryKeyer) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) CreateWithTx(ctx context.Context, + tx *sql.Tx, + entryUUID string, + dek key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, key.Key, error) { args := m.Called(ctx, tx, entryUUID, dek, expire, maxRead) return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) } @@ -36,7 +47,14 @@ func (m *MockEntryKeyer) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID str return args.Get(0).(key.Key), args.Get(1).(*EntryKey), args.Error(2) } -func (m *MockEntryKeyer) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) GenerateEncryptionKey(ctx context.Context, + entryUUID string, + existingKey key.Key, + expire *time.Time, + maxRead *int, +) (*EntryKey, + key.Key, + error) { args := m.Called(ctx, entryUUID, existingKey, expire, maxRead) return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) }