Skip to content

Commit

Permalink
fix: various connection tests
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
  • Loading branch information
dbluhm committed Sep 19, 2024
1 parent 0d3939e commit 1c844c8
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 127 deletions.
4 changes: 4 additions & 0 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
from typing import Dict, List, Optional, Sequence, Text, Tuple, Union
import warnings

import pydid
from base58 import b58decode
Expand Down Expand Up @@ -298,6 +299,8 @@ async def create_did_document(
) -> DIDDoc:
"""Create our DID doc for a given DID.
This method is deprecated and will be removed.
Args:
did_info (DIDInfo): The DID information (DID and verkey) used in the
connection.
Expand All @@ -310,6 +313,7 @@ async def create_did_document(
DIDDoc: The prepared `DIDDoc` instance.
"""
warnings.warn("create_did_document is deprecated and will be removed soon")
did_doc = DIDDoc(did=did_info.did)
did_controller = did_info.did
did_key = did_info.verkey
Expand Down
12 changes: 9 additions & 3 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
InvitationMessage,
)
from aries_cloudagent.protocols.out_of_band.v1_0.messages.service import Service
from aries_cloudagent.wallet.key_type import ED25519
from ....did.did_key import DIDKey

from ....core.in_memory import InMemoryProfile
from ....storage.base import BaseStorage
Expand All @@ -19,6 +21,7 @@ def setUp(self):
self.test_seed = "testseed000000000000000000000001"
self.test_did = "55GkHamhTU1ZbTbV2ab9DE"
self.test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
self.test_didkey = DIDKey.from_public_key_b58(self.test_verkey, ED25519)
self.test_endpoint = "http://localhost"

self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP"
Expand Down Expand Up @@ -355,10 +358,13 @@ async def test_attach_retrieve_invitation(self):
connection_id = await record.save(self.session)

service = Service(
recipient_keys=[self.test_verkey],
_id="asdf",
_type="did-communication",
recipient_keys=[self.test_didkey.did],
service_endpoint="http://localhost:8999",
)
invi = InvitationMessage(
handshake_protocols=["didexchange/1.1"],
services=[service],
label="abc123",
)
Expand Down Expand Up @@ -431,11 +437,11 @@ async def test_deserialize_connection_protocol(self):
state=ConnRecord.State.INIT,
my_did=self.test_did,
their_role=ConnRecord.Role.REQUESTER,
connection_protocol="connections/1.0",
connection_protocol="didexchange/1.0",
)
ser = record.serialize()
deser = ConnRecord.deserialize(ser)
assert deser.connection_protocol == "connections/1.0"
assert deser.connection_protocol == "didexchange/1.0"

async def test_metadata_set_get(self):
record = ConnRecord(
Expand Down
156 changes: 32 additions & 124 deletions aries_cloudagent/connections/tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from unittest import IsolatedAsyncioTestCase
from unittest.mock import call
import secrets

import base58
from pydid import DID, DIDDocument, DIDDocumentBuilder
from pydid.doc.builder import ServiceBuilder
from pydid.verification_method import (
Expand Down Expand Up @@ -31,9 +33,6 @@
from ...messaging.responder import BaseResponder, MockResponder
from ...multitenant.base import BaseMultitenantManager
from ...multitenant.manager import MultitenantManager
from ...protocols.coordinate_mediation.v1_0.models.mediation_record import (
MediationRecord,
)
from ...protocols.coordinate_mediation.v1_0.route_manager import (
CoordinateMediationV1RouteManager,
RouteManager,
Expand All @@ -46,7 +45,7 @@
from ...storage.error import StorageNotFoundError
from ...transport.inbound.receipt import MessageReceipt
from ...utils.multiformats import multibase, multicodec
from ...wallet.base import DIDInfo
from ...wallet.base import BaseWallet, DIDInfo
from ...wallet.did_method import SOV, DIDMethods
from ...wallet.error import WalletNotFoundError
from ...wallet.in_memory import InMemoryWallet
Expand Down Expand Up @@ -112,6 +111,13 @@ async def asyncSetUp(self):
DIDResolver: self.resolver,
},
)
async with self.profile.session() as session:
wallet = session.inject(BaseWallet)
info = await wallet.create_local_did(method=SOV, key_type=ED25519)

self.did = info.did
self.verkey = info.verkey

self.context = self.profile.context

self.multitenant_mgr = mock.MagicMock(MultitenantManager, autospec=True)
Expand All @@ -126,113 +132,6 @@ async def asyncSetUp(self):
self.manager = BaseConnectionManager(self.profile)
assert self.manager._profile

async def test_create_did_document(self):
did_info = DIDInfo(
self.test_did,
self.test_verkey,
None,
method=SOV,
key_type=ED25519,
)

did_doc = await self.manager.create_did_document(
did_info=did_info,
svc_endpoints=[self.test_endpoint],
)

async def test_create_did_document_mediation(self):
did_info = DIDInfo(
self.test_did,
self.test_verkey,
None,
method=SOV,
key_type=ED25519,
)
mediation_record = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
routing_keys=self.test_mediator_routing_keys,
endpoint=self.test_mediator_endpoint,
)
doc = await self.manager.create_did_document(
did_info, mediation_records=[mediation_record]
)
assert doc.service
services = list(doc.service.values())
assert len(services) == 1
(service,) = services
assert service.routing_keys
service_routing_key = service.routing_keys[0]
assert service_routing_key == mediation_record.routing_keys[0]
assert service.endpoint == mediation_record.endpoint

