Source code for stellar_evolution.library_interface

#!/usr/bin/env python3

"""An interface to the POET stellar evolution interpolation utilities."""

from ctypes import\
    cdll,\
    c_int, c_double, c_void_p, c_char_p, c_uint, c_bool,\
    byref
from ctypes.util import find_library
import re

import numpy

# Class naming convention mimics ctypes naming.
# These classes are just placeholders, so they have no public methods.
#pylint: disable=invalid-name
#pylint: disable=too-few-public-methods
class c_interpolator_p(c_void_p):
    """Opaque pointer to an interpolator object owned by the POET library.

    Adds no behaviour over ``c_void_p``; the distinct name documents what the
    pointer refers to wherever it appears in ctypes signatures.
    """

    pass
class c_quantity_p(c_void_p):
    """Opaque pointer to an evolving-quantity object owned by the POET library.

    Adds no behaviour over ``c_void_p``; the distinct name documents what the
    pointer refers to wherever it appears in ctypes signatures.
    """

    pass
#pylint: enable=invalid-name
#pylint: enable=too-few-public-methods
def initialize_library():
    """
    Load POET's stellarEvolution shared library and declare its C signatures.

    Locates the library with ``ctypes.util.find_library``, loads it, and sets
    ``argtypes``/``restype`` on every exported function this module calls, so
    that ctypes converts arguments and return values correctly.

    Returns:
        ctypes.CDLL: The loaded and configured library.

    Raises:
        OSError: If the stellarEvolution library cannot be found or loaded.
    """

    # Local import: only needed to declare the out-parameters of
    # quantity_continuous_range below.
    from ctypes import POINTER

    library_fname = find_library('stellarEvolution')
    if library_fname is None:
        raise OSError('Unable to find POET\'s stellarEvolution library.')
    result = cdll.LoadLibrary(library_fname)

    # Number of interpolated quantities exported by the library; fixes the
    # required length of the per-quantity configuration arrays.
    num_quantities = c_int.in_dll(result, 'NUM_QUANTITIES').value
    per_quantity_shape = (num_quantities,)

    def array_arg(dtype, shape=None):
        """Return an argtype accepting a 1-D C-contiguous numpy array."""

        return numpy.ctypeslib.ndpointer(dtype=dtype,
                                         ndim=1,
                                         shape=shape,
                                         flags='C_CONTIGUOUS')

    result.create_interpolator.argtypes = [
        c_char_p,
        array_arg(c_double, per_quantity_shape),
        array_arg(c_int, per_quantity_shape),
        array_arg(c_bool, per_quantity_shape),
        array_arg(c_bool, per_quantity_shape),
        c_uint
    ]
    result.create_interpolator.restype = c_interpolator_p

    result.destroy_interpolator.argtypes = [c_interpolator_p]
    result.destroy_interpolator.restype = None

    result.create_quantity.argtypes = [c_interpolator_p,
                                       c_int,
                                       c_double,
                                       c_double]
    result.create_quantity.restype = c_quantity_p

    result.destroy_quantity.argtypes = [c_quantity_p]
    result.destroy_quantity.restype = None

    result.evaluate_quantity.argtypes = [c_quantity_p, c_double]
    result.evaluate_quantity.restype = c_double

    result.evaluate_quantity_array.argtypes = [
        c_quantity_p,
        array_arg(c_double),
        c_uint,
        array_arg(c_double)
    ]
    # Output goes into the last (array) argument; callers ignore any return.
    result.evaluate_quantity_array.restype = None

    # Previously only restypes were declared for these three; declare the
    # argument types too, matching how Quantity calls them.
    result.quantity_min_age.argtypes = [c_quantity_p]
    result.quantity_min_age.restype = c_double
    result.quantity_max_age.argtypes = [c_quantity_p]
    result.quantity_max_age.restype = c_double
    result.quantity_continuous_range.argtypes = [c_quantity_p,
                                                 c_double,
                                                 POINTER(c_double),
                                                 POINTER(c_double)]
    result.quantity_continuous_range.restype = None

    result.save_interpolator.argtypes = [c_interpolator_p, c_char_p]
    result.save_interpolator.restype = None

    result.load_interpolator.argtypes = [c_char_p]
    result.load_interpolator.restype = c_interpolator_p

    result.differentiate_quantity.argtypes = [
        c_quantity_p,
        c_double,
        array_arg(c_double)
    ]
    # Derivatives go into the output array; callers ignore any return.
    result.differentiate_quantity.restype = None

    result.differentiate_quantity_array.argtypes = [
        c_quantity_p,
        array_arg(c_double),
        c_uint,
        array_arg(c_double)
    ]
    result.differentiate_quantity_array.restype = None

    result.default_smoothing.argtypes = [c_int]
    result.default_smoothing.restype = c_double

    result.default_vs_log_age.argtypes = [c_int]
    result.default_vs_log_age.restype = c_bool

    # Used by MESAInterpolator.default_log_quantity but previously left with
    # ctypes' implicit int return. Assumed to return a C bool, like
    # default_vs_log_age -- NOTE(review): confirm against the C header.
    result.default_log_quantity.argtypes = [c_int]
    result.default_log_quantity.restype = c_bool

    result.default_nodes.argtypes = [c_int]
    result.default_nodes.restype = c_int

    result.metallicity_from_feh.argtypes = [c_double]
    result.metallicity_from_feh.restype = c_double

    result.feh_from_metallicity.argtypes = [c_double]
    result.feh_from_metallicity.restype = c_double

    result.feh_from_z.argtypes = [c_double]
    result.feh_from_z.restype = c_double

    result.z_from_feh.argtypes = [c_double]
    result.z_from_feh.restype = c_double

    return result
