using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Mirea.Api.Security.Common.Domain;
using Mirea.Api.Security.Common.Domain.Caching;
using Mirea.Api.Security.Common.Interfaces;
using Mirea.Api.Security.Common.OAuth2;
using Mirea.Api.Security.Common.OAuth2.UserInfo;
using Mirea.Api.Security.Common.ViewModel;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using CookieOptions = Mirea.Api.Security.Common.Model.CookieOptions;

namespace Mirea.Api.Security.Services;

/// <summary>
/// Implements the OAuth 2.0 authorization-code flow for the providers listed in <see cref="ProviderData"/>:
/// builds the provider redirect, exchanges the returned code for a token, loads the user profile
/// and hands the outcome over through the cache under a one-off token.
/// </summary>
/// <param name="logger">Logger used for diagnostics of the OAuth flow.</param>
/// <param name="providers">Client credentials (client id and secret) for every enabled provider.</param>
/// <param name="cache">Cache used to store login results and failed-token counters.</param>
public class OAuthService(
    ILogger<OAuthService> logger,
    Dictionary<OAuthProvider, (string ClientId, string Secret)> providers,
    ICacheService cache)
{
    /// <summary>Size in bytes of the IV that is prepended to every encrypted payload.</summary>
    private const int AesIvSize = 16;

    /// <summary>Number of failed token look-ups after which <see cref="GetOAuthUser"/> stops answering.</summary>
    private const int MaxFailedTokenAttempts = 5;

    /// <summary>Sliding lifetime of the failed-token counter.</summary>
    private static readonly TimeSpan FailedTokenCacheExpiration = TimeSpan.FromHours(1);

    /// <summary>Sliding lifetime of a stored login result.</summary>
    private static readonly TimeSpan OAuthUserCacheExpiration = TimeSpan.FromMinutes(15);

    /// <summary>
    /// Single client shared by all outgoing provider calls. Creating a client per request exhausts sockets
    /// under load; the pooled-connection lifetime keeps long-lived connections from pinning stale DNS entries.
    /// </summary>
    private static readonly HttpClient SharedHttpClient = CreateHttpClient();

    /// <summary>
    /// Secret used both as the AES key for the <c>state</c> payload and as the HMAC-SHA256 key
    /// for the request checksum. Its length must therefore be a valid AES key size.
    /// </summary>
    public required ReadOnlyMemory<byte> SecretKey { private get; init; }

    /// <summary>Endpoints, scope and authorization scheme of every supported provider.</summary>
    private static readonly Dictionary<OAuthProvider, OAuthProviderUrisData> ProviderData = new()
    {
        [OAuthProvider.Google] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://accounts.google.com/o/oauth2/v2/auth",
            TokenUrl = "https://oauth2.googleapis.com/token",
            UserInfoUrl = "https://www.googleapis.com/oauth2/v2/userinfo",
            Scope = "openid email profile",
            AuthHeader = "Bearer",
            UserInfoType = typeof(GoogleUserInfo)
        },
        [OAuthProvider.Yandex] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://oauth.yandex.ru/authorize",
            TokenUrl = "https://oauth.yandex.ru/token",
            UserInfoUrl = "https://login.yandex.ru/info?format=json",
            Scope = "login:email login:info login:avatar",
            AuthHeader = "OAuth",
            UserInfoType = typeof(YandexUserInfo)
        },
        [OAuthProvider.MailRu] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://oauth.mail.ru/login",
            TokenUrl = "https://oauth.mail.ru/token",
            UserInfoUrl = "https://oauth.mail.ru/userinfo",
            // An empty header means the access token is sent as a query parameter instead.
            AuthHeader = "",
            Scope = "",
            UserInfoType = typeof(MailRuUserInfo)
        }
    };

    /// <summary>Creates the shared HTTP client with the service's User-Agent header.</summary>
    private static HttpClient CreateHttpClient()
    {
        var client = new HttpClient(new SocketsHttpHandler
        {
            PooledConnectionLifetime = TimeSpan.FromMinutes(15)
        });
        client.DefaultRequestHeaders.UserAgent.ParseAdd("MireaSchedule/1.0 (Winsomnia)");
        return client;
    }

    /// <summary>
    /// Exchanges an authorization code for tokens at the provider's token endpoint.
    /// </summary>
    /// <param name="requestUri">Token endpoint of the provider.</param>
    /// <param name="redirectUrl">Redirect URI that was used when the code was requested.</param>
    /// <param name="code">Authorization code returned by the provider.</param>
    /// <param name="clientId">OAuth client id.</param>
    /// <param name="secret">OAuth client secret.</param>
    /// <param name="cancellation">Token that cancels the HTTP call.</param>
    /// <returns>The deserialized token response, or <c>null</c> when the body deserializes to null.</returns>
    /// <exception cref="HttpRequestException">The provider answered with a non-success status; the message is the response body.</exception>
    private static async Task<OAuthTokenResponse?> ExchangeCodeForTokensAsync(string requestUri, string redirectUrl,
        string code, string clientId, string secret, CancellationToken cancellation)
    {
        using var tokenRequest = new HttpRequestMessage(HttpMethod.Post, requestUri);
        tokenRequest.Content = new FormUrlEncodedContent(new Dictionary<string, string>
        {
            { "code", code },
            { "client_id", clientId },
            { "client_secret", secret },
            { "redirect_uri", redirectUrl },
            { "grant_type", "authorization_code" }
        });

        using var response = await SharedHttpClient.SendAsync(tokenRequest, cancellation);
        var data = await response.Content.ReadAsStringAsync(cancellation);

        if (!response.IsSuccessStatusCode)
            throw new HttpRequestException(data);

        return JsonSerializer.Deserialize<OAuthTokenResponse>(data);
    }

    /// <summary>
    /// Loads the user profile from the provider and maps it to the internal user model.
    /// </summary>
    /// <param name="requestUri">User-info endpoint of the provider.</param>
    /// <param name="authHeader">Authorization scheme; when empty the token is passed as <c>access_token</c> query parameter.</param>
    /// <param name="accessToken">Access token obtained from the token endpoint.</param>
    /// <param name="provider">Provider whose <c>UserInfoType</c> is used for deserialization.</param>
    /// <param name="cancellation">Token that cancels the HTTP call.</param>
    /// <returns>The mapped user, or <c>null</c> when the body does not deserialize to an <see cref="IUserInfo"/>.</returns>
    /// <exception cref="HttpRequestException">The provider answered with a non-success status; the message is the response body.</exception>
    private static async Task<OAuthUser?> GetUserProfileAsync(string requestUri, string authHeader, string accessToken,
        OAuthProvider provider, CancellationToken cancellation)
    {
        using var request = new HttpRequestMessage(HttpMethod.Get, requestUri);

        if (string.IsNullOrEmpty(authHeader))
        {
            // Respect an already present query string and escape the token so it cannot break the URL.
            var separator = requestUri.Contains('?') ? '&' : '?';
            request.RequestUri = new Uri($"{requestUri}{separator}access_token={Uri.EscapeDataString(accessToken)}");
        }
        else
        {
            request.Headers.Authorization = new AuthenticationHeaderValue(authHeader, accessToken);
        }

        using var response = await SharedHttpClient.SendAsync(request, cancellation);
        var data = await response.Content.ReadAsStringAsync(cancellation);

        if (!response.IsSuccessStatusCode)
            throw new HttpRequestException(data);

        var userInfo = JsonSerializer.Deserialize(data, ProviderData[provider].UserInfoType) as IUserInfo;
        return userInfo?.MapToInternalUser();
    }

    /// <summary>
    /// Computes the Base64 HMAC-SHA256 of <c>Fingerprint_Ip_UserAgent</c> keyed with <see cref="SecretKey"/>.
    /// </summary>
    /// <param name="contextInfo">Request context the checksum is bound to.</param>
    /// <returns>Base64-encoded checksum.</returns>
    private string GetHmacString(RequestContextInfo contextInfo)
    {
        var message = Encoding.UTF8.GetBytes($"{contextInfo.Fingerprint}_{contextInfo.Ip}_{contextInfo.UserAgent}");

        // The static one-shot API needs no disposable HMAC instance and no copy of the key.
        return Convert.ToBase64String(HMACSHA256.HashData(SecretKey.Span, message));
    }

    /// <summary>
    /// Compares two checksums without leaking the position of the first difference through timing.
    /// </summary>
    private static bool ChecksumsMatch(string expected, string actual) =>
        CryptographicOperations.FixedTimeEquals(Encoding.UTF8.GetBytes(expected), Encoding.UTF8.GetBytes(actual));

    /// <summary>
    /// Serializes the payload to JSON and encrypts it with AES using a fresh random IV.
    /// </summary>
    /// <param name="payload">Payload to protect.</param>
    /// <returns>Base64 of the IV followed by the cipher text.</returns>
    private string EncryptPayload(OAuthPayload payload)
    {
        // NOTE(review): the cipher text is not authenticated (the HMAC covers only the request context) — confirm this is acceptable.
        var data = JsonSerializer.Serialize(payload);

        using var aes = Aes.Create();
        aes.Key = SecretKey.ToArray();
        aes.GenerateIV();

        using var encryptor = aes.CreateEncryptor(aes.Key, aes.IV);
        using var ms = new MemoryStream();

        // The IV is stored in front of the cipher text so DecryptPayload can recover it.
        ms.Write(aes.IV, 0, aes.IV.Length);

        using (var cs = new CryptoStream(ms, encryptor, CryptoStreamMode.Write))
        using (var writer = new StreamWriter(cs))
        {
            writer.Write(data);
        }

        return Convert.ToBase64String(ms.ToArray());
    }

    /// <summary>
    /// Reverses <see cref="EncryptPayload"/>: splits off the IV, decrypts and deserializes the payload.
    /// </summary>
    /// <param name="encryptedData">Base64 of the IV followed by the cipher text.</param>
    /// <returns>The decrypted payload.</returns>
    /// <exception cref="InvalidOperationException">
    /// The input is not valid Base64, is too short, cannot be decrypted or does not contain a payload.
    /// The original failure is available as the inner exception.
    /// </exception>
    private OAuthPayload DecryptPayload(string encryptedData)
    {
        try
        {
            var cipherBytes = Convert.FromBase64String(encryptedData);

            if (cipherBytes.Length <= AesIvSize)
                throw new FormatException("The encrypted data is too short to contain an IV and a payload.");

            using var aes = Aes.Create();
            aes.Key = SecretKey.ToArray();
            aes.IV = cipherBytes[..AesIvSize];

            using var ms = new MemoryStream(cipherBytes, AesIvSize, cipherBytes.Length - AesIvSize);
            using var decryptor = aes.CreateDecryptor(aes.Key, aes.IV);
            using var cs = new CryptoStream(ms, decryptor, CryptoStreamMode.Read);
            using var reader = new StreamReader(cs);

            var data = reader.ReadToEnd();

            return JsonSerializer.Deserialize<OAuthPayload>(data) ??
                   throw new JsonException($"Couldn't convert data to {nameof(OAuthPayload)}.");
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Couldn't decrypt the data OAuth request.");
            throw new InvalidOperationException("Couldn't decrypt the data.", ex);
        }
    }

    /// <summary>
    /// Stores the login outcome in the cache as UTF-8 JSON with a 15-minute sliding expiration.
    /// </summary>
    private Task StoreOAuthUserInCache(string key, OAuthUserExtension data, CancellationToken cancellation) =>
        cache.SetAsync(
            key,
            JsonSerializer.SerializeToUtf8Bytes(data),
            slidingExpiration: OAuthUserCacheExpiration,
            cancellationToken: cancellation);

    /// <summary>
    /// Builds the URL the user has to be redirected to in order to authorize at the provider.
    /// </summary>
    /// <param name="cookieOptions">Cookie options used to build the request context.</param>
    /// <param name="context">Current HTTP context.</param>
    /// <param name="redirectUri">Redirect URI registered at the provider.</param>
    /// <param name="provider">Provider to authorize with.</param>
    /// <param name="callback">Callback that is carried inside the encrypted <c>state</c> payload.</param>
    /// <returns>Provider authorization URL including the <c>state</c> parameter.</returns>
    /// <exception cref="KeyNotFoundException">The provider has no credentials or no endpoint data.</exception>
    public Uri GetProviderRedirect(CookieOptions cookieOptions, HttpContext context, string redirectUri,
        OAuthProvider provider, Uri callback)
    {
        var (clientId, _) = providers[provider];
        var providerData = ProviderData[provider];

        var requestInfo = new RequestContextInfo(context, cookieOptions);
        var payload = EncryptPayload(new OAuthPayload()
        {
            Provider = provider,
            Callback = callback.AbsoluteUri
        });
        var checksum = GetHmacString(requestInfo);

        // Every value is escaped so that characters such as '&', '?' or '#' inside it cannot alter the query.
        var redirectUrl = $"?client_id={Uri.EscapeDataString(clientId)}" +
                          "&response_type=code" +
                          $"&redirect_uri={Uri.EscapeDataString(redirectUri)}" +
                          $"&scope={Uri.EscapeDataString(providerData.Scope)}" +
                          $"&state={Uri.EscapeDataString(payload + "_" + checksum)}" +
                          "&prompt=select_account" +
                          "&force_confirm=true";

        logger.LogInformation("Redirecting user Fingerprint: {Fingerprint} to OAuth provider {Provider} with state: {State}",
            requestInfo.Fingerprint,
            provider,
            checksum);

        return new Uri(providerData.RedirectUrl + redirectUrl);
    }

    /// <summary>
    /// Lists the configured providers together with the URL that starts the flow for each of them.
    /// </summary>
    /// <param name="redirectUri">Base URL; a trailing slash is trimmed before <c>/?provider=N</c> is appended.</param>
    /// <returns>One entry per configured provider.</returns>
    public (OAuthProvider Provider, Uri Redirect)[] GetAvailableProviders(string redirectUri) =>
        [.. providers.Select(x => (x.Key, new Uri(redirectUri.TrimEnd('/') + "/?provider=" + (int)x.Key)))];

    /// <summary>
    /// Handles the provider callback: validates <paramref name="state"/>, exchanges <paramref name="code"/>
    /// for a token, loads the user profile and stores the outcome in the cache under the returned token.
    /// </summary>
    /// <param name="cookieOptions">Cookie options used to build the request context.</param>
    /// <param name="context">Current HTTP context.</param>
    /// <param name="redirectUrl">Redirect URI that was used when the code was requested.</param>
    /// <param name="code">Authorization code returned by the provider.</param>
    /// <param name="state">State returned by the provider, expected as <c>payload_checksum</c>.</param>
    /// <param name="cancellation">Token that cancels the operation.</param>
    /// <returns>The login result; on failure <c>ErrorMessage</c> is set and <c>Success</c> stays false.</returns>
    /// <exception cref="InvalidOperationException">The payload part of the state cannot be decrypted.</exception>
    public async Task<LoginOAuth> LoginOAuth(CookieOptions cookieOptions, HttpContext context, string redirectUrl,
        string code, string state, CancellationToken cancellation = default)
    {
        var result = new LoginOAuth()
        {
            Token = GeneratorKey.GenerateBase64(32)
        };

        var parts = state.Split('_');

        if (parts.Length != 2)
        {
            result.ErrorMessage = "The request data is invalid or malformed.";
            await StoreOAuthUserInCache(result.Token, new OAuthUserExtension()
            {
                Message = result.ErrorMessage,
                Provider = null
            }, cancellation);
            return result;
        }

        var payload = DecryptPayload(parts[0]);
        var checksum = parts[1];

        var cacheData = new OAuthUserExtension()
        {
            Provider = payload.Provider
        };

        result.Callback = new Uri(payload.Callback);

        // Every failure path stores the outcome so the token handed to the client always resolves in GetOAuthUser.
        async Task<LoginOAuth> StoreFailureAsync()
        {
            await StoreOAuthUserInCache(result.Token, cacheData, cancellation);
            return result;
        }

        if (!providers.TryGetValue(payload.Provider, out var providerInfo) ||
            !ProviderData.TryGetValue(payload.Provider, out var currentProviderStruct))
        {
            logger.LogWarning("The OAuth provider specified in the payload " +
                              "is not registered as a possible data recipient from state: {State}",
                state);

            result.ErrorMessage = "Invalid authorization request. Please try again later.";
            cacheData.Message = result.ErrorMessage;
            return await StoreFailureAsync();
        }

        var requestInfo = new RequestContextInfo(context, cookieOptions);
        var checksumRequest = GetHmacString(requestInfo);

        result.ErrorMessage = "Authorization failed. Please try again later.";
        cacheData.Message = result.ErrorMessage;

        if (!ChecksumsMatch(checksumRequest, checksum))
        {
            // State is the value received from the client, ExpectedState the one computed for this request.
            logger.LogWarning(
                "Fingerprint mismatch. Possible CSRF attack detected. Fingerprint: {Fingerprint}, State: {State}, ExpectedState: {ExpectedState}",
                requestInfo.Fingerprint,
                checksum,
                checksumRequest
            );
            return await StoreFailureAsync();
        }

        OAuthTokenResponse? accessToken;
        try
        {
            accessToken = await ExchangeCodeForTokensAsync(currentProviderStruct.TokenUrl, redirectUrl, code,
                providerInfo.ClientId, providerInfo.Secret, cancellation);
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Failed to exchange code for access token with provider {Provider}. State: {State}",
                payload.Provider,
                checksum);
            return await StoreFailureAsync();
        }

        if (accessToken == null)
        {
            logger.LogWarning("Provider {Provider} returned an empty token response.", payload.Provider);
            return await StoreFailureAsync();
        }

        OAuthUser? user;
        try
        {
            user = await GetUserProfileAsync(currentProviderStruct.UserInfoUrl, currentProviderStruct.AuthHeader,
                accessToken.AccessToken, payload.Provider, cancellation);
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Failed to retrieve user information from provider {Provider}",
                payload.Provider);
            return await StoreFailureAsync();
        }

        if (user == null)
        {
            logger.LogWarning("Provider {Provider} returned an empty user profile.", payload.Provider);
            return await StoreFailureAsync();
        }

        result.ErrorMessage = null;
        result.Success = true;

        await StoreOAuthUserInCache(result.Token, new OAuthUserExtension
        {
            IsSuccess = true,
            User = user,
            Provider = payload.Provider
        }, cancellation);

        return result;
    }

    /// <summary>
    /// Resolves the login outcome stored under <paramref name="token"/> by <see cref="LoginOAuth"/>.
    /// Unknown tokens are counted per fingerprint; after more than five misses requests are rejected
    /// until the counter expires.
    /// </summary>
    /// <param name="cookieOptions">Cookie options used to build the request context.</param>
    /// <param name="context">Current HTTP context.</param>
    /// <param name="token">Token returned by <see cref="LoginOAuth"/>.</param>
    /// <param name="cancellation">Token that cancels the operation.</param>
    /// <returns>The stored user, message, success flag and provider; for unknown tokens only a message is set.</returns>
    public async Task<(OAuthUser? User, string? Message, bool IsSuccess, OAuthProvider? Provider)>
        GetOAuthUser(CookieOptions cookieOptions, HttpContext context, string token, CancellationToken cancellation = default)
    {
        var requestInfo = new RequestContextInfo(context, cookieOptions);

        var result = await cache.GetAsync<OAuthUserExtension>(token, cancellation);
        var tokenFailedKey = $"{requestInfo.Fingerprint}_oauth_token_failed";

        if (result == null)
        {
            // NOTE(review): assumes ICacheService.GetAsync<T> yields null for a missing key — confirm against the interface.
            var failedTokenAttemptsCount = await cache.GetAsync<int?>(tokenFailedKey, cancellation) ?? 1;

            if (failedTokenAttemptsCount > MaxFailedTokenAttempts)
            {
                logger.LogWarning(
                    "Multiple unsuccessful token attempts detected. Token {Token}, Fingerprint: {Fingerprint}. Attempt count: {AttemptCount}.",
                    token,
                    requestInfo.Fingerprint,
                    failedTokenAttemptsCount);

                return (null, "Too many unsuccessful token attempts. Please try again later.", false, null);
            }

            logger.LogInformation(
                "Cache data not found or expired for token: {Token}. Fingerprint: {Fingerprint}. Attempt count: {AttemptNumber}.",
                token,
                requestInfo.Fingerprint,
                failedTokenAttemptsCount);

            await cache.SetAsync(tokenFailedKey,
                failedTokenAttemptsCount + 1,
                slidingExpiration: FailedTokenCacheExpiration,
                cancellationToken: cancellation);

            return (null, "Invalid or expired token.", false, null);
        }

        // A successful look-up resets the failed-attempt counter for this fingerprint.
        await cache.RemoveAsync(tokenFailedKey, cancellation);

        const string log = "Cache data retrieved for token: {Token}. Fingerprint: {Fingerprint}.";

        if (result.User != null)
        {
            // Arguments follow the placeholder order: Provider first, then UserId.
            logger.LogInformation(log + " Provider: {Provider}. UserId: {UserId}.",
                token,
                requestInfo.Fingerprint,
                result.Provider,
                result.User.Id);
        }
        else if (result.Provider != null)
        {
            logger.LogInformation(log + " Provider: {Provider}.",
                token,
                requestInfo.Fingerprint,
                result.Provider);
        }
        else
        {
            logger.LogInformation(log, token, requestInfo.Fingerprint);
        }

        return (result.User, result.Message, result.IsSuccess, result.Provider);
    }
}