diff --git a/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs b/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs
index 20103e5a2..362181a06 100644
--- a/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs
+++ b/src/IdentityServer/Services/Default/DefaultPersistedGrantService.cs
@@ -28,7 +28,7 @@ public class DefaultPersistedGrantService : IPersistedGrantService
/// The store.
/// The serializer.
/// The logger.
- public DefaultPersistedGrantService(IPersistedGrantStore store,
+ public DefaultPersistedGrantService(IPersistedGrantStore store,
IPersistentGrantSerializer serializer,
ILogger logger)
{
@@ -41,18 +41,34 @@ public DefaultPersistedGrantService(IPersistedGrantStore store,
public async Task> GetAllGrantsAsync(string subjectId)
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.GetAllGrants");
-
+
if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId));
var grants = (await _store.GetAllAsync(new PersistedGrantFilter { SubjectId = subjectId }))
.Where(x => x.ConsumedTime == null) // filter consumed grants
.ToArray();
+ List errors = new List();
+
+ T DeserializeAndCaptureErrors(string data)
+ {
+ try
+ {
+ return _serializer.Deserialize(data);
+ }
+ catch (Exception ex)
+ {
+ errors.Add(ex);
+ return default(T);
+ }
+ }
+
try
{
var consents = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.UserConsent)
- .Select(x => _serializer.Deserialize(x.Data))
- .Select(x => new Grant
+ .Select(x => DeserializeAndCaptureErrors(x.Data))
+ .Where(x => x != default)
+ .Select(x => new Grant
{
ClientId = x.ClientId,
SubjectId = subjectId,
@@ -62,7 +78,8 @@ public async Task> GetAllGrantsAsync(string subjectId)
});
var codes = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.AuthorizationCode)
- .Select(x => _serializer.Deserialize(x.Data))
+ .Select(x => DeserializeAndCaptureErrors(x.Data))
+ .Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
@@ -74,7 +91,8 @@ public async Task> GetAllGrantsAsync(string subjectId)
});
var refresh = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.RefreshToken)
- .Select(x => _serializer.Deserialize(x.Data))
+ .Select(x => DeserializeAndCaptureErrors(x.Data))
+ .Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
@@ -86,7 +104,8 @@ public async Task> GetAllGrantsAsync(string subjectId)
});
var access = grants.Where(x => x.Type == IdentityServerConstants.PersistedGrantTypes.ReferenceToken)
- .Select(x => _serializer.Deserialize(x.Data))
+ .Select(x => DeserializeAndCaptureErrors(x.Data))
+ .Where(x => x != default)
.Select(x => new Grant
{
ClientId = x.ClientId,
@@ -101,6 +120,11 @@ public async Task> GetAllGrantsAsync(string subjectId)
consents = Join(consents, refresh);
consents = Join(consents, access);
+ if (errors.Count > 0)
+ {
+ _logger.LogError(new AggregateException(errors), "One or more errors occured during deserialization of persisted grants, returning successfull items.");
+ }
+
return consents.ToArray();
}
catch (Exception ex)
@@ -115,7 +139,7 @@ private IEnumerable Join(IEnumerable first, IEnumerable sec
{
var list = first.ToList();
- foreach(var other in second)
+ foreach (var other in second)
{
var match = list.FirstOrDefault(x => x.ClientId == other.ClientId);
if (match != null)
@@ -154,10 +178,11 @@ private IEnumerable Join(IEnumerable first, IEnumerable sec
public Task RemoveAllGrantsAsync(string subjectId, string clientId = null, string sessionId = null)
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultPersistedGrantService.RemoveAllGrants");
-
+
if (String.IsNullOrWhiteSpace(subjectId)) throw new ArgumentNullException(nameof(subjectId));
- return _store.RemoveAllAsync(new PersistedGrantFilter {
+ return _store.RemoveAllAsync(new PersistedGrantFilter
+ {
SubjectId = subjectId,
ClientId = clientId,
SessionId = sessionId
diff --git a/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs b/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs
index dae95eb79..1cba7b093 100644
--- a/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs
+++ b/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Reflection;
using System.Security.Claims;
using System.Threading.Tasks;
using Duende.IdentityServer;
@@ -540,4 +541,82 @@ await _userConsent.StoreUserConsentAsync(new Consent()
grants.Count().Should().Be(1);
grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2", "quux3" });
}
-}
\ No newline at end of file
+
+ [Fact]
+ public async Task GetAllGrantsAsync_should_filter_items_with_corrupt_data_from_result()
+ {
+ var mockStore = new CorruptingPersistedGrantStore(_store)
+ {
+ ClientIdToCorrupt = "client2"
+ };
+
+ _subject = new DefaultPersistedGrantService(
+ mockStore,
+ new PersistentGrantSerializer(),
+ TestLogger.Create());
+
+ await _userConsent.StoreUserConsentAsync(new Consent()
+ {
+ ClientId = "client1",
+ SubjectId = "123",
+ Scopes = new string[] { "foo1", "foo2" }
+ });
+ await _userConsent.StoreUserConsentAsync(new Consent()
+ {
+ ClientId = "client2",
+ SubjectId = "123",
+ Scopes = new string[] { "foo3" }
+ });
+
+ var grants = await _subject.GetAllGrantsAsync("123");
+
+ grants.Count().Should().Be(1);
+ grants.First().Scopes.Should().Contain(new string[] { "foo1", "foo2" });
+ }
+
+ class CorruptingPersistedGrantStore : IPersistedGrantStore
+ {
+ public string ClientIdToCorrupt { get; set; }
+
+ private IPersistedGrantStore _inner;
+
+ public CorruptingPersistedGrantStore(IPersistedGrantStore inner)
+ {
+ _inner = inner;
+ }
+
+ public async Task> GetAllAsync(PersistedGrantFilter filter)
+ {
+ var items = await _inner.GetAllAsync(filter);
+ if (ClientIdToCorrupt != null)
+ {
+ var itemsToCorrupt = items.Where(x => x.ClientId == ClientIdToCorrupt);
+ foreach(var corruptItem in itemsToCorrupt)
+ {
+ corruptItem.Data = "corrupt";
+ }
+ }
+ return items;
+ }
+
+ public Task GetAsync(string key)
+ {
+ return _inner.GetAsync(key);
+ }
+
+ public Task RemoveAllAsync(PersistedGrantFilter filter)
+ {
+ return _inner.RemoveAllAsync(filter);
+ }
+
+ public Task RemoveAsync(string key)
+ {
+ return _inner.RemoveAsync(key);
+ }
+
+ public Task StoreAsync(PersistedGrant grant)
+ {
+ return _inner.StoreAsync(grant);
+ }
+ }
+}