async def test_create_did_document_multiple_mediators(self):
did_info = DIDInfo(
self.test_did,
self.test_verkey,
None,
method=SOV,
key_type=ED25519,
)
mediation_record1 = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
routing_keys=self.test_mediator_routing_keys,
endpoint=self.test_mediator_endpoint,
)
mediation_record2 = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id="mediator-conn-id2",
routing_keys=[
"did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz"
],
endpoint="http://mediatorw.example.com",
)
doc = await self.manager.create_did_document(
did_info, mediation_records=[mediation_record1, mediation_record2]
)
assert doc.service
services = list(doc.service.values())
assert len(services) == 1
(service,) = services
assert service.routing_keys[0] == mediation_record1.routing_keys[0]
assert service.routing_keys[1] == mediation_record2.routing_keys[0]
assert service.endpoint == mediation_record2.endpoint

async def test_create_did_document_mediation_svc_endpoints_overwritten(self):
did_info = DIDInfo(
self.test_did,
self.test_verkey,
None,
method=SOV,
key_type=ED25519,
)
mediation_record = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
routing_keys=self.test_mediator_routing_keys,
endpoint=self.test_mediator_endpoint,
)
self.route_manager.routing_info = mock.CoroutineMock(
return_value=(mediation_record.routing_keys, mediation_record.endpoint)
)
doc = await self.manager.create_did_document(
did_info,
svc_endpoints=[self.test_endpoint],
mediation_records=[mediation_record],
)
assert doc.service
services = list(doc.service.values())
assert len(services) == 1
(service,) = services
service_public_keys = service.routing_keys[0]
assert service_public_keys == mediation_record.routing_keys[0]
assert service.endpoint == mediation_record.endpoint

