Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle database overloads and return 429 #141

Merged
merged 10 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Tes/Repository/DatabaseOverloadedException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using System;

namespace Tes.Repository
{
public class DatabaseOverloadedException : Exception
{
public override string Message => "The database is currently overloaded; consider scaling the database up or reduce the number of requests";
}
}
187 changes: 130 additions & 57 deletions src/Tes/Repository/TesTaskPostgreSqlRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace Tes.Repository
using System.Linq.Expressions;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Npgsql;
using Tes.Models;
using Tes.Utilities;

Expand All @@ -20,13 +22,15 @@ namespace Tes.Repository
public class TesTaskPostgreSqlRepository : IRepository<TesTask>
{
private readonly Func<TesDbContext> createDbContext;
private readonly ILogger logger;

/// <summary>
/// Default constructor that also will create the schema if it does not exist
/// </summary>
/// <param name="connectionString">The PostgreSql connection string</param>
public TesTaskPostgreSqlRepository(string connectionString)
public TesTaskPostgreSqlRepository(string connectionString, ILogger logger)
{
this.logger = logger;
createDbContext = () => { return new TesDbContext(connectionString); };
using var dbContext = createDbContext();
dbContext.Database.EnsureCreatedAsync().Wait();
Expand All @@ -36,8 +40,9 @@ public TesTaskPostgreSqlRepository(string connectionString)
/// Default constructor that also will create the schema if it does not exist
/// </summary>
/// <param name="connectionString">The PostgreSql connection string</param>
public TesTaskPostgreSqlRepository(IOptions<PostgreSqlOptions> options)
public TesTaskPostgreSqlRepository(IOptions<PostgreSqlOptions> options, ILogger logger)
{
this.logger = logger;
var connectionString = new ConnectionStringUtility().GetPostgresConnectionString(options);
createDbContext = () => { return new TesDbContext(connectionString); };
using var dbContext = createDbContext();
Expand All @@ -48,8 +53,9 @@ public TesTaskPostgreSqlRepository(IOptions<PostgreSqlOptions> options)
/// Constructor for testing to enable mocking DbContext
/// </summary>
/// <param name="createDbContext">A delegate that creates a TesTaskPostgreSqlRepository context</param>
public TesTaskPostgreSqlRepository(Func<TesDbContext> createDbContext)
public TesTaskPostgreSqlRepository(Func<TesDbContext> createDbContext, ILogger logger)
{
this.logger = logger;
this.createDbContext = createDbContext;
using var dbContext = createDbContext();
dbContext.Database.EnsureCreatedAsync().Wait();
Expand All @@ -63,18 +69,19 @@ public TesTaskPostgreSqlRepository(Func<TesDbContext> createDbContext)
/// <returns></returns>
public async Task<bool> TryGetItemAsync(string id, Action<TesTask> onSuccess = null)
{
using var dbContext = createDbContext();

// Search for Id within the JSON
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id);

if (item is null)
return await ExecuteAsync(async dbContext =>
{
return false;
}
// Search for Id within the JSON
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id);

onSuccess?.Invoke(item.Json);
return true;
if (item is null)
{
return false;
}

onSuccess?.Invoke(item.Json);
return true;
});
}

/// <summary>
Expand All @@ -84,14 +91,15 @@ public async Task<bool> TryGetItemAsync(string id, Action<TesTask> onSuccess = n
/// <returns></returns>
public async Task<IEnumerable<TesTask>> GetItemsAsync(Expression<Func<TesTask, bool>> predicate)
{
using var dbContext = createDbContext();

// Search for items in the JSON
var query = dbContext.TesTasks.Select(t => t.Json).Where(predicate);
return await ExecuteAsync(async dbContext =>
{
// Search for items in the JSON
var query = dbContext.TesTasks.Select(t => t.Json).Where(predicate);

//var sqlQuery = query.ToQueryString();
//Debugger.Break();
return await query.ToListAsync();
//var sqlQuery = query.ToQueryString();
//Debugger.Break();
return await query.ToListAsync();
});
}

/// <summary>
Expand All @@ -101,11 +109,13 @@ public async Task<IEnumerable<TesTask>> GetItemsAsync(Expression<Func<TesTask, b
/// <returns></returns>
public async Task<TesTask> CreateItemAsync(TesTask item)
{
using var dbContext = createDbContext();
var dbItem = new TesTaskDatabaseItem { Json = item };
dbContext.TesTasks.Add(dbItem);
await dbContext.SaveChangesAsync();
return item;
return await ExecuteAsync(async dbContext =>
{
var dbItem = new TesTaskDatabaseItem { Json = item };
dbContext.TesTasks.Add(dbItem);
await dbContext.SaveChangesAsync();
return item;
});
}

/// <summary>
Expand All @@ -115,16 +125,18 @@ public async Task<TesTask> CreateItemAsync(TesTask item)
/// <returns></returns>
public async Task<List<TesTask>> CreateItemsAsync(List<TesTask> items)
{
using var dbContext = createDbContext();

foreach (var item in items)
return await ExecuteAsync(async dbContext =>
{
var dbItem = new TesTaskDatabaseItem { Json = item };
dbContext.TesTasks.Add(dbItem);
}
foreach (var item in items)
{
var dbItem = new TesTaskDatabaseItem { Json = item };
dbContext.TesTasks.Add(dbItem);
}

await dbContext.SaveChangesAsync();
return items;
await dbContext.SaveChangesAsync();

return items;
});
}

/// <summary>
Expand All @@ -134,23 +146,24 @@ public async Task<List<TesTask>> CreateItemsAsync(List<TesTask> items)
/// <returns></returns>
public async Task<TesTask> UpdateItemAsync(TesTask tesTask)
{
using var dbContext = createDbContext();

// Manually set entity state to avoid potential NPG PostgreSql bug
dbContext.ChangeTracker.AutoDetectChangesEnabled = false;
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == tesTask.Id);

if (item is null)
return await ExecuteAsync(async dbContext =>
{
throw new Exception($"No TesTask with ID {tesTask.Id} found in the database.");
}
// Manually set entity state to avoid potential NPG PostgreSql bug
dbContext.ChangeTracker.AutoDetectChangesEnabled = false;
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == tesTask.Id);

if (item is null)
{
throw new Exception($"No TesTask with ID {tesTask.Id} found in the database.");
}

item.Json = tesTask;
item.Json = tesTask;

// Manually set entity state to avoid potential NPG PostgreSql bug
dbContext.Entry(item).State = EntityState.Modified;
await dbContext.SaveChangesAsync();
return item.Json;
// Manually set entity state to avoid potential NPG PostgreSql bug
dbContext.Entry(item).State = EntityState.Modified;
await dbContext.SaveChangesAsync();
return item.Json;
});
}

