"""Test utilities. Common functions used across various GstLAL unittests
"""
import gi
gi.require_version('Gst', '1.0')
from gi.repository import GObject, Gst
GObject.threads_init()
Gst.init(None)
import os
import pathlib
import string
import sys
import tempfile
import types
from typing import Tuple, Dict
from unittest import mock
import pytest
# Platform identifier, read once when this module is imported.
PLATFORM = sys.platform

# Ordered (target, {kwarg: value}) pairs; each pair is forwarded to
# unittest.mock.patch as mock.patch(target, **kwargs).
DEFAULT_MOCK_PATCHES = (
    ('gstlal.datafind.load_frame_cache', {'return_value': [1, 2, 3]}),
)

# str.translate table mapping every character of string.whitespace to None,
# i.e. translating with it deletes those characters.
CLEAN_TRANSLATION = dict.fromkeys(map(ord, string.whitespace))
[docs]
def clean_str(c: str):
    """Return ``c`` with the characters in CLEAN_TRANSLATION removed.

    Intended for normalising text such as a copyright string before it is
    compared against another string.
    """
    stripped = c.translate(CLEAN_TRANSLATION)
    return stripped
[docs]
def is_osx(platform: str = PLATFORM):
    """Return True when ``platform`` names OSX.

    The check is a case-insensitive comparison against ``'darwin'``. When no
    argument is given, the module-level PLATFORM value is used.
    """
    normalized = platform.lower()
    return normalized == 'darwin'
[docs]
def skip_osx(f: types.FunctionType) -> types.FunctionType:
    """Decorator that applies ``pytest.mark.skipif`` to ``f`` when is_osx() is true.

    The condition is evaluated once, at decoration time.
    """
    skip_marker = pytest.mark.skipif(is_osx(), reason='Test not supported on OSX')
    return skip_marker(f)
[docs]
def requires_full_build(f: types.FunctionType):
    """Decorator that applies the ``requires_full_build`` pytest mark to ``f``."""
    marker = pytest.mark.requires_full_build
    return marker(f)
[docs]
def broken(reason: str):
    """Decorator factory that marks a test as broken and skips it.

    Args:
        reason: Text passed to ``pytest.mark.skip`` as the skip reason.

    Returns:
        A decorator that applies ``pytest.mark.skip(reason=reason)`` followed
        by the ``pytest.mark.broken`` mark to the test and returns the test.
    """
    def wrapper(f: types.FunctionType):
        # The skip mark has to be parametrised with the reason first and only
        # then applied to the test. Calling pytest.mark.skip(f, reason) passes
        # two positional arguments, which makes the mark decorator return a
        # new mark object instead of the decorated test function.
        func = pytest.mark.skip(reason=reason)(f)
        func = pytest.mark.broken(func)
        return func
    return wrapper
[docs]
def impl_deprecated(f):
    """Decorator applying broken() to ``f`` with a fixed 'not included in build' reason."""
    mark_broken = broken('Underlying implementation not included in build')
    return mark_broken(f)
[docs]
class GstLALTestManager:
    """Context manager for GstLAL tests.

    On entry it enters a temporary directory, starts one ``mock.patch`` per
    entry of ``patch_info``, applies the requested environment variable
    overrides and, when ``with_pipeline`` is true, creates a ``Gst.Pipeline``.
    On exit everything is undone in the reverse order.

    Args:
        patch_info: Ordered (target, {kwarg: value}) pairs; each pair is
            applied as ``mock.patch(target, **kwargs)``.
        env_overrides: Mapping of environment variable name to the value it
            should have inside the context. The mapping is copied, so the
            caller's object is never modified.
        with_pipeline: When true, ``pipeline`` is a ``Gst.Pipeline`` inside
            the context; otherwise it is ``None``.
    """

    def __init__(self, patch_info: Tuple[Tuple[str, Dict], ...] = DEFAULT_MOCK_PATCHES, env_overrides: dict = None, with_pipeline: bool = False):
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.tmp_path = pathlib.Path(self.tmp_dir.name)
        self.patch_info = patch_info
        self._patches = []
        self._env_originals = {}
        # Copy: reset_env_var() pops entries from this dict, which must not
        # empty the mapping the caller handed in.
        self._env_overrides = {} if env_overrides is None else dict(env_overrides)
        self._with_pipeline = with_pipeline
        # Defined up front so the attribute exists before __enter__ runs.
        self.pipeline = None

    def __enter__(self):
        """Enter the GstLAL testing context and return this manager."""
        # Create temporary directory
        self.tmp_dir.__enter__()
        try:
            # Set all mocks; a patch is recorded only once it has started so
            # that __exit__ never tries to stop a patch that was not started.
            for target, kwargs in self.patch_info:
                p = mock.patch(target, **kwargs)
                p.__enter__()
                self._patches.append(p)
            # Set env overrides; iterate over a snapshot because
            # override_env_var() writes back into the same dict.
            for k in list(self._env_overrides):
                self.override_env_var(k, self._env_overrides[k])
            # Set pipeline
            self.set_pipeline()
        except BaseException:
            # __exit__ is not invoked by the with-statement when __enter__
            # raises, so roll back whatever was already set up, then re-raise.
            self.__exit__(*sys.exc_info())
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the GstLAL context, undoing the setup in reverse order.

        Returns None, so an exception raised inside the context propagates.
        """
        try:
            # Undo all env overrides
            for k in list(self._env_originals):
                self.reset_env_var(k)
        finally:
            try:
                # Unset all mocks, last started first, so that stacked patches
                # on the same target restore the genuine original object.
                while self._patches:
                    self._patches.pop().__exit__(exc_type, exc_val, exc_tb)
            finally:
                # Remove tmp dir
                self.tmp_dir.__exit__(exc_type, exc_val, exc_tb)

    @property
    def cache_path(self):
        """POSIX-style path of ``cache.txt`` inside the temporary directory."""
        return (self.tmp_path / 'cache.txt').as_posix()

    def override_env_var(self, key: str, val: str):
        """Set ``os.environ[key]`` to ``val``, remembering the previous value.

        Raises:
            ValueError: If ``key`` is already overridden by this manager.
        """
        if key in self._env_originals:
            raise ValueError('That env var is already overridden')
        original = os.environ.get(key, None)
        # Assign first: if os.environ rejects the value, no bookkeeping is
        # left behind for a variable that was never actually changed.
        os.environ[key] = val
        self._env_originals[key] = original
        self._env_overrides[key] = val

    def reset_env_var(self, key: str):
        """Restore ``os.environ[key]`` to the value it had before the override.

        The key is also dropped from the stored overrides.

        Raises:
            ValueError: If ``key`` is not currently overridden by this manager.
        """
        if key not in self._env_originals:
            raise ValueError('Unable to reset var: {}'.format(key))
        val = self._env_originals.pop(key)
        if val is None:
            # The variable did not exist originally. Tolerate it having been
            # removed already by the code under test.
            os.environ.pop(key, None)
        else:
            os.environ[key] = val
        self._env_overrides.pop(key, None)

    def set_pipeline(self):
        """Set ``pipeline`` to a new ``Gst.Pipeline`` or to ``None``.

        The pipeline is named after the basename of ``sys.argv[0]``.
        """
        if self._with_pipeline:
            self.pipeline = Gst.Pipeline(name=os.path.split(sys.argv[0])[1])
        else:
            self.pipeline = None