# Source code for autowisp.database.user_interface

"""Define interface to the pipeline database."""

import json
import logging
from time import sleep

from sqlalchemy import sql, select, delete, func, and_

from autowisp.database.interface import Session
from autowisp.data_reduction import DataReductionFile

# False positive
# pylint: disable=no-name-in-module
from autowisp.database.data_model.provenance import (
    Camera,
    CameraType,
    CameraChannel,
)

from autowisp.database.data_model import (
    Condition,
    ConditionExpression,
    Configuration,
    Image,
    ImageProcessingProgress,
    ImageType,
    LightCurveProcessingProgress,
    MasterFile,
    MasterType,
    ObservingSession,
    Parameter,
    ProcessedImages,
    ProcessingSequence,
    Step,
    step_param_association,
)

# pylint: enable=no-name-in-module


def get_db_configuration(
    version, db_session, step_id=None, max_version_only=False
):
    """
    Return the configuration entries in effect for the given version.

    For every parameter, the entries selected are those carrying the largest
    version that does not exceed ``version``.

    Args:
        version:    The configuration version to look up.

        db_session:    Database session used to run the query.

        step_id:    If not None, only parameters associated with the step
            with this ID (through ``step_param_association``) are included.

        max_version_only:    If True, return the largest version found among
            the selected entries instead of the entries themselves.

    Returns:
        The list of matching Configuration instances, or a single version
        value if ``max_version_only`` is True.
    """

    # False positives:
    # pylint: disable=no-member
    # pylint: disable=not-callable
    latest_versions = (
        select(
            Configuration.parameter_id,
            func.max(Configuration.version).label("version"),
        )
        .where(Configuration.version <= version)
        .group_by(Configuration.parameter_id)
        .subquery()
    )

    if max_version_only:
        selected = func.max(Configuration.version)
    else:
        selected = Configuration
    # pylint: enable=not-callable

    query = select(selected).join(
        latest_versions,
        and_(
            Configuration.parameter_id == latest_versions.c.parameter_id,
            Configuration.version == latest_versions.c.version,
        ),
    )

    if step_id is not None:
        query = query.join(
            step_param_association,
            Configuration.parameter_id == step_param_association.c.param_id,
        ).where(step_param_association.c.step_id == step_id)
    # pylint: enable=no-member

    matches = db_session.scalars(query)
    return matches.one() if max_version_only else matches.all()
def get_processing_sequence(db_session):
    """
    Return the sequence of steps in the pipeline.

    Args:
        db_session:    Database session used to run the query.

    Returns:
        list: One row per ProcessingSequence entry that has both a matching
        Step and a matching ImageType, each row holding the (Step, ImageType)
        pair.
    """

    step_match = ProcessingSequence.step_id == Step.id
    image_type_match = ProcessingSequence.image_type_id == ImageType.id

    query = select(Step, ImageType).select_from(ProcessingSequence)
    query = query.join(Step, step_match).join(ImageType, image_type_match)
    return db_session.execute(query).all()
def list_channels(db_session):
    """
    List the combined set of channel names for all cameras.

    Args:
        db_session:    Database session used to run the query.

    Returns:
        list: The distinct values of ``CameraChannel.name``.
    """

    distinct_names = func.distinct(CameraChannel.name)
    return db_session.scalars(distinct_names).all()
