Source code for stellar_evolution.manager

"""Define class for managing many stellar evolution interpolations."""

import os
import os.path
import shutil
from uuid import uuid4 as get_uuid
from decimal import Decimal
import ctypes
from math import isnan
import sys

import numpy
from sqlalchemy import exists, and_, func, create_engine

sys.path.append('..')

#Need to add POET package to module search path before importing
#pylint: disable=wrong-import-position
from stellar_evolution.change_variables import VarChangingInterpolator
from stellar_evolution.library_interface import\
    library_track_fname_rex,\
    library_track_fname,\
    library
from stellar_evolution import Session
from stellar_evolution.basic_utils import db_session_scope, tempdir_scope
from basic_utils import Structure

from .manager_data_model import\
    DataModelBase,\
    SerializedInterpolator,\
    InterpolationParameters,\
    Quantity,\
    Track,\
    ModelSuite

from .managed_interpolator import\
    ManagedInterpolator,\
    verify_checksum,\
    checksum_filename
#pylint: enable=wrong-import-position

[docs]class StellarEvolutionManager: """ Class for managing a collection of stellar evolution inteprolations. """
[docs] @staticmethod def _get_decimal(value): """Prepare a value for db storage as a limited precision decimal.""" return Decimal(value).quantize(Decimal('0.0001'))
[docs] @staticmethod def _define_evolution_quantities(db_session): """ Define the quantities tracked by VarChangingInterpolator instances. Args: db_session: The currently active database session. Returns: None """ db_quantities = [ Quantity(id=id, name=name) for name, id in VarChangingInterpolator.quantity_ids.items() ] db_session.add_all(db_quantities)
[docs] def _initialize_database(self, db_engine, db_session): """ Ensure all database tables exist and contain at least required data. Args: db_session: An sqlalchemy session, used to update the database. Returns: None """ for table in DataModelBase.metadata.sorted_tables: if not db_engine.has_table(table.name): table.create(db_engine) if table.name == 'quantities': self._define_evolution_quantities(db_session) elif table.name == 'model_suites': db_session.add(ModelSuite(name='MESA'))
[docs] def _get_db_config(self, db_session): """Read some configuration from the database.""" self._quantities = [ Structure(id=q.id, name=q.name) for q in db_session.query(Quantity).all() ] self._suites = db_session.query(ModelSuite).all() self._new_track_id = db_session.query(Track).count() + 1
#Not simplifiable without loss of intuitive use. #pylint: disable=too-many-arguments
[docs] def _add_track(self, track_fname, mass, feh, model_suite, db_session): """Add a track to the database.""" db_track = Track(id=self._new_track_id, filename=os.path.abspath(track_fname), mass=mass, feh=feh, checksum=checksum_filename(track_fname), suite=db_session.query( ModelSuite ).filter_by( name=model_suite ).one()) db_session.add(db_track) self._new_track_id += 1
#pylint: enable=too-many-arguments
[docs] @staticmethod def _track_grid_from_files(track_fnames, db_session, model_suite=None): """ Organize track files in a mass - [Fe/H] grid and verify checksums. In addition, if model_suite is not None, all tracks are verified to belong to this suite. Args: track_fnames: See get_interpolator track_fnames argument. Returns: dict: keys - the masses of tracks; and values - further dictionaries with keys the [Fe/H] of tracks and values the filename of each track. """ track_grid = dict() for fname in track_fnames: absolute_fname = os.path.abspath(fname) db_track = db_session.query(Track).filter_by( filename=absolute_fname ) if model_suite is not None: assert db_track.suite.name == model_suite verify_checksum(fname, db_track.checksum, 'track') mass_key = db_track.mass feh_key = db_track.feh if mass_key not in track_grid: track_grid[mass_key] = dict() assert feh_key not in track_grid[mass_key] track_grid[mass_key][feh_key] = (fname, db_track.id) return track_grid
[docs] def _track_grid_from_grid(self, mass_list, feh_list, model_suite, db_session): """ Return a mass - [Fe/H] grid with filenames and checksums. Fails if multiple tracks are registered for some (mass, [Fe/H], model suite) combination. Args: mass_list: The masses for which to include tracks. feh_list: The [Fe/H values for which to include tracks. model_suite: The software suite from whose tracks to choose. db_session: A database session to submit queries to. Returns: dict: See _track_grid_from_files() """ if mass_list is None: mass_list = [ record[0] for record in db_session.query(Track.mass).filter( Track.model_suite_id == ModelSuite.id, ModelSuite.name == model_suite ).distinct().order_by(Track.mass).all() ] if feh_list is None: feh_list = [ record[0] for record in db_session.query(Track.feh).filter( Track.model_suite_id == ModelSuite.id, ModelSuite.name == model_suite ).distinct().order_by(Track.feh).all() ] track_grid = {m: dict() for m in mass_list} for mass in mass_list: for feh in feh_list: db_track = db_session.query(Track).filter( Track.mass == self._get_decimal(mass), Track.feh == self._get_decimal(feh), Track.model_suite_id == ModelSuite.id, ModelSuite.name == model_suite ).one() verify_checksum(db_track.filename, db_track.checksum, 'track') track_grid[mass][feh] = (db_track.filename, db_track.id) return track_grid
[docs] def _find_existing_interpolator(self, *, track_grid, nodes, smoothing, db_session, vs_log_age, log_quantity): """ Return the specified interpolation if already exists, otherwise None. Args: track_grid: See result of _track_grid_from_files() nodes: see get_interpolator. smoothing: see get_interpolator. db_session: The currently active database session. vs_log_age: see get_interpolator. log_quantity: see get_interpolator. Returns: VarChangingInterpolator: Pre-serialized interpolation matching the given arguments if one is found in the interpolation archive. If no pre-serialized interpolation exists, returns None. """ num_tracks = len(track_grid) * len(next(iter(track_grid.values()))) track_counts = db_session.query( SerializedInterpolator.id, func.count('*').label('num_tracks') ).join( SerializedInterpolator.tracks ).group_by( SerializedInterpolator.id ).subquery() match_config = ( [ exists().where( and_( InterpolationParameters.interpolator_id == SerializedInterpolator.id , InterpolationParameters.quantity_id == quantity.id , InterpolationParameters.smoothing == (None if isnan(smoothing[quantity.name]) else self._get_decimal(smoothing[quantity.name])) , InterpolationParameters.nodes == nodes[quantity.name], InterpolationParameters.vs_log_age == vs_log_age[quantity.name], InterpolationParameters.log_quantity == log_quantity[quantity.name] ) ) for quantity in self._quantities ] + [ SerializedInterpolator.tracks.any(Track.id == track_id) for mass, mass_row in track_grid.items() for feh, (track_fname, track_id) in mass_row.items() ] + [track_counts.c.num_tracks == num_tracks] ) result = db_session.query( SerializedInterpolator ).join( track_counts, SerializedInterpolator.id == track_counts.c.id ).filter( *match_config ).one_or_none() if result is None: return result return ManagedInterpolator( db_interpolator=result, serialization_path=self._serialization_path, db_session=db_session )
#TODO: think about breaking this function up. #pylint: disable=too-many-locals
[docs] def _create_new_interpolator(self, *, track_grid, nodes, smoothing, db_session, vs_log_age, log_quantity, num_threads, name=None): """ Generate the specified interpolation and add it to the archive. Args: track_grid: See result of _track_grid_from_files() nodes: see get_interpolator(). smoothing: see get_interpolator(). vs_log_age: see get_interpolator(). log_quantity: see get_interpolator(). num_threads: The number of simultaneous threads to use when constructing the interpolation. db_session: The database query session to use. name: The name to assign to the new interpolator. If None, the UUID used to form the filename is used. Returns: VarChangingInterpolator: Created from scratch based on the given arguments. """ interp_str = str(get_uuid()) interp_fname = os.path.join(self._serialization_path, interp_str) db_interpolator = SerializedInterpolator( id=(db_session.query(SerializedInterpolator).count() + 1), name=(name or interp_str), filename=interp_str ) if ( db_session.query( SerializedInterpolator ).filter_by(name=name).count() ): raise ValueError('Interpolator named %s already exists, with a ' 'different configuration than the one being ' 'constructed!' % repr(name)) with tempdir_scope() as track_dir: for mass, mass_row in track_grid.items(): for feh, (track_fname, track_id) in mass_row.items(): track = db_session.query( Track ).filter_by( id=track_id ).one() db_interpolator.tracks.append(track) shutil.copy( track_fname, os.path.join(track_dir, library_track_fname(mass, feh)) ) interp_smoothing = numpy.empty( len(VarChangingInterpolator.quantity_list), dtype=ctypes.c_double ) interp_nodes = numpy.empty( len(VarChangingInterpolator.quantity_list), dtype=ctypes.c_int ) interp_vs_log_age = numpy.empty( len(VarChangingInterpolator.quantity_list), dtype=ctypes.c_bool ) interp_log_quantity = numpy.empty( len(VarChangingInterpolator.quantity_list), dtype=ctypes.c_bool ) for q_name, q_index in \ VarChangingInterpolator.quantity_ids.items(): interp_smoothing[q_index] = smoothing[q_name] interp_nodes[q_index] = nodes[q_name] interp_vs_log_age[q_index] = vs_log_age[q_name] interp_log_quantity[q_index] = log_quantity[q_name] db_interpolator.parameters = [ InterpolationParameters(quantity_id=q.id, nodes=nodes[q.name], smoothing=smoothing[q.name], vs_log_age=vs_log_age[q.name], log_quantity=log_quantity[q.name], interpolator=db_interpolator) for q in self._quantities ] actual_interpolator = ManagedInterpolator( db_interpolator=db_interpolator, serialization_path=self._serialization_path, db_session=db_session, mesa_dir=track_dir, smoothing=interp_smoothing, nodes=interp_nodes, vs_log_age=interp_vs_log_age, log_quantity=interp_log_quantity, num_threads=num_threads ) actual_interpolator.save(interp_fname) db_interpolator.checksum = checksum_filename(interp_fname) db_session.add(db_interpolator) db_session.add_all(db_interpolator.parameters) return actual_interpolator
#pylint: enable=too-many-locals
[docs] def __init__(self, serialization_path): """ Create a manager storing serialized interpolators in the given path. Args: serialization_path: The path where to store serialized interpolators. Returns: None """ if not os.path.exists(serialization_path): os.makedirs(serialization_path) db_engine = create_engine( 'sqlite:///' + os.path.join(serialization_path, 'serialized.sqlite'), echo=False ) Session.configure(bind=db_engine) self._serialization_path = serialization_path with db_session_scope() as db_session: self._initialize_database(db_engine, db_session) with db_session_scope() as db_session: self._get_db_config(db_session)
[docs] def get_interpolator( self, *, nodes=VarChangingInterpolator.default_nodes, smoothing=VarChangingInterpolator.default_smoothing, vs_log_age=VarChangingInterpolator.default_vs_log_age, log_quantity=VarChangingInterpolator.default_log_quantity, track_fnames=None, masses=None, feh=None, model_suite='MESA', new_interp_name=None, num_threads=1 ): """ Return a stellar evolution interpolator with the given configuration. All tracks that the interpolator should be based on must be pre-registered with the manager. Two ways are supported for identifying tracks: as a list of filenames or as a mass-[Fe/H] grid combined with a suite. The first case always works, while the second requires that the set of identified tracks is unique, i.e. for none of the mass - [Fe/H] combinations there are two or more tracks registered for the given suite. Args: nodes: The number of nodes to use for the age interpolation of each quantity of each track. Should be a dictionary with keys VarChangingInterpolator.quantity_list. See the POET code StellarEvolution::Interpolator::create_from() documentation for a description of what this actually means. smoothing: The amount of smoothing to use for the age interpolation of each quantity of each track. Should be a dictionary with keys VarChangingInterpolator.quantity_list. See the POET code StellarEvolution::Interpolator::create_from() documentation for a description of what this actually means. vs_log_age: Use log(age) instead of age as the independent argument for the intperpolation? Should be a dictionary with keys VarChangingInterpolator.quantity_list. log_quantity: Interpolate log(quantity) instead of quantity? Should be a dictionary with keys VarChangingInterpolator.quantity_list. track_fnames: A list of files containing stellar evolution tracks the interpolator should be based on. masses: A list of the stellar masses to include in the interpolation. Unique tracks with those masses and all selected [Fe/H] (see next argument) must already be registered with the database for the given suite. If None, all track masses from the given suite are used. feh: A list of the stellar [Fe/H] values to include in the interpolation. If None, all track [Fe/H] from the given suite are used. model_suite: The software suite used to generate the stellar evolution tracks. May be omitted if tracks are specified by filename, but must be supplied if using masses and [Fe/H]. new_interp_name: Name to assign to the a newly generated interolator. Ignored if an interpolator matching all other arguments already exists. If not specified, and no interpolator exists matching the remining arguments, a new interpolator is not generated. num_threads: If a new interpolator is created this many simultaneous interpolation threads are used. Returns: VarChangingInterpolator: Configured per the arguments supplied or None if no existing interpolator is found and creating a new one is forbidden (see new_interp_name argument). """ with db_session_scope() as db_session: if track_fnames is None: track_grid = self._track_grid_from_grid(masses, feh, model_suite, db_session) else: track_grid = self._track_grid_from_files(track_fnames, db_session, model_suite) result = self._find_existing_interpolator( track_grid=track_grid, nodes=nodes, smoothing=smoothing, vs_log_age=vs_log_age, log_quantity=log_quantity, db_session=db_session ) if result is not None: return result if new_interp_name is None: return None return self._create_new_interpolator( track_grid=track_grid, nodes=nodes, smoothing=smoothing, vs_log_age=vs_log_age, log_quantity=log_quantity, db_session=db_session, name=new_interp_name, num_threads=num_threads )
[docs] def get_interpolator_by_name(self, name): """Return the interpolator with the given name.""" with db_session_scope() as db_session: #False positive #pylint: disable=no-member return ManagedInterpolator( db_interpolator=db_session.query( SerializedInterpolator ).filter_by( name=name ).one(), serialization_path=self._serialization_path, db_session=db_session )
#pylint: enable=no-member
[docs] @staticmethod def list_interpolator_names(): """Return a list of all intorpolator names.""" with db_session_scope() as db_session: return [ record[0] for record in #False positive #pylint: disable=no-member db_session.query(SerializedInterpolator.name).all() #pylint: enable=no-member ]
[docs] @staticmethod def list_suites(): """Return a list of all software suites with available tracks.""" with db_session_scope() as db_session: return [ record[0] for record in #False positive #pylint: disable=no-member db_session.query(ModelSuite.name).all() #pylint: enable=no-member ]
[docs] @staticmethod def get_suite_tracks(model_suite='MESA'): """Return all tracks from a given suite.""" with db_session_scope() as db_session: #False positive #pylint: disable=no-member return db_session.query(Track).filter( and_(Track.model_suite_id == ModelSuite.id, ModelSuite.name == model_suite) )
#pylint: enable=no-member
[docs] def register_track(self, track_fname, mass, feh, model_suite='MESA'): """Register a track for use in creating interpolators.""" with db_session_scope() as db_session: self._add_track(track_fname, self._get_decimal(mass), self._get_decimal(feh), model_suite, db_session)
[docs] def register_track_collection(self, track_fnames, fname_rex=library_track_fname_rex, model_suite='MESA'): """ Add a collection of tracks with [Fe/H] and M* encoded in filename. Args: -track_fnames: The filenames of the tracks to add. - fname_rex: A regular expression defining groups named 'MASS' and either 'Z' or 'FeH' used to parse the filename for the stellar mass and metallicity each track applies to. - model_suite: The software suite used to generate the stellar evolution tracks. """ for fname in track_fnames: parsed_fname = fname_rex.match( os.path.basename(fname) ).groupdict() mass = self._get_decimal(parsed_fname['MASS']) if 'Z' in parsed_fname: feh = self._get_decimal( library.feh_from_z(float(parsed_fname['Z'])) ) else: feh = self._get_decimal(parsed_fname['FeH']) with db_session_scope() as db_session: self._add_track(fname, mass, feh, model_suite, db_session)