diff --git a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml index 64af657..ef20cb0 100644 --- a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml +++ b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml @@ -2,6 +2,7 @@ \ No newline at end of file diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 034f405..08d2303 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -1,4 +1,7 @@ +using System; using System.Collections.Generic; +using System.Diagnostics; +using Coder.Desktop.App.Converters; using Coder.Desktop.Vpn.Proto; namespace Coder.Desktop.App.Models; @@ -19,11 +22,168 @@ public enum VpnLifecycle Stopping, } +public enum VpnStartupStage +{ + Unknown, + Initializing, + Downloading, + Finalizing, +} + +public class VpnDownloadProgress +{ + public ulong BytesWritten { get; set; } = 0; + public ulong? BytesTotal { get; set; } = null; // null means unknown total size + + public double Progress + { + get + { + if (BytesTotal is > 0) + { + return (double)BytesWritten / BytesTotal.Value; + } + return 0.0; + } + } + + public override string ToString() + { + // TODO: it would be nice if the two suffixes could match + var s = FriendlyByteConverter.FriendlyBytes(BytesWritten); + if (BytesTotal != null) + s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}"; + else + s += " of unknown"; + if (BytesTotal != null) + s += $" ({Progress:0%})"; + return s; + } + + public VpnDownloadProgress Clone() + { + return new VpnDownloadProgress + { + BytesWritten = BytesWritten, + BytesTotal = BytesTotal, + }; + } + + public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto) + { + return new VpnDownloadProgress + { + BytesWritten = proto.BytesWritten, + BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null, + }; + } +} + +public class VpnStartupProgress +{ + public const string DefaultStartProgressMessage = "Starting Coder Connect..."; + + // Scale the download progress to an overall progress value between these + // numbers. + private const double DownloadProgressMin = 0.05; + private const double DownloadProgressMax = 0.80; + + public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown; + public VpnDownloadProgress? DownloadProgress { get; init; } = null; + + // 0.0 to 1.0 + public double Progress + { + get + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return 0.0; + case VpnStartupStage.Downloading: + var progress = DownloadProgress?.Progress ?? 0.0; + return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress; + case VpnStartupStage.Finalizing: + return DownloadProgressMax; + default: + throw new ArgumentOutOfRangeException(); + } + } + } + + public override string ToString() + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return DefaultStartProgressMessage; + case VpnStartupStage.Downloading: + var s = "Downloading Coder Connect binary..."; + if (DownloadProgress is not null) + { + s += "\n" + DownloadProgress; + } + + return s; + case VpnStartupStage.Finalizing: + return "Finalizing Coder Connect startup..."; + default: + throw new ArgumentOutOfRangeException(); + } + } + + public VpnStartupProgress Clone() + { + return new VpnStartupProgress + { + Stage = Stage, + DownloadProgress = DownloadProgress?.Clone(), + }; + } + + public static VpnStartupProgress FromProto(StartProgress proto) + { + return new VpnStartupProgress + { + Stage = proto.Stage switch + { + StartProgressStage.Initializing => VpnStartupStage.Initializing, + StartProgressStage.Downloading => VpnStartupStage.Downloading, + StartProgressStage.Finalizing => VpnStartupStage.Finalizing, + _ => VpnStartupStage.Unknown, + }, + DownloadProgress = proto.Stage is StartProgressStage.Downloading ? + VpnDownloadProgress.FromProto(proto.DownloadProgress) : + null, + }; + } +} + public class RpcModel { public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected; - public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown; + public VpnLifecycle VpnLifecycle + { + get; + set + { + if (VpnLifecycle != value && value == VpnLifecycle.Starting) + // Reset the startup progress when the VPN lifecycle changes to + // Starting. + VpnStartupProgress = null; + field = value; + } + } + + // Nullable because it is only set when the VpnLifecycle is Starting + public VpnStartupProgress? VpnStartupProgress + { + get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null; + set; + } public IReadOnlyList Workspaces { get; set; } = []; @@ -35,6 +195,7 @@ public RpcModel Clone() { RpcLifecycle = RpcLifecycle, VpnLifecycle = VpnLifecycle, + VpnStartupProgress = VpnStartupProgress?.Clone(), Workspaces = Workspaces, Agents = Agents, }; diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 7beff66..168a1be 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -161,7 +161,10 @@ public async Task StartVpn(CancellationToken ct = default) throw new RpcOperationException( $"Cannot start VPN without valid credentials, current state: {credentials.State}"); - MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; }); + MutateState(state => + { + state.VpnLifecycle = VpnLifecycle.Starting; + }); ServiceMessage reply; try @@ -283,15 +286,28 @@ private void ApplyStatusUpdate(Status status) }); } + private void ApplyStartProgressUpdate(StartProgress message) + { + MutateState(state => + { + // The model itself will ignore this value if we're not in the + // starting state. + state.VpnStartupProgress = VpnStartupProgress.FromProto(message); + }); + } + private void SpeakerOnReceive(ReplyableRpcMessage message) { switch (message.Message.MsgCase) { + case ServiceMessage.MsgOneofCase.Start: + case ServiceMessage.MsgOneofCase.Stop: case ServiceMessage.MsgOneofCase.Status: ApplyStatusUpdate(message.Message.Status); break; - case ServiceMessage.MsgOneofCase.Start: - case ServiceMessage.MsgOneofCase.Stop: + case ServiceMessage.MsgOneofCase.StartProgress: + ApplyStartProgressUpdate(message.Message.StartProgress); + break; case ServiceMessage.MsgOneofCase.None: default: // TODO: log unexpected message diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index d8b3182..820ff12 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost { private const int MaxAgents = 5; private const string DefaultDashboardUrl = "https://coder.com"; - private const string DefaultHostnameSuffix = ".coder"; private readonly IServiceProvider _services; private readonly IRpcController _rpcController; @@ -53,6 +52,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -63,6 +63,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -70,7 +71,25 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [NotifyPropertyChangedFor(nameof(ShowFailedSection))] public partial string? VpnFailedMessage { get; set; } = null; - public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started; + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))] + [NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))] + public partial int? VpnStartProgressValue { get; set; } = null; + + public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))] + public partial string? VpnStartProgressMessage { get; set; } = null; + + public string VpnStartProgressMessageOrDefault => + string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage; + + public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0; + + public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started; + + public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting; public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started; @@ -170,6 +189,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel) VpnLifecycle = rpcModel.VpnLifecycle; VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started; + // VpnStartupProgress is only set when the VPN is starting. + if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null) + { + // Convert 0.00-1.00 to 0-100. + var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100); + VpnStartProgressValue = Math.Clamp(progress, 0, 100); + VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString(); + } + else + { + VpnStartProgressValue = null; + VpnStartProgressMessage = null; + } + // Add every known agent. HashSet workspacesWithAgents = []; List agents = []; diff --git a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml index c1d69aa..171e292 100644 --- a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml +++ b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml @@ -36,7 +36,7 @@ diff --git a/App/Views/Pages/TrayWindowMainPage.xaml b/App/Views/Pages/TrayWindowMainPage.xaml index 283867d..f488454 100644 --- a/App/Views/Pages/TrayWindowMainPage.xaml +++ b/App/Views/Pages/TrayWindowMainPage.xaml @@ -43,6 +43,8 @@ + + + HorizontalContentAlignment="Left"> diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index 986ce46..bb9b39c 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -277,8 +277,8 @@ public async Task Download(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(dlTask.Progress, Is.EqualTo(1)); Assert.That(dlTask.IsCompleted, Is.True); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); @@ -300,18 +300,62 @@ public async Task DownloadSameDest(CancellationToken ct) NullDownloadValidator.Instance, ct); var dlTask0 = await startTask0; await dlTask0.Task; - Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask0.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask0.BytesTotal, Is.EqualTo(5)); + Assert.That(dlTask0.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask0.Progress, Is.EqualTo(1)); Assert.That(dlTask0.IsCompleted, Is.True); var dlTask1 = await startTask1; await dlTask1.Task; - Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask1.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask1.BytesTotal, Is.EqualTo(5)); + Assert.That(dlTask1.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask1.Progress, Is.EqualTo(1)); Assert.That(dlTask1.IsCompleted, Is.True); } + [Test(Description = "Download with X-Original-Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithXOriginalContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "4"); + ctx.Response.ContentType = "text/plain"; + // Don't set Content-Length. + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + await dlTask.Task; + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); + } + + [Test(Description = "Download with mismatched Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithMismatchedContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect + ctx.Response.ContentType = "text/plain"; + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + var ex = Assert.ThrowsAsync(() => dlTask.Task); + Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4")); + } + [Test(Description = "Download with custom headers")] [CancelAfter(30_000)] public async Task WithHeaders(CancellationToken ct) @@ -347,7 +391,7 @@ public async Task DownloadExisting(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.Zero); + Assert.That(dlTask.BytesWritten, Is.Zero); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1))); } @@ -368,7 +412,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1))); } diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index 2561a4b..bace7e0 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -60,7 +60,8 @@ message ServiceMessage { oneof msg { StartResponse start = 2; StopResponse stop = 3; - Status status = 4; // either in reply to a StatusRequest or broadcasted + Status status = 4; // either in reply to a StatusRequest or broadcasted + StartProgress start_progress = 5; // broadcasted during startup } } @@ -218,6 +219,28 @@ message StartResponse { string error_message = 2; } +// StartProgress is sent from the manager to the client to indicate the +// download/startup progress of the tunnel. This will be sent during the +// processing of a StartRequest before the StartResponse is sent. +// +// Note: this is currently a broadcasted message to all clients due to the +// inability to easily send messages to a specific client in the Speaker +// implementation. If clients are not expecting these messages, they +// should ignore them. +enum StartProgressStage { + Initializing = 0; + Downloading = 1; + Finalizing = 2; +} +message StartProgressDownloadProgress { + uint64 bytes_written = 1; + optional uint64 bytes_total = 2; // unknown in some situations +} +message StartProgress { + StartProgressStage stage = 1; + optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading +} + // StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a // StopResponse. message StopRequest {} diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 6a3108b..c4a916f 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -339,31 +339,35 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance } /// -/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final +/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final /// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if /// it hasn't changed. /// public class DownloadTask { - private const int BufferSize = 4096; + private const int BufferSize = 64 * 1024; + private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available - private static readonly HttpClient HttpClient = new(); + private static readonly HttpClient HttpClient = new(new HttpClientHandler + { + AutomaticDecompression = DecompressionMethods.All, + }); private readonly string _destinationDirectory; private readonly ILogger _logger; private readonly RaiiSemaphoreSlim _semaphore = new(1, 1); private readonly IDownloadValidator _validator; - public readonly string DestinationPath; + private readonly string _destinationPath; + private readonly string _tempDestinationPath; public readonly HttpRequestMessage Request; - public readonly string TempDestinationPath; - public ulong? TotalBytes { get; private set; } - public ulong BytesRead { get; private set; } public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync - - public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value; + public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download + public ulong BytesWritten { get; private set; } + public ulong? BytesTotal { get; private set; } + public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; public bool IsCompleted => Task.IsCompleted; internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) @@ -374,17 +378,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination if (string.IsNullOrWhiteSpace(destinationPath)) throw new ArgumentException("Destination path must not be empty", nameof(destinationPath)); - DestinationPath = Path.GetFullPath(destinationPath); - if (Path.EndsInDirectorySeparator(DestinationPath)) - throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator", + _destinationPath = Path.GetFullPath(destinationPath); + if (Path.EndsInDirectorySeparator(_destinationPath)) + throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator", nameof(destinationPath)); - _destinationDirectory = Path.GetDirectoryName(DestinationPath) + _destinationDirectory = Path.GetDirectoryName(_destinationPath) ?? throw new ArgumentException( - $"Destination path '{DestinationPath}' must have a parent directory", + $"Destination path '{_destinationPath}' must have a parent directory", nameof(destinationPath)); - TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) + + _tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) + ".download-" + Path.GetRandomFileName()); } @@ -406,9 +410,9 @@ private async Task Start(CancellationToken ct = default) // If the destination path exists, generate a Coder SHA1 ETag and send // it in the If-None-Match header to the server. - if (File.Exists(DestinationPath)) + if (File.Exists(_destinationPath)) { - await using var stream = File.OpenRead(DestinationPath); + await using var stream = File.OpenRead(_destinationPath); var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower(); Request.Headers.Add("If-None-Match", "\"" + etag + "\""); } @@ -419,11 +423,11 @@ private async Task Start(CancellationToken ct = default) _logger.LogInformation("File has not been modified, skipping download"); try { - await _validator.ValidateAsync(DestinationPath, ct); + await _validator.ValidateAsync(_destinationPath, ct); } catch (Exception e) { - _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath); + _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath); throw new Exception("Existing file failed validation after 304 Not Modified", e); } @@ -446,24 +450,38 @@ private async Task Start(CancellationToken ct = default) } if (res.Content.Headers.ContentLength >= 0) - TotalBytes = (ulong)res.Content.Headers.ContentLength; + BytesTotal = (ulong)res.Content.Headers.ContentLength; + + // X-Original-Content-Length overrules Content-Length if set. + if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues)) + { + // If there are multiple we only look at the first one. + var headerValue = headerValues.ToList().FirstOrDefault(); + if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength)) + BytesTotal = originalContentLength; + else + _logger.LogWarning( + "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'", + XOriginalContentLengthHeader, headerValue); + } await Download(res, ct); } private async Task Download(HttpResponseMessage res, CancellationToken ct) { + DownloadStarted = true; try { var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; FileStream tempFile; try { - tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan); + tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan); } catch (Exception e) { - _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath); throw; } @@ -476,13 +494,14 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) { await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); sha1?.TransformBlock(buffer, 0, n, null, 0); - BytesRead += (ulong)n; + BytesWritten += (ulong)n; } } - if (TotalBytes != null && BytesRead != TotalBytes) + BytesTotal ??= BytesWritten; + if (BytesWritten != BytesTotal) throw new IOException( - $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}"); + $"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}"); // Verify the ETag if it was sent by the server. if (res.Headers.Contains("ETag") && sha1 != null) @@ -497,26 +516,34 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) try { - await _validator.ValidateAsync(TempDestinationPath, ct); + await _validator.ValidateAsync(_tempDestinationPath, ct); } catch (Exception e) { _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation", - TempDestinationPath); + _tempDestinationPath); throw new HttpRequestException("Downloaded file failed validation", e); } - File.Move(TempDestinationPath, DestinationPath, true); + File.Move(_tempDestinationPath, _destinationPath, true); } - finally + catch { #if DEBUG _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode", - TempDestinationPath); + _tempDestinationPath); #else - if (File.Exists(TempDestinationPath)) - File.Delete(TempDestinationPath); + try + { + if (File.Exists(_tempDestinationPath)) + File.Delete(_tempDestinationPath); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath); + } #endif + throw; } } } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index fc014c0..fdb62af 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -131,6 +131,8 @@ private async ValueTask HandleClientMessageStart(ClientMessage me { try { + await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct); + var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); if (_status == TunnelStatus.Started && _lastStartRequest != null && @@ -151,10 +153,14 @@ private async ValueTask HandleClientMessageStart(ClientMessage me _lastServerVersion = serverVersion; // TODO: each section of this operation needs a timeout + // Stop the tunnel if it's running so we don't have to worry about // permissions issues when replacing the binary. await _tunnelSupervisor.StopAsync(ct); + await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct); + + await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct); await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, HandleTunnelRpcError, ct); @@ -237,6 +243,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage CurrentStatus(CancellationToken ct = default) private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default) { if (newStatus != null) _status = newStatus.Value; - await _managerRpc.BroadcastAsync(new ServiceMessage + await FallibleBroadcast(new ServiceMessage { Status = await CurrentStatus(ct), }, ct); } + private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default) + { + // Broadcast the messages out with a low timeout. If clients don't + // receive broadcasts in time, it's not a big deal. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(TimeSpan.FromMilliseconds(30)); + try + { + await _managerRpc.BroadcastAsync(message, cts.Token); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message); + } + } + private void HandleTunnelRpcError(Exception e) { _logger.LogError(e, "Manager<->Tunnel RPC error"); @@ -425,12 +450,61 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected _logger.LogDebug("Skipping tunnel binary version validation"); } + // Note: all ETag, signature and version validation is performed by the + // DownloadTask. var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); - // TODO: monitor and report progress when we have a mechanism to do so + // Wait for the download to complete, sending progress updates every + // 50ms. + while (true) + { + // Wait for the download to complete, or for a short delay before + // we send a progress update. + var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct); + var winner = await Task.WhenAny([ + downloadTask.Task, + delayTask, + ]); + if (winner == downloadTask.Task) + break; + + // Task.WhenAny will not throw if the winner was cancelled, so + // check CT afterward and not beforehand. + ct.ThrowIfCancellationRequested(); + + if (!downloadTask.DownloadStarted) + // Don't send progress updates if we don't know what the + // progress is yet. + continue; + + var progress = new StartProgressDownloadProgress + { + BytesWritten = downloadTask.BytesWritten, + }; + if (downloadTask.BytesTotal != null) + progress.BytesTotal = downloadTask.BytesTotal.Value; - // Awaiting this will check the checksum (via the ETag) if the file - // exists, and will also validate the signature and version. + await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct); + } + + // Await again to re-throw any exceptions that occurred during the + // download. await downloadTask.Task; + + // We don't send a broadcast here as we immediately send one in the + // parent routine. + _logger.LogInformation("Completed downloading VPN binary"); + } + + private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default) + { + await FallibleBroadcast(new ServiceMessage + { + StartProgress = new StartProgress + { + Stage = stage, + DownloadProgress = downloadProgress, + }, + }, cancellationToken); } } diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs index c23752f..4920570 100644 --- a/Vpn.Service/ManagerRpc.cs +++ b/Vpn.Service/ManagerRpc.cs @@ -127,14 +127,20 @@ public async Task ExecuteAsync(CancellationToken stoppingToken) public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) { + // Sends messages to all clients simultaneously and waits for them all + // to send or fail/timeout. + // // Looping over a ConcurrentDictionary is exception-safe, but any items // added or removed during the loop may or may not be included. - foreach (var (clientId, client) in _activeClients) + await Task.WhenAll(_activeClients.Select(async item => + { try { - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - cts.CancelAfter(5 * 1000); - await client.Speaker.SendMessage(message, cts.Token); + // Enforce upper bound in case a CT with a timeout wasn't + // supplied. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(TimeSpan.FromSeconds(2)); + await item.Value.Speaker.SendMessage(message, cts.Token); } catch (ObjectDisposedException) { @@ -142,11 +148,12 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) } catch (Exception e) { - _logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId); + _logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key); // TODO: this should probably kill the client, but due to the // async nature of the client handling, calling Dispose // will not remove the client from the active clients list } + })); } private async Task HandleRpcClientAsync(ulong clientId, Speaker speaker, diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs index fc61247..094875d 100644 --- a/Vpn.Service/Program.cs +++ b/Vpn.Service/Program.cs @@ -16,10 +16,12 @@ public static class Program #if !DEBUG private const string ServiceName = "Coder Desktop"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService"; + private const string DefaultLogLevel = "Information"; #else // This value matches Create-Service.ps1. private const string ServiceName = "Coder Desktop (Debug)"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService"; + private const string DefaultLogLevel = "Debug"; #endif private const string ManagerConfigSection = "Manager"; @@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args) builder.Services.AddSingleton(); // Services + builder.Services.AddHostedService(); + builder.Services.AddHostedService(); + + // Either run as a Windows service or a console application if (!Environment.UserInteractive) { MainLogger.Information("Running as a windows service"); @@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args) MainLogger.Information("Running as a console application"); } - builder.Services.AddHostedService(); - builder.Services.AddHostedService(); - var host = builder.Build(); Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!; MainLogger.Information("Application is starting"); @@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder) ["Serilog:Using:0"] = "Serilog.Sinks.File", ["Serilog:Using:1"] = "Serilog.Sinks.Console", - ["Serilog:MinimumLevel"] = "Information", + ["Serilog:MinimumLevel"] = DefaultLogLevel, ["Serilog:Enrich:0"] = "FromLogContext", ["Serilog:WriteTo:0:Name"] = "File", diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs index a323cac..7dd6738 100644 --- a/Vpn.Service/TunnelSupervisor.cs +++ b/Vpn.Service/TunnelSupervisor.cs @@ -99,18 +99,16 @@ public async Task StartAsync(string binPath, }, }; // TODO: maybe we should change the log format in the inner binary - // to something without a timestamp - var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]"); - var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]"); + // to something without a timestamp _subprocess.OutputDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - outLogger.Debug("{Data}", args.Data); + _logger.LogInformation("stdout: {Data}", args.Data); }; _subprocess.ErrorDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - errLogger.Debug("{Data}", args.Data); + _logger.LogInformation("stderr: {Data}", args.Data); }; // Pass the other end of the pipes to the subprocess and dispose diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index d113a50..37ec554 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default) // Handshakes should always finish quickly, so enforce a 5s timeout. using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); cts.CancelAfter(TimeSpan.FromSeconds(5)); - await PerformHandshake(ct); + await PerformHandshake(cts.Token); // Start ReceiveLoop in the background. _receiveTask = ReceiveLoop(_cts.Token);