feat: store the result at each stage

This commit is contained in:
nikita 2024-12-26 14:18:12 +03:00
parent 36026b3afb
commit 157708d00f
2 changed files with 51 additions and 6 deletions

View File

@ -0,0 +1,9 @@
namespace Mirea.Api.Security.Common.Domain.Caching;
internal class OAuthUserExtension
{
public string? Message { get; set; }
public bool IsSuccess { get; set; }
public required OAuthProvider? Provider { get; set; }
public OAuthUser? User { get; set; }
}

View File

@ -1,6 +1,7 @@
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Mirea.Api.Security.Common.Domain; using Mirea.Api.Security.Common.Domain;
using Mirea.Api.Security.Common.Domain.Caching;
using Mirea.Api.Security.Common.Interfaces; using Mirea.Api.Security.Common.Interfaces;
using Mirea.Api.Security.Common.OAuth2; using Mirea.Api.Security.Common.OAuth2;
using Mirea.Api.Security.Common.OAuth2.UserInfo; using Mirea.Api.Security.Common.OAuth2.UserInfo;
@ -11,6 +12,7 @@ using System.IO;
using System.Linq; using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Security;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
@ -166,7 +168,14 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
} }
} }
public Uri GetProviderRedirect(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUri, private Task StoreOAuthUserInCache(string key, OAuthUserExtension data, CancellationToken cancellation) =>
cache.SetAsync(
key,
JsonSerializer.SerializeToUtf8Bytes(data),
slidingExpiration: TimeSpan.FromMinutes(15),
cancellationToken: cancellation);
public Uri GetProviderRedirect(HttpContext context, CookieOptions cookieOptions, string redirectUri, public Uri GetProviderRedirect(HttpContext context, CookieOptions cookieOptions, string redirectUri,
OAuthProvider provider, Uri callback) OAuthProvider provider, Uri callback)
{ {
@ -205,17 +214,30 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
{ {
Token = GeneratorKey.GenerateBase64(32) Token = GeneratorKey.GenerateBase64(32)
}; };
var parts = state.Split('_'); var parts = state.Split('_');
if (parts.Length != 2) if (parts.Length != 2)
{ {
result.ErrorMessage = "The request data is invalid or malformed."; result.ErrorMessage = "The request data is invalid or malformed.";
await StoreOAuthUserInCache(result.Token, new OAuthUserExtension()
{
Message = result.ErrorMessage,
Provider = null
}, cancellation);
return result; return result;
} }
var payload = DecryptPayload(parts[0]); var payload = DecryptPayload(parts[0]);
var checksum = parts[1]; var checksum = parts[1];
var cacheData = new OAuthUserExtension()
{
Provider = payload.Provider
};
result.Callback = new Uri(payload.Callback); result.Callback = new Uri(payload.Callback);
if (!providers.TryGetValue(payload.Provider, out var providerInfo) || if (!providers.TryGetValue(payload.Provider, out var providerInfo) ||
@ -226,6 +248,10 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
state); state);
result.ErrorMessage = "Invalid authorization request. Please try again later."; result.ErrorMessage = "Invalid authorization request. Please try again later.";
cacheData.Message = result.ErrorMessage;
await StoreOAuthUserInCache(result.Token, cacheData, cancellation);
return result; return result;
} }
@ -233,6 +259,7 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
var checksumRequest = GetHmacString(requestInfo); var checksumRequest = GetHmacString(requestInfo);
result.ErrorMessage = "Authorization failed. Please try again later."; result.ErrorMessage = "Authorization failed. Please try again later.";
cacheData.Message = result.ErrorMessage;
if (checksumRequest != checksum) if (checksumRequest != checksum)
{ {
@ -243,6 +270,8 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
checksum checksum
); );
await StoreOAuthUserInCache(result.Token, cacheData, cancellation);
return result; return result;
} }
@ -258,6 +287,8 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
payload.Provider, payload.Provider,
checksum); checksum);
await StoreOAuthUserInCache(result.Token, cacheData, cancellation);
return result; return result;
} }
@ -275,6 +306,8 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
logger.LogWarning(ex, "Failed to retrieve user information from provider {Provider}", logger.LogWarning(ex, "Failed to retrieve user information from provider {Provider}",
payload.Provider); payload.Provider);
await StoreOAuthUserInCache(result.Token, cacheData, cancellation);
return result; return result;
} }
@ -284,12 +317,15 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
result.ErrorMessage = null; result.ErrorMessage = null;
result.Success = true; result.Success = true;
await cache.SetAsync( await StoreOAuthUserInCache(result.Token, new OAuthUserExtension
result.Token, {
JsonSerializer.SerializeToUtf8Bytes(user), IsSuccess = true,
absoluteExpirationRelativeToNow: TimeSpan.FromMinutes(15), User = user,
cancellationToken: cancellation); Provider = payload.Provider
}, cancellation);
return result; return result;
} }
} }