MireaBackend/Security/Services/OAuthService.cs

292 lines
11 KiB
C#
Raw Normal View History

2024-11-04 02:39:10 +03:00
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Mirea.Api.Security.Common.Domain;
using Mirea.Api.Security.Common.Domain.OAuth2;
using Mirea.Api.Security.Common.Domain.OAuth2.UserInfo;
using Mirea.Api.Security.Common.Interfaces;
using System;
using System.Collections.Generic;
2024-12-26 08:47:56 +03:00
using System.IO;
2024-11-04 02:39:10 +03:00
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
2024-12-18 07:24:33 +03:00
using System.Security.Cryptography;
using System.Text;
2024-11-04 02:39:10 +03:00
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Mirea.Api.Security.Services;
/// <summary>
/// Implements the OAuth 2.0 authorization-code flow for the configured external providers
/// (Google, Yandex, Mail.ru): builds the provider redirect with an encrypted <c>state</c>,
/// validates that state on return, exchanges the code for an access token, fetches the user
/// profile and caches it under a freshly generated token.
/// </summary>
/// <param name="logger">Logger for diagnostics and suspected CSRF attempts.</param>
/// <param name="providers">Client id / secret pair for every enabled provider.</param>
/// <param name="cache">Cache in which the resolved user profile is stored for 15 minutes.</param>
public class OAuthService(ILogger<OAuthService> logger, Dictionary<OAuthProvider, (string ClientId, string Secret)> providers,
    ICacheService cache)
{
    /// <summary>Length in bytes of the AES IV that prefixes the encrypted state payload.</summary>
    private const int IvLength = 16;

    /// <summary>
    /// Single HTTP client shared by all provider calls. Creating a client per request can exhaust
    /// sockets; the bounded pooled-connection lifetime lets this long-lived client observe DNS changes.
    /// </summary>
    private static readonly HttpClient SharedHttpClient = CreateHttpClient();

    /// <summary>
    /// Secret used both as the AES key for the state payload (so it must be 16, 24 or 32 bytes)
    /// and as the HMAC key for the request-context checksum.
    /// </summary>
    public required ReadOnlyMemory<byte> SecretKey { private get; init; }

    /// <summary>Static endpoint/scope/auth-scheme data for every supported provider.</summary>
    private static readonly Dictionary<OAuthProvider, OAuthProviderUrisData> ProviderData = new()
    {
        [OAuthProvider.Google] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://accounts.google.com/o/oauth2/v2/auth",
            TokenUrl = "https://oauth2.googleapis.com/token",
            UserInfoUrl = "https://www.googleapis.com/oauth2/v2/userinfo",
            Scope = "openid email profile",
            AuthHeader = "Bearer",
            UserInfoType = typeof(GoogleUserInfo)
        },
        [OAuthProvider.Yandex] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://oauth.yandex.ru/authorize",
            TokenUrl = "https://oauth.yandex.ru/token",
            UserInfoUrl = "https://login.yandex.ru/info?format=json",
            Scope = "login:email login:info login:avatar",
            AuthHeader = "OAuth",
            UserInfoType = typeof(YandexUserInfo)
        },
        [OAuthProvider.MailRu] = new OAuthProviderUrisData
        {
            RedirectUrl = "https://oauth.mail.ru/login",
            TokenUrl = "https://oauth.mail.ru/token",
            UserInfoUrl = "https://oauth.mail.ru/userinfo",
            AuthHeader = "",
            Scope = "",
            UserInfoType = typeof(MailRuUserInfo)
        }
    };

    /// <summary>Creates the shared client and sets the User-Agent header sent to every provider.</summary>
    private static HttpClient CreateHttpClient()
    {
        var client = new HttpClient(new SocketsHttpHandler
        {
            PooledConnectionLifetime = TimeSpan.FromMinutes(5)
        });
        client.DefaultRequestHeaders.UserAgent.ParseAdd("MireaSchedule/1.0 (Winsomnia)");
        return client;
    }

    /// <summary>Exchanges an authorization code for tokens at the provider's token endpoint.</summary>
    /// <param name="requestUri">Provider token endpoint.</param>
    /// <param name="redirectUrl">Redirect URI sent as <c>redirect_uri</c> in the form body.</param>
    /// <param name="code">Authorization code returned by the provider.</param>
    /// <param name="clientId">Application client id.</param>
    /// <param name="secret">Application client secret.</param>
    /// <param name="cancellation">Cancellation token for the HTTP call.</param>
    /// <returns>The deserialized token response, or <c>null</c> if the body deserializes to null.</returns>
    /// <exception cref="HttpRequestException">
    /// The provider answered with a non-success status; the message is the response body.
    /// </exception>
    private static async Task<OAuthTokenResponse?> ExchangeCodeForTokensAsync(string requestUri, string redirectUrl, string code,
        string clientId, string secret, CancellationToken cancellation)
    {
        using var tokenRequest = new HttpRequestMessage(HttpMethod.Post, requestUri)
        {
            Content = new FormUrlEncodedContent(new Dictionary<string, string>
            {
                { "code", code },
                { "client_id", clientId },
                { "client_secret", secret },
                { "redirect_uri", redirectUrl },
                { "grant_type", "authorization_code" }
            })
        };

        using var response = await SharedHttpClient.SendAsync(tokenRequest, cancellation);
        var data = await response.Content.ReadAsStringAsync(cancellation);

        if (!response.IsSuccessStatusCode)
            throw new HttpRequestException(data, null, response.StatusCode);

        return JsonSerializer.Deserialize<OAuthTokenResponse>(data);
    }

    /// <summary>Fetches the user profile from the provider and maps it to an <see cref="OAuthUser"/>.</summary>
    /// <param name="requestUri">Provider user-info endpoint.</param>
    /// <param name="authHeader">
    /// Authorization scheme (e.g. <c>Bearer</c>, <c>OAuth</c>); when empty the token is sent as an
    /// <c>access_token</c> query parameter instead.
    /// </param>
    /// <param name="accessToken">Access token obtained from the token endpoint.</param>
    /// <param name="provider">Provider whose user-info DTO type is used for deserialization.</param>
    /// <param name="cancellation">Cancellation token for the HTTP call.</param>
    /// <returns>The mapped user, or <c>null</c> if the response could not be read as <see cref="IUserInfo"/>.</returns>
    /// <exception cref="HttpRequestException">The provider answered with a non-success status.</exception>
    private static async Task<OAuthUser?> GetUserProfileAsync(string requestUri, string authHeader, string accessToken, OAuthProvider provider,
        CancellationToken cancellation)
    {
        var uri = requestUri;
        if (string.IsNullOrEmpty(authHeader))
        {
            // No auth scheme configured (Mail.ru): pass the token in the query string, extending an
            // existing query with '&' rather than starting a second one with '?'.
            var separator = requestUri.Contains('?') ? '&' : '?';
            uri = $"{requestUri}{separator}access_token={Uri.EscapeDataString(accessToken)}";
        }

        using var request = new HttpRequestMessage(HttpMethod.Get, uri);
        if (!string.IsNullOrEmpty(authHeader))
            request.Headers.Authorization = new AuthenticationHeaderValue(authHeader, accessToken);

        using var response = await SharedHttpClient.SendAsync(request, cancellation);
        var data = await response.Content.ReadAsStringAsync(cancellation);

        if (!response.IsSuccessStatusCode)
            throw new HttpRequestException(data, null, response.StatusCode);

        var userInfo = JsonSerializer.Deserialize(data, ProviderData[provider].UserInfoType) as IUserInfo;
        return userInfo?.MapToInternalUser();
    }

    /// <summary>
    /// Computes a base64 HMAC-SHA256 over the caller's fingerprint, IP and user agent. The value ties
    /// the OAuth <c>state</c> to the client that started the flow.
    /// </summary>
    private string GetHmacString(RequestContextInfo contextInfo)
    {
        var message = Encoding.UTF8.GetBytes($"{contextInfo.Fingerprint}_{contextInfo.Ip}_{contextInfo.UserAgent}");
        return Convert.ToBase64String(HMACSHA256.HashData(SecretKey.Span, message));
    }

    /// <summary>Compares two checksums in constant time so response timing does not leak partial matches.</summary>
    private static bool ChecksumEquals(string expected, string actual) =>
        CryptographicOperations.FixedTimeEquals(Encoding.UTF8.GetBytes(expected), Encoding.UTF8.GetBytes(actual));

    /// <summary>
    /// Serializes <paramref name="payload"/> to JSON and encrypts it with AES-CBC (PKCS7) under
    /// <see cref="SecretKey"/> using a random IV.
    /// </summary>
    /// <returns>Base64 of <c>IV || ciphertext</c>; <see cref="DecryptPayload"/> relies on this layout.</returns>
    private string EncryptPayload(OAuthPayload payload)
    {
        var plaintext = JsonSerializer.SerializeToUtf8Bytes(payload);

        using var aes = Aes.Create();
        aes.Key = SecretKey.ToArray();

        var iv = RandomNumberGenerator.GetBytes(IvLength);
        var ciphertext = aes.EncryptCbc(plaintext, iv);

        byte[] buffer = [.. iv, .. ciphertext];
        return Convert.ToBase64String(buffer);
    }

    /// <summary>Reverses <see cref="EncryptPayload"/>.</summary>
    /// <exception cref="InvalidOperationException">
    /// The data is not valid base64, is too short, fails to decrypt or does not deserialize to an
    /// <see cref="OAuthPayload"/>. The underlying cause is logged as a warning.
    /// </exception>
    /// <remarks>
    /// NOTE(review): AES-CBC without a MAC is malleable, and the HMAC checksum only covers the request
    /// context, not this payload. Consider encrypt-then-MAC or AES-GCM.
    /// </remarks>
    private OAuthPayload DecryptPayload(string encryptedData)
    {
        try
        {
            var cipherBytes = Convert.FromBase64String(encryptedData);
            if (cipherBytes.Length <= IvLength)
                throw new CryptographicException("Encrypted payload is too short.");

            using var aes = Aes.Create();
            aes.Key = SecretKey.ToArray();

            var plaintext = aes.DecryptCbc(cipherBytes.AsSpan(IvLength), cipherBytes.AsSpan(0, IvLength));

            return JsonSerializer.Deserialize<OAuthPayload>(plaintext) ??
                   throw new InvalidOperationException($"Couldn't convert data to {nameof(OAuthPayload)}.");
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Couldn't decrypt the data OAuth request.");
            throw new InvalidOperationException("Couldn't decrypt the data.", ex);
        }
    }

    /// <summary>Builds the provider authorization URL the user agent should be redirected to.</summary>
    /// <param name="context">Current HTTP context, used to derive the request fingerprint.</param>
    /// <param name="cookieOptions">Cookie options passed to <see cref="RequestContextInfo"/>.</param>
    /// <param name="redirectUri">Our callback endpoint registered with the provider.</param>
    /// <param name="provider">Provider to authenticate with; must be configured.</param>
    /// <param name="callback">Where to send the user after login; stored encrypted in <c>state</c>.</param>
    /// <returns>Absolute provider authorization URI with all query parameters percent-encoded.</returns>
    /// <exception cref="KeyNotFoundException"><paramref name="provider"/> is not configured.</exception>
    public Uri GetProviderRedirect(HttpContext context, CookieOptionsParameters cookieOptions, string redirectUri,
        OAuthProvider provider, Uri callback)
    {
        var (clientId, _) = providers[provider];
        var requestInfo = new RequestContextInfo(context, cookieOptions);

        var payload = EncryptPayload(new OAuthPayload
        {
            Provider = provider,
            Callback = callback.AbsoluteUri
        });
        var checksum = GetHmacString(requestInfo);
        var providerData = ProviderData[provider];

        // Every value is escaped: scopes contain spaces and a redirect URI may itself contain '?' or '&'.
        var query = $"?client_id={Uri.EscapeDataString(clientId)}" +
                    "&response_type=code" +
                    $"&redirect_uri={Uri.EscapeDataString(redirectUri)}" +
                    $"&scope={Uri.EscapeDataString(providerData.Scope)}" +
                    $"&state={Uri.EscapeDataString(payload + "_" + checksum)}";

        logger.LogInformation("Redirecting user Fingerprint: {Fingerprint} to OAuth provider {Provider} with state: {State}",
            requestInfo.Fingerprint,
            provider,
            checksum);

        return new Uri(providerData.RedirectUrl + query);
    }

    /// <summary>
    /// Lists configured providers together with the URI that starts the flow for each of them
    /// (<paramref name="redirectUri"/> plus a <c>provider</c> query parameter holding the enum value).
    /// </summary>
    public (OAuthProvider Provider, Uri Redirect)[] GetAvailableProviders(string redirectUri) =>
        [.. providers.Select(x => (x.Key, new Uri(redirectUri.TrimEnd('/') + "/?provider=" + (int)x.Key)))];

    /// <summary>
    /// Completes the authorization-code flow: validates <paramref name="state"/>, exchanges
    /// <paramref name="code"/> for a token, fetches the user profile and caches it under
    /// <see cref="LoginOAuthResult.Token"/> for 15 minutes.
    /// </summary>
    /// <param name="context">Current HTTP context, used to recompute the request fingerprint.</param>
    /// <param name="cookieOptions">Cookie options passed to <see cref="RequestContextInfo"/>.</param>
    /// <param name="redirectUrl">Redirect URI used in the authorization request; sent again to the token endpoint.</param>
    /// <param name="code">Authorization code returned by the provider.</param>
    /// <param name="state">State returned by the provider (<c>payload_checksum</c>).</param>
    /// <param name="cancellation">Cancellation token for provider and cache calls.</param>
    /// <returns>
    /// A result with <c>Success = true</c> on success; otherwise <c>ErrorMessage</c> describes the failure.
    /// <c>Callback</c> is set once the state payload has been decrypted.
    /// </returns>
    public async Task<LoginOAuthResult> LoginOAuth(HttpContext context, CookieOptionsParameters cookieOptions,
        string redirectUrl, string code, string state, CancellationToken cancellation = default)
    {
        var result = new LoginOAuthResult
        {
            Token = GeneratorKey.GenerateBase64(32)
        };

        // Neither half is allowed to contain '_' (both are standard base64), so exactly two parts are expected.
        if (string.IsNullOrEmpty(state) || state.Split('_') is not [var encryptedPayload, var checksum])
        {
            result.ErrorMessage = "The request data is invalid or malformed.";
            return result;
        }

        OAuthPayload payload;
        try
        {
            payload = DecryptPayload(encryptedPayload);
        }
        catch (InvalidOperationException)
        {
            // DecryptPayload has already logged the cause; a tampered state is malformed input, not a server fault.
            result.ErrorMessage = "The request data is invalid or malformed.";
            return result;
        }

        if (!Uri.TryCreate(payload.Callback, UriKind.Absolute, out var callback))
        {
            result.ErrorMessage = "The request data is invalid or malformed.";
            return result;
        }

        result.Callback = callback;

        if (!providers.TryGetValue(payload.Provider, out var providerInfo) ||
            !ProviderData.TryGetValue(payload.Provider, out var currentProviderStruct))
        {
            logger.LogWarning("The OAuth provider specified in the payload " +
                              "is not registered as a possible data recipient from state: {State}",
                state);
            result.ErrorMessage = "Invalid authorization request. Please try again later.";
            return result;
        }

        var requestInfo = new RequestContextInfo(context, cookieOptions);
        var checksumRequest = GetHmacString(requestInfo);

        // Generic message for every remaining failure path; cleared on success.
        result.ErrorMessage = "Authorization failed. Please try again later.";

        if (!ChecksumEquals(checksumRequest, checksum))
        {
            logger.LogWarning(
                "Fingerprint mismatch. Possible CSRF attack detected. Fingerprint: {Fingerprint}, State: {State}, ExpectedState: {ExpectedState}",
                requestInfo.Fingerprint,
                checksumRequest,
                checksum
            );
            return result;
        }

        OAuthTokenResponse? accessToken;
        try
        {
            accessToken = await ExchangeCodeForTokensAsync(currentProviderStruct.TokenUrl, redirectUrl, code, providerInfo.ClientId,
                providerInfo.Secret, cancellation);
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Failed to exchange code for access token with provider {Provider}. State: {State}",
                payload.Provider,
                checksum);
            return result;
        }

        if (accessToken == null)
            return result;

        OAuthUser? user;
        try
        {
            user = await GetUserProfileAsync(currentProviderStruct.UserInfoUrl, currentProviderStruct.AuthHeader, accessToken.AccessToken,
                payload.Provider, cancellation);
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "Failed to retrieve user information from provider {Provider}",
                payload.Provider);
            return result;
        }

        if (user == null)
            return result;

        result.ErrorMessage = null;
        result.Success = true;

        await cache.SetAsync(
            result.Token,
            JsonSerializer.SerializeToUtf8Bytes(user),
            absoluteExpirationRelativeToNow: TimeSpan.FromMinutes(15),
            cancellationToken: cancellation);

        return result;
    }
}