Skip to content

Commit

Permalink
Merge pull request #8 from twosixlabs/implement-a-trained-xview-model…
Browse files Browse the repository at this point in the history
…-in-xview-model-example

implement a trained xview model in the xview model example
  • Loading branch information
mwartell authored Oct 24, 2023
2 parents b0f5667 + b0a7029 commit 4efef4d
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions examples/src/charmory_examples/xview_example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from pathlib import Path
from pprint import pprint
import sys

from PIL import Image
import albumentations as A
import art.attacks.evasion
from art.estimators.object_detection import PyTorchFasterRCNN
import boto3
import botocore
from datasets import load_dataset
import jatic_toolbox
from jatic_toolbox import __version__ as jatic_version
from jatic_toolbox.interop.huggingface import HuggingFaceObjectDetectionDataset
from jatic_toolbox.interop.torchvision import TorchVisionObjectDetector
import numpy as np
import torch
from torchvision.transforms._presets import ObjectDetection

from armory.art_experimental.attacks.patch import AttackWrapper
from armory.metrics.compute import BasicProfiler
Expand All @@ -19,20 +24,26 @@
from charmory.evaluation import Attack, Dataset, Evaluation, Metric, Model
from charmory.model.object_detection import JaticObjectDetectionModel
from charmory.tasks.object_detection import ObjectDetectionTask
from charmory.track import track_init_params, track_params

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from charmory.track import track_init_params
from charmory.utils import create_jatic_dataset_transform

BATCH_SIZE = 1
TRAINING_EPOCHS = 20
import torch
BUCKET_NAME = "armory-library-data"
KEY = "fasterrcnn_mobilenet_v3_2"


torch.set_float32_matmul_precision("high")
import armory.data.datasets


def load_huggingface_dataset():
train_data = load_dataset("Honaker/xview_dataset", split="train")
train_data = load_dataset("Honaker/xview_dataset_subset", split="train")

new_dataset = train_data.train_test_split(test_size=0.2, seed=1)
new_dataset = train_data.train_test_split(test_size=0.4, seed=3)
train_dataset, test_dataset = new_dataset["train"], new_dataset["test"]

train_dataset, test_dataset = HuggingFaceObjectDetectionDataset(
Expand All @@ -53,20 +64,31 @@ def main(argv: list = sys.argv[1:]):
###
# Model
###
model = track_params(jatic_toolbox.load_model)(
provider="torchvision",
model_name="fasterrcnn_resnet50_fpn",
task="object-detection",
s3 = boto3.resource("s3")
try:
s3.Bucket(BUCKET_NAME).download_file(
KEY, Path.cwd() / "fasterrcnn_mobilenet_v3_2"
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
print("The object does not exist.")
else:
raise

model = torch.load(Path.cwd() / "fasterrcnn_mobilenet_v3_2")
model.to(DEVICE)

model = TorchVisionObjectDetector(
model=model, processor=ObjectDetection(), labels=None
)

# Bypass JATIC model wrapper to allow targeted adversarial attacks
model.forward = model._model.forward

detector = track_init_params(PyTorchFasterRCNN)(
JaticObjectDetectionModel(model),
channels_first=True,
clip_values=(0.0, 1.0),
)

model_transform = create_jatic_dataset_transform(model.preprocessor)

train_dataset, test_dataset = load_huggingface_dataset()
Expand Down Expand Up @@ -105,7 +127,6 @@ def transform(sample):
transformed = model_transform(transformed)
return transformed

train_dataset.set_transform(transform)
test_dataset.set_transform(transform)

train_dataloader = ArmoryDataLoader(
Expand All @@ -124,7 +145,7 @@ def transform(sample):
test_dataloader=test_dataloader,
)
eval_model = Model(
name="fasterrcnn-resnet-50",
name="xview-trained-fasterrcnn-resnet-50",
model=detector,
)

Expand Down Expand Up @@ -163,7 +184,7 @@ def transform(sample):

task = ObjectDetectionTask(
evaluation,
export_every_n_batches=5,
export_every_n_batches=2,
class_metrics=False,
)
engine = EvaluationEngine(task, limit_test_batches=10)
Expand Down

0 comments on commit 4efef4d

Please sign in to comment.