ZeruiW/TemporalSHAP

Temporal SHAP - README

Overview

This repository computes SHAP (Shapley) values for the temporal segments of videos classified by a TimeSformer model. It includes tools for processing videos, computing exact or approximate Shapley values, and evaluating performance (classification metrics, run time and energy use).

Requirements

  • Python 3.8+
  • PyTorch
  • Transformers (Hugging Face)
  • av (PyAV)
  • numpy
  • scikit-learn
  • psutil
  • tqdm

Installation

  1. Clone the repository:

    git clone https://github.com/yourusername/temporal-shap.git
    cd temporal-shap
  2. Install the required packages:

    pip install -r requirements.txt

Usage

Configuration

Modify the config dictionary in the script to set your parameters:

config = {
    "model_name": "facebook/timesformer-base-finetuned-k400",
    "image_processor_name": "MCG-NJU/videomae-base-finetuned-kinetics",
    "num_samples": 100,
    "num_classes": 400,
    "num_samples_per_class": 25,
    "video_list_path": "archive/kinetics400_val_list_videos.txt",
    "video_directory": "archive/zoom_blur",
    "use_exact": True
}
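The README does not show how sampled_files and true_labels are built from video_list_path, num_classes and num_samples_per_class. Below is a minimal sketch of one way to do it. It assumes each line of the list file holds a filename followed by an integer class label, separated by whitespace; the helper name sample_videos_per_class and its fixed seed are hypothetical and not part of this repository.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def sample_videos_per_class(
    list_lines: List[str],
    num_classes: int,
    num_samples_per_class: int,
    seed: int = 0,
) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: pick up to num_samples_per_class videos per class.

    Assumes every non-empty line looks like "<filename> <integer label>".
    """
    by_class: Dict[int, List[str]] = defaultdict(list)
    for line in list_lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        filename, label = parts[0], int(parts[-1])
        if 0 <= label < num_classes:  # ignore labels outside the configured range
            by_class[label].append(filename)

    rng = random.Random(seed)  # fixed seed so the sampled subset is reproducible
    sampled_files: List[str] = []
    true_labels: List[int] = []
    for label in sorted(by_class):
        files = by_class[label]
        chosen = rng.sample(files, min(num_samples_per_class, len(files)))
        sampled_files.extend(chosen)
        true_labels.extend([label] * len(chosen))
    return sampled_files, true_labels
```

To use it with the configuration above, open config["video_list_path"] and pass f.readlines() together with config["num_classes"] and config["num_samples_per_class"]. With the values shown, that yields at most 400 × 25 = 10,000 videos.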

Initialization

Initialize the video processor and SHAP calculator:

video_processor = VideoProcessor(config["model_name"], config["image_processor_name"])
shap_calculator = TemporalShap(num_samples=config["num_samples"])
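The README does not say how TemporalShap divides a clip into temporal segments or how a removed segment is represented when a coalition is scored. Here is a minimal sketch under two assumptions: segments are contiguous, near-equal runs of frames, and removed segments are replaced by a fixed baseline frame (for example an all-black frame). The names split_into_segments and mask_clip are hypothetical.

```python
from typing import List, Sequence, Set, TypeVar

Frame = TypeVar("Frame")


def split_into_segments(num_frames: int, num_segments: int) -> List[range]:
    """Split frame indices 0..num_frames-1 into contiguous, near-equal segments."""
    if not 0 < num_segments <= num_frames:
        raise ValueError("need 0 < num_segments <= num_frames")
    base, extra = divmod(num_frames, num_segments)
    segments: List[range] = []
    start = 0
    for i in range(num_segments):
        length = base + (1 if i < extra else 0)  # the first `extra` segments get one extra frame
        segments.append(range(start, start + length))
        start += length
    return segments


def mask_clip(
    frames: Sequence[Frame], segments: List[range], keep: Set[int], baseline: Frame
) -> List[Frame]:
    """Copy `frames`, replacing every frame of a segment not in `keep` with `baseline`."""
    masked = list(frames)
    for seg_id, seg in enumerate(segments):
        if seg_id not in keep:
            for t in seg:
                masked[t] = baseline
    return masked
```

For example, 8 frames split into 3 segments gives frame groups [0, 1, 2], [3, 4, 5] and [6, 7]. Keeping only segment 1 leaves frames 3 to 5 intact and replaces the rest with the baseline.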

Video Processing

Process the videos and compute predictions and SHAP values:

sampled_files = [...]  # List of video filenames
true_labels = [...]  # Corresponding true labels

video_data = process_videos(video_processor, shap_calculator, sampled_files, true_labels, use_exact=config["use_exact"])
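Setting use_exact to True presumably selects exact rather than sampled Shapley values. Exact computation enumerates every coalition of segments, so it needs 2^n model evaluations for n segments and is practical only when n is small. The sketch below is a generic implementation of the standard exact Shapley formula, not the repository's code. value_fn is a hypothetical callable that returns the model's score for the target label when only the segments in the given set are kept.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Dict, FrozenSet, List

ValueFn = Callable[[FrozenSet[int]], float]


def exact_shapley_values(num_segments: int, value_fn: ValueFn) -> List[float]:
    """Exact Shapley value of each segment, enumerating all coalitions.

    phi_i = sum over S not containing i of |S|! (n-|S|-1)! / n! * (v(S + {i}) - v(S))
    """
    n = num_segments
    cache: Dict[FrozenSet[int], float] = {}

    def v(coalition: FrozenSet[int]) -> float:
        # Cache results so each coalition is scored once (2**n evaluations in total).
        if coalition not in cache:
            cache[coalition] = value_fn(coalition)
        return cache[coalition]

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):  # subsets of the other n-1 segments have size 0..n-1
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (v(s | {i}) - v(s))
    return phi
```

With 8 segments, this means 256 forward passes per explained label, which is why a sampling estimator is usually preferred for longer segmentations.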

SHAP Calculation

Estimate the Shapley values of the segments with respect to both the true label and the model's predicted label:

sv_true_label = shap_calculator.approximate_shapley_values(segment_outputs, true_label)
sv_video_pred = shap_calculator.approximate_shapley_values(segment_outputs, video_pred_label)
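The README does not document which estimator approximate_shapley_values uses. Permutation sampling is one common choice, shown here as a sketch that is not necessarily TemporalShap's method. It draws num_samples random orderings of the segments and averages each segment's marginal contribution at the point where it is added. value_fn is a hypothetical callable that maps a set of kept segment indices to the model's score for the label being explained.

```python
import random
from typing import Callable, FrozenSet, List, Optional

ValueFn = Callable[[FrozenSet[int]], float]


def permutation_shapley_values(
    num_segments: int,
    value_fn: ValueFn,
    num_samples: int = 100,
    seed: Optional[int] = None,
) -> List[float]:
    """Monte Carlo Shapley estimate averaged over `num_samples` random segment orderings."""
    rng = random.Random(seed)
    totals = [0.0] * num_segments
    for _ in range(num_samples):
        order = list(range(num_segments))
        rng.shuffle(order)
        coalition: FrozenSet[int] = frozenset()
        prev_value = value_fn(coalition)  # score with every segment removed
        for seg in order:
            coalition = coalition | {seg}
            value = value_fn(coalition)
            totals[seg] += value - prev_value  # marginal contribution of `seg` in this ordering
            prev_value = value
    return [t / num_samples for t in totals]
```

Each ordering costs n + 1 evaluations, so the total cost grows linearly with num_samples instead of exponentially with the number of segments. Because the marginal contributions in each ordering telescope, the estimates always sum to v(all segments) − v(no segments).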

Performance Evaluation

Compute performance metrics:

accuracy, precision, recall, f1 = compute_metrics(video_data)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

save_performance_metrics(accuracy, precision, recall, f1, time_consumed, cpu_energy_consumed, gpu_energy_consumed, filename="performance.json")
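The README does not show compute_metrics, describe the structure of video_data, or state how precision, recall and F1 are averaged over the 400 classes. Below is a dependency-free sketch that assumes macro averaging. Per-class scores are computed over the labels that appear in either list and then averaged, with precision taken as 0 for a class that is never predicted. This mirrors scikit-learn's precision_recall_fscore_support with average="macro". The function takes plain lists of true and predicted labels, and the name macro_metrics is hypothetical.

```python
from typing import List, Sequence, Tuple


def macro_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[float, float, float, float]:
    """Return (accuracy, macro precision, macro recall, macro F1)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

    precisions: List[float] = []
    recalls: List[float] = []
    f1s: List[float] = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        prec = tp / (tp + fp) if tp + fp else 0.0  # class never predicted -> precision 0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    k = len(precisions)
    # Macro F1 is the mean of per-class F1, not the harmonic mean of macro P and R.
    return accuracy, sum(precisions) / k, sum(recalls) / k, sum(f1s) / k
```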

Results

Results are saved in results.json and performance metrics are saved in performance.json. You can load and inspect them for detailed analysis.
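The schemas of results.json and performance.json are not documented, and neither is how time_consumed and the CPU/GPU energy figures are measured. The sketch below writes and reads back a performance file under assumed key names. write_performance_json is a hypothetical stand-in for save_performance_metrics that takes the same arguments, and the energy values are stored exactly as passed in. Wall-clock time could be measured as the difference between two time.perf_counter() calls around the processing step.

```python
import json
from typing import Any


def write_performance_json(
    accuracy: float,
    precision: float,
    recall: float,
    f1: float,
    time_consumed: float,
    cpu_energy_consumed: float,
    gpu_energy_consumed: float,
    filename: str = "performance.json",
) -> None:
    """Hypothetical stand-in for save_performance_metrics; key names are assumed."""
    record = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "time_consumed_s": time_consumed,  # seconds, if measured with time.perf_counter()
        "cpu_energy_consumed": cpu_energy_consumed,  # units depend on how it was measured
        "gpu_energy_consumed": gpu_energy_consumed,
    }
    with open(filename, "w") as f:
        json.dump(record, f, indent=2)


def load_json(filename: str) -> Any:
    """Load results.json or performance.json for inspection."""
    with open(filename) as f:
        return json.load(f)
```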

Contributing

Contributions are welcome! Please open an issue or submit a pull request with your changes.
