Source code for idq.factories

import importlib.util
import os
import sys

import numpy as np
import pluggy

from . import classifiers
from . import features
from . import utils
from .io import reporters
from .io import triggers


DEFAULT_DT = classifiers.DEFAULT_DT


class DataLoaderFactory(object):
    """
    A factory to create instances of DataLoaders
    """

    def __init__(self):
        # set up plugin manager
        manager = pluggy.PluginManager("iDQ")
        manager.add_hookspecs(triggers)

        # load in dataloaders by module
        # FIXME: do this with importlib instead when python3 only
        from .io.triggers import kw, omicron, synthetic

        # base dataloaders
        manager.register(kw)
        manager.register(synthetic)
        manager.register(omicron)

        # extra (optional) dataloaders based on libraries
        try:
            from .io.triggers import snax
        except ImportError:
            pass
        else:
            manager.register(snax)

        # add all registered plugins to registry
        self._dataloaders = {}
        for loader_plugins in manager.hook.get_dataloaders():
            self._dataloaders.update(loader_plugins)

    def __call__(self, start, end, segs=None, flavor=None, **kwargs):
        """
        standardizes how we map data loader params into new instances

        * start is a float (gps seconds)
        * end is a float (gps seconds)
        * segs, if provided, should be a list of segments (a list of 2-element lists)
        * kwargs should be the result of **config.items(nickname)

        NOTE: 'flavor' is a reserved option that must be supplied

        returns

        * a data loader corresponding to the parameters within **kwargs
        """
        if flavor not in self._dataloaders:
            raise KeyError("%s is not a known DataLoader!" % flavor)
        return self._dataloaders[flavor](start, end, segs=segs, **kwargs)
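
# Usage sketch (illustrative, not part of the module): build a loader for a
# stretch of time by flavor. The "kw" flavor name and the segment values are
# assumptions; real flavor keys come from each plugin's get_dataloaders() hook,
# and the remaining kwargs come from the relevant config section.
#
#   factory = DataLoaderFactory()
#   dataloader = factory(
#       1187000000,  # start (gps seconds)
#       1187003600,  # end (gps seconds)
#       segs=[[1187000000, 1187003600]],
#       flavor="kw",  # must be a key in the registered DataLoader registry
#   )
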
class ReporterFactory(object):
    """
    A factory to create instances of Reporter-like objects
    """

    def __init__(self):
        # set up plugin manager
        manager = pluggy.PluginManager("iDQ")
        manager.add_hookspecs(reporters)

        # load in reporters by module
        # FIXME: do this with importlib instead when python3 only

        # base reporters
        from .io.reporters import pkl, hdf5

        manager.register(pkl)
        manager.register(hdf5)

        # extra (optional) reporters based on libraries
        try:
            from .io.reporters import gwf
        except ImportError:
            pass
        else:
            manager.register(gwf)

        try:
            from .io.reporters import ligolw
        except ImportError:
            pass
        else:
            manager.register(ligolw)

        # add all registered plugins to registry
        self._reporters = {}
        for reporter_plugins in manager.hook.get_reporters():
            self._reporters.update(reporter_plugins)

    def __call__(self, rootdir, start, end, flavor=None, **kwargs):
        """
        standardize how we construct reporters
        """
        if flavor not in self._reporters:
            raise KeyError("%s is not a known Reporter!" % flavor)
        return self._reporters[flavor](rootdir, start, end, **kwargs)
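
# Usage sketch (illustrative): the "pkl" flavor name is an assumption based on
# the pkl module registered above; any extra kwargs are forwarded to the
# Reporter's constructor.
#
#   factory = ReporterFactory()
#   reporter = factory("/path/to/rootdir", 1187000000, 1187003600, flavor="pkl")
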
class ClassifierFactory(object):
    """
    A factory to create instances of SupervisedClassifier-like objects
    """

    def __init__(self):
        # set up plugin manager
        manager = pluggy.PluginManager("iDQ")
        manager.add_hookspecs(classifiers)

        # load in classifiers by module
        # FIXME: do this with importlib instead when python3 only

        # base classifiers
        from .classifiers import ovl

        manager.register(ovl)

        # extra (optional) classifiers based on libraries
        try:
            from .classifiers import sklearn
        except ImportError:
            pass
        else:
            manager.register(sklearn)

        try:
            from .classifiers import keras
        except ImportError:
            pass
        else:
            manager.register(keras)

        try:
            from .classifiers import xgb
        except ImportError:
            pass
        else:
            manager.register(xgb)

        # resolve local classifiers path
        suffix = os.path.join("idq", "classifiers.py")
        if "XDG_CONFIG_HOME" in os.environ:
            local_path = os.path.join(os.getenv("XDG_CONFIG_HOME"), suffix)
        else:
            local_path = os.path.join(os.getenv("HOME"), ".config", suffix)

        # load local classifiers if available
        if os.path.exists(local_path):
            spec = importlib.util.spec_from_file_location("classifiers", local_path)
            local_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(local_module)
            manager.register(local_module)
            # NOTE: register local module so that classifiers can be pickled
            sys.modules["classifiers"] = local_module

        # add all registered plugins to registry
        self._classifiers = {}
        for classifier_plugins in manager.hook.get_classifiers():
            self._classifiers.update(classifier_plugins)

    def __call__(self, nickname, flavor=None, **kwargs):
        if flavor not in self._classifiers:
            raise KeyError("%s is not a known SupervisedClassifier!" % flavor)
        return self._classifiers[flavor](nickname, **kwargs)
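
# Usage sketch (illustrative): the "ovl" flavor name is an assumption based on
# the ovl module registered above; kwargs are classifier-specific options,
# typically the result of config.items(nickname). User-defined classifiers in
# $XDG_CONFIG_HOME/idq/classifiers.py (or ~/.config/idq/classifiers.py) are
# picked up automatically, as handled above.
#
#   factory = ClassifierFactory()
#   classifier = factory("my_ovl", flavor="ovl")
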
class IncrementalClassifierFactory(object):
    """
    A factory to create instances of IncrementalSupervisedClassifier-like objects
    """

    def __init__(self):
        # set up plugin manager
        manager = pluggy.PluginManager("iDQ")
        manager.add_hookspecs(classifiers)

        # load in classifiers by module
        # FIXME: do this with importlib instead when python3 only

        # base classifiers
        from .classifiers import ovl

        manager.register(ovl)

        # extra (optional) classifiers based on libraries
        try:
            from .classifiers import sklearn
        except ImportError:
            pass
        else:
            manager.register(sklearn)

        # add all registered plugins to registry
        self._classifiers = {}
        for classifier_plugins in manager.hook.get_incremental_classifiers():
            self._classifiers.update(classifier_plugins)

    def __call__(self, nickname, flavor=None, **kwargs):
        if flavor not in self._classifiers:
            raise KeyError(
                "%s is not a known IncrementalSupervisedClassifier!" % flavor
            )
        return self._classifiers[flavor](nickname, **kwargs)
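
# Usage sketch (illustrative): only plugins implementing the
# get_incremental_classifiers() hook end up in this registry, so a flavor that
# is valid for ClassifierFactory may not be valid here. The flavor name below
# is an assumption.
#
#   factory = IncrementalClassifierFactory()
#   classifier = factory("my_ovl", flavor="ovl")
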
class DatasetFactory(object):
    """
    a factory object to generate datasets
    """

    def __init__(self, data):
        self.data = data
    def labeled(self, target_times, random_times):
        """
        construct a dataset for these times

        target_times -> label=1.
        random_times -> label=0.
        """
        return self(
            np.concatenate((target_times, random_times)),
            labels=[1.0] * len(target_times) + [0.0] * len(random_times),
        )
    def unlabeled(self, dt=DEFAULT_DT, t_offset=0, segs=None):
        """
        construct a dataset of unlabeled times regularly sampled with spacing dt
        """
        assert t_offset <= dt, "t_offset must be <= dt"
        if segs is None:
            segs = self.data.segs
        times = np.concatenate(tuple(utils.segs2times(segs, dt))) + t_offset
        return self(times)
    def __call__(self, times, labels=None):
        """
        a helper function to standardize how we assemble datasets
        """
        if labels:
            assert len(labels) == len(times), "labels and times must have the same length!"

        # populate dataset with features
        if isinstance(self.data, features.DataChunk):
            return features.Dataset.from_datachunks(
                self.data, times=times, labels=labels
            )
        else:
            return features.Dataset(times=times, labels=labels, dataloader=self.data)
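
# Usage sketch (illustrative): `data` may be a features.DataChunk or a
# dataloader, as handled in __call__ above; the time arrays below are
# hypothetical arrays of gps seconds.
#
#   factory = DatasetFactory(dataloader)
#   labeled = factory.labeled(target_times, random_times)     # labels 1.0 / 0.0
#   unlabeled = factory.unlabeled(dt=DEFAULT_DT, t_offset=0)  # regular sampling
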