library = initialize_library()

# Matches track basenames such as 'M1.0_Z0.015.csv', capturing the mass and
# metallicity text. The dot before 'csv' is escaped so that it matches only a
# literal '.'. Lowercase 'e' is accepted alongside 'E' because repr(float),
# which library_track_fname uses, writes exponents like '1e-05'.
library_track_fname_rex = re.compile(
    r'M(?P<MASS>[0-9.eE+-]+)_Z(?P<Z>[0-9.eE+-]+)\.csv'
)
def library_track_fname(mass, feh):
    """
    Build the track file basename that the stellar evolution library expects.

    Args:
        mass: The mass of the star whose evolution is stored in the track
            (converted to float before formatting).

        feh: The [Fe/H] of the star whose evolution is stored in the track.
            It is converted to a metallicity Z with the library's
            ``z_from_feh``.

    Returns:
        str: A name of the form ``'M<mass>_Z<Z>.csv'``, where both numbers
        are formatted with ``repr``.
    """

    mass_text = repr(float(mass))
    metallicity_text = repr(library.z_from_feh(feh))
    return 'M%s_Z%s.csv' % (mass_text, metallicity_text)
class MESAInterpolator:
    """A class for interpolating among a set of MESA tracks."""

    quantity_list = ['RADIUS', 'ICONV', 'LUM', 'IRAD', 'MRAD', 'RRAD']

    # Integer ID of each quantity, read from the library's exported integer
    # symbol of the same name, plus the reverse mapping.
    quantity_ids = {q: c_int.in_dll(library, q).value for q in quantity_list}
    quantity_names = {c_int.in_dll(library, q).value: q
                      for q in quantity_list}

    # Library-provided default interpolation settings, keyed by quantity name.
    default_smoothing = {q_name: library.default_smoothing(q_id)
                         for q_name, q_id in quantity_ids.items()}
    default_nodes = {q_name: library.default_nodes(q_id)
                     for q_name, q_id in quantity_ids.items()}
    default_vs_log_age = {q_name: library.default_vs_log_age(q_id)
                          for q_name, q_id in quantity_ids.items()}
    default_log_quantity = {q_name: library.default_log_quantity(q_id)
                            for q_name, q_id in quantity_ids.items()}

    def __init__(self, **kwargs):
        """
        Prepare a MESA based interpolation.

        Args:
            mesa_dir: A directory containing a grid (mass and metallicity) of
                MESA tracks to base the interpolation on. Must not be
                specified if interpolator_fname is. If given, smoothing,
                nodes, vs_log_age, log_quantity and num_threads are all
                required.

            smoothing: A numpy float array of the smoothing arguments to use
                for the interpolation of each quantity. Should be in the
                order defined by quantity_ids.

            nodes: A numpy integer array of the nodes to use for the
                interpolation of each quantity. Same order as smoothing.

            vs_log_age: A numpy boolean array indicating whether the
                interpolation for each quantity should be done vs log(age)
                instead of age.

            log_quantity: A numpy boolean array indicating whether the
                interpolation for each quantity should be of log(quantity)
                instead of quantity.

            interpolator_fname: The filename of a previously saved
                interpolator state. Must not be specified together with
                mesa_dir. If passed, the smoothing and nodes arguments are
                ignored.

            num_threads: The number of simultaneous threads to use when
                constructing the interpolation.

        Returns:
            None.
        """

        if 'mesa_dir' in kwargs:
            # No saved file backs a freshly built interpolator until save()
            # is called; set the attribute so it always exists.
            self.filename = None
            self.interpolator = library.create_interpolator(
                kwargs['mesa_dir'].encode('ascii'),
                kwargs['smoothing'],
                kwargs['nodes'],
                kwargs['vs_log_age'],
                kwargs['log_quantity'],
                kwargs['num_threads']
            )
        else:
            assert 'interpolator_fname' in kwargs, (
                'MESAInterpolator requires either mesa_dir or '
                'interpolator_fname'
            )
            self.filename = kwargs['interpolator_fname']
            self.interpolator = library.load_interpolator(
                kwargs['interpolator_fname'].encode('ascii')
            )

    def delete(self):
        """
        Free the resources allocated at construction.

        Safe to call more than once: after the first call the handle is
        cleared, so the underlying C interpolator is never freed twice.
        """

        if self.interpolator is None:
            return
        library.destroy_interpolator(self.interpolator)
        self.interpolator = None

    def save(self, filename):
        """
        Save the interpolator created to the given file for faster creation.

        Args:
            filename: The name of the file to use for saving the state.
                Overwritten if exists.

        Returns:
            None
        """

        self.filename = filename
        library.save_interpolator(self.interpolator,
                                  filename.encode('ascii'))

    def __call__(self, quantity, mass, feh):
        """
        Return a stellar quantity interpolated to the given mass and [Fe/H].

        Args:
            quantity: A string identifying the quantity to interpolate. The
                following values are allowed: 'radius', 'iconv', 'lum',
                'irad', 'mrad', 'rrad'. This is a case insensitive argument.

            mass: The mass of the star for which this quantity should be
                defined in solar masses.

            feh: The [Fe/H] of the star for which this quantity should be
                defined.

        Returns:
            Quantity: callable with an age parameter evaluating to the
            quantity at the given age.
        """

        return Quantity(
            library.create_quantity(self.interpolator,
                                    self.quantity_ids[quantity.upper()],
                                    c_double(mass),
                                    c_double(feh))
        )
