Source code for idq.io.triggers

import pluggy

from astropy.table import vstack
from gwpy.table.filters import in_segmentlist
from ligo.segments import segment, segmentlist

from ... import features
from ... import utils


class DataLoader(object):
    """A data loader to retrieve features spanning multiple channels."""

    _skip_filter = False
    _default_columns = []  # which columns will be pulled from data by default
    _allowed_columns = None  # which columns are allowed; None allows anything
    _required_kwargs = []  # which kwargs are required for this type of data object

    def __init__(self, start, end, segs=None, columns=None, **kwargs):
        """
        Each instance points to what is immutable:
          - start/end time
          - columns

        In addition, while segments are essentially frozen, they can be
        trimmed as long as they're a strict subset of the previous segments.

        When data is queried it will automatically be restricted to these
        segments and columns, and is returned as tabular data keyed by channel.
        """
        self._data = {}
        self._cached = {}

        assert start <= end, "start must be <= end"
        self._start = start
        self._end = end

        if segs is not None:
            assert (
                segmentlist([segment(start, end)]) & segs
            ) == segs, "segments must be a subset of start, end times"
        else:
            segs = segmentlist([segment(start, end)])
        self._segs = segs

        for kwarg in self._required_kwargs:
            assert kwarg in kwargs, "kwarg=%s required" % kwarg
        self.kwargs = kwargs

        if not columns:
            columns = self._default_columns
        if self._allowed_columns is not None:
            for col in columns:
                assert col in self._allowed_columns, "column=%s not allowed" % col
        self._columns = [str(col) for col in columns]

    @property
    def start(self):
        return self._start

    @property
    def end(self):
        return self._end

    @property
    def segs(self):
        return self._segs

    @property
    def columns(self):
        return self._columns

    @property
    def channels(self):
        return sorted(self._cached.keys())
    def is_cached(self, channel, bounds=None):
        """
        Returns whether or not data is cached for the requested channel.
        """
        if bounds is None:
            bounds = dict()

        if channel not in self._cached:
            return False

        # check to see whether there are any cached column bounds
        # that are not within the new requested bounds
        cache = self._cached[channel]
        for col in cache.keys():
            # check if requesting possibly looser bounds than what's cached
            if col not in bounds:
                return False
            else:
                # they share this column, so check relative bounds
                requested_min, requested_max = bounds[col]
                cached_min, cached_max = cache[col]

                # check if we want more than is cached
                if (requested_min < cached_min) or (requested_max > cached_max):
                    return False

        return True
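
    # An illustrative sketch (not part of this module), using a hypothetical loader
    # instance and channel name: a request only counts as cached when its per-column
    # bounds lie within the bounds already stored in the cache, e.g.
    #
    #     loader._cached["H1:TEST-CHANNEL"] = {"snr": (5.0, 100.0)}
    #     loader.is_cached("H1:TEST-CHANNEL", bounds={"snr": (6.0, 50.0)})  # True
    #     loader.is_cached("H1:TEST-CHANNEL", bounds={"snr": (4.0, 50.0)})  # False (wider than cached)
    #     loader.is_cached("H1:TEST-CHANNEL", bounds={})                    # False (no bound on "snr")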

    def query(
        self, channels=None, columns=None, segs=None, time=None, bounds=None, **kwargs
    ):
        """
        Submits a query to get features and returns the result as a
        features.DataChunk of EventTables, keyed by channel.
        """
        # check that all columns in bounds will be returned
        if bounds is None:
            bounds = dict()
        else:
            for col in bounds.keys():
                assert (
                    col in self.columns
                ), f"bounds can only be placed on included columns, which {col} is not"

        # format and validate query
        if channels:
            if isinstance(channels, str):
                channels = [channels]
        if segs:
            assert segs in self.segs, (
                "segments=%s passed in query are not a subset of contained segments"
                % segs
            )
            assert time, "if passing in segments, a time column needs to be specified"
        if columns and (self._allowed_columns is not None):
            if isinstance(columns, str):
                columns = [columns]
            for col in columns:
                assert col in self._allowed_columns, "column=%s not allowed" % col

        # determine if there's any cached data
        if self._data and (channels is not None):
            channels2query = [
                chan for chan in channels if not self.is_cached(chan, bounds=bounds)
            ]
            if channels2query:
                self._data.update(
                    self._query(
                        channels=channels2query, segs=segs, bounds=bounds, **kwargs
                    )
                )
        else:
            # no cached data, retrieve all features from query
            self._data.update(
                self._query(channels=channels, segs=segs, bounds=bounds, **kwargs)
            )

        # grab queried data from cache
        if channels:
            queried_data = {channel: self._data[channel] for channel in channels}
        else:
            queried_data = dict(self._data)
        if segs:
            for channel in queried_data.keys():
                queried_data[channel] = queried_data[channel].filter(
                    (time, in_segmentlist, segs)
                )
        if columns:
            for channel in queried_data.keys():
                queried_data[channel] = queried_data[channel][columns]

        # update what is considered cached
        for channel in self._data.keys():
            if not self.is_cached(channel, bounds=bounds):
                # we requested more data than was available;
                # set the bounds to the most recent (widest) request
                self._cached[channel] = bounds

        # filter cached data by bounds
        for chan, datum in queried_data.items():
            for col, (min_, max_) in bounds.items():
                datum = datum.filter(f"{col} >= {min_}", f"{col} <= {max_}")
            queried_data[chan] = datum

        return features.DataChunk(
            start=self.start,
            end=self.end,
            segs=segs,
            columns=tuple(columns if columns else self.columns),
            features=queried_data,
            skip_filter=self._skip_filter,
        )
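
    # A usage sketch (not part of this module), assuming a hypothetical concrete
    # subclass "SomeDataLoader" and illustrative channel/column names; query()
    # returns a features.DataChunk whose .features dict is keyed by channel:
    #
    #     loader = SomeDataLoader(1187000000, 1187000100, columns=["time", "snr"])
    #     chunk = loader.query(channels="H1:TEST-CHANNEL", bounds={"snr": (5.0, 100.0)})
    #     table = chunk.features["H1:TEST-CHANNEL"]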

    def _query(self, **kwargs):
        raise NotImplementedError("Child classes should define this for themselves!")

    def filter(self, segs=None, bounds=None, time=features.DEFAULT_TIME_NAME):
        """
        Updates segments and filters out data that doesn't span those segments.
        Also filters by bounds and updates the cache as needed.

        NOTE: this requires knowledge of the "time" key within data.
        """
        # check that all columns in bounds will be returned
        if bounds is None:
            bounds = {}
        else:
            for col in bounds.keys():
                assert (
                    col in self.columns
                ), f"bounds can only be placed on included columns, which {col} is not"

        if segs is None:
            segs = self.segs
        assert (self.segs & segs) == segs, (
            "new segs must be a strict subset of existing segs\nsegs=%s\nnew segs=%s"
            % (self.segs, segs)
        )

        # filter by segments
        self._segs = segmentlist(segs)
        for channel in self._data.keys():
            self._data[channel] = self._data[channel].filter(
                (time, in_segmentlist, self.segs)
            )

        # filter by bounds
        for chan, datum in self._data.items():
            for col, (min_, max_) in bounds.items():
                datum = datum.filter(f"{col} >= {min_}", f"{col} <= {max_}")
            self._data[chan] = datum

        # update what is considered cached
        for channel in self._data.keys():
            for col, col_bounds in bounds.items():
                if col not in self._cached[channel]:
                    self._cached[channel][col] = col_bounds
                else:
                    new_min, new_max = col_bounds
                    cache_min, cache_max = self._cached[channel][col]
                    self._cached[channel][col] = (
                        max(new_min, cache_min),
                        min(new_max, cache_max),
                    )

    def flush(
        self, max_stride=features.DEFAULT_MAX_STRIDE, time=features.DEFAULT_TIME_NAME
    ):
        """
        Removes data down to a target span and number of samples.
        """
        # update span and filter
        self._start = max(self.start, self.end - max_stride)
        return self.filter(time=time)

    def pop(self, channel, default=None):
        """Remove and return all data associated with this channel."""
        self._cached.pop(channel, None)
        return self._data.pop(channel, default)

    def __add__(self, other):
        """
        Returns a new data loader combining the two data loaders, as well as
        any cached data contained within. We take the union of segments and
        spans, and we also keep the union of channels from each object.

        NOTE: only data loaders with the same columns/kwargs can be combined.
        """
        assert type(self) is type(
            other
        ), "DataLoaders need to be the same type to be combined"
        assert set(self.columns) == set(
            other.columns
        ), "DataLoaders must have the same columns to be combined"
        assert (
            self.kwargs == other.kwargs
        ), "DataLoaders must have the same kwargs to be combined"

        # take union of various properties
        start = min(self.start, other.start)
        end = max(self.end, other.end)
        segs = self.segs | other.segs

        # set up and combine cached data from both loaders
        dataloader = self.__class__(
            start, end, segs=segs, columns=self.columns, **self.kwargs
        )

        # find channels with data only present in one loader.
        # we can combine data from these easily without careful
        # treatment with caching since bounds don't overlap
        for this, that in [(self, other), (other, self)]:
            diff = set(this.channels) - set(that.channels)
            dataloader._data.update(
                {
                    channel: data
                    for channel, data in this._data.items()
                    if channel in diff
                }
            )
            dataloader._cached.update(
                {
                    channel: cached
                    for channel, cached in this._cached.items()
                    if channel in diff
                }
            )

        # find channels common to both and combine data. we can check whether
        # data is cached for each channel in both loaders, which can be safely
        # combined if both specify each other's data is cached since this
        # implies bounds are identical
        common = set(self.channels) & set(other.channels)
        uncached = []
        for channel in common:
            self_bounds = self._cached[channel]
            other_bounds = other._cached[channel]
            if self.is_cached(channel, other_bounds) and other.is_cached(
                channel, self_bounds
            ):
                dataloader._data[channel] = vstack(
                    [self._data[channel], other._data[channel]],
                    join_type="exact",
                )
                dataloader._cached[channel] = self._cached[channel]
            else:
                uncached.append(channel)

        # deal with channels that can't be safely combined. do this by taking the
        # intersection of the two bounds and combining both datasets. we do this
        # so that we don't trigger an expensive load operation or make the
        # addition process "lossy". if a user requests bounds in a future query,
        # this will err on the side of querying for more data than it may need
        for channel in uncached:
            self_bounds = set(self._cached[channel].keys()) - set(
                other._cached[channel].keys()
            )
            other_bounds = set(other._cached[channel].keys()) - set(
                self._cached[channel].keys()
            )
            common_bounds = set(self._cached[channel].keys()) & set(
                other._cached[channel].keys()
            )

            # copy over bounds which are present in one but not the other
            dataloader._cached[channel] = {
                bound: self._cached[channel][bound] for bound in self_bounds
            }
            dataloader._cached[channel].update(
                {bound: other._cached[channel][bound] for bound in other_bounds}
            )

            # take intersection of common bounds
            for bound in common_bounds:
                self_min, self_max = self._cached[channel][bound]
                other_min, other_max = other._cached[channel][bound]
                dataloader._cached[channel][bound] = (
                    min(self_min, other_min),
                    max(self_max, other_max),
                )

            # combine data together
            dataloader._data[channel] = vstack(
                [self._data[channel], other._data[channel]],
                join_type="exact",
            )

        return dataloader
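
    # An illustrative sketch (not part of this module), again assuming the
    # hypothetical "SomeDataLoader": adding two loaders of the same type with
    # identical columns/kwargs yields a new loader spanning the union of their
    # segments, with cached data from both merged per channel:
    #
    #     loader_a = SomeDataLoader(1187000000, 1187000050, columns=["time", "snr"])
    #     loader_b = SomeDataLoader(1187000050, 1187000100, columns=["time", "snr"])
    #     combined = loader_a + loader_b
    #     assert combined.start == 1187000000 and combined.end == 1187000100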

    def target_times(self, time, target_channel, target_bounds, segs=None):
        """
        A convenience function to extract target times, implicitly loading
        needed data.
        """
        # find target times
        target_times = self.query(
            channels=target_channel, bounds=target_bounds
        ).features[target_channel][time]

        # filter times by segments
        if segs is not None:
            target_times = target_times[utils.times_in_segments(target_times, segs)]

        return target_times

    def random_times(
        self, time, target_channel, dirty_bounds, dirty_window, random_rate, segs=None
    ):
        """
        A convenience function to extract random times, implicitly loading
        needed data.
        """
        # draw random times
        dirty_times = self.query(
            channels=target_channel, bounds=dirty_bounds
        ).features[target_channel][time]

        # generate segments
        dirty_seg = utils.times2segments(dirty_times, dirty_window)

        # only draw from segments that are outside of dirty_seg but within
        # classifier_data.segs
        random_times = utils.draw_random_times(self.segs - dirty_seg, rate=random_rate)

        # filter times by segments
        if segs is not None:
            random_times = random_times[utils.times_in_segments(random_times, segs)]

        # find segments where data is clean
        clean_seg = self.segs - dirty_seg
        if segs is not None:
            clean_seg &= segs

        return random_times, clean_seg
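
    # A usage sketch (not part of this module), with hypothetical channel names and
    # bounds: target_times() extracts feature times from the target channel, while
    # random_times() draws times from segments away from the "dirty" windows
    # surrounding those features:
    #
    #     target = loader.target_times("time", "H1:TARGET", {"snr": (6.0, 1e6)})
    #     rnd, clean_seg = loader.random_times(
    #         "time", "H1:TARGET", {"snr": (5.0, 1e6)}, dirty_window=0.5, random_rate=0.1
    #     )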


hookspec = pluggy.HookspecMarker("iDQ")


@hookspec
def get_dataloaders():
    """
    This hook is used to return data loaders in the form:

        {"type[:format]": DataLoader}

    where the type refers to a specific data backend (snax, kleine-welle, etc.)
    and format (optional) refers to a specific file format and/or data layout.
    """
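

# An illustrative sketch, not part of this module: a plugin would typically register
# its loaders against the hook above using pluggy's HookimplMarker with the same
# project name ("iDQ"). The "example" key, the ExampleDataLoader class, and its
# _tables store below are assumptions for illustration only; real backends
# (snax, kleine-welle, etc.) supply their own DataLoader subclasses.
#
#     hookimpl = pluggy.HookimplMarker("iDQ")
#
#
#     class ExampleDataLoader(DataLoader):
#         """A hypothetical backend serving pre-loaded EventTables keyed by channel."""
#
#         _default_columns = ["time", "snr"]
#
#         def __init__(self, start, end, tables=None, **kwargs):
#             super().__init__(start, end, **kwargs)
#             self._tables = tables or {}  # {channel: EventTable}
#
#         def _query(self, channels=None, **kwargs):
#             # return tabular data keyed by channel, restricted to self.columns
#             channels = channels if channels is not None else sorted(self._tables)
#             return {
#                 channel: self._tables[channel][self.columns] for channel in channels
#             }
#
#
#     @hookimpl
#     def get_dataloaders():
#         return {"example": ExampleDataLoader}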