async def test_did_key_storage(self):
await self.manager.add_key_for_did(
did=self.test_target_did, key=self.test_target_verkey
Expand All @@ -252,7 +151,7 @@ async def test_fetch_connection_targets_no_my_did(self):

async def test_fetch_connection_targets_in_progress_conn(self):
mock_conn = mock.MagicMock(
my_did=self.test_did,
my_did=self.did,
their_did=self.test_target_did,
connection_id="dummy",
their_role=ConnRecord.Role.RESPONDER.rfc23,
Expand All @@ -275,6 +174,7 @@ async def test_fetch_targets_for_connection_in_progress_inv(self):
state=ConnRecord.State.INVITATION.rfc23,
invitation_msg_id="test-invite-msg-id",
)
mock_conn.retrieve_invitation = mock.CoroutineMock()
with mock.patch.object(
self.manager,
"_fetch_connection_targets_for_invitation",
Expand All @@ -292,11 +192,15 @@ async def test_fetch_targets_for_connection_in_progress_implicit(self):
connection_id="dummy",
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.INVITATION.rfc23,
invitation_msg_id=None,
invitation_key=None,
)
with mock.patch.object(
self.manager,
"resolve_invitation",
mock.CoroutineMock(),
mock.CoroutineMock(
return_value=(mock.MagicMock(), mock.MagicMock(), mock.MagicMock())
),
) as mock_resolve_invitation:
await self.manager._fetch_targets_for_connection_in_progress(
mock_conn, self.test_did
Expand Down Expand Up @@ -350,7 +254,7 @@ async def test_fetch_connection_targets_conn_invitation_btcr_without_services(se
self.resolver.resolve = mock.CoroutineMock(return_value=did_doc)
self.context.injector.bind_instance(DIDResolver, self.resolver)

invitation = InvitationMessage(did=did_doc.id)
invitation = InvitationMessage(services=[did_doc.id])
mock_conn = mock.MagicMock(
my_did=did_doc.id,
their_did=self.test_target_did,
Expand Down Expand Up @@ -397,7 +301,7 @@ async def test_fetch_connection_targets_conn_invitation_no_didcomm_services(self
did_doc = builder.build()
self.resolver.resolve = mock.CoroutineMock(return_value=did_doc)
self.context.injector.bind_instance(DIDResolver, self.resolver)
invitation = InvitationMessage(did=did_doc.id)
invitation = InvitationMessage(services=[did_doc.id])
mock_conn = mock.MagicMock(
my_did=did_doc.id,
their_did=self.test_target_did,
Expand Down Expand Up @@ -920,8 +824,10 @@ async def test_get_connection_targets_retrieve_connection(self):
service = OOBService(
did=None,
service_endpoint=self.test_endpoint,
recipient_keys=[self.test_target_verkey],
routing_keys=[self.test_verkey],
recipient_keys=[
DIDKey.from_public_key_b58(self.test_target_verkey, ED25519).did
],
routing_keys=[DIDKey.from_public_key_b58(self.test_verkey, ED25519).did],
)
conn_invite = InvitationMessage(
services=[service],
Expand Down Expand Up @@ -952,8 +858,8 @@ async def test_get_connection_targets_retrieve_connection(self):
assert target.did == mock_conn.their_did
assert target.endpoint == service.service_endpoint
assert target.label == conn_invite.label
assert target.recipient_keys == service.recipient_keys
assert target.routing_keys == service.routing_keys
assert target.recipient_keys == [self.test_target_verkey]
assert target.routing_keys == [self.test_verkey]
assert target.sender_key == local_did.verkey

async def test_get_connection_targets_from_cache(self):
Expand Down Expand Up @@ -1068,8 +974,10 @@ async def test_get_conn_targets_conn_invitation_no_cache(self):
service = OOBService(
did=None,
service_endpoint=self.test_endpoint,
recipient_keys=[self.test_target_verkey],
routing_keys=[self.test_verkey],
recipient_keys=[
DIDKey.from_public_key_b58(self.test_target_verkey, ED25519).did
],
routing_keys=[DIDKey.from_public_key_b58(self.test_verkey, ED25519).did],
)
conn_invite = InvitationMessage(
services=[service],
Expand All @@ -1093,14 +1001,14 @@ async def test_get_conn_targets_conn_invitation_no_cache(self):
assert target.did == mock_conn.their_did
assert target.endpoint == service.service_endpoint
assert target.label == conn_invite.label
assert target.recipient_keys == service.recipient_keys
assert target.routing_keys == service.routing_keys
assert target.recipient_keys == [self.test_target_verkey]
assert target.routing_keys == [self.test_verkey]
assert target.sender_key == local_did.verkey

async def test_create_static_connection(self):
with mock.patch.object(ConnRecord, "save", autospec=True) as mock_conn_rec_save:
_my, _their, conn_rec = await self.manager.create_static_connection(
my_did=self.test_did,
my_did=base58.b58encode(secrets.token_bytes(16)).decode(),
their_did=self.test_target_did,
their_verkey=self.test_target_verkey,
their_endpoint=self.test_endpoint,
Expand Down

0 comments on commit 1c844c8

Please sign in to comment.