Skip to content

Commit

Permalink
Merge pull request #390 from matrix-org/kegan/fallback-key-types-fix
Browse files Browse the repository at this point in the history
bugfix: correctly tell clients when the fallback key has been used
  • Loading branch information
kegsay authored Jan 5, 2024
2 parents a8e9c56 + eae54fb commit e6a2e67
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 18 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,12 @@ go test -p 1 -count 1 $(go list ./... | grep -v tests-e2e) -timeout 120s
Run end-to-end tests:

```shell
# Run each line in a separate terminal windows. Will need to `docker login`
# to ghcr and pull the image.
docker run --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.72.0
# Will need to `docker login` to ghcr and pull the image.
docker run -d --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.94.0

export SYNCV3_SECRET=foobar
export SYNCV3_SERVER=http://localhost:8888
export SYNCV3_DB="user=$(whoami) dbname=syncv3_test sslmode=disable"

(go build ./cmd/syncv3 && dropdb syncv3_test && createdb syncv3_test && cd tests-e2e && ./run-tests.sh -count=1 .)
```
1 change: 1 addition & 0 deletions internal/device_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type DeviceData struct {
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`

DeviceLists DeviceLists `json:"dl"`
Expand Down
32 changes: 21 additions & 11 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,20 +727,30 @@ func (p *poller) parseE2EEData(ctx context.Context, res *SyncResponse) error {
}
shouldSetOTKs = true
}
var changedFallbackTypes []string
var changedFallbackTypes []string // nil slice == don't set, empty slice = no fallback key
shouldSetFallbackKeys := false
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
} else {
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
break
}
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
// length mismatch always causes an update
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
shouldSetFallbackKeys = true
} else {
// lengths match, if they are non-zero then compare each element.
// if they are zero, check for nil vs empty slice.
if len(res.DeviceUnusedFallbackKeyTypes) == 0 {
isCurrentNil := res.DeviceUnusedFallbackKeyTypes == nil
isPreviousNil := p.fallbackKeyTypes == nil
if isCurrentNil != isPreviousNil {
shouldSetFallbackKeys = true
changedFallbackTypes = []string{}
}
}
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
shouldSetFallbackKeys = true
break
}
}
shouldSetFallbackKeys = true
}

deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
Expand Down
6 changes: 3 additions & 3 deletions sync3/extensions/e2ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (r *E2EERequest) Name() string {
type E2EEResponse struct {
OTKCounts map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
FallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
FallbackKeyTypes *[]string `json:"device_unused_fallback_key_types,omitempty"`
}

type E2EEDeviceList struct {
Expand All @@ -37,7 +37,7 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
if isInitial {
return true // ensure we send OTK counts immediately
}
return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 0
return r.DeviceLists != nil || r.FallbackKeyTypes != nil || len(r.OTKCounts) > 0
}

func (r *E2EERequest) AppendLive(ctx context.Context, res *Response, extCtx Context, up caches.Update) {
Expand All @@ -63,7 +63,7 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx
extRes := &E2EEResponse{}
hasUpdates := false
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || extCtx.IsInitial) {
extRes.FallbackKeyTypes = dd.FallbackKeyTypes
extRes.FallbackKeyTypes = &dd.FallbackKeyTypes
hasUpdates = true
}
if dd.OTKCounts != nil && (dd.OTKCountChanged() || extCtx.IsInitial) {
Expand Down
287 changes: 287 additions & 0 deletions tests-e2e/encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
package syncv3_test

import (
"encoding/json"
"fmt"
"testing"

"github.com/matrix-org/complement/b"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/sync3/extensions"
"github.com/matrix-org/sliding-sync/testutils/m"
)

func TestEncryptionFallbackKey(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
roomID := alice.MustCreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
bob.JoinRoom(t, roomID, nil)

// snaffled from rust SDK
keysUploadBody := fmt.Sprintf(`{
"device_keys": {
"algorithms": [
"m.olm.v1.curve25519-aes-sha2",
"m.megolm.v1.aes-sha2"
],
"device_id": "MUPCQIATEC",
"keys": {
"curve25519:MUPCQIATEC": "NroPrV4HHJ/Wj0A0XMrHt7IuThVnwpT6tRZXQXkO4kI",
"ed25519:MUPCQIATEC": "G9zNR/pZb24Rm0FXiQYutSzcbQvii+AZn/4cmi6LOUI"
},
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "2CHK2tJO/p2OiNWC2jLKsH5t+pHwnomSHOIpAPuEVi2vJZ4BRRsb4tSFYzEx4cUDg3KCYjoQuCymYHpnk1uqDQ"
}
},
"user_id": "%s"
},
"fallback_keys": {
"signed_curve25519:AAAAAAAAAAA": {
"fallback": true,
"key": "s5+eOJYK1s5xPt51BlYEXx8fQ8NqpwAUjE1mVxw05V8",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "TLGi0LJEDxgt37gBCpd8huZa72h0UTB8jIEUoTz/rjbCcGQo1xOlvA5rU+RoTkF1KwVtduOMbZcSGg4ZTfBkDQ"
}
}
}
},
"one_time_keys": {
"signed_curve25519:AAAAAAAAAA0": {
"key": "IuCQvr2AaZC70tCG6g1ZardACNe3mcKZ2PjKJ2p49UM",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "FXBkzwuLkfriWJ1B2z9wTHvi7WTOZGvs2oSNJ7CycXJYC6k06sa7a+OMQtpMP2RTuIpiYC+wZ3nFoKp1FcCcBQ"
}
}
},
"signed_curve25519:AAAAAAAAAA4": {
"key": "pgeLFCJPLYUtyLPKDPr76xRYgPjjY4/lEUH98tExxCo",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "/o44D5qjTdiYORSXmCVYE3Vzvbz2OlIBC58ELe+EAAgIZTJyDxmBJIFotP6CIuFmB/p4lGCd41Fb6T5BnmLvBQ"
}
}
},
"signed_curve25519:AAAAAAAAAA8": {
"key": "gAhoEOtrGTEG+gfAsCU+JS7+wJTlC51+kZ9vLr9BZGA",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "DLDj1c2UncqcCrEwSUEf31ni6W+E6D58EEGFIWj++ydBxuiEnHqFMF7AZU8GGcjQBDIH13uNe8xxO7/KeBbUDQ"
}
}
}
}
}`, bob.UserID, bob.UserID, bob.UserID, bob.UserID, bob.UserID, bob.UserID)

bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)

res := bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{"signed_curve25519"}), m.MatchOTKCounts(map[string]int{
"signed_curve25519": 3,
}))

// claim a OTK, it should decrease the count
mustClaimOTK(t, alice, bob)
// claiming OTKs does not wake up the sync loop, so send something to kick it.
alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 2,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 1,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 0,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: []
MatchOTKAndFallbackTypes(nil, []string{}),
)

// now re-upload a fallback key, it should be repopulated.
keysUploadBody = fmt.Sprintf(`{
"fallback_keys": {
"signed_curve25519:AAAAAAAAADA": {
"fallback": true,
"key": "N8DKj83RTN7lLZrH6shMqHbVhNrxd96OQseQVFmNgTU",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "ZnKsVcNmOLBv0LMGeNpCfCO2am9L223EiyddWPx9wPOtuYt6KZIPox/SFwVmqBwkUdnmeTb6tVgCpZwcH8doDw"
}
}
}
}
}`, bob.UserID)
bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)

alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: ["signed_curve25519"]
MatchOTKAndFallbackTypes(nil, []string{"signed_curve25519"}),
)

// another claim should remove it
mustClaimOTK(t, alice, bob)

alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: []
MatchOTKAndFallbackTypes(nil, []string{}),
)
}

// Regression test to make sure EX uploads a fallback key initially.
// EX relies on device_unused_fallback_key_types: [] being present in the
// sync response before it will upload any fallback keys at all, it doesn't
// automatically do it on first login.
func TestEncryptionFallbackKeyToldIfMissingInitially(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
roomID := alice.MustCreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
bob.JoinRoom(t, roomID, nil)
res := bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{}))

// upload a fallback key and do another initial request => should include key
keysUploadBody := fmt.Sprintf(`{
"fallback_keys": {
"signed_curve25519:AAAAAAAAADA": {
"fallback": true,
"key": "N8DKj83RTN7lLZrH6shMqHbVhNrxd96OQseQVFmNgTU",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "ZnKsVcNmOLBv0LMGeNpCfCO2am9L223EiyddWPx9wPOtuYt6KZIPox/SFwVmqBwkUdnmeTb6tVgCpZwcH8doDw"
}
}
}
}
}`, bob.UserID)
bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)
sentinelEventID := bob.SendEventSynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": "Sentinel",
},
})
bob.SlidingSyncUntilEventID(t, "", roomID, sentinelEventID)
res = bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{"signed_curve25519"}))

// consume the fallback key and do another initial request => should be []
mustClaimOTK(t, alice, bob)
sentinelEventID = bob.SendEventSynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": "Sentinel 2",
},
})
bob.SlidingSyncUntilEventID(t, "", roomID, sentinelEventID)
res = bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{}))
}

func MatchOTKAndFallbackTypes(otkCount map[string]int, fallbackKeyTypes []string) m.RespMatcher {
return func(r *sync3.Response) error {
err := m.MatchOTKCounts(otkCount)(r)
if err != nil {
return err
}
// we should explicitly be sent device_unused_fallback_key_types: []
return m.MatchFallbackKeyTypes(fallbackKeyTypes)(r)
}
}

func mustClaimOTK(t *testing.T, claimer, claimee *CSAPI) {
claimRes := claimer.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "claim"}, client.WithJSONBody(t, map[string]any{
"one_time_keys": map[string]any{
claimee.UserID: map[string]any{
claimee.DeviceID: "signed_curve25519",
},
},
}))
var res struct {
Failures map[string]any `json:"failures"`
OTKs map[string]map[string]any `json:"one_time_keys"`
}
if err := json.NewDecoder(claimRes.Body).Decode(&res); err != nil {
t.Fatalf("failed to decode OTK response: %s", err)
}
if len(res.Failures) > 0 {
t.Fatalf("OTK response had failures: %+v", res.Failures)
}
otk := res.OTKs[claimee.UserID][claimee.DeviceID]
if otk == nil {
t.Fatalf("OTK was not claimed for %s|%s", claimee.UserID, claimee.DeviceID)
}
}
8 changes: 7 additions & 1 deletion testutils/m/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,13 @@ func MatchFallbackKeyTypes(fallbackKeyTypes []string) RespMatcher {
if res.Extensions.E2EE == nil {
return fmt.Errorf("MatchFallbackKeyTypes: no E2EE extension present")
}
if !reflect.DeepEqual(res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
if res.Extensions.E2EE.FallbackKeyTypes == nil { // not supplied
if fallbackKeyTypes == nil {
return nil
}
return fmt.Errorf("MatchFallbackKeyTypes: FallbackKeyTypes is missing but want %v", fallbackKeyTypes)
}
if !reflect.DeepEqual(*res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
return fmt.Errorf("MatchFallbackKeyTypes: got %v want %v", res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes)
}
return nil
Expand Down

0 comments on commit e6a2e67

Please sign in to comment.