Skip to content

Commit f2182d6

Browse files
readme fix
README fix; allow config override via an `override` parameter; fix EE session initialization; Config now uses an externally provided .env file for setup; added notebook examples from the Applied Deep Learning Book, Chapter 1.
1 parent 7db84e5 commit f2182d6

9 files changed

+9960
-43
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ from aces.config import Config
2323
from aces.model_trainer import ModelTrainer
2424

2525
if __name__ == "__main__":
26+
config_file = "config.env"
2627
config = Config()
2728
trainer = ModelTrainer(config)
2829
trainer.train_model()

aces/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ class Config:
9595
GCP_MACHINE_TYPE (str): The Google Cloud Platform machine type.
9696
"""
9797

98-
def __init__(self, config_file) -> None:
98+
def __init__(self, config_file, override=False) -> None:
9999

100-
load_dotenv(config_file)
100+
load_dotenv(config_file, override=override)
101101

102102
self.BASEDIR = Path(os.getenv("BASEDIR"))
103103
_DATADIR = os.getenv("DATADIR")

aces/ee_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_credentials_by_service_account_key(key):
3636
return credentials
3737

3838
@staticmethod
39-
def initialize_session(use_highvolume : bool = False, key : Union[str, None] = None):
39+
def initialize_session(use_highvolume : bool = False, key : Union[str, None] = None, project: str = None):
4040
"""
4141
Initialize the Earth Engine session.
4242
@@ -45,14 +45,22 @@ def initialize_session(use_highvolume : bool = False, key : Union[str, None] = N
4545
key (str or None): The path to the service account key JSON file. If None, the default credentials will be used.
4646
"""
4747
if key is None:
48-
if use_highvolume:
48+
if use_highvolume and project:
49+
ee.Initialize(opt_url="https://earthengine-highvolume.googleapis.com", project=project)
50+
elif use_highvolume:
4951
ee.Initialize(opt_url="https://earthengine-highvolume.googleapis.com")
52+
elif project:
53+
ee.Initialize(project=project)
5054
else:
5155
ee.Initialize()
5256
else:
5357
credentials = EEUtils.get_credentials_by_service_account_key(key)
54-
if use_highvolume:
58+
if use_highvolume and project:
59+
ee.Initialize(credentials, opt_url="https://earthengine-highvolume.googleapis.com", project=project)
60+
elif use_highvolume:
5561
ee.Initialize(credentials, opt_url="https://earthengine-highvolume.googleapis.com")
62+
elif project:
63+
ee.Initialize(credentials, project=project)
5664
else:
5765
ee.Initialize(credentials)
5866

notebook/aces_rice_classification_paro_2021.ipynb

Lines changed: 6148 additions & 0 deletions
Large diffs are not rendered by default.

notebook/count_sample_size.ipynb

Lines changed: 790 additions & 0 deletions
Large diffs are not rendered by default.

notebook/prediction_dnn.ipynb

Lines changed: 1339 additions & 0 deletions
Large diffs are not rendered by default.

notebook/prediction_unet.ipynb

Lines changed: 1625 additions & 0 deletions
Large diffs are not rendered by default.

workflow/v2/4.export_image_for_prediction.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,47 +78,50 @@
7878
composite_during = composite_during.regexpRename("$(.*)", "_during")
7979
image = composite_before.addBands(composite_during).toFloat()
8080

81-
if Config.USE_ELEVATION:
81+
config_file = "config.env"
82+
config = Config(config_file)
83+
84+
if config.USE_ELEVATION:
8285
elevation = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/elevationParo")
8386
slope = ee.Image("projects/servir-sco-assets/assets/Bhutan/ACES_2/slopeParo")
8487
image = image.addBands(elevation).addBands(slope).toFloat()
85-
Config.FEATURES.extend(["elevation", "slope"])
88+
config.FEATURES.extend(["elevation", "slope"])
8689

8790

88-
if Config.USE_S1:
91+
if config.USE_S1:
8992
sentinel1_asc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscBefore")
9093
sentinel1_asc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Ascending2021/s1AscDuring")
9194
sentinel1_desc_before_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescBefore")
9295
sentinel1_desc_during_composite = ee.Image("projects/servir-sco-assets/assets/Bhutan/Sentinel1Descending2021/s1DescDuring")
9396

9497
image = image.addBands(sentinel1_asc_before_composite).addBands(sentinel1_asc_during_composite).addBands(sentinel1_desc_before_composite).addBands(sentinel1_desc_during_composite).toFloat()
95-
Config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during",
98+
config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during",
9699
"vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"])
97100

98101
# dem = ee.Image("MERIT/DEM/v1_0_3") # ee.Image('USGS/SRTMGL1_003');
99102
# dem = dem.clip(fc_country)
100103
# riceZone = dem.gt(rice_zone[region]["min"]).And(dem.lte(rice_zone[region]["max"]))
101104
# image = image.clip(region_fc).updateMask(riceZone)
102105

103-
image = image.select(Config.FEATURES)
106+
image = image.select(config.FEATURES)
104107
print("image", image.bandNames().getInfo())
105108

106109
# Specify patch and file dimensions.
107110
formatOptions = {
108-
"patchDimensions": [Config.PATCH_SHAPE_SINGLE, Config.PATCH_SHAPE_SINGLE],
111+
"patchDimensions": [config.PATCH_SHAPE_SINGLE, config.PATCH_SHAPE_SINGLE],
109112
"maxFileSize": 104857600,
110113
"compressed": True
111114
}
112115

113-
if Config.KERNEL_BUFFER:
114-
formatOptions["kernelSize"] = Config.KERNEL_BUFFER
116+
if config.KERNEL_BUFFER:
117+
formatOptions["kernelSize"] = config.KERNEL_BUFFER
115118

116119
# Setup the task
117120
image_export_options = {
118-
"description": Config.GCS_IMAGE_DIR.split("/")[-1],
119-
"file_name_prefix": f"{Config.GCS_IMAGE_DIR}/{Config.GCS_IMAGE_PREFIX}",
120-
"bucket": Config.GCS_BUCKET,
121-
"scale": Config.SCALE,
121+
"description": "export_task_for_prediction",
122+
"file_name_prefix": f"{config.GCS_IMAGE_DIR}/{config.GCS_IMAGE_PREFIX}",
123+
"bucket": config.GCS_BUCKET,
124+
"scale": config.SCALE,
122125
"file_format": "TFRecord",
123126
"region": region_fc, # image.geometry(),
124127
"format_options": formatOptions,

workflow/v2/5.prediction_unet.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@
1616
import subprocess
1717

1818

19-
OUTPUT_IMAGE_FILE = str(Config.MODEL_DIR / "prediction" / f"{Config.OUTPUT_NAME}.TFRecord")
20-
if not os.path.exists(str(Config.MODEL_DIR / "prediction")): os.mkdir(str(Config.MODEL_DIR / "prediction"))
19+
config_file = "config.env"
20+
config = Config(config_file)
21+
22+
OUTPUT_IMAGE_FILE = str(config.MODEL_DIR / "prediction" / f"{config.OUTPUT_NAME}.TFRecord")
23+
if not os.path.exists(str(config.MODEL_DIR / "prediction")): os.mkdir(str(config.MODEL_DIR / "prediction"))
2124
print(f"OUTPUT_IMAGE_FILE: {OUTPUT_IMAGE_FILE}")
2225

23-
OUTPUT_GCS_PATH = f"gs://{Config.GCS_BUCKET}/prediction/{Config.OUTPUT_NAME}.TFRecord"
26+
OUTPUT_GCS_PATH = f"gs://{config.GCS_BUCKET}/prediction/{config.OUTPUT_NAME}.TFRecord"
2427
print(f"OUTPUT_GCS_PATH: {OUTPUT_GCS_PATH}")
2528

26-
ls = f"sudo gsutil ls gs://{Config.GCS_BUCKET}/{Config.GCS_IMAGE_DIR}/"
29+
ls = f"sudo gsutil ls gs://{config.GCS_BUCKET}/{config.GCS_IMAGE_DIR}/"
2730
print(f"ls >> : {ls}")
2831
files_list = subprocess.check_output(ls, shell=True)
2932
files_list = files_list.decode("utf-8")
3033
files_list = files_list.split("\n")
3134

3235
# Get only the files generated by the image export.
33-
exported_files_list = [s for s in files_list if Config.GCS_IMAGE_PREFIX in s]
36+
exported_files_list = [s for s in files_list if config.GCS_IMAGE_PREFIX in s]
3437

3538
print(f"exported_files_list: {exported_files_list}")
3639

@@ -51,11 +54,11 @@
5154
print(f"json_file: {json_file}")
5255

5356
if Config.USE_BEST_MODEL_FOR_INFERENCE:
54-
print(f"Using best model for inference.\nLoading model from {str(Config.MODEL_DIR)}/{Config.MODEL_CHECKPOINT_NAME}.tf")
55-
this_model = tf.keras.models.load_model(f"{str(Config.MODEL_DIR)}/{Config.MODEL_CHECKPOINT_NAME}.tf")
57+
print(f"Using best model for inference.\nLoading model from {str(config.MODEL_DIR)}/{config.MODEL_CHECKPOINT_NAME}.tf")
58+
this_model = tf.keras.models.load_model(f"{str(config.MODEL_DIR)}/{config.MODEL_CHECKPOINT_NAME}.tf")
5659
else:
57-
print(f"Using last model for inference.\nLoading model from {str(Config.MODEL_DIR)}/trained-model")
58-
this_model = tf.keras.models.load_model(f"{str(Config.MODEL_DIR)}/trained-model")
60+
print(f"Using last model for inference.\nLoading model from {str(config.MODEL_DIR)}/trained-model")
61+
this_model = tf.keras.models.load_model(f"{str(config.MODEL_DIR)}/trained-model")
5962

6063
print(this_model.summary())
6164

@@ -74,40 +77,40 @@
7477
patch_dimensions_flat = [patch_width * patch_height, 1]
7578

7679
# Get set up for prediction.
77-
if Config.KERNEL_BUFFER:
78-
x_buffer = Config.KERNEL_BUFFER[0] // 2
79-
y_buffer = Config.KERNEL_BUFFER[1] // 2
80+
if config.KERNEL_BUFFER:
81+
x_buffer = config.KERNEL_BUFFER[0] // 2
82+
y_buffer = config.KERNEL_BUFFER[1] // 2
8083

8184
buffered_shape = [
82-
Config.PATCH_SHAPE[0] + Config.KERNEL_BUFFER[0],
83-
Config.PATCH_SHAPE[1] + Config.KERNEL_BUFFER[1],
85+
config.PATCH_SHAPE[0] + config.KERNEL_BUFFER[0],
86+
config.PATCH_SHAPE[1] + config.KERNEL_BUFFER[1],
8487
]
8588
else:
8689
x_buffer = 0
8790
y_buffer = 0
88-
buffered_shape = Config.PATCH_SHAPE
91+
buffered_shape = config.PATCH_SHAPE
8992

90-
if Config.USE_ELEVATION:
91-
Config.FEATURES.extend(["elevation", "slope"])
93+
if config.USE_ELEVATION:
94+
config.FEATURES.extend(["elevation", "slope"])
9295

9396

94-
if Config.USE_S1:
95-
Config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during",
97+
if config.USE_S1:
98+
config.FEATURES.extend(["vv_asc_before", "vh_asc_before", "vv_asc_during", "vh_asc_during",
9699
"vv_desc_before", "vh_desc_before", "vv_desc_during", "vh_desc_during"])
97100

98-
print(f"Config.FEATURES: {Config.FEATURES}")
101+
print(f"Config.FEATURES: {config.FEATURES}")
99102

100103
image_columns = [
101-
tf.io.FixedLenFeature(shape=buffered_shape, dtype=tf.float32) for k in Config.FEATURES
104+
tf.io.FixedLenFeature(shape=buffered_shape, dtype=tf.float32) for k in config.FEATURES
102105
]
103106

104-
image_features_dict = dict(zip(Config.FEATURES, image_columns))
107+
image_features_dict = dict(zip(config.FEATURES, image_columns))
105108

106109
def parse_image(example_proto):
107110
return tf.io.parse_single_example(example_proto, image_features_dict)
108111

109112
def toTupleImage(inputs):
110-
inputsList = [inputs.get(key) for key in Config.FEATURES]
113+
inputsList = [inputs.get(key) for key in config.FEATURES]
111114
stacked = tf.stack(inputsList, axis=0)
112115
stacked = tf.transpose(stacked, [1, 2, 0])
113116
return stacked
@@ -139,8 +142,8 @@ def toTupleImage(inputs):
139142
print(f"Writing patch {i}...")
140143

141144
prediction_patch = prediction_patch[
142-
x_buffer: x_buffer+Config.PATCH_SHAPE[0],
143-
y_buffer: y_buffer+Config.PATCH_SHAPE[1]
145+
x_buffer: x_buffer+config.PATCH_SHAPE[0],
146+
y_buffer: y_buffer+config.PATCH_SHAPE[1]
144147
]
145148

146149
example = tf.train.Example(
@@ -181,6 +184,6 @@ def toTupleImage(inputs):
181184
print(f"uploading classified image to earth engine: {result}")
182185

183186
# upload to earth engine asset
184-
upload_image = f"earthengine upload image --asset_id={Config.EE_OUTPUT_ASSET}/{Config.OUTPUT_NAME} --pyramiding_policy=mode {OUTPUT_GCS_PATH} {json_file}"
187+
upload_image = f"earthengine upload image --asset_id={config.EE_OUTPUT_ASSET}/{config.OUTPUT_NAME} --pyramiding_policy=mode {OUTPUT_GCS_PATH} {json_file}"
185188
result = subprocess.check_output(upload_image, shell=True)
186189
print(f"uploading classified image to earth engine: {result}")

0 commit comments

Comments (0)