def get_progress_images(step_id, image_type_id, config_version, db_session):
    """
    Return number of images in final state and by status for given step/imtype.

    Args:
        step_id:    ID of the step for which to return the progress.

        image_type_id:    ID of the image type for which to return the
            progress.

        config_version:    Version of the configuration for which to report
            progress.

        db_session:    Database session to use.

    Returns:
        [str, int, int]:
            Information on the images in final state. The entries are
            channel name, example status (>0 indicates success <0 indicates
            failure), number of images of that channel that have that status
            sign and are flagged final.

        [str, int]:
            Information about the images not in final state. The entries
            are channel name, number of non-final images of that channel.

        [str, int, int]:
            The pending images broken by status. The format is the same as
            the final state information, except for images not flagged as in
            final state for the given step.
    """

    # Largest configuration version relevant for this step.
    step_version = get_db_configuration(
        config_version, db_session, step_id, max_version_only=True
    )

    def restrict_to_step(query):
        """Join the query to the progress table and filter to this step."""

        return query.join(
            ImageProcessingProgress,
            ProcessedImages.progress_id == ImageProcessingProgress.id,
        ).where(
            ImageProcessingProgress.step_id == step_id,
            ImageProcessingProgress.configuration_version == step_version,
            ImageProcessingProgress.image_type_id == image_type_id,
        )

    # False positives
    # pylint: disable=not-callable
    # pylint: disable=no-member
    images_per_channel = (
        select(CameraChannel.name, func.count(Image.id))
        .join(ObservingSession)
        .join(Camera)
        .join(CameraType)
        .join(CameraChannel)
    )

    processed_counts = restrict_to_step(
        select(
            ProcessedImages.channel,
            ProcessedImages.status,
            func.count(ProcessedImages.image_id),
        )
        .join(Image)
        .join(ImageType)
    ).where(ImageType.id == image_type_id)
    # pylint: enable=not-callable

    # Final entries are grouped by the sign of the status, not its value.
    final_query = processed_counts.where(ProcessedImages.final).group_by(
        ProcessedImages.status > 0,
        ProcessedImages.channel,
    )
    final = db_session.execute(final_query).all()

    non_final_query = processed_counts.where(~ProcessedImages.final).group_by(
        ProcessedImages.status, ProcessedImages.channel
    )
    by_status = db_session.execute(non_final_query).all()

    finished = (
        restrict_to_step(
            select(ProcessedImages.image_id, ProcessedImages.channel)
        )
        .where(ProcessedImages.final)
        .subquery()
    )

    # Image/channel combinations without a final entry are the ones for
    # which the outer join finds no partner, i.e. the joined columns are NULL.
    pending_query = (
        images_per_channel.outerjoin(
            finished,
            and_(
                Image.id == finished.c.image_id,
                CameraChannel.name == finished.c.channel,
            ),
        )
        .where(finished.c.image_id.is_(None))
        .where(Image.image_type_id == image_type_id)
        .group_by(CameraChannel.name)
    )
    # pylint: enable=no-member
    pending = db_session.execute(pending_query).all()

    return final, pending, by_status
