Skip to content

Commit

Permalink
make file downloading cancellable
Browse files Browse the repository at this point in the history
  • Loading branch information
sensslen committed Jan 2, 2024
1 parent 427aa8a commit 3e68d25
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 38 deletions.
17 changes: 12 additions & 5 deletions src/NuGetUtility/LicenseValidator/LicenseValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public LicenseValidator(IImmutableDictionary<Uri, string> licenseMapping,
}

public async Task<IEnumerable<LicenseValidationResult>> Validate(
IAsyncEnumerable<ReferencedPackageWithContext> packages)
IAsyncEnumerable<ReferencedPackageWithContext> packages,
CancellationToken token)
{
var result = new ConcurrentDictionary<LicenseNameAndVersion, LicenseValidationResult>();
await foreach (ReferencedPackageWithContext info in packages)
Expand All @@ -45,7 +46,7 @@ public async Task<IEnumerable<LicenseValidationResult>> Validate(
}
else if (info.PackageInfo.LicenseUrl != null)
{
await ValidateLicenseByUrl(info.PackageInfo, info.Context, result);
await ValidateLicenseByUrl(info.PackageInfo, info.Context, result, token);
}
else
{
Expand Down Expand Up @@ -150,14 +151,20 @@ private void ValidateLicenseByMetadata(IPackageMetadata info,

private async Task ValidateLicenseByUrl(IPackageMetadata info,
string context,
ConcurrentDictionary<LicenseNameAndVersion, LicenseValidationResult> result)
ConcurrentDictionary<LicenseNameAndVersion, LicenseValidationResult> result,
CancellationToken token)
{
if (info.LicenseUrl!.IsAbsoluteUri)
{
try
{
await _fileDownloader.DownloadFile(info.LicenseUrl,
$"{info.Identity.Id}__{info.Identity.Version}.html");
$"{info.Identity.Id}__{info.Identity.Version}.html",
token);
}
catch (OperationCanceledException)
{
// swallow cancellation
}
catch (Exception e)
{
Expand Down Expand Up @@ -207,7 +214,7 @@ private bool IsLicenseValid(string licenseId)
return true;
}

return _allowedLicenses.Any(l => l.Equals(licenseId));
return _allowedLicenses.Any(allowedLicense => allowedLicense.Equals(licenseId));
}

private string GetLicenseNotAllowedMessage(string license)
Expand Down
2 changes: 1 addition & 1 deletion src/NuGetUtility/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ private async Task<int> OnExecuteAsync(CancellationToken cancellationToken)
});
IAsyncEnumerable<ReferencedPackageWithContext> downloadedLicenseInformation =
packagesForProject.SelectMany(p => GetPackageInfos(p, overridePackageInformation, cancellationToken));
var results = (await validator.Validate(downloadedLicenseInformation)).ToList();
var results = (await validator.Validate(downloadedLicenseInformation, cancellationToken)).ToList();

if (projectReaderExceptions.Any())
{
Expand Down
12 changes: 6 additions & 6 deletions src/NuGetUtility/Wrapper/HttpClientWrapper/FileDownloader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ public FileDownloader(HttpClient client, string downloadDirectory)
_downloadDirectory = downloadDirectory;
}

