import numpy as np
from gwpy.table import EventTable
from gwpy.table.filters import in_segmentlist
from gwtrigfind import find_trigger_files
from lal.utils import CacheEntry
from ... import utils
from ... import hookimpl
from . import DataLoader
class KWDataLoader(DataLoader):
"""
an extension of ClassifierData specifically for KleineWelle triggers.
expects triggers to be in multi-channel files
Note, if we do not request any specific channel(s), all discoverable
channels will be returned
"""
_default_columns = utils.KW_COLUMNS
_allowed_columns = utils.KW_COLUMNS
_required_kwargs = ["instrument"]
_suffix = "trg"
    def _query(self, channels=None, bounds=None, verbose=False, **kwargs):
        """
        read all relevant KleineWelle triggers from multi-channel files
        spanning [self.start, self.end)
        we assume the following directory structure
        ${rootdir}/${basename}-${gpsMOD1e5}/${basename}-${start}-${dur}.trg
        returns a dict mapping each channel name to an EventTable of triggers
        """
if bounds is None:
bounds = dict()
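        # bounds is expected to map column name -> (min, max) thresholds,
        # e.g. {"significance": (25, 1e6)} (illustrative column and values)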
# look into non-standard directory if specified
if "rootdir" in self.kwargs:
base = self.kwargs["rootdir"]
else:
base = None
        # NOTE: need to specify a channel to find triggers even if it's
        # not used to find files, so use one guaranteed to exist
if not channels:
channel = "{}CAL-DELTAL_EXTERNAL_DQ_32_2048".format(
self.kwargs["instrument"]
)
else:
channel = channels[0]
# generate file cache and filter by segments
cache = find_trigger_files(
channel, "kw", self.start, self.end, base=base, ext=self._suffix
)
cache = [CacheEntry.from_T050017(path) for path in cache]
cache = [c for c in cache if self.segs.intersects_segment(c.segment)]
# set up filters for selection
selection = [(utils.KW_TIMECOL, in_segmentlist, self.segs)]
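        # gwpy accepts filters either as (column, callable, operand) tuples,
        # like the time/segment filter above, or as plain "column >= value"
        # strings, like the bounds added below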
if bounds:
for col, (min_, max_) in bounds.items():
selection.extend([f"{col} >= {min_}", f"{col} <= {max_}"])
# read in triggers
table = EventTable.read(
cache,
format="ascii.commented_header",
columns=self.columns + ["channel"],
selection=selection,
nproc=self.kwargs.get("nproc", 1),
)
# split by channel, remove column
data = {}
grouped = table.group_by("channel")
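        # grouped.groups.keys holds one row per unique channel, aligned with
        # the corresponding sub-tables in grouped.groups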
for key, group in zip(grouped.groups.keys, grouped.groups):
if not channels or key["channel"] in channels:
channel = key["channel"]
data[channel] = group
del data[channel]["channel"]
# fill in missing channels
if channels:
dtype = [(col, "float") for col in self.columns]
for channel in channels:
if channel not in data:
data[channel] = EventTable(data=np.array([], dtype=dtype))
return data
@hookimpl
def get_dataloaders():
return {
"kw:ascii": KWDataLoader,
"kw": KWDataLoader,
}
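# A minimal usage sketch (assumption: the DataLoader base class accepts
# (start, end, segs=segs, columns=columns, **kwargs); that signature is not
# defined in this file, and segs is presumably a ligo.segments.segmentlist):
#
#     segs = segmentlist([segment(start, end)])
#     loader = KWDataLoader(start, end, segs=segs,
#                           columns=["time", "significance"],
#                           instrument="H1", rootdir="/path/to/kw/triggers")
#     data = loader._query(channels=["H1_CAL-DELTAL_EXTERNAL_DQ_32_2048"])
#
# where ``data`` maps each requested channel to an EventTable of triggers;
# the column names and channel shown above are illustrative only.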