sec: add payload

This commit is contained in:
Polianin Nikita 2024-12-26 08:47:56 +03:00
parent 97187a8e45
commit 17fd260068
2 changed files with 107 additions and 21 deletions

View File

@ -16,7 +16,7 @@ using Mirea.Api.Security.Services;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Security.Claims;
using System.Linq;
using System.Threading.Tasks;
using OAuthProvider = Mirea.Api.Security.Common.Domain.OAuthProvider;
@ -130,17 +130,23 @@ public class AuthController(IOptionsSnapshot<Admin> user, IOptionsSnapshot<Gener
/// This method generates a redirect URL for the selected provider and redirects the user to it.
/// </remarks>
/// <param name="provider">The identifier of the OAuth provider to authorize with.</param>
/// <param name="callback">The address where the user will need to be redirected after the end of communication with the OAuth provider</param>
/// <returns>A redirect to the OAuth provider's authorization URL.</returns>
/// <exception cref="ControllerArgumentException">Thrown if the specified provider is not valid.</exception>
[HttpGet("AuthorizeOAuth2")]
[MaintenanceModeIgnore]
public ActionResult AuthorizeOAuth2([FromQuery] int provider)
public ActionResult AuthorizeOAuth2([FromQuery] int provider, [FromQuery] Uri callback)
{
if (!Enum.IsDefined(typeof(OAuthProvider), provider))
throw new ControllerArgumentException("There is no selected provider");
return Redirect(oAuthService.GetProviderRedirect(HttpContext, GetCookieParams(), HttpContext.GetApiUrl(Url.Action("OAuth2")!),
(OAuthProvider)provider).AbsoluteUri);
if (!callback.IsAbsoluteUri)
throw new ControllerArgumentException("The callback URL must be absolute.");
return Redirect(oAuthService.GetProviderRedirect(HttpContext, GetCookieParams(),
HttpContext.GetApiUrl(Url.Action("OAuth2")!),
(OAuthProvider)provider,
callback).AbsoluteUri);
}
/// <summary>
@ -152,9 +158,17 @@ public class AuthController(IOptionsSnapshot<Admin> user, IOptionsSnapshot<Gener
/// <returns>A list of available providers and their redirect URLs.</returns>
[HttpGet("AvailableProviders")]
[MaintenanceModeIgnore]
public ActionResult<List<AvailableOAuthProvidersResponse>> AvailableProviders() =>
public ActionResult<List<AvailableOAuthProvidersResponse>> AvailableProviders([FromQuery] Uri callback) =>
Ok(oAuthService
.GetAvailableProviders(HttpContext.GetApiUrl(Url.Action("AuthorizeOAuth2")!))
.Select(x =>
{
if (!callback.IsAbsoluteUri)
throw new ControllerArgumentException("The callback URL must be absolute.");
x.Redirect = new Uri(x.Redirect + "&callback=" + Uri.EscapeDataString(callback.AbsoluteUri));
return x;
})
.ConvertToDto());
/// <summary>

View File

@ -6,6 +6,7 @@ using Mirea.Api.Security.Common.Domain.OAuth2.UserInfo;
using Mirea.Api.Security.Common.Interfaces;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
@ -110,23 +111,83 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
Encoding.UTF8.GetBytes($"{contextInfo.Fingerprint}_{contextInfo.Ip}_{contextInfo.UserAgent}")));
}
public Uri GetProviderRedirect(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUri, OAuthProvider provider)
private string EncryptPayload(OAuthPayload payload)
{
var data = JsonSerializer.Serialize(payload);
var aes = Aes.Create();
aes.Key = SecretKey.ToArray();
aes.GenerateIV();
using var encryptor = aes.CreateEncryptor(aes.Key, aes.IV);
using var ms = new MemoryStream();
ms.Write(aes.IV, 0, aes.IV.Length);
using (var cs = new CryptoStream(ms, encryptor, CryptoStreamMode.Write))
using (var writer = new StreamWriter(cs))
{
writer.Write(data);
}
return Convert.ToBase64String(ms.ToArray());
}
private OAuthPayload DecryptPayload(string encryptedData)
{
try
{
var cipherBytes = Convert.FromBase64String(encryptedData);
using var aes = Aes.Create();
aes.Key = SecretKey.ToArray();
var iv = new byte[16];
Array.Copy(cipherBytes, 0, iv, 0, iv.Length);
aes.IV = iv;
using var ms = new MemoryStream(cipherBytes, 16, cipherBytes.Length - 16);
using var decryptor = aes.CreateDecryptor(aes.Key, aes.IV);
using var cs = new CryptoStream(ms, decryptor, CryptoStreamMode.Read);
using var reader = new StreamReader(cs);
var data = reader.ReadToEnd();
return JsonSerializer.Deserialize<OAuthPayload>(data) ??
throw new NullReferenceException($"Couldn't convert data to {nameof(OAuthPayload)}.");
}
catch (Exception ex)
{
logger.LogWarning(ex, "Couldn't decrypt the data OAuth request.");
throw new InvalidOperationException("Couldn't decrypt the data.", ex);
}
}
public Uri GetProviderRedirect(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUri,
OAuthProvider provider, Uri callback)
{
var (clientId, _) = providers[provider];
var requestInfo = new RequestContextInfo(context, cookieOptions);
var state = GetHmacString(requestInfo, secretKey);
var payload = EncryptPayload(new OAuthPayload()
{
Provider = provider,
Callback = callback.AbsoluteUri
});
var checksum = GetHmacString(requestInfo);
var redirectUrl = $"?client_id={clientId}" +
"&response_type=code" +
$"&redirect_uri={redirectUri}" +
$"&scope={ProviderData[provider].Scope}" +
$"&state={Uri.EscapeDataString(state + "_" + Enum.GetName(provider))}";
$"&state={Uri.EscapeDataString(payload + "_" + checksum)}";
logger.LogInformation("Redirecting user Fingerprint: {Fingerprint} to OAuth provider {Provider} with state: {State}",
requestInfo.Fingerprint,
provider,
state);
checksum);
return new Uri(ProviderData[provider].RedirectUrl + redirectUrl);
}
@ -137,27 +198,38 @@ public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider
public async Task<(OAuthProvider provider, OAuthUser User)> LoginOAuth(HttpContext context, CookieOptionsParameters cookieOptions,
string redirectUrl, string code, string state, CancellationToken cancellation = default)
{
var partsState = state.Split('_');
var parts = state.Split('_');
if (!Enum.TryParse<OAuthProvider>(partsState.Last(), true, out var provider) ||
!providers.TryGetValue(provider, out var providerInfo) ||
!ProviderData.TryGetValue(provider, out var currentProviderStruct))
if (parts.Length != 2)
{
logger.LogWarning("Failed to parse OAuth provider from state: {State}", state);
throw new InvalidOperationException("Invalid authorization request.");
throw new SecurityException("The request data is invalid or malformed.");
}
var secretStateData = string.Join("_", partsState.SkipLast(1));
var requestInfo = new RequestContextInfo(context, cookieOptions);
var secretData = GetHmacString(requestInfo, secretKey);
var payload = DecryptPayload(parts[0]);
var checksum = parts[1];
if (secretData != secretStateData)
if (!providers.TryGetValue(payload.Provider, out var providerInfo) ||
!ProviderData.TryGetValue(payload.Provider, out var currentProviderStruct))
{
logger.LogWarning("The OAuth provider specified in the payload " +
"is not registered as a possible data recipient from state: {State}",
state);
throw new SecurityException("Invalid authorization request. Please try again later.");
}
var requestInfo = new RequestContextInfo(context, cookieOptions);
var checksumRequest = GetHmacString(requestInfo);
if (checksumRequest != checksum)
{
logger.LogWarning(
"Fingerprint mismatch. Possible CSRF attack detected. Fingerprint: {Fingerprint}, State: {State}, ExpectedState: {ExpectedState}",
requestInfo.Fingerprint,
secretData,
secretStateData
checksumRequest,
checksum
);
throw new SecurityException("Suspicious activity detected. Please try again.");
}