Source code for stellar_evolution.managed_interpolator

"""Define a stellar evolution interpolator class managed through a database."""

import os
import hashlib
import sys

import numpy

sys.path.append('..')

#Need to add POET package to module search path before importing
#pylint: disable=wrong-import-position
from stellar_evolution.basic_utils import db_session_scope
from stellar_evolution.change_variables import VarChangingInterpolator

from .manager_data_model import \
    VarchangeGrid,\
    VarchangeDependentValue,\
    VarchangeAgeNode,\
    VarchangeMassNode,\
    VarchangeFeHNode,\
    VarchangeDependentVariable
#pylint: enable=wrong-import-position

[docs]def checksum_filename(fname): """Return a str checksum of the file with the given name.""" assert os.path.exists(fname) with open(fname, 'rb') as opened_file: return hashlib.sha1(opened_file.read()).hexdigest()
[docs]def verify_checksum(filename, checksum, what): """ Check if the given file has the expected checksum. Args: filename: The name of the file whose checksum to verify. checksum: The expected value of the checksum. what: What is being verified (only used if error message if checksums do not match). Returns: None """ if checksum != checksum_filename(filename): raise IOError( '%s with filename %s registered, with a different checksum!' % (what.title(), repr(filename)) )
#This may be something to try to pick apart at a later time. #pylint: disable=too-many-instance-attributes
[docs]class ManagedInterpolator(VarChangingInterpolator): """Add properties describing the configuration of an interpolator."""
[docs] def _new_var_change_grid(self, *, grid_name, feh, masses, ages, db_session): """ Create a new grid with the given nodes and register it with the DB. Args: grid_name: The name to assign to the new grid in the database. feh: The [Fe/H] values at which to tabulate the dependent variables. masses: The stellar masses at which to tabulate the dependent variables. ages: The ages (in Gyrs) at which to tabulate the dependent variables. db_session: A database session to submit queries to. Returns: None """ self._varchange_grid_name = grid_name self._define_var_change_grid(feh, masses, ages) self._grid_db_id = db_session.query(VarchangeGrid).count() + 1 db_grid = VarchangeGrid( id=self._grid_db_id, name=grid_name, interpolator_id=self._db_id ) db_grid.feh_nodes = [ VarchangeFeHNode(index=index, value=value) for index, value in enumerate(feh) ] db_grid.mass_nodes = [ VarchangeMassNode(index=index, value=value) for index, value in enumerate(masses) ] db_grid.age_nodes = [ VarchangeAgeNode(index=index, value=value) for index, value in enumerate(ages) ] db_session.add(db_grid) db_session.add_all(db_grid.feh_nodes) db_session.add_all(db_grid.mass_nodes) db_session.add_all(db_grid.age_nodes)
[docs] @staticmethod def _variable_db_id(variable, db_session, must_exist=True): """ Return the ID of the given varibale in the database. Args: - variable: The name of the variable whose ID to return. - must_exst: If False, and the variable is not yet in the varchange_dependent_variables table, a new entry is added. Otherwise, an exception is raised if it is not there. Returns: - variable_db_id: The ID of the variable in the database. """ variable_db_id = db_session.query( VarchangeDependentVariable.id ).filter_by(name=variable).all() if must_exist or variable_db_id: return variable_db_id[0][0] variable_db_id = ( db_session.query(VarchangeDependentVariable).count() + 1 ) db_session.add(VarchangeDependentVariable(id=variable_db_id, name=variable)) return variable_db_id
[docs] def _read_variable_from_db(self, variable, db_session): """ Read the given variable's grid values from the DB. Args: - variable_name: The name of the variable to add. - db_session: A database session for queries. Returns: None, but has the same effect as calling VarChangingInterpolator._add_grid_variable, but finishes much faster. """ variable_db_id = self._variable_db_id(variable, db_session) setattr( self.grid, variable, numpy.empty( ( #Fales positive #pylint: disable=no-member self.grid.masses.size, self.grid.ages.size, self.grid.feh.size #pylint: enable=no-member ), dtype=(bool if variable == 'weights' else float) ) ) grid_var = getattr(self.grid, variable) for feh_i, mass_i, age_i, value in db_session.query( VarchangeDependentValue.feh_node_index, VarchangeDependentValue.mass_node_index, VarchangeDependentValue.age_node_index, VarchangeDependentValue.value ).filter_by( variable_id=variable_db_id, grid_id=self._grid_db_id ): grid_var[mass_i, age_i, feh_i] = value
[docs] def _add_variable_to_db(self, variable, db_session): """Add pre-calculated node values of a variable to DB.""" variable_db_id = self._variable_db_id(variable, db_session, False) grid_var = getattr(self.grid, variable) db_session.add_all( ( VarchangeDependentValue( variable_id=variable_db_id, grid_id=self._grid_db_id, feh_node_index=feh_i, mass_node_index=mass_i, age_node_index=age_i, value=grid_var[mass_i, age_i, feh_i] ) #Fales positive #pylint: disable=no-member for feh_i in range(self.grid.feh.size) for mass_i in range(self.grid.masses.size) for age_i in range(self.grid.ages.size) #pylint: enable=no-member ) )
[docs] def _add_grid_variable(self, variable): """ Prepares to use another dependent variable to change from. Args: see VarChangingInterpolator._add_grid_variable. Returns: None """ with db_session_scope() as db_session: variable_db_id = self._variable_db_id(variable, db_session, False) if ( #False positive #pylint: disable=no-member db_session.query( VarchangeDependentValue ).filter_by( variable_id=variable_db_id, grid_id=self._grid_db_id ).count() > 0 #pylint: enable=no-member ): self._read_variable_from_db(variable, db_session) #Attribute defined_weights defined by parent class. #pylint: disable=access-member-before-definition #pylint: disable=attribute-defined-outside-init if not self.defined_weights: self._read_variable_from_db('weights', db_session) self.defined_weights = True #pylint: enable=access-member-before-definition else: new_weights = not self.defined_weights super()._add_grid_variable(variable) assert self.defined_weights self._add_variable_to_db(variable, db_session) if new_weights: self._add_variable_to_db('weights', db_session)
[docs] def _set_var_change_grid(self, grid_name, db_session): """ Read a varchange grid from the DB and set the interpolator to use it. See VarChangingInterpolator._define_var_change_grid for newly created members of self. Args: - grid_name: The name of the grid in the database to read. - db_session: A database session for queries. Returns: True if a grid with the given name exists, False otherwise. """ grid_db_id = db_session.query( VarchangeGrid.id ).filter_by( name=grid_name, interpolator_id=self._db_id ).all() if not grid_db_id: return False self._grid_db_id = grid_db_id[0][0] self._define_var_change_grid( feh=numpy.array( db_session.query( VarchangeFeHNode.value ).filter_by( grid_id=self._grid_db_id, ).order_by( VarchangeFeHNode.index ).all() ).flatten(), masses=numpy.array( db_session.query( VarchangeMassNode.value ).filter_by( grid_id=self._grid_db_id, ).order_by( VarchangeMassNode.index ).all() ).flatten(), ages=numpy.array( db_session.query( VarchangeAgeNode.value ).filter_by( grid_id=self._grid_db_id, ).order_by( VarchangeAgeNode.index ).all() ).flatten() ) return True
[docs] def __init__(self, db_interpolator, serialization_path, db_session, **kwargs): """ Create VarChangingInterpolator and add properties describing config. Defines the following properties containing the information from db_interpolator: - name: The human readable name of the interpolator - _db_id: The ID of the interpolator in the database. - filename: The filename from which the interpolator was read. - nodes: A dictionary indexed by quantity giving the number of interpolation nodes used. - smoothing: Same as nodes but for the smoothing arguments. - suite: The software suite used to generate the tracks on which the interpolator is based. - track_masses: List of stellar masses on whose tracks the interpolation is based. - track_feh: List of stellar [Fe/H] on whose tracks the interpolation is based. Args: - db_interpolator: SerializedInterpolator instance from which to initialize self. - serialization_path: The directory where serialized interpolators are stored. Keyword only arguments: If not an empty dictionary, the underlying interpolator is constructed using those instead of the serialized filename. Returns: None """ interpolator_fname = os.path.join(serialization_path, db_interpolator.filename) if db_interpolator.checksum is not None: verify_checksum(interpolator_fname, db_interpolator.checksum, 'interpolator') self.name = db_interpolator.name self._db_id = db_interpolator.id self.filename = db_interpolator.filename self.smoothing = dict() self.nodes = dict() self.vs_log_age = dict() self.log_quantity = dict() suite = {track.suite.name for track in db_interpolator.tracks} assert len(suite) == 1 self.suite = suite.pop() for param in db_interpolator.parameters: quantity = self.quantity_names[param.quantity_id] self.smoothing[quantity] = param.smoothing or float('nan') self.nodes[quantity] = param.nodes self.vs_log_age[quantity] = param.vs_log_age self.log_quantity[quantity] = param.log_quantity self.track_masses = sorted( {track.mass for track in db_interpolator.tracks} ) self.track_feh = sorted( {track.feh for track in db_interpolator.tracks} ) if kwargs: super().__init__(grid_feh=numpy.array([]), grid_masses=numpy.array([]), grid_ages=numpy.array([]), **kwargs) else: super().__init__(grid_feh=numpy.array([]), grid_masses=numpy.array([]), grid_ages=numpy.array([]), interpolator_fname=interpolator_fname) if not self._set_var_change_grid('default', db_session): self._new_var_change_grid( grid_name='default', feh=numpy.linspace( float(self.track_feh[0]) * 0.99, float(self.track_feh[-1]) * 0.99, 3 * len(self.track_feh) ), masses=numpy.linspace( float(self.track_masses[0]), float(self.track_masses[-1]), 10 * len(self.track_masses) ), ages=numpy.linspace(1e-2, 13.71, 412), db_session=db_session )
[docs] def mass_range(self): """Return the range of masess covered by the interpolation grid.""" return ( round(float(self.track_masses[0]), 3), round(float(self.track_masses[-1]), 3) )
[docs] def feh_range(self): """Return the range of [Fe/H] covered by the interpolation grid.""" return ( round(float(self.track_feh[0]), 3), round(float(self.track_feh[-1]), 3) )
[docs] def mass_in_range(self, mass): """True iff the given mass is within the interpolation grid.""" return ( round(float(self.track_masses[0]), 3) <= mass <= round(float(self.track_masses[-1]), 3) )
[docs] def feh_in_range(self, feh): """True iff the given [Fe/H] is within the interpolation grid.""" return ( round(float(self.track_feh[0]), 3) <= feh <= round(float(self.track_feh[-1]), 3) )
[docs] def in_range(self, mass, feh): """True iff mass and [Fe/H] are within the interpolation grid.""" return self.mass_in_range(mass) and self.feh_in_range(feh)
[docs] def __str__(self): """Human readable representation of the interpolator.""" return (self.name + '[' + ', '.join(['%s(n: %d, s: %g)' % (quantity, self.nodes[quantity], self.smoothing[quantity]) for quantity in self.quantity_list]) + '] masses: [' + ', '.join([str(m) for m in self.track_masses]) + '], [Fe/H]: [' + ', '.join([str(feh) for feh in self.track_feh]) + ']')
#pylint: enable=too-many-instance-attributes