/// <summary>
Expand All @@ -160,16 +173,18 @@ public async Task<TesTask> UpdateItemAsync(TesTask tesTask)
/// <returns></returns>
public async Task DeleteItemAsync(string id)
{
using var dbContext = createDbContext();
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id);

if (item is null)
await ExecuteAsync(async dbContext =>
{
throw new Exception($"No TesTask with ID {item.Id} found in the database.");
}
var item = await dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id);

dbContext.TesTasks.Remove(item);
await dbContext.SaveChangesAsync();
if (item is null)
{
throw new Exception($"No TesTask with ID {item.Id} found in the database.");
}

dbContext.TesTasks.Remove(item);
await dbContext.SaveChangesAsync();
});
}

/// <summary>
Expand All @@ -182,8 +197,66 @@ public async Task DeleteItemAsync(string id)
public async Task<(string, IEnumerable<TesTask>)> GetItemsAsync(Expression<Func<TesTask, bool>> predicate, int pageSize, string continuationToken)
{
// TODO paging support
var results = await GetItemsAsync(predicate);
return (null, results);
return (null, await GetItemsAsync(predicate));
}

private async Task<T> ExecuteAsync<T>(Func<TesDbContext, Task<T>> action)
{
try
{
using var dbContext = createDbContext();
return await action(dbContext);
}
catch (NpgsqlException npgEx) when (npgEx.InnerException is TimeoutException)
{
logger.LogError(npgEx, npgEx.Message);
throw LogDatabaseOverloadedException();
}
catch (InvalidOperationException ioEx) when (ioEx.InnerException is TimeoutException)
{
logger.LogError(ioEx, ioEx.Message);
throw LogDatabaseOverloadedException();
}
catch (InvalidOperationException ioEx) when
(ioEx.InnerException is NpgsqlException npgSqlEx
&& npgSqlEx.Message?.StartsWith("The connection pool has been exhausted", StringComparison.OrdinalIgnoreCase) == true)
{
logger.LogError(ioEx, ioEx.Message);
throw LogDatabaseOverloadedException();
}
}

private async Task ExecuteAsync(Func<TesDbContext, Task> action)
{
try
{
using var dbContext = createDbContext();
await action(dbContext);
}
catch (NpgsqlException npgEx) when (npgEx.InnerException is TimeoutException)
{
logger.LogError(npgEx, npgEx.Message);
throw LogDatabaseOverloadedException();
}
catch (InvalidOperationException ioEx) when (ioEx.InnerException is TimeoutException)
{
logger.LogError(ioEx, ioEx.Message);
throw LogDatabaseOverloadedException();
}
catch (InvalidOperationException ioEx) when
(ioEx.InnerException is NpgsqlException npgSqlEx
&& npgSqlEx.Message?.StartsWith("The connection pool has been exhausted", StringComparison.OrdinalIgnoreCase) == true)
{
logger.LogError(ioEx, ioEx.Message);
throw LogDatabaseOverloadedException();
}
}

public DatabaseOverloadedException LogDatabaseOverloadedException()
{
var exception = new DatabaseOverloadedException();
logger.LogCritical(exception, exception.Message);
return exception;
}

public void Dispose()
Expand Down
Loading