def get_progress_lightcurves(
    step_id, image_type_id, config_version, db_session
):
    """
    Same as `get_progress_images()` but for lightcurve steps.

    Progress is counted per single photometric reference: each reference
    whose image type matches ``image_type_id`` adds one to either the final
    or the pending count of its channel.

    Args:
        step_id:    ID of the step for which to return the progress.

        image_type_id:    ID of the image type for which to return the
            progress.

        config_version:    Version of the configuration for which to report
            progress.

        db_session:    Database session to use.

    Returns:
        [str, int, int]:
            For each channel: the channel name, 1, and the number of single
            photometric references flagged final for the step.

        [str, int]:
            For each channel: the channel name and the number of single
            photometric references not flagged final.

        []:
            Always empty (no breakdown by status for lightcurve steps).
    """

    logger = logging.getLogger(__name__)
    max_attempts = 10
    retry_delay = 10

    step_version = get_db_configuration(
        config_version, db_session, step_id, max_version_only=True
    )
    final = {}
    pending = {}
    for db_sphotref in db_session.scalars(
        select(MasterFile)
        .join(MasterType)
        .where(MasterType.name == "single_photref")
    ).all():
        # Only reading the file is retried. Database errors are no longer
        # swallowed and retried along with it.
        header_info = None
        for attempt in range(1, max_attempts + 1):
            try:
                with DataReductionFile(
                    db_sphotref.filename, "r"
                ) as sphotref_dr:
                    header = sphotref_dr.get_frame_header()
                    header_info = (header["RAWFNAME"], header["CLRCHNL"])
                break
            # h5py refuses to provide public interface to exceptions, but
            # KeyboardInterrupt/SystemExit must still propagate.
            # pylint: disable=broad-except
            except Exception:
                if attempt == max_attempts:
                    logger.warning(
                        "Giving up reading the header of %s after %d "
                        "attempts.",
                        db_sphotref.filename,
                        max_attempts,
                        exc_info=True,
                    )
                else:
                    sleep(retry_delay)
            # pylint: enable=broad-except

        if header_info is None:
            continue
        raw_fname, channel = header_info

        sphotref_image_type_id = db_session.scalar(
            select(ImageType.id)
            .select_from(Image)
            .join(ImageType)
            .where(Image.raw_fname.contains(raw_fname + ".fits"))
        )
        # References of another image type are skipped right away instead of
        # being re-read on every retry iteration.
        if sphotref_image_type_id != image_type_id:
            continue

        final.setdefault(channel, 0)
        pending.setdefault(channel, 0)

        # False positive
        # pylint: disable=not-callable
        is_final = db_session.scalar(
            select(func.max(LightCurveProcessingProgress.final)).filter_by(
                step_id=step_id,
                single_photref_id=db_sphotref.id,
                configuration_version=step_version,
            )
        )
        # pylint: enable=not-callable
        if is_final:
            final[channel] += 1
        else:
            pending[channel] += 1

    return (
        [(channel, 1, count) for channel, count in final.items()],
        list(pending.items()),
        [],
    )
def get_progress(step, *args, **kwargs):
    """
    Return info about completed work on a given step.

    Args:
        step:    The step to report on. Its ``name`` selects between the
            lightcurve and the image progress functions and its ``id`` is
            passed on as their first argument.

        args, kwargs:    Passed on unchanged after the step ID.

    Returns:
        Whatever `get_progress_lightcurves()` or `get_progress_images()`
        returns.
    """

    lightcurve_steps = (
        "epd",
        "tfa",
        "generate_epd_statistics",
        "generate_tfa_statistics",
    )
    if step.name in lightcurve_steps:
        progress_getter = get_progress_lightcurves
    else:
        progress_getter = get_progress_images
    return progress_getter(step.id, *args, **kwargs)
def _get_config_info(version, step="All"):
    """
    Return info for displaying the configuration with given version.

    Args:
        version:    The configuration version to describe.

        step:    Name of the step to which to restrict the parameters, or
            ``"All"`` for no restriction.

    Returns:
        dict: Indexed by parameter name. Each entry is a dict with keys:

            * ``values``: dict indexed by configured value, holding the set
              of condition expressions (other than ``"True"``) attached to
              that value.

            * ``expression_counts``: dict counting how many configuration
              entries of the parameter use each expression.

            * ``description``: the description of the parameter.
    """

    # False positive:
    # pylint: disable=no-member
    with Session.begin() as db_session:
        # pylint: enable=no-member
        restricted = step != "All"
        if restricted:
            selected_step = db_session.scalar(
                select(Step).filter_by(name=step)
            )
            allowed_param_ids = {
                param.id for param in selected_step.parameters
            }

        config_info = {}
        for config in get_db_configuration(version, db_session):
            parameter = config.parameter
            if restricted and parameter.id not in allowed_param_ids:
                continue

            if parameter.name not in config_info:
                config_info[parameter.name] = {
                    "values": {},
                    "expression_counts": {},
                    "description": parameter.description,
                }
            param_info = config_info[parameter.name]

            expressions = [
                expr.expression for expr in config.condition_expressions
            ]
            param_info["values"][config.value] = {
                expression
                for expression in expressions
                if expression != "True"
            }

            counts = param_info["expression_counts"]
            for expression in expressions:
                counts[expression] = counts.get(expression, 0) + 1

        return config_info