class Quantity:
    """Callable that evaluates to the value of the quantity at a given age."""

    def __init__(self, underlying_quantity):
        """
        Wrap the underlying EvolvingStellarQuantity into a callable.

        Args:
            underlying_quantity: The quantity handle returned by the library's
                ``create_quantity``.
        """

        self.underlying_quantity = underlying_quantity
        # Valid age range for evaluation, as reported by the library.
        self.min_age = library.quantity_min_age(underlying_quantity)
        self.max_age = library.quantity_max_age(underlying_quantity)

    def delete(self):
        """
        Destroy the underlying quantity.

        Safe to call more than once: after the first call the handle is
        cleared, so the underlying C object is never freed twice.
        """

        if self.underlying_quantity is None:
            return
        library.destroy_quantity(self.underlying_quantity)
        self.underlying_quantity = None

    def __call__(self, age):
        """
        Evaluate the underlying quantity at the given age(s) (in Gyr).

        Args:
            age: Either a single float or a numpy array of floats giving the
                ages at which to evaluate the quantity. All ages must lie
                within [min_age, max_age] (inclusive).

        Returns:
            type(age): The value(s) of the quantity in the same format as
            age (a float, or an array with the same shape as age).
        """

        if isinstance(age, numpy.ndarray):
            # The library needs a flat, C-contiguous double array; this is a
            # no-op for arrays already in that form.
            ages = numpy.ascontiguousarray(age, dtype=c_double).ravel()
            assert (ages >= self.min_age).all()
            assert (ages <= self.max_age).all()
            result = numpy.empty(dtype=c_double,
                                 shape=(ages.size,),
                                 order='C')
            library.evaluate_quantity_array(self.underlying_quantity,
                                            ages,
                                            ages.size,
                                            result)
            return result.reshape(age.shape)

        # Same inclusive bounds as the array branch above (previously the
        # scalar branch wrongly excluded the end points).
        assert age >= self.min_age
        assert age <= self.max_age
        return library.evaluate_quantity(self.underlying_quantity,
                                         c_double(age))

    def deriv(self, age):
        """
        Return the 0-th, 1-st and 2-nd order derivatives of the quantity.

        Args:
            age(float or numpy array): Either a single float or a numpy array
                of floats giving the ages at which to evaluate the quantity.

        Returns:
            numpy array: Of shape (3,) if age is a single float, or
            ``(3,) + age.shape`` (2-D for a 1-D age array) if age is a numpy
            array. The outside (or only) index is the derivative order.
        """

        if isinstance(age, numpy.ndarray):
            ages = numpy.ascontiguousarray(age, dtype=c_double).ravel()
            result = numpy.empty(dtype=c_double,
                                 shape=(3, ages.size),
                                 order='C')
            # reshape of a C-ordered array is a view, so the library writes
            # directly into `result`.
            library.differentiate_quantity_array(
                self.underlying_quantity,
                ages,
                ages.size,
                result.reshape(3 * ages.size)
            )
            return result.reshape((3,) + age.shape)

        result = numpy.empty(dtype=c_double, shape=(3,), order='C')
        library.differentiate_quantity(self.underlying_quantity,
                                       c_double(age),
                                       result)
        return result

    def continuous_range(self, age):
        """
        Return the range around age over which the quantity is continuous.

        Args:
            age: The age around which the continuous region is required.

        Returns:
            A 2-tuple of the minimum and maximum ages surrounding age over
            which the quantity is guaranteed continuous.
        """

        min_age, max_age = c_double(), c_double()
        library.quantity_continuous_range(self.underlying_quantity,
                                          c_double(age),
                                          byref(min_age),
                                          byref(max_age))
        return min_age.value, max_age.value
