Skip to content

Commit

Permalink
Merge pull request #22 from lincc-frameworks/main
Browse files Browse the repository at this point in the history
Merge recent changes
  • Loading branch information
delucchi-cmu committed Oct 18, 2023
2 parents 4199c4c + 6d22296 commit 32d4cbe
Show file tree
Hide file tree
Showing 7 changed files with 642 additions and 632 deletions.
84 changes: 84 additions & 0 deletions astrodet/astrodet.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,90 @@ def add_val_loss(self,val_loss):
self.vallossList.append(val_loss)


class LazyAstroTrainer(SimpleTrainer):
def __init__(self, model, data_loader, optimizer, cfg, cfg_old):
super().__init__(model, data_loader, optimizer)
#super().__init__(model, data_loader, optimizer)

# Borrowed from DefaultTrainer constructor
# see https://detectron2.readthedocs.io/en/latest/_modules/detectron2/engine/defaults.html#DefaultTrainer
self.checkpointer = checkpointer.DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,cfg_old.OUTPUT_DIR)
# load weights
self.checkpointer.load(cfg.train.init_checkpoint)

# record loss over iteration
self.lossList = []
self.vallossList = []

self.period = 20
self.iterCount = 0

self.scheduler = self.build_lr_scheduler(cfg_old, optimizer)
#self.scheduler = instantiate(cfg.lr_multiplier)
self.valloss=0



#Note: print out loss over p iterations
def set_period(self,p):
self.period = p



# Copied directly from SimpleTrainer, add in custom manipulation with the loss
# see https://detectron2.readthedocs.io/en/latest/_modules/detectron2/engine/train_loop.html#SimpleTrainer
def run_step(self):
self.iterCount = self.iterCount + 1
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
start = time.perf_counter()
data_time = time.perf_counter() - start
data = next(self._data_loader_iter)
# Note: in training mode, model() returns loss
loss_dict = self.model(data)
#print('Loss dict',loss_dict)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
all_losses = [l.cpu().detach().item() for l in loss_dict.values()]
self.optimizer.zero_grad()
losses.backward()


#self._write_metrics(loss_dict,data_time)

self.optimizer.step()


self.lossList.append(losses.cpu().detach().numpy())
if self.iterCount % self.period == 0 and comm.is_main_process():
#print("Iteration: ", self.iterCount, " time: ", data_time," loss: ",losses.cpu().detach().numpy(), "val loss: ",self.valloss, "lr: ", self.scheduler.get_lr())
print("Iteration: ", self.iterCount, " time: ", data_time, loss_dict.keys(), all_losses, "val loss: ",self.valloss, "lr: ", self.scheduler.get_lr())

del data
gc.collect()
torch.cuda.empty_cache()

@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)

def add_val_loss(self,val_loss):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""

self.vallossList.append(val_loss)