def get_json_config(version=0, step="All", **dump_kwargs):
    """
    Return the configuration as a JSON object.

    The result is a tree: a ``step`` root, one ``parameter`` node per
    parameter, below that ``condition`` nodes for the expressions that must
    hold, and ``value`` leaves.

    Args:
        version:    The configuration version to export.

        step:    Name of the step to export or ``"All"``.

        dump_kwargs:    Passed directly to ``json.dumps()``.

    Returns:
        str: The JSON encoded configuration tree.
    """

    def get_children(values, expression_order):
        """Return the sub-tree for the given expressions."""

        nodes = []
        with_first = {}
        without_first = {}
        for value, val_expressions in values.items():
            if not val_expressions:
                # No conditions left: the value is a leaf at this level.
                nodes.append({"name": value, "type": "value", "children": []})
                continue
            first_expression = expression_order[0]
            if first_expression in val_expressions:
                with_first[value] = val_expressions - {first_expression}
            else:
                without_first[value] = val_expressions

        remaining_order = expression_order[1:]
        if with_first:
            nodes.append(
                {
                    "name": expression_order[0],
                    "type": "condition",
                    "children": get_children(with_first, remaining_order),
                }
            )
        if without_first:
            nodes.extend(get_children(without_first, remaining_order))
        return nodes

    parameter_nodes = []
    for param, param_info in _get_config_info(version, step).items():
        counts = param_info["expression_counts"]
        # Most frequently used expressions end up closest to the root.
        expression_order = sorted(counts, key=counts.get, reverse=True)
        parameter_nodes.append(
            {
                "name": param,
                "type": "parameter",
                "description": param_info["description"],
                "children": get_children(
                    param_info["values"], expression_order
                ),
            }
        )

    config_data = {
        "name": step,
        "type": "step",
        "children": parameter_nodes,
    }
    return json.dumps(config_data, **dump_kwargs)
[docs] def _parse_json_config(json_config): """ Organize the given JSON configuration to parameters and conditions. Args: json_config: JSON configuration to be parsed. Formatted as a decision tree, where the path through the tree defines the combination of condition expressions that must be satisfied and the leaf at the end specifies the value for the parameter . Returns: dict: parameter name: [ { 'expressions': set(expression ID index in below list), 'value': value of parameter if all expressions are satisfied }, ... ] [str]: list of expression strings """ result = {} expression_list = [] def walk_json(sub_tree, parameter=None, expression_ids=None): """Recursively walk the JSON configuration tree adding to results.""" if sub_tree["type"] == "parameter": assert parameter is None assert sub_tree["name"] not in result assert expression_ids is None assert sub_tree["children"] for child in sub_tree["children"]: walk_json(child, sub_tree["name"], ()) elif sub_tree["type"] == "value": assert not sub_tree["children"] assert parameter assert expression_ids is not None if parameter not in result: result[parameter] = [] print( "Adding to parsed: " + repr(set(expression_ids)) + " -> " + repr(sub_tree["name"]) ) result[parameter].append( {"expressions": set(expression_ids), "value": sub_tree["name"]} ) elif sub_tree["type"] == "condition": assert sub_tree["children"] assert parameter try: condition_id = expression_list.index(sub_tree["name"]) except ValueError: condition_id = len(expression_list) expression_list.append(sub_tree["name"]) for child in sub_tree["children"]: walk_json(child, parameter, expression_ids + (condition_id,)) else: raise ValueError( f'Unexpected node type: {sub_tree["type"]} in JSON' " configuration" ) for child in json_config["children"]: walk_json(child) return result, expression_list
def _get_db_conditions(db_session):
    """
    Return dict of condition IDs containing sets of expression IDs.

    Args:
        db_session:    Database session used to run the query.

    Returns:
        dict: Indexed by ``Condition.id``, holding the set of
        ``Condition.expression_id`` values found with that ID.
    """

    # False positive
    # pylint: disable=no-member
    rows = db_session.execute(
        select(Condition.id, Condition.expression_id)
    ).all()
    # pylint: enable=no-member

    conditions = {}
    for condition_id, expression_id in rows:
        conditions.setdefault(condition_id, set()).add(expression_id)
    return conditions