def example():
    """
    Example of the usage of the interface.

    Loads a previously saved interpolator from the file 'saved_interpolator'
    in the current directory. For a 1.0 solar mass star with [Fe/H] = 0.25,
    it prints every quantity at three ages spaced logarithmically inside the
    quantity's valid age range. All library objects are released afterwards,
    even if an error occurs.
    """

    # To build the interpolator directly from MESA tracks instead of loading a
    # saved one, pass mesa_dir together with all of smoothing, nodes,
    # vs_log_age, log_quantity and num_threads (MESAInterpolator requires them
    # in that case), e.g.:
    #
    #   mesa_dir = '../poet_src/StellarEvolution/MESA'
    #   interpolator = MESAInterpolator(mesa_dir=mesa_dir,
    #                                   smoothing=..., nodes=...,
    #                                   vs_log_age=..., log_quantity=...,
    #                                   num_threads=...)
    #   interpolator.save(mesa_dir + '/saved_interpolator')

    loaded_interpolator = MESAInterpolator(
        interpolator_fname='saved_interpolator'
    )
    try:
        for quantity_name in MESAInterpolator.quantity_list:
            quantity = loaded_interpolator(quantity_name, 1.0, 0.25)
            try:
                # Clamp the lower end to a positive value so its log is
                # finite, then keep only the interior points of the grid.
                ages = numpy.exp(
                    numpy.linspace(numpy.log(max(quantity.min_age, 1e-5)),
                                   numpy.log(quantity.max_age),
                                   5)[1:-1]
                )
                print(quantity_name
                      + '('
                      + repr(ages)
                      + ') = '
                      + repr(quantity(ages)))
            finally:
                quantity.delete()
    finally:
        loaded_interpolator.delete()
if __name__ == '__main__':
    # Run the demonstration only when executed as a script, not on import.
    example()