from __future__ import division, print_function, absolute_import
from warnings import warn
import copy as _copy
import numpy as _np
from . import utils
from astropy import time as _time, table as _tbl, units as _u, constants as _const
_table = _tbl  # second alias kept because both names are used in this module
from astropy.io import fits as _fits
from .crebin import rebin as _rebin
from matplotlib import pyplot as plt
_pl = plt  # second alias kept because both names are used in this module
from functools import reduce
def _FITSformat(dtype):
dstr = str(dtype)
dstr = dstr.replace('>', '')
if dstr in ['uint8']:
return 'B'
if dstr in ['int8', 'i1', 'int16']:
return 'I'
if dstr in ['uint16', 'int32', 'i2', 'i4']:
return 'J'
if dstr in ['uint32', 'int64', 'i8']:
return 'K'
if dstr in ['float32', 'f4']:
return 'E'
if dstr in ['float64', 'f8']:
return 'D'
    raise ValueError('Not sure what to do with the {} data type.'.format(dstr))
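# A quick sanity sketch of the dtype -> FITS format mapping above, per the FITS
# binary-table format codes (e.g. 'D' is a 64-bit float, 'I' a 16-bit integer):
#   _FITSformat(_np.dtype('float64'))  # -> 'D'
#   _FITSformat(_np.dtype('int16'))    # -> 'I'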
_name_dict = {'t':'time', 'w':'wavelength', 'y':'xdisp', 'a':'area_eff', 'q':'qualflag', 'o':'order', 'n':'obs_no',
'e':'wght', 'r':'rgn_wght'}
class Photons:
"""
Base class for holding spectral photon data. Each instance contains a list of observed photons, including,
at a minimum:
- time of arrival, t
- wavelength, w
Optionally, the list may also include:
- detector effective area at arrival location, a
- data quality flags, q
- spectrum order number, o
- cross dispersion distance from spectral trace, y
- event weight (termed epsilon for HST observations), e
    - region weight (1.0 for signal, -signal_area/background_area for background, 0.0 otherwise), r
    - anything else the user wants to define
Some operators have been defined for this class:
#FIXME
These attributes are all protected in this class and are accessed using the bracket syntax, e.g. if `photons` is an
instance of class `Photons`, then
>>> photons['w']
will return the photon wavelengths.
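
    Alternate, longer column names are also recognized (see `_alternate_names` below), e.g.

    >>> photons['wavelength']

    returns the same column as photons['w'].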
    I thought about deriving this class from FITS_rec or _np.recarray, but ultimately it felt cumbersome. I decided
    it's better that the data be a part of the object, rather than a parent to it, since the object will also contain
    various metadata.
"""
# for ease of use, map some alternative names to the proper photon property names
_alternate_names = {'time':'t',
'wavelength':'w', 'wave':'w', 'wvln':'w', 'wav':'w', 'waveln':'w',
'effective area':'a', 'area':'a', 'area_eff':'a',
'data quality':'q', 'dq':'q', 'quality':'q', 'flags':'q', 'qualflag':'q',
'order':'o', 'segment':'o', 'seg':'o',
'observation':'n', 'obs':'n','obs_no':'n',
'xdisp':'y', 'cross dispersion':'y',
'weight':'e', 'eps':'e', 'epsilon':'e', 'event weight':'e', 'wt':'e', 'wght':'e',
'region':'r', 'ribbon':'r', 'rgn_wght':'r'}
#region OBJECT HANDLING
def __init__(self, **kwargs):
"""
Create an empty Photons object. This is really just to show what the derived classes must define. If I ever
see a need to be able to create Photons objects using arrays of data in the future, then I'll rewrite this.
Parameters
----------
obs_metadata : list of metadata dict (or similar types) for each obs
time_datum : time object
        data : astropy Table with at least a 't' and 'w' column (see class description for what to name others)
Returns
-------
A Photons object.
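
        A minimal construction sketch (hypothetical metadata values; assumes the
        column-naming conventions in the class docstring):

        >>> data = _tbl.Table(names=['t', 'w'], dtype=['f8', 'f8'])
        >>> ph = Photons(photons=data, obs_metadata=[{'targname': 'AK Sco'}],
        ...              time_datum=_time.Time('2015-01-01T00:00:00'))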
"""
# metadata associated with the observations that recorded the photons, one list entry for each observation
self.obs_metadata = kwargs.get('obs_metadata', [{}])
self.time_datum = kwargs.get('time_datum', _time.Time('2000-01-01T00:00:00'))
if 'photons' in kwargs:
self.photons = kwargs['photons']
else:
self.photons = _tbl.Table(names=['t', 'w'], dtype=['f8', 'f8'])
self.photons['t'].unit = _u.s
self.photons['w'].unit = _u.AA
if 'obs_times' in kwargs:
self.obs_times = kwargs['obs_times']
else:
            if 'n' in self and len(self) > 0:
                # group_by returns a new, grouped copy; capture it to use its groups
                grouped = self.photons.group_by('n')
                rngs = [_np.array([[a['t'].min(), a['t'].max()]]) for a in grouped.groups]
                self.obs_times = rngs
            elif len(self) > 0:
                self.obs_times = [_np.array([[self.photons['t'].min(), self.photons['t'].max()]])]
            else:
                self.obs_times = [_np.array([[]])]
if 'obs_bandpasses' in kwargs:
self.obs_bandpasses = kwargs['obs_bandpasses']
else:
            if 'n' in self and len(self) > 0:
                grouped = self.photons.group_by('n')
                rngs = [_np.array([[a['w'].min(), a['w'].max()]]) for a in grouped.groups]
                self.obs_bandpasses = rngs
            elif len(self) > 0:
                self.obs_bandpasses = [_np.array([[self.photons['w'].min(), self.photons['w'].max()]])]
            else:
                self.obs_bandpasses = [_np.array([[]])]
_ovr_doc = ('"overlap_handling : "adjust Aeff"|"clip"\n'
'\t How to handle multiple observations that overlap in wavelength.\n'
'\t - If "clip," photons from the observation with fewer photons in the overlap are removed.\n'
'\t - If "adjust Aeff," the effective area estimate at the wavelengths of the affected photons '
'are increased as appropriate.')
def merge_like_observations(self, overlap_handling="adjust Aeff", min_rate_ratio=0.5):
"""
Merge observations that have the same exposure times in place.
Parameters
----------
{ovr}
If their bandpass ranges overlap, then the photons
in the overlap get de-weighted due to the "extra" observing time that isn't otherwise accounted for.
Returns
-------
None (operation done in place)
"""
def get_signal():
if 'r' in self:
signal = self['r'] > 0
else:
signal = _np.ones(len(self), bool)
return signal
signal = get_signal()
# make sure these are lists so that you can delete from them
self.obs_times = list(self.obs_times)
self.obs_metadata = list(self.obs_metadata)
self.obs_bandpasses = list(self.obs_bandpasses)
i = 0
while i < len(self.obs_metadata):
j = i + 1
while j < len(self.obs_metadata):
# if observations have same exposure starts and ends
if _np.all(self.obs_times[i] == self.obs_times[j]):
i_photons = self['n'] == i
j_photons = self['n'] == j
overlap = utils.rangeset_intersect(self.obs_bandpasses[i], self.obs_bandpasses[j])
if len(overlap) > 0:
in_overlap = utils.inranges(self['w'], overlap)
xi = i_photons & in_overlap
xj = j_photons & in_overlap
Ni = float(_np.sum(xi & signal))
Nj = float(_np.sum(xj & signal))
worthwhile = (Ni > 0) and (Nj > 0) and (Ni/Nj > min_rate_ratio) and (Nj/Ni > min_rate_ratio)
if overlap_handling == "adjust Aeff" and worthwhile:
Ai = self._Aeff_interpolator(filter=xi)
Aj = self._Aeff_interpolator(filter=xj)
wi = self['w'][xi]
wj = self['w'][xj]
Ai_at_j = Ai(wj)
Aj_at_i = Aj(wi)
self['a'][xi] += Aj_at_i
self['a'][xj] += Ai_at_j
elif overlap_handling == "clip" or not worthwhile:
if Ni < Nj:
ii, = _np.nonzero(xi)
self.photons.remove_rows(ii)
self.obs_bandpasses[i] = utils.rangeset_subtract(self.obs_bandpasses[i], overlap)
else:
jj, = _np.nonzero(xj)
self.photons.remove_rows(jj)
self.obs_bandpasses[j] = utils.rangeset_subtract(self.obs_bandpasses[j], overlap)
j_photons = self['n'] == j
signal = get_signal()
else:
raise ValueError("overlap_handling option not recognized.")
# associate photons from observation j with i
self['n'][j_photons] = i
# decrement higher observation numbers (else while loop indexing gets messed up)
self['n'][self.photons['n'] > j] -= 1
# update properties of observation i
self.obs_metadata[i] += self.obs_metadata[j]
self.obs_bandpasses[i] = utils.rangeset_union(self.obs_bandpasses[i], self.obs_bandpasses[j])
# remove observation j
del self.obs_times[j]
del self.obs_metadata[j]
del self.obs_bandpasses[j]
else:
j += 1
i += 1
merge_like_observations.__doc__ = merge_like_observations.__doc__.format(ovr=_ovr_doc)
def merge_orders(self, overlap_handling="adjust Aeff"):
"""
Merge the orders in each observation in place.
Parameters
----------
{ovr}
Returns
-------
None (operation done in place)
Notes
-----
Merging is accomplished by changing the "region" weights of the photons where there is overlap in wavelength
to account for the double (or more than double) counting.
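
        A minimal usage sketch (assumes the photons carry an 'o' order column):

        >>> photons.merge_orders(overlap_handling='clip')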
"""
# split into separate photons objects for each observation, and split the orders within that observation into
# faux separate observations, then merge them
if len(self.obs_times) > 1:
separate = [self.get_obs(i) for i in range(len(self.obs_times))]
else:
separate = [self]
for obj in separate:
temp_meta = obj.obs_metadata
order_ranges = obj.obs_bandpasses[0]
Norders = len(order_ranges)
obj.obs_metadata = [0]*Norders
obj.obs_times *= Norders
obj.obs_bandpasses = [order_ranges[[i]] for i in range(len(order_ranges))]
obj.photons['n'] = obj.photons['o'] - _np.min(obj.photons['o'])
obj.photons.remove_column('o')
obj.merge_like_observations(overlap_handling=overlap_handling)
obj.obs_metadata = temp_meta
obj.photons.remove_column('n')
obj.obs_bandpasses = [_np.vstack(obj.obs_bandpasses)]
obj.obs_times = [obj.obs_times[0]]
        merged = sum(separate[1:], separate[0])
        self.obs_times = merged.obs_times
        self.obs_bandpasses = merged.obs_bandpasses
        self.obs_metadata = merged.obs_metadata
        self.photons = merged.photons
merge_orders.__doc__ = merge_orders.__doc__.format(ovr=_ovr_doc)
def __getitem__(self, key):
key = self._get_proper_key(key)
return self.photons[key]
def __setitem__(self, key, value):
try:
key = self._get_proper_key(key)
self.photons[key] = value
except KeyError:
# try to make a new column. this might be dangerous
if _np.isscalar(value):
value = [value]*len(self)
col = _tbl.Column(data=value, name=key)
self.photons.add_column(col)
def __len__(self):
return len(self.photons)
def __add__(self, other):
"""
When adding, the photon recarrays will be added. The observation numbers will be adjusted or added as
appropriate and times will be referenced to the first of the two objects.
Parameters
----------
other
Returns
-------
"""
if not isinstance(other, Photons):
raise ValueError('Can only add a Photons object to another Photons object.')
# don't want to modify what is being added
other = other.copy()
# make column units consistent with self
other.match_units(self)
# add and /or update observation columns as necessary
self.add_observations_column()
other.add_observations_column()
n_obs_self = len(self.obs_metadata)
other['n'] += n_obs_self
# re-reference times to the datum of self
other.set_time_datum(self.time_datum)
# stack the data tables
photons = _tbl.vstack([self.photons, other.photons])
# leave it to the user to deal with sorting and grouping and dealing with overlap as they see fit :)
obs_metadata = self.obs_metadata + other.obs_metadata
obs_times = list(self.obs_times) + list(other.obs_times)
obs_bandpasses = list(self.obs_bandpasses) + list(other.obs_bandpasses)
return Photons(photons=photons, obs_metadata=obs_metadata, time_datum=self.time_datum, obs_times=obs_times,
obs_bandpasses=obs_bandpasses)
def copy(self):
new = Photons()
new.obs_metadata = [item.copy() for item in self.obs_metadata]
new.time_datum = self.time_datum
new.obs_times = [item.copy() for item in self.obs_times]
new.photons = self.photons.copy()
new.obs_bandpasses = [item.copy() for item in self.obs_bandpasses]
return new
def __contains__(self, item):
return item in self.photons.colnames
def writeFITS(self, path, overwrite=False):
"""
Parameters
----------
path
overwrite
Returns
-------
"""
primary_hdu = _fits.PrimaryHDU()
# save photontable to first extension
photon_cols = []
for colname in self.photons.colnames:
tbl_col = self[colname]
name = _name_dict.get(colname, colname)
format = _FITSformat(tbl_col.dtype)
fits_col = _fits.Column(name=name, format=format, array=tbl_col.data, unit=str(tbl_col.unit))
photon_cols.append(fits_col)
photon_hdr = _fits.Header()
photon_hdr['zerotime'] = (self.time_datum.jd, 'julian date of time = 0.0')
photon_hdu = _fits.BinTableHDU.from_columns(photon_cols, header=photon_hdr)
# save obs metadata, time ranges, bandpasses to additional extensions
obs_hdus = []
for meta, times, bands in zip(self.obs_metadata, self.obs_times, self.obs_bandpasses):
# check that meta can be saved as a fits header
if isinstance(meta, _fits.Header):
hdr = meta
            elif hasattr(meta, 'items'):
                hdr = _fits.Header(iter(meta.items()))
            else:
                raise ValueError('FITS file cannot be constructed because Photons object has an improper list of '
                                 'observation metadata. The metadata items must either be pyFITS header objects or '
                                 'have an "items()" method (i.e. be dictionary-like).')
# save obs time and wave ranges to each extension
bandpas0, bandpas1 = bands.T
start, stop = times.T
arys = [bandpas0, bandpas1, start, stop]
names = ['bandpas0', 'bandpas1', 'start', 'stop']
units = [str(self['w'].unit)]*2 + [str(self['t'].unit)]*2
formats = ['{}D'.format(len(a)) for a in arys]
info_cols = [_fits.Column(array=a.reshape([1,-1]), name=n, unit=u, format=fmt)
for a,n,u,fmt in zip(arys, names, units, formats)]
            hdu = _fits.BinTableHDU.from_columns(info_cols, header=hdr)
obs_hdus.append(hdu)
# save all extensions
hdulist = _fits.HDUList([primary_hdu, photon_hdu] + obs_hdus)
hdulist.writeto(path, overwrite=overwrite)
@classmethod
def loadFITS(cls, path):
"""
Parameters
----------
path
Returns
-------
Photons object
"""
# create an empty Photons object
obj = Photons()
# open file
hdulist = _fits.open(path)
# parse photon data
photon_hdu = hdulist[1]
photon_hdr, photons = photon_hdu.header, photon_hdu.data
obj.time_datum = _time.Time(photon_hdr['zerotime'], format='jd')
tbl_cols = []
for i, key in enumerate(photons.names):
unit = photon_hdr['TUNIT{}'.format(i+1)]
if unit not in [None, "None"]:
unit = _u.Unit(unit)
name = cls._alternate_names.get(key, key.lower())
col = _tbl.Column(data=photons[key], name=name, unit=unit)
tbl_cols.append(col)
obj.photons = _tbl.Table(tbl_cols)
# parse observation time and wavelength ranges
def parse_bands_and_time(extension):
bandpas0, bandpas1, start, stop = [extension.data[s] for s in ['bandpas0', 'bandpas1', 'start', 'stop']]
bands = _np.vstack([bandpas0, bandpas1]).T
times = _np.vstack([start, stop]).T
            bands = bands.reshape(-1, 2)
times = times.reshape(-1,2)
return bands, times
# parse observation metadata
obj.obs_metadata = [hdu.header for hdu in hdulist[2:]]
# parse times and bands
pairs = list(map(parse_bands_and_time, hdulist[2:]))
bands, times = list(zip(*pairs))
obj.obs_times = times
obj.obs_bandpasses = bands
return obj
#endregion
#region DATA MANIPULATION METHODS
def set_time_datum(self, new_datum=None):
"""
Modifies the Photons object in-place to have a new time datum. Default is to set to the time of the earliest
photon.
Parameters
----------
new_datum : any object recognized by astropy.time.Time()
Returns
-------
None
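
        Example: re-reference times to the arrival of the earliest photon.

        >>> photons.set_time_datum()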
"""
        if new_datum is None:
            dt = _time.TimeDelta(self['t'].min() * self['t'].unit)
            new_datum = self.time_datum + dt
        else:
            dt = new_datum - self.time_datum
        # convert to a plain number in the units of the 't' column for use below
        dt = dt.to(self['t'].unit).value
# ensure appropriate precision is maintained
ddt = _np.diff(self['t'])
dtmin = ddt[ddt > 0].min()
max_bit = _np.log2(self['t'][-1] + abs(dt))
min_bit = _np.log2(dtmin)
need_bits = _np.ceil(max_bit - min_bit) + 3
need_bytes = _np.ceil(need_bits/8.)
if need_bytes > 8:
            raise ValueError('Resetting the time datum of this observation by {} {} will result in loss of numerical '
                             'precision of the photon arrival times.'.format(dt, self['t'].unit))
use_bytes = have_bytes = int(self['t'].dtype.str[-1])
while need_bytes > use_bytes and use_bytes < 8:
use_bytes *= 2
if use_bytes != have_bytes:
new_dtype = 'f' + str(use_bytes)
self['t'] = self['t'].astype(new_dtype)
self['t'] -= dt
self.time_datum = new_datum
self.obs_times = [t - dt for t in self.obs_times]
def match_units(self, other):
"""
        Converts the units of each column in self to the units of the corresponding column in other, if matching
        columns are present.
Parameters
----------
other : Photons object
Returns
-------
"""
for key in self.photons.colnames:
if key in other.photons.colnames:
if other[key].unit:
unit = other[key].unit
if str(unit).lower() == 'none' and str(self[key].unit).lower() == 'none':
continue
self[key].convert_unit_to(unit)
def add_observations_column(self):
"""
Adds a column for observation identifiers (to match the index of the observation metadata list item) to self,
if such a column is not already present.
"""
if 'n' not in self:
if len(self.obs_metadata) > 1:
                raise ValueError('Photons cannot be assigned to multiple observations because who the F knows which '
                                 'observation they belong to?')
n_ary = _np.zeros(len(self.photons))
n_col = _tbl.Column(data=n_ary, dtype='i2', name='n')
self.photons.add_column(n_col)
def divvy(self, ysignal, yback=(), order=None):
"""
        Provides a simple means of divvying photons into signal and background regions (and adding/updating the
        associated 'r' column) by specifying limits of these regions in the y coordinate.
Users can implement more complicated divvying schemes (such as changing signal and background region sizes) by
simply creating their own 'r' column explicitly.
Parameters
----------
ysignal : 1D or 2D array-like
[[y00, y01], [y10, y11], ...] giving limits of signal regions
yback : 1D or 2D array-like
akin to ysignal, but for background regions
Returns
-------
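
        A usage sketch (hypothetical y limits: one signal ribbon, two background ribbons):

        >>> photons.divvy(ysignal=[[-5., 5.]], yback=[[-15., -7.], [7., 15.]])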
"""
# groom the input
ysignal, yback = [_np.reshape(a, [-1, 2]) for a in [ysignal, yback]]
assert order is None or type(order) == int
# join the edges into one list
edges, isignal, iback, area_ratio = self._get_ribbon_edges(ysignal, yback)
# deal just with the photons of appropriate order
if order is not None:
filter = self['o'] == order
else:
filter = slice(None)
# determine which band counts are in
y = self['y'][filter]
ii = _np.searchsorted(edges, y)
# add/modify weights in 'r' column
# TRADE: sacrifice memory with a float weights column versus storing the area ratio and just using integer
# flags because this allows better flexibility when combining photons from multiple observations
if 'r' not in self:
self['r'] = _np.zeros(len(self), 'f4')
self['r'][filter] = 0
signal = _np.zeros(len(self), bool)
signal[filter] = reduce(_np.logical_or, [ii == i for i in isignal])
self['r'][signal] = 1.0
if len(yback) > 0:
bkgnd = _np.zeros(len(self), bool)
bkgnd[filter] = reduce(_np.logical_or, [ii == i for i in iback])
self['r'][bkgnd] = -area_ratio
def squish(self, keep='both'):
"""
Removes superfluous counts -- those that aren't in a signal region or background region.
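
        Example: keep only signal-region photons (an 'r' column from divvy is required).

        >>> signal_only = photons.squish(keep='signal')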
"""
if 'r' not in self:
raise ValueError('Photon object must have an \'r\' column (specifying a region weight) before it can be '
'squished.')
valid_keeps = ['both'] + ['back', 'background', 'bg'] + ['signal']
if keep not in valid_keeps:
raise ValueError('keep parameter must be one of {}'.format(valid_keeps))
new = self.copy()
if keep in ['both']:
superfluous = (self['r'] == 0)
new.photons = self.photons[~superfluous]
elif keep in ['back', 'background', 'bg']:
new.photons = self.photons[self['r'] < 0]
del new.photons['r']
elif keep in ['signal']:
new.photons = self.photons[self['r'] > 0]
del new.photons['r']
return new
#endregion
# region ANALYSIS METHODS
def spectrum(self, bins, waveranges=None, fluxed=False, energy_units='erg', order=None, time_ranges=None,
bin_method='elastic', background=False):
"""
Parameters
----------
bins
waveranges
fluxed
energy_units
order
Returns
-------
bin_edges, bin_midpts, rates, errors
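
        A minimal usage sketch (hypothetical wavelength bin edges, in the units of
        the 'w' column):

        >>> edges = _np.linspace(1300., 1400., 101)
        >>> bin_edges, bin_midpts, rates, errors = photons.spectrum(edges)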
"""
filter = self._filter_boiler(waveranges=None, time_ranges=time_ranges, order=order)
bin_edges, i_gaps = self._groom_wbins(bins, waveranges, bin_method=bin_method)
counts, errors = self._histogram('w', bin_edges, None, fluxed, energy_units, filter, background=background)
# add nans if there are gaps
if i_gaps is not None:
counts[i_gaps] = _np.nan
errors[i_gaps] = _np.nan
# divide by bin widths and exposure time to get rates
bin_exptime = self.time_per_bin(bin_edges, time_ranges=time_ranges)
bin_exptime[bin_exptime == 0] = _np.nan # use nans to avoid division errors
bin_widths = _np.diff(bin_edges)
rates = counts/bin_exptime/bin_widths
errors = errors/bin_exptime/bin_widths
# get bin midpoints
bin_midpts = (bin_edges[:-1] + bin_edges[1:])/2.0
return bin_edges, bin_midpts, rates, errors
def spectrum_smooth(self, n, wave_range=None, time_ranges=None, fluxed=False, energy_units='erg', order=None):
"""
Parameters
----------
n
wave_range
        time_ranges
fluxed
energy_units
Returns
-------
bin_start, bin_stop, bin_midpts, rates, errors
"""
# TODO: check with G230L spectrum from ak sco and see what is going on
filter = self._filter_boiler(waveranges=wave_range, time_ranges=time_ranges, order=order)
# get pertinent photon info
weights = self._full_weights(fluxed, energy_units)
w = self['w']
# which photons have nonzero weight
countable = (weights != 0)
# filter out superfluous photons
keep = filter & countable
w, weights = [a[keep] for a in [w, weights]]
# sort photons and weights in order of wavelength
isort = _np.argsort(w)
w, weights = w[isort], weights[isort]
if wave_range is None:
wave_range = w[[0, -1]]
# smooth using same process as for lightcurve_smooth
bin_start, bin_stop, bin_midpts, rates, errors = _smooth_boilerplate(w, weights, n, wave_range)
# divide by time to get rates
bin_exptimes = self.time_per_bin([bin_start, bin_stop], time_ranges)
rates = rates/bin_exptimes
errors = errors/bin_exptimes
return bin_start, bin_stop, bin_midpts, rates, errors
def lightcurve(self, time_step, bandpasses, time_range=None, bin_method='elastic', fluxed=False,
energy_units='erg', nan_between=False, background=False, order=None):
"""
Parameters
----------
time_step
bandpasses
time_range
bin_method:
elastic, full, or partial
fluxed
energy_units
Returns
-------
bin_start, bin_stop, bin_midpts, rates, errors
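
        A minimal usage sketch (hypothetical 60 s time step over a single band):

        >>> out = photons.lightcurve(60., [[1300., 1400.]])
        >>> bin_start, bin_stop, bin_midpts, rates, errors = out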
"""
# check that at least some observations cover the input
covered = self.check_wavelength_coverage(bandpasses)
if not _np.any(covered):
raise ValueError('None of the observations cover the provided bandpasses.')
# construct time bins. this is really where this method is doing a lot of work for the user in dealing with
# the exposures and associated gaps
if hasattr(time_step, '__iter__'):
edges = _np.unique(_np.ravel(time_step))
mids = utils.midpts(edges)
valid_times = utils.rangeset_intersect(_np.vstack(self.obs_times), _np.array(time_step))
valid_bins = utils.inranges(mids, valid_times)
else:
edges, valid_bins = self._construct_time_bins(time_step, bin_method, time_range, bandpasses)
inbands = self._bandpass_filter(bandpasses, check_coverage=False)
filter = self._filter_boiler(waveranges=None, time_ranges=None, order=order)
filter = filter & inbands
# histogram the counts
counts, errors = self._histogram('t', edges, time_range, fluxed, energy_units, filter=filter,
background=background)
# get length of each time bin and the bin start and stop
dt = _np.diff(edges)
bin_start, bin_stop = edges[:-1].copy(), edges[1:].copy()
# get rid of the bins in between exposures
if nan_between:
betweens = _np.ones_like(counts, bool)
betweens[valid_bins] = False
for a in [counts, errors, dt, bin_start, bin_stop]:
a[betweens] = _np.nan
else:
counts, errors, dt, bin_start, bin_stop = [a[valid_bins] for a in [counts, errors, dt, bin_start, bin_stop]]
if counts.size == 0:
return [_np.array([]) for _ in range(5)]
# divide by exposure time to get rates
rates, errors = counts/dt, errors/dt
# bin midpoints
bin_midpts = (bin_start + bin_stop)/2.0
return bin_start, bin_stop, bin_midpts, rates, errors
def lightcurve_smooth(self, n, bandpasses, time_range=None, fluxed=False, energy_units='erg', nan_between=False,
independent=False, order=None):
"""
Parameters
----------
n
bandpasses
time_range
fluxed
energy_units
nan_between
            Add NaN points between each observation. Useful for plotting because it breaks the plotted line between
exposures.
independent :
            Return only every nth point in each exposure such that the points are statistically independent.
Returns
-------
bin_start, bin_stop, bin_midpt, rates, error
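
        A minimal usage sketch (hypothetical: 100 counts per smoothing bin):

        >>> out = photons.lightcurve_smooth(100, [[1300., 1400.]])
        >>> bin_start, bin_stop, bin_midpt, rates, error = out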
"""
# FIXME there is a bug where fluxing doesn't work when time range is set
# check that at least some observations cover the input
covered = self.check_wavelength_coverage(bandpasses)
if not _np.any(covered):
raise ValueError('None of the observations cover the provided bandpasses.')
# get pertinent photon info
weights = self._full_weights(fluxed, energy_units)
t = self['t']
        obs = self['n'] if 'n' in self else _np.zeros(len(t), int)
# which photons are in wavelength bandpasses
inbands = self._bandpass_filter(bandpasses, check_coverage=False)
# which photons are in specified time range and order
filter = self._filter_boiler(waveranges=None, time_ranges=time_range, order=order)
# which photons have nonzero weight
countable = (weights != 0)
# filter superfluous photons
keep = inbands & filter & countable
t, obs, weights = [a[keep] for a in [t, obs, weights]]
isort = _np.argsort(t)
t, obs, weights = [a[isort] for a in [t, obs, weights]]
        curves = []  # each curve in the list will have bin_start, bin_stop, bin_midpts, rates, error
for i in range(len(self.obs_metadata)):
from_obs_i = (obs == i)
_t, _weights = t[from_obs_i], weights[from_obs_i]
for rng in self.obs_times[i]:
inrng = (_t > rng[0]) & (_t < rng[1])
if _np.sum(inrng) < n:
continue
curve = _smooth_boilerplate(_t[inrng], _weights[inrng], n, time_range, independent)
if nan_between:
curve = [_np.append(a, _np.nan) for a in curve]
curves.append(curve)
# sneaky code to take the list of curves and combine them
if len(curves) > 0:
bin_start, bin_stop, bin_midpt, rates, error = [_np.hstack(a) for a in zip(*curves)]
else:
bin_start, bin_stop, bin_midpt, rates, error = [_np.array([]) for _ in range(5)]
return bin_start, bin_stop, bin_midpt, rates, error
def spectrum_frames(self, bins, time_step, **kws):
"""
Parameters
----------
bins
time_step
        waveranges
time_range
bin_method
fluxed
energy_units
order
progress_bar
Returns
-------
starts, stops, time_midpts, bin_edges, bin_midpts, rates, errors
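
        A minimal usage sketch (hypothetical wavelength bin edges and a 600 s
        time step):

        >>> frames = photons.spectrum_frames(_np.linspace(1300., 1400., 201), 600.)
        >>> starts, stops, time_midpts, bin_edges, bin_midpts, rates, errors = frames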
"""
waveranges = kws.get('waveranges', None)
time_range = kws.get('time_range', None)
w_bin_method = kws.get('w_bin_method', 'full')
t_bin_method = kws.get('t_bin_method', 'full')
fluxed = kws.get('fluxed', False)
energy_units = kws.get('energy_units', 'erg')
order = kws.get('order', None)
progress_bar = kws.get('progress_bar', False)
slim = kws.get('slim', True)
# if using limited ranges, slim down the photon list accordingly to
# speed up runtime (can make a huge difference)
if slim and (waveranges is not None or time_range is not None):
keep = _np.ones(len(self), bool)
if waveranges is not None:
keep = keep & utils.inranges(self['w'], waveranges)
if time_range is not None:
keep = keep & utils.inranges(self['t'], time_range)
kws['slim'] = False
p = _copy.deepcopy(self)
p.photons = p.photons[keep]
return p.spectrum_frames(bins, time_step, **kws)
if progress_bar:
try:
from tqdm import tqdm as bar
except ImportError:
raise ImportError('You need the tqdm module for a progress '
'bar.')
else:
bar = lambda x: x
kws = dict(fluxed=fluxed, energy_units=energy_units, order=order)
# make wbins ahead of time so this doesn't happen in loop
bin_edges, i_gaps = self._groom_wbins(bins, waveranges, bin_method=w_bin_method)
kws['bins'] = bin_edges
# get start and stop of all the time steps
if hasattr(time_step, '__iter__'):
            print('I see that you gave user-supplied time bins. That is fine, but note that no checks will be '
                  'performed to ensure wavelength and time coverage in that range, and any time_range parameter '
                  'will be ignored.')
starts, stops = _np.asarray(time_step).T
else:
time_edges, valid_time_bins = self._construct_time_bins(time_step, t_bin_method, time_range)
starts, stops = time_edges[:-1][valid_time_bins], time_edges[1:][valid_time_bins]
time_midpts = (starts + stops)/2.0
spectra = []
for time_range in bar(list(zip(starts, stops))):
kws['time_ranges'] = time_range
spectra.append(self.spectrum(**kws))
bin_edges, bin_midpts, rates, errors = list(map(_np.array, list(zip(*spectra))))
        if i_gaps is not None:
            rates[:, i_gaps] = _np.nan
            errors[:, i_gaps] = _np.nan
return starts, stops, time_midpts, bin_edges, bin_midpts, rates, errors
def continuum_subtracted_lightcurves(self, dt, dw, continuum_bands, lc_bands, poly_order, fluxed=False,
energy_units='erg', time_range=None, w_bin_method='elastic',
t_bin_method='elastic', progress_bar=False):
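        # Fit a polynomial to continuum_bands in each spectral frame, then subtract the
        # integrated continuum from each band's flux to get continuum-subtracted lightcurves.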
kws = dict(fluxed=fluxed, energy_units=energy_units)
t0, t1, t, we, w, f, e = self.spectrum_frames(dw, dt, waveranges=continuum_bands, w_bin_method=w_bin_method,
t_bin_method=t_bin_method, time_range=time_range, progress_bar=progress_bar, **kws)
tbins = _np.array([t0, t1]).T
good = ~_np.isnan(f[0])
wbins = _np.array([we[0,:-1][good], we[0,1:][good]]).T
dw = wbins[:,1] - wbins[:,0]
polyfuncs = []
for i in range(len(t)):
_, _, pf = utils.polyfit_binned(wbins, f[i,good]*dw, e[i,good]*dw,
poly_order)
polyfuncs.append(pf)
# tf, cf, lf for total flux, continuum flux, line flux
lfs, les = [], []
for band in lc_bands:
_, _, _, tf, te = self.lightcurve(tbins, band, bin_method=t_bin_method, **kws)
lf, le = [_np.zeros_like(tf) for _ in range(2)]
for i in range(len(t)):
cf, cfe = polyfuncs[i](band)
cf = _np.sum(cf)
cfe = _np.sqrt(_np.sum(cfe**2))
lf[i] = tf[i] - cf
le[i] = _np.sqrt(te[i]**2 + cfe**2)
lfs.append(lf)
les.append(le)
return t0, t1, t, lfs, les
def average_rate(self, bands, timeranges=None, fluxed=False, energy_units='erg', order=None):
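        # Time-weighted mean of the lightcurve rates: sum(dt * rate) / sum(dt),
        # defaulting to the full set of exposure time ranges.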
if timeranges is None:
timeranges = _np.vstack(self.obs_times)
t0, t1, t, f, e = self.lightcurve(timeranges, bands, fluxed=fluxed, energy_units=energy_units, order=order)
dt = t1 - t0
return _np.sum(dt * f) / _np.sum(dt)
#endregion
#region UTILITIES
def get_obs(self, obsno):
if 'n' not in self:
self.photons['n'] = 0
new = Photons()
new.photons = self.photons[self.photons['n'] == obsno]
new.photons['n'] = 0
new.time_datum = self.time_datum
new.obs_bandpasses = [self.obs_bandpasses[obsno]]
new.obs_times = [self.obs_times[obsno]]
new.obs_metadata = [self.obs_metadata[obsno]]
return new
def which_obs(self, t):
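        # Assign each input time to the observation whose exposure ranges contain it.
        # searchsorted against the raveled [start, stop] pairs yields an odd index
        # exactly when a time falls within one of the ranges; -1 marks times outside
        # all observations.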
obsno = _np.ones(t.shape, 'i2')*(-1)
for n, tranges in enumerate(self.obs_times):
i = _np.searchsorted(tranges.ravel(), t)
inobs = (i % 2) == 1
if not _np.all(obsno[inobs] == -1):
raise ValueError('Some observation time ranges overlap at the input times, so these times cannot be '
'uniquely associated with an observation.')
obsno[inobs] = n
return obsno
def image(self, xax, yax, bins, weighted=False, scalefunc=None, show=True):
"""