diff --git a/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs b/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs index 20103e5a2..362181a06 100644 --- a/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs +++ b/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs @@ -28,7 +28,7 @@ public class DefaultPersistedGrantService : IPersistedGrantService /// The store. /// The serializer. /// The logger. - public DefaultPersistedGrantService(IPersistedGrantStore store, + public DefaultPersistedGrantService(IPersistedGrantStore store, IPersistentGrantSerializer serializer, ILogger logger) { @@ -41,18 +41,34 @@ public DefaultPersistedGrantService(IPersistedGrantStore store, public async Task> GetAllGrantsAsync(string subjectId) { using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.GetAllGrants"); - + if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId)); var grants = (await _store.GetAllAsync(new PersistedGrantFilter { SubjectId = subjectId })) .Where(x => x.ConsumedTime == null) // filter consumed grants .ToArray(); + List errors = new List(); + + T DeserializeAndCaptureErrors(string data) + { + try + { + return _serializer.Deserialize(data); + } + catch (Exception ex) + { + errors.Add(ex); + return default(T); + } + } + try { var consents = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.UserConsent) - .Select(x => _serializer.Deserialize(x.Data)) - .Select(x => new Grant + .Select(x => DeserializeAndCaptureErrors(x.Data)) + .Where(x => x != default) + .Select(x => new Grant { ClientId = x.ClientId, SubjectId = subjectId, @@ -62,7 +78,8 @@ public async Task> GetAllGrantsAsync(string subjectId) }); var codes = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.AuthorizationCode) - .Select(x => _serializer.Deserialize(x.Data)) + .Select(x => DeserializeAndCaptureErrors(x.Data)) + .Where(x => x != default) .Select(x => new Grant { ClientId = x.ClientId, @@ -74,7 +91,8 @@ public async Task> GetAllGrantsAsync(string subjectId) }); var refresh = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.RefreshToken) - .Select(x => _serializer.Deserialize(x.Data)) + .Select(x => DeserializeAndCaptureErrors(x.Data)) + .Where(x => x != default) .Select(x => new Grant { ClientId = x.ClientId, @@ -86,7 +104,8 @@ public async Task> GetAllGrantsAsync(string subjectId) }); var access = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.ReferenceToken) - .Select(x => _serializer.Deserialize(x.Data)) + .Select(x => DeserializeAndCaptureErrors(x.Data)) + .Where(x => x != default) .Select(x => new Grant { ClientId = x.ClientId, @@ -101,6 +120,11 @@ public async Task> GetAllGrantsAsync(string subjectId) consents = Join(consents, refresh); consents = Join(consents, access); + if (errors.Count > 0) + { + _logger.LogError(new AggregateException(errors), "One or more errors occured during deserialization of persisted grants, returning successfull items."); + } + return consents.ToArray(); } catch (Exception ex) @@ -115,7 +139,7 @@ private IEnumerable Join(IEnumerable first, IEnumerable sec { var list = first.ToList(); - foreach(var other in second) + foreach (var other in second) { var match = list.FirstOrDefault(x => x.ClientId == other.ClientId); if (match != null) @@ -154,10 +178,11 @@ private IEnumerable Join(IEnumerable first, IEnumerable sec public Task RemoveAllGrantsAsync(string subjectId, string clientId = null, string sessionId = null) { using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.RemoveAllGrants"); - + if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId)); - return _store.RemoveAllAsync(new PersistedGrantFilter { + return _store.RemoveAllAsync(new PersistedGrantFilter + { SubjectId = subjectId, ClientId = clientId, SessionId = sessionId diff --git a/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs b/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs index dae95eb79..1cba7b093 100644 --- a/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs +++ b/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Security.Claims; using System.Threading.Tasks; using Duende.IdentityServer; @@ -540,4 +541,82 @@ await _userConsent.StoreUserConsentAsync(new Consent() grants.Count().Should().Be(1); grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2", "quux3" }); } -} \ No newline at end of file + + [Fact] + public async Task GetAllGrantsAsync_should_filter_items_with_corrupt_data_from_result() + { + var mockStore = new CorruptingPersistedGrantStore(_store) + { + ClientIdToCorrupt = "client2" + }; + + _subject = new DefaultPersistedGrantService( + mockStore, + new PersistentGrantSerializer(), + TestLogger.Create()); + + await _userConsent.StoreUserConsentAsync(new Consent() + { + ClientId = "client1", + SubjectId = "123", + Scopes = new string[] { "foo1", "foo2" } + }); + await _userConsent.StoreUserConsentAsync(new Consent() + { + ClientId = "client2", + SubjectId = "123", + Scopes = new string[] { "foo3" } + }); + + var grants = await _subject.GetAllGrantsAsync("123"); + + grants.Count().Should().Be(1); + grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2" }); + } + + class CorruptingPersistedGrantStore : IPersistedGrantStore + { + public string ClientIdToCorrupt { get; set; } + + private IPersistedGrantStore _inner; + + public CorruptingPersistedGrantStore(IPersistedGrantStore inner) + { + _inner = inner; + } + + public async Task> GetAllAsync(PersistedGrantFilter filter) + { + var items = await _inner.GetAllAsync(filter); + if (ClientIdToCorrupt != null) + { + var itemsToCorrupt = items.Where(x => x.ClientId == ClientIdToCorrupt); + foreach(var corruptItem in itemsToCorrupt) + { + corruptItem.Data = "corrupt"; + } + } + return items; + } + + public Task GetAsync(string key) + { + return _inner.GetAsync(key); + } + + public Task RemoveAllAsync(PersistedGrantFilter filter) + { + return _inner.RemoveAllAsync(filter); + } + + public Task RemoveAsync(string key) + { + return _inner.RemoveAsync(key); + } + + public Task StoreAsync(PersistedGrant grant) + { + return _inner.StoreAsync(grant); + } + } +}