diff --git a/Endpoint/Controllers/V1/AuthController.cs b/Endpoint/Controllers/V1/AuthController.cs index 0ce1e4d..e53a5f2 100644 --- a/Endpoint/Controllers/V1/AuthController.cs +++ b/Endpoint/Controllers/V1/AuthController.cs @@ -16,7 +16,7 @@ using Mirea.Api.Security.Services; using System; using System.Collections.Generic; using System.Diagnostics; -using System.Security.Claims; +using System.Linq; using System.Threading.Tasks; using OAuthProvider = Mirea.Api.Security.Common.Domain.OAuthProvider; @@ -130,17 +130,23 @@ public class AuthController(IOptionsSnapshot user, IOptionsSnapshot /// The identifier of the OAuth provider to authorize with. + /// The address where the user will need to be redirected after the end of communication with the OAuth provider /// A redirect to the OAuth provider's authorization URL. /// Thrown if the specified provider is not valid. [HttpGet("AuthorizeOAuth2")] [MaintenanceModeIgnore] - public ActionResult AuthorizeOAuth2([FromQuery] int provider) + public ActionResult AuthorizeOAuth2([FromQuery] int provider, [FromQuery] Uri callback) { if (!Enum.IsDefined(typeof(OAuthProvider), provider)) throw new ControllerArgumentException("There is no selected provider"); - return Redirect(oAuthService.GetProviderRedirect(HttpContext, GetCookieParams(), HttpContext.GetApiUrl(Url.Action("OAuth2")!), - (OAuthProvider)provider).AbsoluteUri); + if (!callback.IsAbsoluteUri) + throw new ControllerArgumentException("The callback URL must be absolute."); + + return Redirect(oAuthService.GetProviderRedirect(HttpContext, GetCookieParams(), + HttpContext.GetApiUrl(Url.Action("OAuth2")!), + (OAuthProvider)provider, + callback).AbsoluteUri); } /// @@ -152,9 +158,17 @@ public class AuthController(IOptionsSnapshot user, IOptionsSnapshotA list of available providers and their redirect URLs. [HttpGet("AvailableProviders")] [MaintenanceModeIgnore] - public ActionResult> AvailableProviders() => + public ActionResult> AvailableProviders([FromQuery] Uri callback) => Ok(oAuthService .GetAvailableProviders(HttpContext.GetApiUrl(Url.Action("AuthorizeOAuth2")!)) + .Select(x => + { + if (!callback.IsAbsoluteUri) + throw new ControllerArgumentException("The callback URL must be absolute."); + + x.Redirect = new Uri(x.Redirect + "&callback=" + Uri.EscapeDataString(callback.AbsoluteUri)); + return x; + }) .ConvertToDto()); /// diff --git a/Security/Services/OAuthService.cs b/Security/Services/OAuthService.cs index a0a5964..c3ab569 100644 --- a/Security/Services/OAuthService.cs +++ b/Security/Services/OAuthService.cs @@ -6,6 +6,7 @@ using Mirea.Api.Security.Common.Domain.OAuth2.UserInfo; using Mirea.Api.Security.Common.Interfaces; using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Net.Http; using System.Net.Http.Headers; @@ -110,23 +111,83 @@ public class OAuthService(ILogger logger, Dictionary(data) ?? + throw new NullReferenceException($"Couldn't convert data to {nameof(OAuthPayload)}."); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Couldn't decrypt the data OAuth request."); + throw new InvalidOperationException("Couldn't decrypt the data.", ex); + } + } + + public Uri GetProviderRedirect(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUri, + OAuthProvider provider, Uri callback) { var (clientId, _) = providers[provider]; var requestInfo = new RequestContextInfo(context, cookieOptions); - var state = GetHmacString(requestInfo, secretKey); + var payload = EncryptPayload(new OAuthPayload() + { + Provider = provider, + Callback = callback.AbsoluteUri + }); + + var checksum = GetHmacString(requestInfo); var redirectUrl = $"?client_id={clientId}" + "&response_type=code" + $"&redirect_uri={redirectUri}" + $"&scope={ProviderData[provider].Scope}" + - $"&state={Uri.EscapeDataString(state + "_" + Enum.GetName(provider))}"; + $"&state={Uri.EscapeDataString(payload + "_" + checksum)}"; logger.LogInformation("Redirecting user Fingerprint: {Fingerprint} to OAuth provider {Provider} with state: {State}", requestInfo.Fingerprint, provider, - state); + checksum); return new Uri(ProviderData[provider].RedirectUrl + redirectUrl); } @@ -137,27 +198,38 @@ public class OAuthService(ILogger logger, Dictionary LoginOAuth(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUrl, string code, string state, CancellationToken cancellation = default) { - var partsState = state.Split('_'); + + var parts = state.Split('_'); - if (!Enum.TryParse(partsState.Last(), true, out var provider) || - !providers.TryGetValue(provider, out var providerInfo) || - !ProviderData.TryGetValue(provider, out var currentProviderStruct)) + if (parts.Length != 2) { - logger.LogWarning("Failed to parse OAuth provider from state: {State}", state); - throw new InvalidOperationException("Invalid authorization request."); + throw new SecurityException("The request data is invalid or malformed."); + } - var secretStateData = string.Join("_", partsState.SkipLast(1)); - var requestInfo = new RequestContextInfo(context, cookieOptions); - var secretData = GetHmacString(requestInfo, secretKey); + var payload = DecryptPayload(parts[0]); + var checksum = parts[1]; - if (secretData != secretStateData) + if (!providers.TryGetValue(payload.Provider, out var providerInfo) || + !ProviderData.TryGetValue(payload.Provider, out var currentProviderStruct)) + { + logger.LogWarning("The OAuth provider specified in the payload " + + "is not registered as a possible data recipient from state: {State}", + state); + + throw new SecurityException("Invalid authorization request. Please try again later."); + } + var requestInfo = new RequestContextInfo(context, cookieOptions); + var checksumRequest = GetHmacString(requestInfo); + + + if (checksumRequest != checksum) { logger.LogWarning( "Fingerprint mismatch. Possible CSRF attack detected. Fingerprint: {Fingerprint}, State: {State}, ExpectedState: {ExpectedState}", requestInfo.Fingerprint, - secretData, - secretStateData + checksumRequest, + checksum ); throw new SecurityException("Suspicious activity detected. Please try again."); }