import pluggy
from astropy.table import vstack
from gwpy.table.filters import in_segmentlist
from ligo.segments import segment, segmentlist
from ... import features
from ... import utils


class DataLoader(object):
"""A data loader to retrieve features spanning multiple channels."""
_skip_filter = False
_default_columns = [] # which columns will be pulled from data by default
_allowed_columns = None # which columns are allowed; None allows anything
_required_kwargs = [] # which kwargs are required for this type of data object

    def __init__(self, start, end, segs=None, columns=None, **kwargs):
        """
        Each instance is tied to immutable properties:

        - start/end times
        - columns

        In addition, while segments are essentially frozen, they can be
        trimmed as long as they remain a subset of the previous segments.
        When data is queried, it is automatically restricted to these
        segments and columns, and is returned as tabular data keyed by
        channel.
        """
self._data = {}
self._cached = {}
assert start <= end, "start must be <= end"
self._start = start
self._end = end
if segs is not None:
assert (
segmentlist([segment(start, end)]) & segs
) == segs, "segments must be a subset of start, end times"
else:
segs = segmentlist([segment(start, end)])
self._segs = segs
for kwarg in self._required_kwargs:
assert kwarg in kwargs, "kwarg=%s required" % kwarg
self.kwargs = kwargs
if not columns:
columns = self._default_columns
if self._allowed_columns is not None:
for col in columns:
assert col in self._allowed_columns, "column=%s not allowed" % col
self._columns = [str(col) for col in columns]

    @property
    def start(self):
        return self._start

    @property
    def end(self):
        return self._end

    @property
    def segs(self):
        return self._segs

    @property
    def columns(self):
        return self._columns

    @property
    def channels(self):
        return sorted(self._cached.keys())

    def is_cached(self, channel, bounds=None):
"""
Returns whether or not data is cached for the channel requested.
"""
if bounds is None:
bounds = dict()
if channel not in self._cached:
return False
# check to see whether there are any cached column bounds
# that are not within the new requested bounds
cache = self._cached[channel]
for col in cache.keys():
# check if requesting possibly looser bounds than what's cached
if col not in bounds:
return False
else:
# they share this column, so check relative bounds
requested_min, requested_max = bounds[col]
cached_min, cached_max = cache[col]
# check if we want more than is cached
if (requested_min < cached_min) or (requested_max > cached_max):
return False
return True

    def query(
self, channels=None, columns=None, segs=None, time=None, bounds=None, **kwargs
):
"""
Submits a query to get features and returns the result of such a query
as a dictionary of EventTables, keyed by channel.
"""
# check that all columns in bounds will be returned
if bounds is None:
bounds = dict()
else:
for col in bounds.keys():
                assert (
                    col in self.columns
                ), f"bounds can only be placed on included columns; {col} is not included"
# format and validate query
if channels:
if isinstance(channels, str):
channels = [channels]
if segs:
assert segs in self.segs, (
"segments=%s passed in query are not a subset of contained segments"
% segs
)
assert time, "if passing in segments, time column needs to be specified"
if columns and (self._allowed_columns is not None):
if isinstance(columns, str):
columns = [columns]
for col in columns:
assert col in self._allowed_columns, "column=%s not allowed" % col
# determine if there's any cached data
if self._data and (channels is not None):
channels2query = [
chan for chan in channels if not self.is_cached(chan, bounds=bounds)
]
if channels2query:
self._data.update(
self._query(
channels=channels2query, segs=segs, bounds=bounds, **kwargs
)
)
else:
# no cached data, retrieve all features from query
self._data.update(
self._query(channels=channels, segs=segs, bounds=bounds, **kwargs)
)
# grab queried data from cache
if channels:
queried_data = {channel: self._data[channel] for channel in channels}
else:
queried_data = dict(self._data)
if segs:
for channel in queried_data.keys():
queried_data[channel] = queried_data[channel].filter(
(time, in_segmentlist, segs)
)
if columns:
for channel in queried_data.keys():
queried_data[channel] = queried_data[channel][columns]
# update what is considered cached
for channel in self._data.keys():
if not self.is_cached(channel, bounds=bounds):
                # we requested more data than was cached;
                # set the cached bounds to the most recent (widest) request
self._cached[channel] = bounds
# filter cached data by bounds
for chan, datum in queried_data.items():
for col, (min_, max_) in bounds.items():
datum = datum.filter(f"{col} >= {min_}", f"{col} <= {max_}")
queried_data[chan] = datum
return features.DataChunk(
start=self.start,
end=self.end,
segs=segs,
columns=tuple(columns if columns else self.columns),
features=queried_data,
skip_filter=self._skip_filter,
)

    def _query(self, **kwargs):
raise NotImplementedError("Child classes should define this for themselves!")
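    # A minimal sketch of what a child class might provide (hypothetical; the
    # actual I/O is elided and MyFileDataLoader is illustrative):
    #
    #     class MyFileDataLoader(DataLoader):
    #         _required_kwargs = ["rootdir"]
    #
    #         def _query(self, channels=None, segs=None, bounds=None, **kwargs):
    #             # read features from disk and return a dict of tables,
    #             # keyed by channel and restricted to self.columns
    #             ...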

    def filter(self, segs=None, bounds=None, time=features.DEFAULT_TIME_NAME):
"""
update segments and filters out data that don't span segments
also filters by bounds and updates cache as needed
NOTE: this requires knowledge of the "time" key within data
"""
# check that all columns in bounds will be returned
if bounds is None:
bounds = {}
else:
for col in bounds.keys():
                assert (
                    col in self.columns
                ), f"bounds can only be placed on included columns; {col} is not included"
if segs is None:
segs = self.segs
        assert (self.segs & segs) == segs, (
            "new segs must be a subset of existing segs\nsegs=%s\nnew segs=%s"
            % (self.segs, segs)
        )
# filter by segments
self._segs = segmentlist(segs)
for channel in self._data.keys():
self._data[channel] = self._data[channel].filter(
(time, in_segmentlist, self.segs)
)
# filter by bounds
for chan, datum in self._data.items():
for col, (min_, max_) in bounds.items():
datum = datum.filter(f"{col} >= {min_}", f"{col} <= {max_}")
self._data[chan] = datum
# update what is considered cached
for channel in self._data.keys():
for col, col_bounds in bounds.items():
if col not in self._cached[channel]:
self._cached[channel][col] = col_bounds
else:
new_min, new_max = col_bounds
cache_min, cache_max = self._cached[channel][col]
self._cached[channel][col] = (
max(new_min, cache_min),
min(new_max, cache_max),
)

    def flush(
self, max_stride=features.DEFAULT_MAX_STRIDE, time=features.DEFAULT_TIME_NAME
):
"""
remove data to a target span and number of samples
"""
        # update span, trim segments accordingly and filter
        self._start = max(self.start, self.end - max_stride)
        segs = self.segs & segmentlist([segment(self.start, self.end)])
        return self.filter(segs=segs, time=time)

    def pop(self, channel, default=None):
"""Remove and return all data associated with this channel."""
self._cached.pop(channel, None)
return self._data.pop(channel, default)

    def __add__(self, other):
        """
        Return a new data loader combining the two data loaders,
        as well as any cached data contained within.

        We take the union of segments and spans, and we keep the
        union of channels from each object.

        NOTE: only data loaders with the same columns/kwargs can be combined.
        """
assert type(self) is type(
other
), "DataLoaders need to be the same type to be combined"
assert set(self.columns) == set(
other.columns
), "DataLoaders must have the same columns to be combined"
assert (
self.kwargs == other.kwargs
), "DataLoaders must have the same kwargs to be combined"
# take union of various properties
start = min(self.start, other.start)
end = max(self.end, other.end)
segs = self.segs | other.segs
# set up and combine cached data from both loaders
dataloader = self.__class__(
start, end, segs=segs, columns=self.columns, **self.kwargs
)
# find channels with data only present in one loader.
# we can combine data from these easily without careful
# treatment with caching since bounds don't overlap
for this, that in [(self, other), (other, self)]:
diff = set(this.channels) - set(that.channels)
dataloader._data.update(
{
channel: data
for channel, data in this._data.items()
if channel in diff
}
)
dataloader._cached.update(
{
channel: cached
for channel, cached in this._cached.items()
if channel in diff
}
)
        # find channels common to both and combine their data. data for a
        # channel can be combined safely if each loader reports the other's
        # cached bounds as cached, since this implies the bounds are identical
common = set(self.channels) & set(other.channels)
uncached = []
for channel in common:
self_bounds = self._cached[channel]
other_bounds = other._cached[channel]
if self.is_cached(channel, other_bounds) and other.is_cached(
channel, self_bounds
):
dataloader._data[channel] = vstack(
[self._data[channel], other._data[channel]],
join_type="exact",
)
dataloader._cached[channel] = self._cached[channel]
else:
uncached.append(channel)
        # deal with channels that can't be safely combined. do this by taking
        # the intersection of the two bounds and combining both datasets. we do
        # this so that we don't trigger an expensive load operation or make the
        # addition process "lossy". if a user requests bounds in a future query,
        # this will err on the side of querying for more data than it may need
        for channel in uncached:
self_bounds = set(self._cached[channel].keys()) - set(
other._cached[channel].keys()
)
other_bounds = set(other._cached[channel].keys()) - set(
self._cached[channel].keys()
)
common_bounds = set(self._cached[channel].keys()) & set(
other._cached[channel].keys()
)
# copy over bounds which are present in one but not the other
dataloader._cached[channel] = {
bound: self._cached[channel][bound] for bound in self_bounds
}
dataloader._cached[channel].update(
{bound: other._cached[channel][bound] for bound in other_bounds}
)
# take intersection of common bounds
for bound in common_bounds:
self_min, self_max = self._cached[channel][bound]
other_min, other_max = other._cached[channel][bound]
                dataloader._cached[channel][bound] = (
                    max(self_min, other_min),
                    min(self_max, other_max),
                )
# combine data together
dataloader._data[channel] = vstack(
[self._data[channel], other._data[channel]],
join_type="exact",
)
return dataloader

    def target_times(self, time, target_channel, target_bounds, segs=None):
"""
A convenience function to extract target times, implicitly loading needed data.
"""
# find target times
target_times = self.query(
channels=target_channel, bounds=target_bounds
).features[target_channel][time]
# filter times by segments
if segs is not None:
target_times = target_times[utils.times_in_segments(target_times, segs)]
return target_times

    def random_times(
self, time, target_channel, dirty_bounds, dirty_window, random_rate, segs=None
):
"""
A convenience function to extract random times, implicitly loading needed data.
"""
# draw random times
dirty_times = self.query(channels=target_channel, bounds=dirty_bounds).features[
target_channel
][time]
# generate segments
dirty_seg = utils.times2segments(dirty_times, dirty_window)
        # only draw from segments that are outside of the dirty segments but
        # within self.segs
random_times = utils.draw_random_times(self.segs - dirty_seg, rate=random_rate)
# filter times by segments
if segs is not None:
random_times = random_times[utils.times_in_segments(random_times, segs)]
# find segments where data is clean
clean_seg = self.segs - dirty_seg
if segs is not None:
clean_seg &= segs
return random_times, clean_seg


hookspec = pluggy.HookspecMarker("iDQ")


@hookspec
def get_dataloaders():
"""
This hook is used to return data loaders in the form:
{"type[:format]": DataLoader}
where the type refers to a specific data backend (snax, kleine-welle, etc.)
    and format (optional) refers to a specific file format and/or data layout.
"""