def _save_expressions(expressions, db_session):
    """
    Save new expressions to the database and return the DB IDs of all.

    Empty expressions are treated as the always satisfied ``"True"``
    expression, both when looking them up and when creating them.

    Args:
        expressions([str]):    The expression strings to look up or create.

        db_session:    Database session to use. New expressions are added
            and flushed so that their IDs are available.

    Returns:
        list: The database ID of each entry in ``expressions``, in the same
        order.
    """

    expression_db_ids = []
    for expression_str in expressions:
        # Normalize once so the lookup and the newly created row agree.
        expression_str = expression_str or "True"
        expression = db_session.execute(
            select(ConditionExpression).where(
                ConditionExpression.expression == expression_str
            )
        ).scalar_one_or_none()
        if expression is None:
            expression = ConditionExpression(expression=expression_str)
            db_session.add(expression)
            # The flush assigns the ID and makes the row visible to later
            # lookups of the same expression.
            db_session.flush()
        expression_db_ids.append(expression.id)
    return expression_db_ids
def _save_conditions(configuration, expression_db_ids, db_session):
    """
    Create new conditions encountered in configuration and add their IDs.

    Args:
        configuration(dict):    The parsed configuration (first return value
            of `_parse_json_config()`). Updated in place: for each entry,
            ``'expressions'`` is replaced by the set of database expression
            IDs (excluding the ``"True"`` expression) and ``'condition_id'``
            is added.

        expression_db_ids(list):    The database ID of each expression,
            indexed like the expression indices used in ``configuration``.

        db_session:    Database session to use. New Condition rows are added
            to it.

    Returns:
        None

    Raises:
        IndexError:    If the database holds no condition consisting of only
            the ``"True"`` expression.
    """

    logger = logging.getLogger(__name__)

    db_conditions = _get_db_conditions(db_session)
    logger.debug(
        "DB conditions:\n\t%s",
        "\n\t".join(f"{k}: {v!r}" for k, v in db_conditions.items()),
    )

    new_condition_id = db_session.scalar(
        # False positive
        # pylint: disable=no-member
        select(sql.functions.max(Condition.id) + 1)
        # pylint: enable=no-member
    )
    default_expression_set = {
        db_session.scalar(
            select(ConditionExpression.id).where(
                ConditionExpression.expression == "True"
            )
        )
    }
    default_condition_id = [
        condition_id
        for condition_id, expression_ids in db_conditions.items()
        if expression_ids == default_expression_set
    ][0]

    for param_info in configuration.values():
        for param_condition in param_info:
            condition_expression_ids = {
                expression_db_ids[expr_id]
                for expr_id in param_condition["expressions"]
            } - default_expression_set
            param_condition["expressions"] = condition_expression_ids

            if not condition_expression_ids:
                param_condition["condition_id"] = default_condition_id
                continue

            matching_condition = next(
                (
                    condition_id
                    for condition_id, expression_ids in db_conditions.items()
                    if expression_ids == condition_expression_ids
                ),
                None,
            )
            if matching_condition is not None:
                param_condition["condition_id"] = matching_condition
                continue

            db_session.add_all(
                # False positive
                # pylint: disable=not-callable
                Condition(id=new_condition_id, expression_id=expression_id)
                # pylint: enable=not-callable
                for expression_id in condition_expression_ids
            )
            param_condition["condition_id"] = new_condition_id
            # Remember the new condition so the same expression set met
            # again re-uses it instead of creating a duplicate.
            db_conditions[new_condition_id] = set(condition_expression_ids)
            new_condition_id += 1
