Skip to content

Commit

Permalink
dependencies update
Browse files Browse the repository at this point in the history
  • Loading branch information
cy-xu committed Oct 22, 2023
1 parent d257524 commit 284c87d
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 215 deletions.
5 changes: 5 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
History
=======

0.5.0 (2023-10-21)
- remove development dependencies from inference install

------------------

0.4.1 (2022-06-10)
- fix broken preivew images from online service

Expand Down
2 changes: 1 addition & 1 deletion cosmic_conn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

__author__ = """Chengyuan Xu, Curtis McCully, Boning Dong, D. Andrew Howell, and Pradeep Sen"""
__email__ = "[email protected]"
__version__ = "0.4.1"
__version__ = "0.5.0"

from cosmic_conn.inference_cr import init_model, detect_image, detect_FITS
3 changes: 2 additions & 1 deletion cosmic_conn/cr_pipeline/lco_cr_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
subtract_sky,
erase_boundary_np,
)
from cosmic_conn.cr_pipeline.utils_img import sep_source_mask, center_crop_npy
from cosmic_conn.cr_pipeline.utils_img import center_crop_npy
from cosmic_conn.cr_pipeline.utils_sep import sep_source_mask
from cosmic_conn.cr_pipeline.utils_io import (
hdul_to_array,
save_fits_with_CR,
Expand Down
195 changes: 1 addition & 194 deletions cosmic_conn/cr_pipeline/utils_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import os
import numpy as np

from skimage import exposure, morphology, segmentation, io
from scipy import ndimage, signal
from skimage import exposure, segmentation, io

import sep
from astropy.table import Table
from astropy.stats import sigma_clipped_stats

from cosmic_conn.cr_pipeline.zscale import zscale
Expand Down Expand Up @@ -57,196 +54,6 @@ def variation_intensity(vector):
return variation_intensity


def prune_nans_from_table(table):
"""
From banzai
https://github.com/LCOGT/banzai/blob/master/banzai/utils/array_utils.py
"""
nan_in_row = np.zeros(len(table), dtype=bool)
for col in table.colnames:
nan_in_row |= np.isnan(table[col])
return table[~nan_in_row]


def sep_source_mask(frames, valid_mask):
"""
Following Banzai's SEP procedure with some modifications
https://github.com/LCOGT/banzai/blob/master/banzai/photometry.py
SEP manual
https://www.astromatic.net/pubsvn/software/sextractor/trunk/doc/sextractor.pdf
1. Extrat sources from the median frame, more reliable than SEP on individual frames
2. Calculate each soruce's SEP stats, ellipse, window size
3. Expand sources' rectangle window, acquire a local background mask
4. Local background mask shuold be very close across frames, except for extreme PSF wings
5. Calculate local background stats for each source and extract sources' mask from background
6. Generate a source mask for each frame
"""

# lower threshold as we work on the median frame
threshold = 3.5
min_area = 9

# prep median frame for SEP source extractor
median_frame = np.median(frames, axis=0)

# increase limit as we work on the median frame
ny, nx = median_frame.shape
sep.set_extract_pixstack(int(nx * ny * 0.10))

# subtract background
bkg = sep.Background(median_frame, bw=32, bh=32, fw=3, fh=3)
bkg_img = bkg.back()
median_frame -= bkg_img

# SEP applies a gaussian blur kernel
# [[1,2,1], [2,4,2], [1,2,1]] by default
try:
# corner case when active sources pixel exceeds limit will crash
sources = sep.extract(
median_frame, threshold, err=bkg.globalrms, minarea=min_area
)
except:
return None

# Convert the detections into a table
sources = Table(sources)

if len(sources) > 3000:
return None

# We remove anything with a detection flag >= 8
# This includes memory overflows and objects that are too close the edge
sources = sources[sources["flag"] < 8]

sources = prune_nans_from_table(sources)

# Calculate the ellipticity and elongation
sources["ellipticity"] = 1.0 - (sources["b"] / sources["a"])
sources["elongation"] = sources["a"] / sources["b"]

# Fix any value of theta that are invalid due to floating point rounding
# -pi / 2 < theta < pi / 2
sources["theta"][sources["theta"] > (np.pi / 2.0)] -= np.pi
sources["theta"][sources["theta"] < (-np.pi / 2.0)] += np.pi

# Calculate the FWHMs of the stars:
fwhm = 2.0 * (np.log(2) * (sources["a"]
** 2.0 + sources["b"] ** 2.0)) ** 0.5
sources["fwhm"] = fwhm

# small fwhm are often bad pixels/dust on CCD (consistent across frames)
sources = sources[fwhm > 1.0]

"""
Generate local mask for each source in each frame
"""
sources_masks = []

img_height = median_frame.shape[0]
img_width = median_frame.shape[1]

for frame in frames:
sources_mask = np.zeros_like(median_frame, dtype="uint8")

for i in range(len(sources)):
src = sources[i]

src_ellipse = np.zeros_like(median_frame, dtype="uint8")

# soruce's ellipse based on the median frame
sep.mask_ellipse(
src_ellipse, src["x"], src["y"], src["a"], src["b"], src["theta"], r=2
)

# expand source window by 1.5x so more sky and PSF wings could be included
x_expand = int((src["xmax"] - src["xmin"]) * 0.25)
y_expand = int((src["ymax"] - src["ymin"]) * 0.25)

# min, max to handle boundary cases
ymin = max(0, src["ymin"] - y_expand)
ymax = min(src["ymax"] + y_expand, img_height)
xmin = max(0, src["xmin"] - x_expand)
xmax = min(src["xmax"] + x_expand, img_width)

# fail safe #0, if source too close to valid_mask boundary, pass
# we got valid mask from reprojecting consecutive frames to same coordinates
# so some frame will have missing information near boundary
# valid_count = np.sum(valid_mask[ymin:ymax, xmin:xmax] == 0)
# if valid_count > 0:
# continue

src_window = np.zeros_like(median_frame, dtype="uint8")
src_window[ymin:ymax, xmin:xmax] = 1

# remove soruce ellipse from source window mask
bkg_mask = src_window * (1 - src_ellipse)

# fail safe #1, if ellipse larger than window, use ellipse directly
if bkg_mask.sum() == 0.0:
sources_mask[src_ellipse > 0] = 1
continue

# window crop background pixels from each frame
bkg_frm = bkg_mask * frame

# shift background and source pxiels values
# valid bkg area will be above zero, all other area == 0
zero_offset = np.min(bkg_frm) - 1
src_frm = (frame - zero_offset) * src_window
bkg_frm = (bkg_frm - zero_offset) * bkg_mask

# if effective background usable, get local background stats
# consider Lucy smoothing
bkg_vec = bkg_frm[ymin:ymax, xmin:xmax].flatten()
bkg_vec = bkg_vec[bkg_vec != 0]

# use 2 sigma above local background median as the spread boundary
# robust standard deviation ensures to reject PFS wings from stats
# 2 sigma widen the source's extent slightly but includes more soft contour
bkg_std = robust_standard_deviation(bkg_vec)
bkg_upperbound = np.median(bkg_vec) + 2.0 * bkg_std

src_peak_value = np.max(src_frm)
# src_peak_value = max(src_peak, np.max(src_frm))

# fail safe #2, if background upperbound higher than source peak
if bkg_upperbound >= src_peak_value:
sources_mask[src_ellipse > 0] = 1
continue

# effective pixel values range for the source
tolerance = src_peak_value - bkg_upperbound

# starting point for floodfill, from true peak pixel
seed_point = np.unravel_index(src_frm.argmax(), src_frm.shape)
# seed_point = (src['ypeak'], src['xpeak'])

# flood filling is better than thresholding
# as the mask expands from source center, CR in the source window won't be flagged
src_mask = segmentation.flood(
src_frm, seed_point, tolerance=tolerance
).astype("uint8")

sources_mask[src_ellipse > 0] = 1
sources_mask[src_mask > 0] = 1

sources_masks.append(sources_mask)

sources_masks = np.stack(sources_masks)

# dilate CR by n pixel to avoid introducing artificl sharp edges
struct = ndimage.generate_binary_structure(2, 2)
sources_dilation = np.zeros_like(sources_masks)

for i in range(sources_masks.shape[0]):
sources_dilation[i] = ndimage.binary_dilation(
sources_masks[i], structure=struct, iterations=2
)

return sources_dilation


def sigma_clip(ndarray, sigma=0.0, clip_source=False):
# clamp extreme outlier pixels for better stats
if sigma == 0.0:
Expand Down
4 changes: 3 additions & 1 deletion cosmic_conn/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import argparse
import glob
import requests
import zipfile
from pathlib import Path
from cosmic_conn.dl_framework.options import ModelOptions
Expand All @@ -23,6 +22,9 @@ def is_fits_file(filename):
return any(filename.endswith(extension) for extension in EXTENSIONS)

def download_file(url):
# only import requests if needed
import requests

local_filename = url.split('/')[-1]
# NOTE the stream=True parameter below
with requests.get(url, stream=True) as r:
Expand Down
5 changes: 3 additions & 2 deletions cosmic_conn/dl_framework/cosmic_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# model import
from cosmic_conn.dl_framework.utils_ml import (
clean_large,
memory_check,
modulus_boundary_crop,
subtract_sky,
median_weighted_bce,
Expand All @@ -26,7 +25,6 @@
remove_nan
)
from cosmic_conn.dl_framework.unet import UNet_module
from cosmic_conn.cr_pipeline.utils_img import save_as_png


class Cosmic_CoNN(nn.Module):
Expand Down Expand Up @@ -473,6 +471,9 @@ def save_checkpoint(self, epoch):
return model_path

def get_current_visuals(self, epoch):
# some packages are not required for inference
from cosmic_conn.cr_pipeline.utils_img import save_as_png

os.makedirs(self.epoch_dir, exist_ok=True)
root = self.epoch_dir

Expand Down
1 change: 0 additions & 1 deletion cosmic_conn/dl_framework/utils_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import datetime
import random
import psutil
import numpy as np

import torch
Expand Down
7 changes: 2 additions & 5 deletions cosmic_conn/inference_cr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
import shutil
import logging
from tqdm import tqdm
import pretty_errors

import torch.backends.cudnn as cudnn
from astropy.io import fits

from cosmic_conn.dl_framework.cosmic_conn import Cosmic_CoNN
from cosmic_conn.dl_framework.options import ModelOptions
from cosmic_conn.data_utils import check_trained_models, console_arguments
from cosmic_conn.data_utils import parse_input
from cosmic_conn.data_utils import check_trained_models, console_arguments, parse_input

cudnn.enabled = True
cudnn.benchmark = True
Expand Down Expand Up @@ -120,8 +118,7 @@ def detect_FITS(model):
logging.info(msg)

except:
msg = f"No valid data found in extention 0, 1 or {ext}, \
to specify extension name: -e SCI."
msg = f"No valid data found in extention 0, 1 or {ext}, please check correct EXTNAME for the image and specify with: cosmic-conn -e EXTNAME"
logging.error(msg)
raise ValueError(msg)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.1
current_version = 0.5.0
commit = True
tag = True

Expand Down
13 changes: 4 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@
"torch>=1.6.0",
"numpy",
"astropy>=3.0",
"reproject",
"sep",
"scikit-image",
"psutil",
"pretty-errors",
"tqdm",
"requests",
]

extras_require = {
"webapp": ["Flask>=1.1.0", "Flask-APScheduler>=1.12.0"],
"develop": ["tensorboard>=2.4.0", "scikit-learn>=0.24.0", "Flask>=1.1.0", "Flask-APScheduler>=1.12.0"],
extras_require = {
"webapp": ["requests", "Flask>=1.1.0", "Flask-APScheduler>=1.12.0"],
"develop": ["sep", "reproject", "psutil", "requests", "pretty-errors", "tensorboard>=2.4.0", "scikit-learn>=0.24.0", "Flask>=1.1.0", "Flask-APScheduler>=1.12.0"],
}

setup_requirements = [
Expand Down Expand Up @@ -90,6 +85,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/cy-xu/cosmic-conn",
version="0.4.1",
version="0.5.0",
zip_safe=False,
)

0 comments on commit 284c87d

Please sign in to comment.