"""Unified interface to the detrending algorithms."""

from multiprocessing import Pool
import logging

import numpy
from scipy.optimize import minimize
import pandas

from autowisp.multiprocessing_util import setup_process
from autowisp import DataReductionFile, LightCurveFile
from autowisp.catalog import read_catalog_file
from autowisp.database.interface import db_engine
from .epd_correction import EPDCorrection
from .reconstructive_correction_transit import ReconstructiveCorrectionTransit


def save_correction_statistics(correction_statistics, filename):
    """Save the given statistics (result of apply_parallel_correction())."""

    print("Correction statistics:\n" + repr(correction_statistics))
    dframe = pandas.DataFrame(
        {
            column: correction_statistics[column]
            for column in ["ID", "mag", "xi", "eta"]
        },
    )
    num_photometries = correction_statistics["rms"][0].size
    for prefix in ["rms", "num_finite"]:
        for phot_index in range(num_photometries):
            dframe[prefix + f"_{phot_index:02d}"] = correction_statistics[
                prefix
            ][:, phot_index]
    with open(filename, "w", encoding="utf-8") as outf:
        dframe.to_string(outf, col_space=25, index=False, justify="left")


def load_correction_statistics(filename, add_catalog=False):
    """Read previously stored statistics from a file."""

    with DataReductionFile() as mem_dr:
        dframe = pandas.read_csv(
            filename, delim_whitespace=True, index_col="ID"
        )
        num_photometries = (dframe.shape[1] - 3) // 2
        result_dtype = EPDCorrection.get_result_dtype(num_photometries)
        if add_catalog:
            catalog = read_catalog_file(
                add_catalog, add_gnomonic_projection=True
            )
            result_dtype += [
                (col, catalog[col].dtype)
                for col in catalog.columns
                if col not in ["xi", "eta"]
            ]
            dframe = dframe.drop(columns=["xi", "eta"]).join(
                catalog, how="inner"
            )
        # The inner join with the catalog may drop sources, so size the
        # result only after any join has been applied.
        result = numpy.empty(len(dframe), dtype=result_dtype)
        for column in dframe.columns:
            if column.startswith("rms_") or column.startswith("num_finite_"):
                continue
            result[column] = dframe[column]
        for prefix in ["rms", "num_finite"]:
            for phot_index in range(num_photometries):
                result[prefix][:, phot_index] = dframe[
                    prefix + f"_{phot_index:02d}"
                ]
        if "2MASSID" in dframe.columns:
            for index, source_id in enumerate(dframe["2MASSID"]):
                result["ID"][index] = mem_dr.parse_hat_source_id(source_id)
        else:
            result["ID"] = dframe.index
        return result
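

# Round-trip sketch (hypothetical filenames): statistics written by
# save_correction_statistics() can be read back, optionally augmented with
# catalog information via the add_catalog argument:
#
#     statistics = load_correction_statistics(
#         "epd_statistics.txt",
#         add_catalog="catalog_sources.fits",  # hypothetical catalog file
#     )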


def calculate_iterative_rejection_scatter(
    values,
    calculate_average,
    calculate_scatter,
    outlier_threshold,
    max_outlier_rejections,
    *,
    return_average=False,
):
    """
    Calculate the scatter of a dataset, with outlier rejection iterations.

    Args:
        values(numpy array like):    The data to calculate the scatter of.

        calculate_average(callable):    A callable that returns the average of
            the data around which the scatter will be calculated.

        calculate_scatter(callable):    The scatter is defined as the square
            root of whatever calculate_scatter returns when applied to the
            square deviations of the data from the average.

        outlier_threshold(float):    How far away from the average, in units
            of the scatter, a point must be to be considered an outlier.

        max_outlier_rejections(int):    The maximum number of iterations
            alternating outlier rejection with re-calculation of the scatter.

        return_average(bool):    Should the average of the points also be
            returned?

    Returns:
        float, int:
            The scatter in values and the number of non-rejected points used
            in the last scatter calculation.
    """

    include_points = numpy.ones(values.shape, dtype=bool)
    non_outliers = True
    for _ in range(max_outlier_rejections):
        include_points = numpy.logical_and(include_points, non_outliers)
        average = calculate_average(values[include_points])
        square_deviations = numpy.square(values - average)
        square_scatter = calculate_scatter(square_deviations[include_points])
        non_outliers = (
            square_deviations <= outlier_threshold**2 * square_scatter
        )
        if non_outliers[include_points].all():
            break

    if return_average:
        return numpy.sqrt(square_scatter), include_points.sum(), average
    return numpy.sqrt(square_scatter), include_points.sum()
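

# A minimal, self-contained example of the rejection logic (assumes only
# numpy): the median serves as the average and the mean of the square
# deviations as the squared scatter, so with outlier_threshold=2.0 the
# single 10.0 outlier is clipped on the second iteration:
#
#     values = numpy.array([1.0, 1.1, 0.9, 1.0, 10.0])
#     scatter, num_used = calculate_iterative_rejection_scatter(
#         values,
#         calculate_average=numpy.median,
#         calculate_scatter=numpy.mean,
#         outlier_threshold=2.0,
#         max_outlier_rejections=5,
#     )
#     # num_used == 4 and scatter == sqrt(0.005) ~ 0.07: the rejected 10.0
#     # point no longer contributes.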


def recalculate_correction_statistics(
    lc_fnames,
    fit_datasets,
    variables,
    lc_points_filter_expression,
    **calculate_scatter_config,
):
    """
    Extract the performance metrics of a de-trending step directly from LCs.

    Args:
        lc_fnames([str]):    The filenames of the light curves that were
            corrected.

        fit_datasets:    See Correction.__init__().

        variables:    The variables made available when evaluating
            lc_points_filter_expression on each light curve.

        lc_points_filter_expression:    Expression selecting which light
            curve points are included in the scatter calculation.

        calculate_scatter_config:    Arguments passed directly to
            calculate_iterative_rejection_scatter().

    Returns:
        See apply_parallel_correction()'s return value.
    """

    result = numpy.empty(
        len(lc_fnames),
        dtype=EPDCorrection.get_result_dtype(len(fit_datasets)),
    )
    for lc_index, fname in enumerate(lc_fnames):
        with LightCurveFile(fname, "r") as lightcurve:
            for fit_index, (_, substitutions, to_dset) in enumerate(
                fit_datasets
            ):
                try:
                    stat_points = lightcurve.evaluate_expression(
                        variables, lc_points_filter_expression
                    )
                    # False positive
                    # pylint: disable=unbalanced-tuple-unpacking
                    (
                        result["rms"][lc_index][fit_index],
                        result["num_finite"][lc_index][fit_index],
                    ) = calculate_iterative_rejection_scatter(
                        lightcurve.get_dataset(to_dset, **substitutions)[
                            stat_points
                        ],
                        **calculate_scatter_config,
                    )
                    # pylint: enable=unbalanced-tuple-unpacking
                except OSError:
                    result["rms"][lc_index][fit_index] = numpy.nan
                    result["num_finite"][lc_index][fit_index] = 0

    return result
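

# Hypothetical invocation sketch. Each fit_datasets entry unpacks here as
# (source_dataset, substitutions, to_dset); only the substitutions and the
# corrected (to_dset) dataset are used. All names below are illustrative
# placeholders, not actual pipeline dataset keys:
#
#     statistics = recalculate_correction_statistics(
#         lc_fnames,
#         fit_datasets=[
#             (raw_key, {"aperture_index": ap}, corrected_key)
#             for ap in range(num_apertures)
#         ],
#         variables=lc_variables,  # as expected by evaluate_expression()
#         lc_points_filter_expression=points_filter,
#         calculate_average=numpy.nanmedian,
#         calculate_scatter=numpy.nanmean,
#         outlier_threshold=3.0,
#         max_outlier_rejections=5,
#     )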


def pool_init(config):
    """Set up a pool worker process."""

    db_engine.dispose()
    setup_process(**config)


def apply_parallel_correction(
    lc_fnames, correct, num_parallel_processes, **config
):
    """
    Correct LCs, running one of the detrending algorithms in parallel.

    Args:
        lc_fnames([str]):    The filenames of the light curves to correct.

        correct(Correction):    The underlying correction to apply in
            parallel.

        num_parallel_processes(int):    The maximum number of parallel
            processes to use.

        config:    Passed directly to setup_process() when initializing each
            worker process.

    Returns:
        numpy.array:
            The return values of correct.__call__() in the same order as
            lc_fnames.
    """

    logger = logging.getLogger(__name__)
    logger.info("Starting to detrend %d light curves.", len(lc_fnames))
    if num_parallel_processes == 1:
        result = numpy.concatenate([correct(lcf) for lcf in lc_fnames])
    else:
        with Pool(
            num_parallel_processes, initializer=pool_init, initargs=(config,)
        ) as correction_pool:
            result = numpy.concatenate(correction_pool.map(correct, lc_fnames))
    logger.info("Finished detrending.")
    return result
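

# Typical driver sketch (hypothetical arguments; EPDCorrection's constructor
# signature is not reproduced here): build a correction object and run it
# over all light curves in parallel:
#
#     correct = EPDCorrection(...)
#     statistics = apply_parallel_correction(
#         lc_fnames,
#         correct,
#         num_parallel_processes=8,
#         **process_setup_config,  # forwarded to setup_process() per worker
#     )
#     save_correction_statistics(statistics, "epd_statistics.txt")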


def apply_reconstructive_correction_transit(
    lc_fname,
    correct,
    *,
    transit_model,
    transit_parameters,
    fit_parameter_flags,
    num_limbdark_coef,
):
    """
    Perform a reconstructive correction on an LC assumed to contain a transit.

    The corrected lightcurve, preserving the best-fit transit, is saved in the
    lightcurve file just like for non-reconstructive corrections.

    Args:
        lc_fname(str):    The filename of the light curve to correct.

        correct(Correction):    Instance of one of the correction algorithms
            to make adaptive.

        transit_model:    Object which supports the transit model interface of
            pytransit.

        transit_parameters(scipy float array):    The full array of parameters
            required by the transit model's evaluate() method.

        fit_parameter_flags(scipy bool array):    Flags indicating parameters
            whose values should be fit for (by having a corresponding entry of
            True). Must match exactly the shape of transit_parameters.

        num_limbdark_coef(int):    How many of the transit parameters are limb
            darkening coefficients? Those need to be passed to the model
            separately.

    Returns:
        (scipy array, scipy array):

            * The best fit transit parameters.

            * The return value of ReconstructiveCorrectionTransit.__call__()
              for the best-fit transit parameters.
    """

    # This is intended to serve as a callable.
    # pylint: disable=too-few-public-methods
    class MinimizeFunction:
        """Suitable callable for scipy.optimize.minimize()."""

        def __init__(self):
            """Create the underlying correction object."""

            self.correct = ReconstructiveCorrectionTransit(
                transit_model,
                correct,
                fit_amplitude=False,
            )
            self.transit_parameters = numpy.copy(transit_parameters)

        def __call__(self, fit_params):
            """
            Return the RMS residual of the corrected LC around a transit model.

            Args:
                fit_params(scipy array):    The values of the mutable model
                    parameters for the current minimization function
                    evaluation.

            Returns:
                float:
                    RMS of the residuals after correcting around the transit
                    model with the given parameters.
            """

            self.transit_parameters[fit_parameter_flags] = fit_params
            return self.correct(
                lc_fname,
                self.transit_parameters[0],
                self.transit_parameters[1 : num_limbdark_coef + 1],
                *self.transit_parameters[num_limbdark_coef + 1 :],
                save=False,
            )["rms"]

    # pylint: enable=too-few-public-methods

    rms_function = MinimizeFunction()
    best_fit_transit = numpy.copy(transit_parameters)
    if fit_parameter_flags.any():
        minimize_result = minimize(
            rms_function, transit_parameters[fit_parameter_flags]
        )
        assert minimize_result.success
        best_fit_transit[fit_parameter_flags] = minimize_result.x

    return (
        best_fit_transit,
        rms_function.correct(
            lc_fname,
            best_fit_transit[0],
            best_fit_transit[1 : num_limbdark_coef + 1],
            *best_fit_transit[num_limbdark_coef + 1 :],
        ),
    )
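

# Hypothetical end-to-end sketch assuming a pytransit-style model (e.g.
# pytransit.QuadraticModel with two limb-darkening coefficients). The layout
# of transit_parameters must match the slicing above: one leading parameter,
# then num_limbdark_coef limb-darkening coefficients, then the remaining
# arguments of the model's evaluate() method:
#
#     best_transit, correction_result = (
#         apply_reconstructive_correction_transit(
#             "lc_0001.h5",  # hypothetical light curve file
#             correct,
#             transit_model=QuadraticModel(),
#             transit_parameters=transit_parameters,
#             fit_parameter_flags=fit_flags,  # boolean, same shape
#             num_limbdark_coef=2,
#         )
#     )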