import numpy as np
from bidict import bidict
from scipy import constants,integrate
import warnings
from cycler import cycler
import matplotlib as mpl
from matplotlib import pyplot as plt
from copy import deepcopy
import re

def nm2A(x):
    """Convert nanometres to Angstroms."""
    return 10.*x

def hdf5_to_dict(filename_or_hdf5_object):
    """Load all elements in an hdf5 file into a dictionary. Groups define
    subdictionaries."""
    import h5py
    ## if given a filename then open it, else assume an open h5py object
    if isinstance(filename_or_hdf5_object,str):
        filename_or_hdf5_object = h5py.File(expand_path(filename_or_hdf5_object),'r')
    retval = {}                 # the output data
    ## recurse through object loading data
    for key in filename_or_hdf5_object.keys():
        if isinstance(filename_or_hdf5_object[key],h5py.Dataset):
            retval[str(key)] = filename_or_hdf5_object[key][()] # add data
        else:
            retval[str(key)] = hdf5_to_dict(filename_or_hdf5_object[key]) # make a new subdict recursively
    ## convert bytes objects to unicode
    for key in retval:
        if isinstance(retval[key],bytes):
            retval[key] = retval[key].decode()
    return retval

def dict_to_hdf5(filename,dictionary,keys=None,overwrite=True,compress_data=False):
    """Save all elements of a dictionary as datasets in an hdf5 file.
    Compression options a la h5py, e.g., 'gzip' or True for defaults."""
    import h5py
    import os
    filename = expand_path(filename)
    if keys is None:
        keys = list(dictionary.keys()) # default: add all keys as datasets
    if overwrite:
        if os.path.isdir(filename):
            raise Exception("Is directory: "+filename)
        if os.path.lexists(filename):
            os.unlink(filename)
    else:
        if os.path.lexists(filename):
            raise Exception("File exists: "+filename)
    f = h5py.File(filename,mode='w')
    for key in keys:
        kwargs = {}
        if compress_data and not np.isscalar(dictionary[key]):
            kwargs['compression'] = "gzip"
            kwargs['compression_opts'] = 9
        f.create_dataset(key,data=dictionary[key],**kwargs)
    f.close()
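## A minimal usage sketch of the two hdf5 functions above, assuming h5py is
## installed; the file path is arbitrary and for illustration only.
def _example_hdf5_round_trip():
    data = {'x':np.arange(5),'note':'test data'}
    dict_to_hdf5('/tmp/example.h5',data)       # write dictionary to file
    reloaded = hdf5_to_dict('/tmp/example.h5') # read it back
    assert np.all(reloaded['x']==data['x'])
    assert reloaded['note']=='test data'       # bytes decoded back to str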
def txt_to_dict(
        filename,               # path or open data stream (will be closed)
        labels=None,
        delimiter=None,
        skiprows=0,
        comment_regexp='#',
        labels_commented=True,
        awkfilter=None,
        filter_function=None,
        filter_regexp=None,
        ignore_blank_lines=True,
        replacement_for_blank_elements='nan'):
    """Convert a text file to a dictionary. Keys are taken from the first
    uncommented record, or the last commented record if
    labels_commented=True. Leading/trailing whitespace and leading
    comment_regexp matches are stripped from keys.
    filter_function: if not None, run lines through this function before parsing.
    filter_regexp: if not None, must be (pattern,repl) and each line is run
    through re.sub(pattern,repl,line) before parsing."""
    ## If filename is a path name, open it, else assume it is already an open file.
    if isinstance(filename,str):
        filename = expand_path(filename)
        if awkfilter is not None:
            filename = pipe_through_awk(filename,awkfilter)
        else:
            filename = open(filename,'r',encoding='utf8')
    ## read all data, filter, and split
    lines = []
    last_line_in_first_block_of_commented_lines = None
    first_block_commented_lines_passed = False
    number_of_columns = None
    for i,line in enumerate(filename.readlines()):
        if i < skiprows:
            continue
        line = line.strip()     # remove leading/trailing whitespace
        if ignore_blank_lines and len(line)==0:
            continue
        if filter_function is not None:
            line = filter_function(line)
        if filter_regexp is not None:
            line = re.sub(filter_regexp[0],filter_regexp[1],line)
        line = (line.split() if delimiter is None else line.split(delimiter)) # split into fields
        if comment_regexp is not None and re.match(comment_regexp,line[0]): # commented line found
            if not first_block_commented_lines_passed:
                line[0] = re.sub(comment_regexp,'',line[0]) # remove comment start
                if len(line[0])==0:
                    line = line[1:] # first element was comment only
                last_line_in_first_block_of_commented_lines = line
            continue
        first_block_commented_lines_passed = True # look for no more key labels
        if number_of_columns is None:
            number_of_columns = len(line)
        else:
            assert len(line)==number_of_columns,f'Wrong number of columns on line {i}'
        lines.append(line)      # store this line, it contains data
    filename.close()
    if number_of_columns is None:
        return {}               # no data
    ## get data labels
    if labels is None:          # look for labels if not given
        if labels_commented:    # expect last line of initial unbroken comment block
            if last_line_in_first_block_of_commented_lines is None:
                labels = ['column'+str(i) for i in range(number_of_columns)] # fall back to column indices
            else:
                labels = [t.strip() for t in last_line_in_first_block_of_commented_lines] # get labels from commented line
        else:
            labels = [t.strip() for t in lines.pop(0)] # get from first line of data
    assert len(set(labels))==len(labels),'Non-unique data labels: '+repr(labels)
    assert len(labels)==number_of_columns,f'Number of labels ({len(labels)}) does not match number of columns ({number_of_columns})'
    if len(lines)==0:
        return {t:[] for t in labels} # no data
    ## get data from rest of file, and convert to arrays
    data = {}
    for key,column in zip(labels,zip(*lines)):
        column = [(t.strip() if len(t.strip())>0 else replacement_for_blank_elements) for t in column]
        data[key] = try_cast_to_numerical_array(column)
    return data

def try_cast_to_numerical_array(x):
    """Try to cast an iterable into an array of floats. On failure return
    as an array of strings."""
    try:
        return np.array(x,dtype=float)
    except ValueError:
        return np.array(x,dtype=str)

def expand_path(path):
    """Shortcut to os.path.expanduser(path)."""
    import os
    return os.path.expanduser(path)

def file_to_array_unpack(*args,**kwargs):
    """Same as file_to_array but unpack data by default."""
    kwargs.setdefault('unpack',True)
    return file_to_array(*args,**kwargs)
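## A minimal usage sketch of txt_to_dict, assuming a whitespace-delimited
## file with a commented label line; the path and contents are arbitrary.
def _example_txt_to_dict():
    with open('/tmp/example.txt','w') as f:
        f.write('# wavelength cross_section\n100. 1e-18\n200. 2e-18\n')
    data = txt_to_dict('/tmp/example.txt')
    assert list(data) == ['wavelength','cross_section'] # labels from comment line
    assert data['wavelength'][1] == 200.                # columns cast to float arrays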
def resample(xin,yin,xout):
    """Spline or bin (as appropriate) (x,y) data onto a given xout grid."""
    assert np.all(xin==np.unique(xin)),'Input x-data not monotonically increasing.'
    assert all(yin>=0),'Negative cross section in input data'
    assert not np.any(np.isnan(yin)),'NaN cross section in input data'
    assert xout[0]>=xin[0],'Output x minimum less than input.'
    assert xout[-1]<=xin[-1],'Output x maximum greater than input.'
    ## integration region boundary points -- edge points and midpoints of xout
    xbnd = np.concatenate((xout[0:1],(xout[1:]+xout[:-1])/2,xout[-1:]))
    ## linearly spline data onto original and boundary points
    xfull = np.unique(np.concatenate((xin,xbnd)))
    yfull = spline(xin,yin,xfull,order=1)
    ## identify boundary points in the full grid
    ibnd = np.searchsorted(xfull,xbnd)
    ## compute trapezoidal cumulative integral
    ## (cumtrapz was renamed cumulative_trapezoid in SciPy 1.6)
    cumtrapz = getattr(integrate,'cumulative_trapezoid',None) or integrate.cumtrapz
    ycum = np.concatenate(([0],cumtrapz(yfull,xfull)))
    ## output points are the mean integrated values between bounds
    yout = (ycum[ibnd[1:]]-ycum[ibnd[:-1]])/(xfull[ibnd[1:]]-xfull[ibnd[:-1]])
    return yout

def resample_out_of_bounds_to_zero(xin,yin,xout):
    """Like resample, but output is zero wherever xout lies outside the
    domain of xin."""
    yout = np.zeros(xout.shape,dtype=float)
    i = (xout>=xin[0])&(xout<=xin[-1])
    if sum(i)>0:
        yout[i] = resample(xin,yin,xout[i])
    return yout

def spline(xi,yi,x,s=0,order=3,check_bounds=True,
           set_out_of_bounds_to_zero=True,sort_data=True,ignore_nan_data=False):
    """Evaluate spline interpolation of (xi,yi) at x. Optional argument s
    is spline tension. Order is the degree of the spline, silently
    reduced to 2 or 1 if only 3 or 2 data points are given."""
    from scipy import interpolate
    order = min(order,len(xi)-1)
    xi,yi,x = np.array(xi),np.array(yi),np.array(x)
    if ignore_nan_data:
        i = np.isnan(xi)|np.isnan(yi)
        if any(i):
            xi,yi = xi[~i],yi[~i]
    if sort_data:
        i = np.argsort(xi)
        xi,yi = xi[i],yi[i]
    if set_out_of_bounds_to_zero:
        i = (x>=xi[0])&(x<=xi[-1])
        y = np.zeros(x.shape)
        if any(i):
            y[i] = spline(xi,yi,x[i],s=s,order=order,set_out_of_bounds_to_zero=False,sort_data=False)
        return y
    if check_bounds:
        assert x[0]>=xi[0],'Splined lower limit outside data range: '+str(x[0])+' < '+str(xi[0])
        assert x[-1]<=xi[-1],'Splined upper limit outside data range: '+format(x[-1],'0.10g')+' > '+format(xi[-1],'0.10g')
    return interpolate.UnivariateSpline(xi,yi,k=order,s=s)(x)

def date_string():
    """Get a string representing today's date in ISO format."""
    import datetime
    t = datetime.datetime.now()
    return '-'.join([str(t.year),format(t.month,'02d'),format(t.day,'02d')])
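## A quick sketch of resample and spline on synthetic data -- purely
## illustrative, values are arbitrary.
def _example_resample_and_spline():
    xin = np.linspace(100.,200.,1001)
    yin = np.exp(-((xin-150.)/10.)**2)           # a smooth positive "cross section"
    xout = np.linspace(110.,190.,51)             # coarser grid inside the input domain
    yout = resample(xin,yin,xout)                # bin-averaged onto xout
    ysmooth = spline(xin[::100],yin[::100],xout) # cubic spline through a subset
    assert len(yout)==len(ysmooth)==len(xout)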
def format_columns(data,fmt='>10g',labels=None,header=None,
                   record_separator='\n',field_separator=' ',
                   comment_string='# '):
    """Format data with fixed column widths. Labels are column titles.
    NOT QUITE THERE YET."""
    import re
    ## if data is a dict, reinterpret appropriately
    if hasattr(data,'keys'):
        labels = list(data.keys())
        data = [data[key] for key in data]
    ## make formats a list as long as data
    if isinstance(fmt,str):
        fmt = [fmt for t in data]
    ## get string formatting for labels and failed formatting
    fmt_functions = []
    for f in fmt:
        def fcn(val,f=f):       # bind f now, not at call time
            try:
                ## return in the given format if possible
                return format(val,f)
            except (ValueError,TypeError):
                ## else default to a string of the correct length
                r = re.match(r'[^0-9]*([0-9]+)(\.[0-9]+)?[^0-9].*',f)
                return format(val,'>'+r.groups()[0]+'s')
        fmt_functions.append(fcn)
    ## begin output records
    records = []
    ## add header if required
    if header is not None:
        records.append(comment_string+header.strip().replace('\n','\n'+comment_string))
    ## labels if required
    if labels is not None:
        records.append(comment_string+field_separator.join(
            [f(label) for (f,label) in zip(fmt_functions,labels)]))
    ## compose formatted data columns
    comment_pad = ' '*len(comment_string)
    records.extend([comment_pad+field_separator.join(
        [f(field) for (f,field) in zip(fmt_functions,record)])
        for record in zip(*data)])
    return record_separator.join(records)

def string_to_file(filename,string,mode='w'):
    """Write string to filename."""
    with open(expand_path(filename),mode=mode) as f:
        f.write(string)
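## Example of format_columns output -- illustrative only. String fields fall
## back to right-padded strings of the same field width.
def _example_format_columns():
    table = {'x':[1.0,2.5],'label':['a','b']}
    print(format_columns(table,fmt='>8g'))
    ## prints a commented label record followed by two aligned data records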
def leastsq(func,x0,dx,R=100.,print_error_mesg=True,error_only=False,
            xtol=1.49012e-8,):
    """Rejig the inputs of scipy.optimize.leastsq so that they do what I
    want them to.

    Inputs:

    func -- The same as for leastsq.
    x0   -- The same as for leastsq.
    dx   -- A sequence of the same length as x0 containing the desired
            absolute stepsize to use when calculating the finite-difference
            Jacobian.
    R    -- The ratio of two step sizes: Dx/dx, where Dx is the maximum
            stepsize taken at any time. Note that this is only valid for
            the first iteration, after which leastsq appears to
            approximately double the 'factor' parameter.
    print_error_mesg -- if True output error code and message on failure.

    Outputs: (x,sigma)

    x     -- array of fitted parameters
    sigma -- error of these

    The reason for doing this is that I found it difficult to tweak the
    epsfcn, diag and factor parameters of leastsq to do what I wanted. As
    far as I can determine these behave in the following way:

    dx = x*sqrt(epsfcn) ; x!=0,
    dx = 1*sqrt(epsfcn) ; x==0.
    Default epsfcn=2.2e-16 on scucomp2.

    Dx = abs(x*100)       ; x!=0, factor is not set,
    Dx = abs(x*factor)    ; x!=0, factor is set,
    Dx = abs(factor)      ; x==0, factor is set,
    Dx = 100              ; x==0, factor is not set, diag is not set,
    Dx = abs(100/diag)    ; x==0, factor is not set, diag is set,
    Dx = abs(factor/diag) ; x==0, factor is set, diag is set.

    Many confusing cases, particularly annoying when initial x==0 and it
    is not possible to control dx or Dx individually for each parameter.

    My solution is to add a large value to each parameter so that there is
    little or no chance it will change magnitude during the course of the
    optimisation. This value is calculated separately for each parameter,
    giving individual control over dx. I did not think of a way to also
    control Dx individually; instead the ratio R=Dx/dx may be set globally."""
    from scipy import optimize
    ## Limit the number of evaluations to the minimum needed to compute the
    ## uncertainty from the second derivative -- make R small to improve
    ## performance? Doesn't work for very large numbers of parameters --
    ## errors are all NaN, probably because of a bug in leastsq?
    if error_only:
        maxfev = len(x0)+1
        R = 1.
    else:
        maxfev = 0
    ## try and wrangle the actual inputs of scipy.optimize.leastsq to get
    ## the right step sizes
    x0 = np.array(x0)
    dx = np.array(dx)
    epsfcn = 1e-15              # sqrt(epsfcn) sets the relative finite-difference step
    ## shift each parameter so its internal value is dx/sqrt(epsfcn), giving
    ## the requested absolute step dx (NB: this line reconstructs a lost
    ## original, following the description in the docstring above)
    xshift = x0 - dx/np.sqrt(epsfcn)
    factor = R*np.sqrt(epsfcn)
    x = x0-xshift
    ## perform optimisation. try block is for the case where the error
    ## calculation fails
    try:
        (x,cov_x,info,mesg,success) = optimize.leastsq(
            lambda x:func(x+xshift),
            x,
            epsfcn=epsfcn,
            factor=factor,
            full_output=True,
            maxfev=maxfev,
            xtol=xtol,
            )
    except ValueError as err:
        if str(err)=='array must not contain infs or NaNs':
            raise Exception('Bad covariance matrix in error calculation, residual independent of some variable?')
        else:
            raise
    ## warn on error if requested (ier codes 1-4 indicate success)
    if success not in (1,2,3,4) and print_error_mesg:
        warnings.warn("leastsq exit code: "+str(success)+' '+mesg)
    ## sometimes this is not an array
    if not np.iterable(x):
        x = [x]
    ## attempt to calculate covariance of parameters
    if cov_x is None:
        sigma_x = np.nan*np.ones(len(x))
    else:
        chisq = sum(info["fvec"]*info["fvec"])
        dof = len(info["fvec"])-len(x)+1 # degrees of freedom
        ## assumes unweighted data with experimental uncertainty deduced
        ## from the fitted residual. Ref. gavin2011.
        std_y = np.sqrt(chisq/dof)
        sigma_x = np.sqrt(cov_x.diagonal())*std_y
    return (x+xshift,sigma_x)

def mkdir_if_necessary(directory):
    """Create a directory tree if it doesn't exist.
    NB: os.makedirs(directory,exist_ok=True) achieves the same on modern Python."""
    import os
    directory = expand_path(directory)
    partial_directories = directory.split('/')
    for i in range(len(partial_directories)):
        partial_directory = '/'.join(partial_directories[0:i+1])
        if partial_directory=='' or os.path.isdir(partial_directory):
            continue
        else:
            if os.path.exists(partial_directory):
                raise Exception("Exists and is not a directory: "+partial_directory)
            os.mkdir(partial_directory)
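## A usage sketch of leastsq above, fitting a straight line to
## quasi-noisy data -- the data and step sizes are arbitrary.
def _example_leastsq():
    xdata = np.linspace(0.,1.,20)
    ydata = 2.*xdata + 1. + 0.01*np.sin(37.*xdata) # line plus pseudo-noise
    def residual(p):
        return ydata - (p[0]*xdata + p[1])
    p,sigma_p = leastsq(residual,x0=[1.,0.],dx=[1e-5,1e-5])
    assert abs(p[0]-2.)<0.1 and abs(p[1]-1.)<0.1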
## dictionary mapping for converting species names
_species_name_translation_dict = dict(
    leiden = bidict({
        'Ca':'ca', 'He':'he', 'Cl':'cl', 'Cr':'cr', 'Mg':'mg',
        'Mn':'mn', 'Na':'na', 'Ni':'ni', 'Rb':'rb', 'Ti':'ti',
        'Zn':'zn', 'Si':'si', 'Li':'li', 'Fe':'fe', 'HCl':'hcl',
        'Al':'al', 'AlH':'alh', 'LiH':'lih', 'MgH':'mgh',
        'NaCl':'nacl', 'NaH':'nah', 'SiH':'sih', 'Co':'cob',}),)

## functions for converting a species name
_species_name_translation_functions = {}

def _f(name):
    """Translate from my normal species names into something that looks
    nice in matplotlib."""
    name = re.sub(r'([0-9]+)',r'$_{\1}$',name) # subscript multiplicity
    name = re.sub(r'([+-])',r'$^{\1}$',name)   # superscript charge
    return name
_species_name_translation_functions[('standard','matplotlib')] = _f

def _f(leiden_name):
    """Translate from the Leiden database into standard names."""
    ## default to upper-casing
    name = leiden_name.upper()
    name = name.replace('C-','c-')
    name = name.replace('L-','l-')
    ## look for two-letter element names
    name = name.replace('CL','Cl')
    name = name.replace('SI','Si')
    name = name.replace('CA','Ca')
    ## look for isotopologues
    name = name.replace('C*','13C')
    name = name.replace('O*','18O')
    name = name.replace('N*','15N')
    ## assume a final P implies +
    if name[-1]=='P' and name!='P':
        name = name[:-1]+'+'
    return name
_species_name_translation_functions[('leiden','standard')] = _f

def _f(standard_name):
    """Translate from my normal species names into the Leiden database
    equivalent."""
    standard_name = standard_name.replace('+','p')
    return standard_name.lower()
_species_name_translation_functions[('standard','leiden')] = _f

def translate_species(name,input_encoding,output_encoding):
    """Translate a species name between different formats."""
    ## vectorise
    if not np.isscalar(name):
        return [translate_species(namei,input_encoding,output_encoding) for namei in name]
    ## check translation dictionaries first
    if (output_encoding in _species_name_translation_dict
            and name in _species_name_translation_dict[output_encoding]):
        return _species_name_translation_dict[output_encoding][name]
    ## use a function
    if (input_encoding,output_encoding) in _species_name_translation_functions:
        return _species_name_translation_functions[(input_encoding,output_encoding)](name)
    ## else try passing through the 'standard' form
    if input_encoding=='standard' or output_encoding=='standard':
        raise Exception(f'Do not know how to translate species name between {repr((input_encoding,output_encoding))}')
    return translate_species(
        translate_species(name,input_encoding,'standard'),
        'standard',output_encoding)
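## Round-trip examples for translate_species -- illustrative only.
def _example_translate_species():
    assert translate_species('HCl','standard','leiden') == 'hcl'  # via the bidict
    assert translate_species('hcl','leiden','standard') == 'HCl'  # via a function
    assert translate_species('H2+','standard','matplotlib') == 'H$_{2}$$^{+}$'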
def save_cross_section_leiden_photodissoc_database(
        filename,
        header,
        lines_wavelength=[],
        lines_integrated_cross_section=[],
        continuum_wavelength=[],
        continuum_cross_section=[],):
    """Save a cross section in the file format of the Leiden
    photodissociation database. LINES NOT IMPLEMENTED!!! Input wavelengths
    are expected in nm but printed to file in Angstroms."""
    lines_wavelength,lines_integrated_cross_section = np.array(lines_wavelength),np.array(lines_integrated_cross_section)
    continuum_wavelength,continuum_cross_section = np.array(continuum_wavelength),np.array(continuum_cross_section)
    ## ## do not include zero cross section
    ## i = continuum_cross_section>0
    ## continuum_wavelength,continuum_cross_section = continuum_wavelength[i],continuum_cross_section[i]
    ## convert to Angstroms
    lines_wavelength = nm2A(np.array(lines_wavelength,ndmin=1,dtype=float))
    lines_integrated_cross_section *= 10. # convert to cm2.Angstroms
    continuum_wavelength = nm2A(np.array(continuum_wavelength,ndmin=1,dtype=float))
    ## begin file data
    lines = []
    lines.append(header.strip())
    ## IMPLEMENT LINES HERE
    lines.append(format(len(lines_wavelength))) # number of discrete lines
    for i,(wavelength,cross_section) in enumerate(zip(lines_wavelength,lines_integrated_cross_section)):
        lines.append(' '.join([format(i,'4d'),format(wavelength,'15.5f'),format(cross_section,'15.5e')]))
    ## continuum cross section
    lines.append(str(len(continuum_wavelength))) # number of continuum points
    if len(continuum_wavelength)>0:
        # lines.append(format(-(int(np.floor(continuum_wavelength[-1])+1)),'d')) # longest-wavelength threshold of continuum
        lines.append('-1')
        lines.extend([
            format(i,'>4d')+' '+format(wavelength,'10.5f')+' '+format(cross_section,'0.3e')
            for (i,(wavelength,cross_section)) in enumerate(zip(continuum_wavelength,continuum_cross_section))])
    ## write to file
    with open(filename,'w') as fid:
        fid.write('\n'.join(lines))

## standard paper sizes for figures -- in inches
papersize = dict(
    a4=(8.3,11.7),
    a4_portrait=(8.3,11.7),
    a4_landscape=(11.7,8.3),
    a5=(5.87,8.3),
    a5landscape=(8.3,5.87),
    letter=(8.5,11),
    letter_portrait=(8.5,11),
    letter_landscape=(11,8.5),
    latexTiny=(constants.golden*1.2,1.2),
    latexSmall=(constants.golden*1.7,1.7),
    latexMedium=(constants.golden*3.,3.),
    latexLarge=(constants.golden*4.,4.),
    squareMedium=(5.,5.),
    squareLarge=(8.,8.),
    article_full_page_width=6.77,
    # article_single_column_width=3.27,
    article_single_column_width=3.5,
    article_full_page_height=8.66,
    )
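## Sketch of writing a continuum-only cross section with the Leiden-format
## writer above -- the file name and data values are arbitrary.
def _example_save_leiden():
    save_cross_section_leiden_photodissoc_database(
        '/tmp/example_xsect.dat',
        header='example molecule',
        continuum_wavelength=[91.2,100.,110.],       # nm
        continuum_cross_section=[1e-17,5e-18,1e-18], # cm2
        )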
def presetRcParams(
        preset='base',          # name of the preset to use
        make_fig=False,         # make a figure and axes and return (fig,ax)
        **params,               # any valid rcParams, possibly abbreviated by removing xxx.yyy. etc
        ):
    """Call this function with some presets before the figure object is
    created. If make_fig=True return (fig,ax) figure and axis objects.
    Additional kwargs are applied directly to rcParams."""
    xscreen,yscreen = 5,5
    ## dictionary of dictionaries containing key:value pairs to substitute
    ## into rcParams
    presets = dict()
    ## the base params
    presets['base'] = {
        'legend.handlelength'  :1.5,
        'legend.handletextpad' :0.4,
        'legend.labelspacing'  :0.,
        # 'legend.loc'         :'best', # setting this to best makes things slooooow
        'legend.numpoints'     :1,
        'font.family'          :'serif',
        'text.usetex'          :False,
        ## a single string, as required by matplotlib >= 3.3
        'text.latex.preamble'  :(
            r'\usepackage{mhchem} '
            r'\usepackage[np]{numprint}\npdecimalsign{\ensuremath{.}} \npthousandsep{\,} \npproductsign{\times} \npfourdigitnosep '),
        'mathtext.fontset'     :'cm',
        'lines.markeredgewidth':1,
        # 'axes.prop_cycle'    :cycler('color',linecolors_colorblind_safe),
        'axes.prop_cycle'      :cycler('color',linecolors_print),
        'patch.edgecolor'      :'none',
        'xtick.minor.top'      :True,
        'xtick.minor.bottom'   :True,
        'xtick.minor.visible'  :True,
        'xtick.top'            :True,
        'xtick.bottom'         :True,
        'ytick.minor.right'    :True,
        'ytick.minor.left'     :True,
        'ytick.minor.visible'  :True,
        'ytick.right'          :True,
        'ytick.left'           :True,
        'path.simplify'        :False, # whether or not to speed up plots by joining line segments
        'path.simplify_threshold':1,   # how much to do so
        'agg.path.chunksize'   :10000, # antialiasing speed-up -- does not seem to do anything over path.simplify
        ## set up axes tick marks and labels
        'axes.formatter.limits':(-3,6), # use scientific notation if log10 of the axis range is smaller than the first or larger than the second
        'axes.formatter.use_mathtext':True, # when True, use mathtext for scientific notation
        'axes.formatter.useoffset':False,   # if True, the tick label formatter defaults to labeling ticks relative to an offset when the data range is small compared to the minimum absolute value of the data
        'axes.formatter.offset_threshold':4, # when useoffset is True, the offset is used when it can remove at least this number of significant digits from tick labels
        }
    presets['screen'] = deepcopy(presets['base'])
    presets['screen'].update({
        'figure.figsize'        :(xscreen,yscreen),
        'figure.subplot.left'   :0.05,
        'figure.subplot.right'  :0.95,
        'figure.subplot.bottom' :0.05,
        'figure.subplot.top'    :0.95,
        'figure.subplot.wspace' :0.2,
        'figure.subplot.hspace' :0.2,
        'figure.autolayout'     :True, # rerun tight_layout every time the figure is redrawn -- seems to cause problems with long title and label strings
        # 'toolbar'             :'none', # hides toolbar but also disables keyboard shortcuts
        'legend.handlelength'   :4,
        'text.usetex'           :False,
        'lines.linewidth'       :1,
        'lines.markersize'      :10.0,
        'grid.alpha'            :1.0,
        'grid.color'            :'gray',
        'grid.linestyle'        :':',
        'legend.fontsize'       :9.,
        'axes.titlesize'        :14.,
        'axes.labelsize'        :14.,
        'xtick.labelsize'       :14.,
        'ytick.labelsize'       :14.,
        'font.size'             :14.,
        'axes.prop_cycle'       :cycler('color',linecolors_screen),
        'path.simplify'         :True, # speed up plots by joining line segments
        'path.simplify_threshold':1,   # how much to do so
        'agg.path.chunksize'    :10000, # antialiasing speed-up -- does not seem to do anything over path.simplify
        })
    ## apply the chosen preset, then any explicitly-given rcParams
    ## (minimal completion -- the original tail of this function was lost;
    ## the abbreviated-key expansion mentioned above is not reproduced)
    mpl.rcParams.update(presets[preset])
    mpl.rcParams.update(params)
    if make_fig:
        fig,ax = plt.subplots()
        return (fig,ax)

linecolors_screen = (
    'red','blue','green','black','orange','magenta','aqua','indigo','brown',
    'burlywood','cadetblue','chartreuse','chocolate','coral','cornflowerblue',
    'crimson','cyan',
    ## other named colors considered but commented out:
    ## 'grey','aliceblue','aquamarine','azure','beige','bisque','blanchedalmond',
    ## 'blueviolet','cornsilk','darkblue','darkcyan','darkgoldenrod','darkgray',
    ## 'darkgreen','darkkhaki','darkmagenta','darkolivegreen','darkorange',
    ## 'darkorchid','darkred','darksalmon','darkseagreen','darkslateblue',
    ## 'darkslategray','darkturquoise','darkviolet','deeppink','deepskyblue',
    ## 'dimgray','dodgerblue','firebrick','forestgreen','fuchsia','gainsboro',
    ## 'gold','goldenrod','gray','greenyellow','honeydew','hotpink','indianred',
    ## 'ivory','khaki','lavender','lavenderblush','lawngreen','lemonchiffon',
    ## 'lightblue','lightcoral','lightcyan','lightgoldenrodyellow','lightgreen',
    ## 'lightgray','lightpink','lightsalmon','lightseagreen','lightskyblue',
    ## 'lightslategray','lightsteelblue','lightyellow','lime','limegreen','linen',
    ## 'maroon','mediumaquamarine','mediumblue','mediumorchid','mediumpurple',
    ## 'mediumseagreen','mediumslateblue','mediumspringgreen','mediumturquoise',
    ## 'mediumvioletred','midnightblue','mintcream','mistyrose','moccasin',
    ## 'navajowhite','olive','olivedrab','orangered','orchid','palegoldenrod',
    ## 'palegreen','paleturquoise','palevioletred','papayawhip','peachpuff',
    ## 'peru','pink','plum','powderblue','purple','rosybrown','royalblue',
    ## 'saddlebrown','salmon','sandybrown','seagreen','seashell','sienna',
    ## 'silver','skyblue','slateblue','slategray','snow','springgreen',
    ## 'steelblue','tan','teal','thistle','tomato','turquoise','violet','wheat',
    ## 'yellow','yellowgreen','floralwhite','ghostwhite','navy','oldlace',
    ## 'white','whitesmoke','antiquewhite',
    )
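## Usage sketch for presetRcParams -- set a preset (and any explicit
## rcParams) before plotting; purely illustrative.
def _example_presetRcParams():
    fig,ax = presetRcParams('screen',make_fig=True,**{'font.size':10.})
    ax.plot([0.,1.],[0.,1.],label='test')
    fig.savefig('/tmp/example.pdf')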
linecolors_colorblind_safe = (
    (204./256.,102./256.,119./256.),
    (61./256., 170./256.,153./256.),
    (51./256., 34./256., 136./256.),
    ## (17./256., 119./256.,51./256.),
    (170./256.,68./256., 153./256.),
    ## (136./256.,34./256., 85./256.),
    (153./256.,153./256.,51./256.),
    (136./256.,204./256.,238./256.),
    (221./256.,204./256.,199./256.),
    (51./256., 102./256.,170./256.),
    (17./256., 170./256.,153./256.),
    (102./256.,170./256.,85./256.),
    (153./256.,34./256., 136./256.),
    (238./256.,51./256., 51./256.),
    (238./256.,119./256.,34./256.),
    ## (204./256.,204./256.,85./256.),
    ## (255./256.,238./256.,51./256.),
    ## (119./256.,119./256.,119./256.),
    )                           # from http://www.sron.nl/~pault/

## from http://colorbrewer2.org/#type=diverging&scheme=RdYlBu&n=6
linecolors_print = (
    ## attempt 1: '#a50026','#f46d43','#fdae61','#fee090','#74add1',
    ## '#4575b4','#313695','#d73027','#abd9e9','#e0f3f8',
    ## attempt 2:
    '#e41a1c',
    '#377eb8',
    '#4daf4a',
    '#984ea3',
    '#ff7f00',
    '#a65628',
    '#f781bf',
    '#ffff33',                  # light yellow
    )

# linecolors = mpl.rcParams['axes.color_cycle']
linecolors = [f['color'] for f in mpl.rcParams['axes.prop_cycle']]

_newcolor_nextcolor = 0         # module-level cycle counter used by newcolor

def newcolor(index=None,reset=None,linecolors=None):
    """Returns a color string, different to the last one, from the list
    linecolors. If reset is set, return to the first element of
    linecolors and return no color. If index (int) is supplied return
    that color. If index is supplied and reset is true, also set the
    cycle to that color."""
    global _newcolor_nextcolor
    if linecolors is None:
        linecolors = [f['color'] for f in mpl.rcParams['axes.prop_cycle']]
    ## these string "indices" also mean reset
    if index in ['None','none','']:
        _newcolor_nextcolor = 0
        return
    if index is not None:
        ## index should be an int -- but make it work for anything hashable
        try:
            index = int(index)
        except (TypeError,ValueError):
            index = hash(index)
        if reset:
            _newcolor_nextcolor = index % len(linecolors)
        return linecolors[index % len(linecolors)]
    if reset is not None:
        _newcolor_nextcolor = 0
        return
    ## otherwise return the next color in the cycle
    retval = linecolors[_newcolor_nextcolor]
    _newcolor_nextcolor = (_newcolor_nextcolor+1) % len(linecolors)
    return retval
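## Usage sketch for newcolor -- cycle colors across successive plot calls;
## purely illustrative.
def _example_newcolor():
    newcolor(reset=True)                # restart the cycle
    for i in range(3):
        plt.plot(np.arange(10),np.arange(10)*i,color=newcolor())
    assert newcolor(0) == newcolor(0)   # indexed access is deterministic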
def legend(
        *plot_kwargs_or_lines,  # can be dicts of plot kwargs including label
        ax=None,                # axis to add legend to
        include_ax_lines=True,  # add current lines in axis to legend
        color_text=True,        # color the label text
        show_style=False,       # if False, hide line/marker styles in the legend
        in_layout=False,        # whether the legend constrains tight_layout
        **legend_kwargs,        # passed to ax.legend
        ):
    """Make a legend and add it to an axis. Operates completely outside
    the normal scheme."""
    if ax is None:
        ax = plt.gca()
    def _reproduce_line(line):
        """Make a new empty line with the properties of the input line."""
        new_line = plt.Line2D([],[]) # the new line to fill with properties
        for key in ('alpha','color','fillstyle','label',
                    'linestyle','linewidth','marker',
                    'markeredgecolor','markeredgewidth','markerfacecolor',
                    'markerfacecoloralt','markersize','markevery',
                    'solid_capstyle','solid_joinstyle',):
            if hasattr(line,'get_'+key): # the input line has this property
                getattr(new_line,'set_'+key)(getattr(line,'get_'+key)()) # copy it to the new line
            elif hasattr(line,'get_children'): # no property, but has children (e.g., an errorbar container): search them
                for child in line.get_children():
                    if hasattr(child,'get_'+key): # property found
                        try: # try to set in the new line; ValueError means an invalid value for a Line2D
                            getattr(new_line,'set_'+key)(getattr(child,'get_'+key)())
                            break # property added successfully, search no more children
                        except ValueError:
                            pass
            else:
                pass            # property not found anywhere -- do nothing
        return new_line
    ## collect line handles and labels
    handles,labels = [],[]
    ## add existing lines in the axis to the legend
    if include_ax_lines:
        for handle,label in zip(*ax.get_legend_handles_labels()):
            if label!='_nolegend_':
                labels.append(label)
                handles.append(_reproduce_line(handle))
    ## add input lines or kwargs
    for i,t in enumerate(plot_kwargs_or_lines):
        if isinstance(t,(mpl.lines.Line2D,mpl.container.ErrorbarContainer)):
            raise Exception("Does not currently work for some reason.")
            ## disabled until the above is resolved:
            t = t[0]
            if t.get_label()!='_nolegend_':
                labels.append(t.get_label())
                handles.append(_reproduce_line(t))
        elif isinstance(t,dict):
            if t['label']!='_nolegend_':
                labels.append(t['label'])
                handles.append(plt.Line2D([],[],**t))
        else:
            raise Exception(f'Unhandled plot container type: {type(t)}')
    ## hide line/marker styles if desired
    if not show_style:
        for t in handles:
            t.set_linestyle('')
            t.set_marker('')
        legend_kwargs['handlelength'] = 0
        legend_kwargs['handletextpad'] = 0
    ## make a legend
    legend_kwargs.setdefault('handlelength',2)
    legend_kwargs.setdefault('loc','best')
    legend_kwargs.setdefault('frameon',False)
    legend_kwargs.setdefault('framealpha',1)
    legend_kwargs.setdefault('edgecolor','black')
    legend_kwargs.setdefault('fontsize','medium')
    if len(labels)==0:
        return None
    leg = ax.legend(labels=labels,handles=handles,**legend_kwargs)
    leg.set_in_layout(in_layout)
    ## color the text
    if color_text:
        for text,handle in zip(leg.get_texts(),handles):
            if isinstance(handle,mpl.container.ErrorbarContainer):
                color = handle[0].get_color()
            else:
                color = handle.get_color()
            text.set_color(color)
    ## add to axis
    # ax.add_artist(leg)
    return leg

def find(x):
    """Convert a boolean array into an array of True indices."""
    return np.where(x)[0]

def common(x,y,use_hash=False):
    """Return indices of common elements in x and y listed in the order
    they appear in x. Raises an exception if repeated matches are found."""
    if not use_hash:
        ix,iy = [],[]
        for ixi,xi in enumerate(x):
            iyi = find([xi==t for t in y])
            if len(iyi)==1:
                ix.append(ixi)
                iy.append(iyi[0])
            elif len(iyi)==0:
                continue
            else:
                raise Exception('Repeated value in y for: '+repr(xi))
        if len(np.unique(iy))!=len(iy):
            raise Exception('Repeated value in x for something.')
        return (np.array(ix),np.array(iy))
    else:
        xhash = np.array([hash(t) for t in x])
        yhash = np.array([hash(t) for t in y])
        ## get sorted hashes, checking for uniqueness
        xhash,ixhash = np.unique(xhash,return_index=True)
        assert len(xhash)==len(x),f'Non-unique values in x.'
        yhash,iyhash = np.unique(yhash,return_index=True)
        assert len(yhash)==len(y),f'Non-unique values in y.'
        ## use np.searchsorted to find one set of hashes in the other
        iy = np.arange(len(yhash))
        ix = np.searchsorted(xhash,yhash)
        ## minimal completion -- the original tail of this branch was lost:
        ## remove y-hashes beyond the range of x, keep only exact matches,
        ## map back to the original unsorted indices, and order by
        ## appearance in x
        i = ix<len(xhash)
        ix,iy = ix[i],iy[i]
        i = xhash[ix]==yhash[iy]
        ix,iy = ixhash[ix[i]],iyhash[iy[i]]
        j = np.argsort(ix)
        return (ix[j],iy[j])
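## The reconstruction below calls gaussian(x,fwhm=..,mean=..) and
## lorentzian(x,Gamma=..,k0=..), which are not defined in this file.
## Minimal unit-area implementations are assumed here, since each profile
## is scaled by an integrated cross section below.
def gaussian(x,fwhm=1.,mean=0.):
    """Unit-area Gaussian profile (assumed implementation)."""
    sigma = fwhm/(2.*np.sqrt(2.*np.log(2.)))
    return np.exp(-(x-mean)**2/(2.*sigma**2))/(sigma*np.sqrt(2.*constants.pi))

def lorentzian(x,Gamma=1.,k0=0.):
    """Unit-area Lorentzian profile with FWHM Gamma (assumed implementation)."""
    return (Gamma/(2.*constants.pi))/((x-k0)**2+(Gamma/2.)**2)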
def combine_lines_and_continuum(
        data,                   # dict with 'lines_wavelength', 'lines_integrated_cross_section', 'continuum_wavelength', 'continuum_cross_section'
        lineshape='gaussian',   # or 'lorentzian'
        linewidth_fwhm=0.01,    # FWHM of each broadened line
        fwhms_to_specify=10.,   # how many FWHMs from line center to compute the profile
        ):
    """Combine the continuum and discrete-line cross sections in `data`
    into a single ('wavelength','cross_section') pair, broadening each
    line with the given lineshape. (The original name, signature, and
    defaults of this function were lost; this header is a minimal
    assumption consistent with the body below.)"""
    if len(data['lines_wavelength'])>0:
        ## get domains of line data -- possibly overlapping
        domains = []
        width = linewidth_fwhm*fwhms_to_specify
        lines = np.sort(data['lines_wavelength'])
        domains.append([lines[0]-width])
        for (line1,line2) in zip(lines[0:-1],lines[1:]):
            if line2-line1 < width*2:
                continue
            else:
                domains[-1].append(line1+width)
                domains.append([line2-width])
        domains[-1].append(lines[-1]+width)
        ## get combined wavelength grid
        wavelength = np.unique(np.concatenate((
            data['continuum_wavelength'],
            np.concatenate([
                np.concatenate(([t0-0.1*linewidth_fwhm,t1+0.1*linewidth_fwhm],
                                np.arange(t0,t1,linewidth_fwhm/10.)))
                for (t0,t1) in domains]))))
        ## get total cross section on this wavelength scale
        if len(data['continuum_wavelength']) > 0:
            cross_section = resample_out_of_bounds_to_zero(
                data['continuum_wavelength'],data['continuum_cross_section'],wavelength)
        else:
            cross_section = np.full(len(wavelength),0.)
        for (t0,t1) in zip(data['lines_wavelength'],data['lines_integrated_cross_section']):
            i = (wavelength>=(t0-width))&(wavelength<=(t0+width))
            if lineshape=='gaussian':
                cross_section[i] += gaussian(wavelength[i],fwhm=linewidth_fwhm,mean=t0)*t1
            elif lineshape=='lorentzian':
                cross_section[i] += lorentzian(wavelength[i],Gamma=linewidth_fwhm,k0=t0)*t1
            else:
                raise ValueError('Invalid lineshape: '+repr(lineshape))
        data['wavelength'],data['cross_section'] = wavelength,cross_section
    else:
        data['wavelength'],data['cross_section'] = data['continuum_wavelength'],data['continuum_cross_section']
    return data
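## Sketch of combining a continuum with one broadened line using the
## reconstructed function above -- the data values are arbitrary.
def _example_combine_lines_and_continuum():
    data = {
        'lines_wavelength':np.array([120.]),                 # nm
        'lines_integrated_cross_section':np.array([1e-16]),  # cm2.nm
        'continuum_wavelength':np.linspace(100.,150.,500),   # nm
        'continuum_cross_section':np.full(500,1e-18),        # cm2
        }
    combine_lines_and_continuum(data,lineshape='gaussian',linewidth_fwhm=0.05)
    assert len(data['wavelength'])==len(data['cross_section'])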