def save_json_config(json_config, version):
    """
    Save configuration provided in JSON format to the database.

    Only parameters whose set of (condition, value) entries differs from
    what is in effect for ``version`` are re-written: their entries with
    exactly this version are deleted and the new ones are added.

    Args:
        json_config(str or bytes):    The JSON encoded configuration tree
            (see `_parse_json_config()` for the format). Bytes input must be
            UTF-8, UTF-16 or UTF-32 encoded, which includes plain ASCII.

        version:    The configuration version under which to save the
            changed parameters.

    Returns:
        None
    """

    # json.loads() handles str as well as encoded bytes, so no explicit
    # (ASCII only) decoding is needed.
    configuration, expressions = _parse_json_config(json.loads(json_config))

    # False positive:
    # pylint: disable=no-member
    with Session.begin() as db_session:
        # pylint: enable=no-member
        compare_config = get_db_configuration(version, db_session)
        _save_conditions(
            configuration,
            _save_expressions(expressions, db_session),
            db_session,
        )

        params_to_save = {}
        for param_name, param_info in configuration.items():
            param_id = db_session.scalar(
                select(Parameter.id).where(Parameter.name == param_name)
            )
            for condition_info in param_info:
                for old_config in compare_config:
                    if (
                        old_config.parameter_id == param_id
                        and (
                            old_config.condition_id
                            == condition_info["condition_id"]
                        )
                        and old_config.value == condition_info["value"]
                    ):
                        # Entry unchanged. Remove it so that what remains
                        # are old entries missing from the new configuration.
                        compare_config.remove(old_config)
                        break
                else:
                    params_to_save[param_name] = param_id

        # Old entries no longer present also mark their parameter as changed.
        for old_config in compare_config:
            if old_config.parameter.name not in configuration:
                continue
            params_to_save[old_config.parameter.name] = old_config.parameter_id

        for param_name, param_info in configuration.items():
            if param_name not in params_to_save:
                continue
            parameter_id = params_to_save[param_name]
            # False positive
            # pylint: disable=no-member
            delete_statement = (
                delete(Configuration)
                .where(Configuration.parameter_id == parameter_id)
                .where(Configuration.version == version)
            )
            # pylint: enable=no-member
            db_session.execute(delete_statement)
            db_session.add_all(
                # False positive
                # pylint: disable=not-callable
                Configuration(
                    parameter_id=parameter_id,
                    condition_id=condition_info["condition_id"],
                    value=condition_info["value"],
                    version=version,
                )
                # pylint: enable=not-callable
                for condition_info in param_info
            )
def list_steps():
    """
    List the pipeline steps.

    Returns:
        list: The ``name`` of every Step in the database.
    """

    # False positive:
    # pylint: disable=no-member
    with Session.begin() as db_session:
        # pylint: enable=no-member
        step_names = db_session.scalars(select(Step.name)).all()
    return step_names
def main():
    """
    Avoid polluting the global namespace.

    Enables debug logging (including the SQLAlchemy engine) and prints the
    list of channels and the progress of an example step.
    """

    example_step_id = 8
    example_image_type_id = 4
    example_config_version = 0

    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger("sqlalchemy.engine").setLevel(logging.DEBUG)
    # False positive:
    # pylint: disable=no-member
    with Session.begin() as db_session:
        # pylint: enable=no-member
        print("Channels: " + repr(list_channels(db_session)))

        # get_progress() reads the name and id attributes of the step, so it
        # needs the Step instance rather than just the ID.
        step = db_session.scalar(select(Step).filter_by(id=example_step_id))
        if step is None:
            print(f"No step with ID {example_step_id} found in the database.")
        else:
            print(
                get_progress(
                    step,
                    example_image_type_id,
                    example_config_version,
                    db_session,
                )
            )
# Only run the example when executed as a script, not when imported.
if __name__ == "__main__":
    main()