Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Make accuracy and loss values available for upstream use-cases #662

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions TrainingLoop/TrainingProgress.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ let progressBarLength = 30
/// A progress bar that displays to the console as a model trains, and as validation is performed.
/// It hooks into a TrainingLoop via a callback method.
public class TrainingProgress {
public var accuracies: [Float] // accessible list of accuracies values
public var losses: [Float] // acceessible list of loss values
var statistics: TrainingStatistics?
let metrics: Set<TrainingMetrics>
let liveStatistics: Bool
Expand All @@ -34,6 +36,8 @@ public class TrainingProgress {
/// This has an impact on performance, due to materialization of tensors, and updating values
/// on every batch can reduce training speed by up to 30%.
public init(metrics: Set<TrainingMetrics> = [.accuracy, .loss], liveStatistics: Bool = true) {
self.accuracies = []
self.losses = []
self.metrics = metrics
self.liveStatistics = liveStatistics
if !metrics.isEmpty {
Expand Down Expand Up @@ -68,6 +72,15 @@ public class TrainingProgress {
return result
}

func updateMetrics() {
if metrics.contains(.loss) {
losses.append(statistics!.averageLoss())
}
if metrics.contains(.accuracy) {
accuracies.append(statistics!.accuracy())
}
}

/// The callback used to hook into the TrainingLoop. This is updated once per event.
///
/// - Parameters:
Expand All @@ -92,6 +105,7 @@ public class TrainingProgress {
let metricDescriptionComponent: String
if liveStatistics || (batchCount == (batchIndex + 1)) {
metricDescriptionComponent = metricDescription()
updateMetrics()
} else {
metricDescriptionComponent = ""
}
Expand Down