public async Task DownloadFile(Uri url, string fileName)
public async Task DownloadFile(Uri url, string fileName, CancellationToken token)
{
await _parallelDownloadLimiter.WaitAsync();
await _parallelDownloadLimiter.WaitAsync(token);
try
{
for (int i = 0; i < MAX_RETRIES; i++)
{
await using FileStream file = File.OpenWrite(Path.Combine(_downloadDirectory, fileName));
var request = new HttpRequestMessage(HttpMethod.Get, url);

HttpResponseMessage response = await _client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
HttpResponseMessage response = await _client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
response.EnsureSuccessStatusCode();
if (response.StatusCode == System.Net.HttpStatusCode.TooManyRequests)
{
await Task.Delay((int)Math.Pow(EXPONENTIAL_BACKOFF_WAIT_TIME_MILLISECONDS, i + 1));
await Task.Delay((int)Math.Pow(EXPONENTIAL_BACKOFF_WAIT_TIME_MILLISECONDS, i + 1), token);
continue;
}
using Stream downloadStream = await response.Content.ReadAsStreamAsync();
using Stream downloadStream = await response.Content.ReadAsStreamAsync(token);

await downloadStream.CopyToAsync(file);
await downloadStream.CopyToAsync(file, token);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
{
public interface IFileDownloader
{
public Task DownloadFile(Uri url, string fileName);
public Task DownloadFile(Uri url, string fileName, CancellationToken token);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
public class NopFileDownloader : IFileDownloader
{
public Task DownloadFile(Uri url, string fileName)
public Task DownloadFile(Uri url, string fileName, CancellationToken token)
{
return Task.CompletedTask;
}
Expand Down
57 changes: 33 additions & 24 deletions tests/NuGetUtility.Test/LicenseValidator/LicenseValidatorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,34 @@ public void SetUp()
_context = fixture.Create<string>();
_projectUrl = fixture.Create<Uri>();
_ignoredLicenses = fixture.Create<string[]>();
_token = new CancellationTokenSource();

_uut = new NuGetUtility.LicenseValidator.LicenseValidator(_licenseMapping,
_allowedLicenses,
_fileDownloader,
_ignoredLicenses);
}

[TearDown]
public void TearDown()
{
_token.Dispose();
}

private NuGetUtility.LicenseValidator.LicenseValidator _uut = null!;
private IImmutableDictionary<Uri, string> _licenseMapping = null!;
private IEnumerable<string> _allowedLicenses = null!;
private string _context = null!;
private IFileDownloader _fileDownloader = null!;
private Uri _projectUrl = null!;
private string[] _ignoredLicenses = null!;
private CancellationTokenSource _token = null!;

[Test]
public async Task ValidatingEmptyList_Should_ReturnEmptyValidatedLicenses()
{
IAsyncEnumerable<ReferencedPackageWithContext> emptyListToValidate = Enumerable.Empty<ReferencedPackageWithContext>().AsAsyncEnumerable();
IEnumerable<LicenseValidationResult> results = await _uut.Validate(emptyListToValidate);
IEnumerable<LicenseValidationResult> results = await _uut.Validate(emptyListToValidate, _token.Token);
CollectionAssert.AreEqual(Enumerable.Empty<LicenseValidationResult>(), results);
}

Expand Down Expand Up @@ -101,7 +109,7 @@ public async Task ValidatingLicenses_Should_IgnorePackage_If_PackageNameMatchesE

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -129,7 +137,7 @@ public async Task ValidatingLicenses_Should_NotIgnorePackage_If_PackageNameDoesN

IPackageMetadata package = SetupPackageWithExpressionLicenseInformation(packageId, packageVersion, license);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -159,7 +167,7 @@ public async Task ValidatingLicenses_Should_IgnorePackage_If_IgnoreWildcardMatch

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -189,7 +197,7 @@ public async Task ValidatingLicenses_Should_IgnorePackage_If_IgnoreWildcardMatch

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -220,7 +228,7 @@ public async Task ValidatingLicenses_Should_IgnorePackage_If_IgnoreWildcardMatch

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand All @@ -247,7 +255,7 @@ public async Task ValidatingLicenses_Should_IgnorePackage_If_IgnoreWildcardMatch

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -275,7 +283,7 @@ public async Task ValidatingLicensesWithExpressionLicenseInformation_Should_Give

IPackageMetadata package = SetupPackageWithExpressionLicenseInformation(packageId, packageVersion, license);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -303,7 +311,7 @@ public async Task ValidatingLicensesWithOverwriteLicenseInformation_Should_GiveC

IPackageMetadata package = SetupPackageWithOverwriteLicenseInformation(packageId, packageVersion, license);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -340,7 +348,7 @@ public async Task ValidatingLicensesWithMatchingLicenseUrl_Should_GiveCorrectVal
KeyValuePair<Uri, string> mappingLicense = _licenseMapping.Shuffle(34561).First();
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, mappingLicense.Key);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -368,7 +376,7 @@ public async Task ValidatingLicensesWithMatchingLicenseUrl_Should_GiveCorrectVal

IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, licenseUrl);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -399,7 +407,7 @@ public async Task ValidatingLicensesWithNotSupportedLicenseMetadata_Should_GiveC

IPackageMetadata package = SetupPackageWithLicenseInformationOfType(packageId, packageVersion, license, licenseType);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -431,7 +439,7 @@ public async Task ValidatingLicensesWithoutLicenseInformation_Should_GiveCorrect

IPackageMetadata package = SetupPackage(packageId, packageVersion);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -459,7 +467,7 @@ public async Task ValidatingLicensesWithExpressionLicenseInformation_Should_Give
{
IPackageMetadata package = SetupPackageWithExpressionLicenseInformation(packageId, packageVersion, license);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -487,7 +495,7 @@ public async Task ValidatingLicensesWithOverwriteLicenseInformation_Should_GiveC
{
IPackageMetadata package = SetupPackageWithOverwriteLicenseInformation(packageId, packageVersion, license);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -515,7 +523,7 @@ public async Task ValidatingLicensesWithExpressionLicenseInformation_Should_Give
string validLicense = _allowedLicenses.Shuffle(135643).First();
IPackageMetadata package = SetupPackageWithExpressionLicenseInformation(packageId, packageVersion, validLicense);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand All @@ -538,7 +546,7 @@ public async Task ValidatingLicensesWithOverwriteLicenseInformation_Should_GiveC
string validLicense = _allowedLicenses.Shuffle(135643).First();
IPackageMetadata package = SetupPackageWithOverwriteLicenseInformation(packageId, packageVersion, validLicense);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand All @@ -561,7 +569,7 @@ public async Task ValidatingLicensesWithMatchingUrlInformation_Should_GiveCorrec
KeyValuePair<Uri, string> urlMatch = _licenseMapping.Shuffle(765).First();
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, urlMatch.Key);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down Expand Up @@ -589,10 +597,11 @@ public async Task ValidatingLicensesWithUrlInformation_Should_StartDownloadingSa
KeyValuePair<Uri, string> urlMatch = _licenseMapping.Shuffle(4567).First();
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, urlMatch.Key);

_ = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
_ = await _uut.Validate(CreateInput(package, _context), _token.Token);

await _fileDownloader.Received(1).DownloadFile(package.LicenseUrl!,
$"{package.Identity.Id}__{package.Identity.Version}.html");
$"{package.Identity.Id}__{package.Identity.Version}.html",
_token.Token);
}

[Test]
Expand All @@ -603,11 +612,11 @@ public void ValidatingLicensesWithUrlInformation_Should_ThrowLicenseDownloadInfo
{
KeyValuePair<Uri, string> urlMatch = _licenseMapping.Shuffle(12345).First();
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, urlMatch.Key);
_fileDownloader.When(m => m.DownloadFile(package.LicenseUrl!, Arg.Any<string>()))
_fileDownloader.When(m => m.DownloadFile(package.LicenseUrl!, Arg.Any<string>(), Arg.Any<CancellationToken>()))
.Do(_ => throw new Exception());

LicenseDownloadException? exception =
Assert.ThrowsAsync<LicenseDownloadException>(() => _uut.Validate(LicenseValidatorTest.CreateInput(package, _context)));
Assert.ThrowsAsync<LicenseDownloadException>(() => _uut.Validate(CreateInput(package, _context), _token.Token));
Assert.IsInstanceOf<Exception>(exception!.InnerException);
Assert.AreEqual(
$"Failed to download license for package {packageId} ({packageVersion}).\nContext: {_context}",
Expand All @@ -627,7 +636,7 @@ public async Task ValidatingLicensesWithMatchingUrlInformation_Should_GiveCorrec
_ignoredLicenses);
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, urlMatch.Key);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand All @@ -650,7 +659,7 @@ public async Task ValidatingLicensesWithNotMatchingUrlInformation_Should_GiveCor
{
IPackageMetadata package = SetupPackageWithLicenseUrl(packageId, packageVersion, licenseUrl);

IEnumerable<LicenseValidationResult> result = await _uut.Validate(LicenseValidatorTest.CreateInput(package, _context));
IEnumerable<LicenseValidationResult> result = await _uut.Validate(CreateInput(package, _context), _token.Token);

Assert.That(result,
Is.EquivalentTo(new[]
Expand Down

0 comments on commit 3e68d25

Please sign in to comment.