class AstroPredictor:
Expand Down
204 changes: 204 additions & 0 deletions deepdisc-env-nobuilds.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
name: ddtestnew
channels:
- https://ftp.osuosl.org/pub/open-ce/1.8.0
- conda-forge
- https://ftp.osuosl.org/pub/open-ce/1.5.1
- https://ftp.osuosl.org/pub/open-ce/current
- https://ftp.osuosl.org/pub/open-ce/current/
- https://ftp.osuosl.org/pub/open-ce/1.8.0/
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- _pytorch_select=2.0
- absl-py=0.12.0
- aiohttp=3.8.5
- aiosignal=1.2.0
- astropy=5.1
- async-timeout=4.0.2
- attrs=23.1.0
- av=8.0.3
- blas=1.0
- blinker=1.6.2
- brotli=1.0.9
- brotli-bin=1.0.9
- brotlipy=0.7.0
- bzip2=1.0.8
- c-ares=1.19.1
- ca-certificates=2023.7.22
- cachetools=4.2.2
- cairo=1.16.0
- certifi=2023.7.22
- cffi=1.14.6
- charset-normalizer=2.0.4
- click=8.0.4
- cloudpickle=2.2.1
- contourpy=1.0.5
- cryptography=39.0.1
- cudatoolkit=11.2.2
- cudnn=8.1.1_11.2
- cycler=0.11.0
- cytoolz=0.11.2
- dask-core=2023.9.3
- eigen=3.4.0
- expat=2.5.0
- ffmpeg=4.2.2
- fontconfig=2.14.1
- fonttools=4.25.0
- freetype=2.12.1
- frozenlist=1.3.3
- fsspec=2023.9.2
- future=0.18.3
- geos=3.8.0
- glib=2.69.1
- gmp=6.2.1
- gnutls=3.6.15
- google-auth=1.23.0
- google-auth-oauthlib=0.4.4
- graphite2=1.3.14
- grpcio=1.42.0
- harfbuzz=4.3.0
- hdf5=1.10.6
- icu=58.2
- idna=3.4
- imagecodecs-lite=2019.12.3
- imageio=2.13.1
- imgaug=0.4.0
- importlib-metadata=6.0.0
- importlib_metadata=6.0.0
- importlib_resources=5.2.0
- jpeg=9e
- kiwisolver=1.4.4
- lame=3.100
- lcms2=2.12
- ld_impl_linux-ppc64le=2.38
- lerc=3.0
- leveldb=1.20
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libdeflate=1.17
- libffi=3.3
- libgcc-ng=11.2.0
- libgfortran-ng=7.3.0
- libgomp=11.2.0
- libidn2=2.3.4
- libopenblas=0.3.18
- libopus=1.3.1
- libpng=1.6.39
- libprotobuf=3.14.0
- libstdcxx-ng=11.2.0
- libtasn1=4.19.0
- libtiff=4.5.1
- libunistring=0.9.10
- libuuid=1.41.5
- libvpx=1.7.0
- libwebp-base=1.3.2
- libxcb=1.15
- libxml2=2.10.4
- lmdb=0.9.29
- locket=1.0.0
- lz4-c=1.9.4
- markdown=3.3.3
- markupsafe=2.1.1
- matplotlib=3.7.1
- matplotlib-base=3.7.1
- multidict=6.0.2
- munkres=1.1.4
- nccl=2.11.4
- ncurses=6.4
- nettle=3.7.3
- networkx=2.8.4
- numactl=2.0.16
- numpy=1.20.3
- numpy-base=1.20.3
- oauthlib=3.2.0
- olefile=0.46
- opencv=4.6.0
- openh264=2.1.1
- openjpeg=2.4.0
- openssl=1.1.1w
- packaging=23.1
- partd=1.4.1
- pcre=8.45
- pillow=8.3.1
- pip=23.2.1
- pixman=0.40.0
- protobuf=3.14.0
- pyasn1=0.4.8
- pyasn1-modules=0.2.8
- pybind11=2.9.2
- pybind11-global=2.9.2
- pycparser=2.21
- pyerfa=2.0.0
- pyjwt=2.4.0
- pyopenssl=23.2.0
- pyparsing=3.0.9
- pysocks=1.7.1
- python=3.9.15
- python-dateutil=2.8.2
- python_abi=3.9
- pytorch=1.10.1
- pytorch-base=1.10.1
- pywavelets=1.3.0
- pyyaml=5.4.1
- readline=8.2
- requests=2.31.0
- requests-oauthlib=1.3.0
- rsa=4.7.2
- scikit-image=0.19.2
- scipy=1.7.3
- sentencepiece=0.1.96
- shapely=2.0.1
- six=1.15.0
- snappy=1.1.9
- sqlite=3.41.2
- tabulate=0.8.10
- tensorboard=2.7.0
- tensorboard-data-server=0.6.1
- tensorboard-plugin-wit=1.6.0
- tifffile=2020.6.3
- tk=8.6.12
- toolz=0.12.0
- torchtext-base=0.11.1
- torchvision=0.11.2
- torchvision-base=0.11.2
- tornado=6.3.2
- tqdm=4.65.0
- tzdata=2023c
- urllib3=1.26.16
- werkzeug=2.2.3
- wheel=0.41.2
- x264=1!157.20191217
- xz=5.4.2
- yaml=0.2.5
- yarl=1.8.1
- zipp=3.11.0
- zlib=1.2.13
- zstd=1.5.5
- pip:
- antlr4-python3-runtime==4.9.3
- autograd==1.6.2
- black==23.9.1
- filelock==3.12.4
- fvcore==0.1.5.post20221221
- huggingface-hub==0.17.3
- hydra-core==1.3.2
- iopath==0.1.9
- mypy-extensions==1.0.0
- omegaconf==2.3.0
- pathspec==0.11.2
- peigen==0.0.9
- platformdirs==3.11.0
- portalocker==2.8.2
- proxmin==0.6.12
- pycocotools==2.0.7
- safetensors==0.4.0
- scarlet==1.0.1+g45187fd
- setuptools==68.2.2
- termcolor==2.3.0
- timm==0.9.7
- tomli==2.0.1
- typing-extensions==4.8.0
- yacs==0.1.8
prefix: /home/g4merz/.conda/envs/ddtestnew
172 changes: 88 additions & 84 deletions demo_hsc.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 32d4cbe

Please sign in to comment.