Skip to content

Commit

Permalink
training data generation script fixes
Browse files Browse the repository at this point in the history
remove docs build on PR accept
bugfixes
  • Loading branch information
biplovbhandari committed Apr 19, 2024
1 parent 0ce74ea commit 3980fdc
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 80 deletions.
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ MODEL_DIR_NAME = "unet_v1"
# if so it is generated as trial_MODELTYPE + datetime.now() + _v + version
AUTO_MODEL_DIR_NAME = False

# training data output config
DATA_OUTPUT_DIR = "training_data"

# specify features and labels
# keep them in this format for reading
FEATURES = "red_before
Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ on:
branches:
- main
- master
pull_request:
branches:
- main
- master

jobs:
deploy:
Expand Down
4 changes: 4 additions & 0 deletions aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Config:
MODEL_CHECKPOINT_NAME (str): The name for model checkpoints.
MODEL_DIR_NAME (str): The name of the model directory. MODEL_DIR = OUTPUT_DIR / MODEL_DIR_NAME
AUTO_MODEL_DIR_NAME (bool): Flag to use automatic model directory naming.
DATA_OUTPUT_DIR (str): The output directory for data export. Used in the workflow/v2/generate_training_patches script.
# True generates as trial_MODELTYPE + datetime.now() + _v + version
# False uses the MODEL_DIR_NAME
FEATURES (str): The list of features used in the model.
Expand Down Expand Up @@ -131,6 +132,9 @@ def __init__(self, config_file, override=False) -> None:
self.MODEL_DIR = self.OUTPUT_DIR / self.MODEL_DIR_NAME

self.AUTO_MODEL_DIR_NAME = os.getenv("AUTO_MODEL_DIR_NAME") == "True"

self.DATA_OUTPUT_DIR = os.getenv("DATA_OUTPUT_DIR")

self.SCALE = int(os.getenv("SCALE"))

self.FEATURES = os.getenv("FEATURES").split("\n")
Expand Down
22 changes: 15 additions & 7 deletions aces/ee_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def beam_export_collection_to_cloud_storage(collection_index, start_training, **
print(f"Exporting training data to gs://{bucket}/{file_prefix}..")
training_task = ee.batch.Export.table.toCloudStorage(
collection=collection,
description=description,
description=f"{description}__index_{index}",
fileNamePrefix=file_prefix,
bucket=bucket,
fileFormat=kwargs.get("file_format", "TFRecord"),
Expand Down Expand Up @@ -356,8 +356,8 @@ def sample_image_by_collection(image: ee.Image, collection: ee.FeatureCollection
@staticmethod
def sample_image(image: ee.Image, region: ee.FeatureCollection, **kwargs: dict) -> ee.FeatureCollection:
sample = image.sample(region=region,
scale=kwargs.get("scale", Config.SCALE),
seed=kwargs.get("seed", Config.SEED),
scale=kwargs.get("SCALE") or kwargs.get("scale"),
seed=kwargs.get("SEED") or kwargs.get("seed"),
geometries=kwargs.get("geometries", False))
return sample

Expand All @@ -382,11 +382,19 @@ def beam_yield_sample_points(index, sample_locations: ee.List, use_service_accou
return point["coordinates"]

@staticmethod
def beam_sample_neighbourhood(coords_index, image, use_service_account: bool = False):
def beam_sample_neighbourhood(coords_index, image, config: Union[Config, str] = "config.env", use_service_account: bool = False):
from aces.ee_utils import EEUtils
from aces.config import Config
import ee
EEUtils.initialize_session(use_highvolume=True, key=Config.EE_SERVICE_CREDENTIALS if use_service_account else None)

if isinstance(config, str):
config = Config(config)
elif isinstance(config, Config):
config = config
else:
raise ValueError("config must be of type Config or str")

EEUtils.initialize_session(use_highvolume=True, key=config.EE_SERVICE_CREDENTIALS if use_service_account else None)

coords = coords_index[0]
index = coords_index[1]
Expand All @@ -403,12 +411,12 @@ def create_neighborhood(kernel) -> ee.Image:
def sample_data(image, points) -> ee.FeatureCollection:
return image.sample(
region=points,
scale=Config.SCALE,
scale=config.SCALE,
tileScale=16,
geometries=False
)

image_kernel = get_kernel(Config.PATCH_SHAPE_SINGLE)
image_kernel = get_kernel(config.PATCH_SHAPE_SINGLE)
neighborhood = create_neighborhood(image_kernel)
training_data = sample_data(neighborhood, ee.Geometry.Point(coords))
return training_data, index
Expand Down
Loading

0 comments on commit 3980fdc

Please sign in to comment.