Coverage for frank/fit.py: 77%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Frankenstein: 1D disc brightness profile reconstruction from Fourier data
2# using non-parametric Gaussian Processes
3#
4# Copyright (C) 2019-2020 R. Booth, J. Jennings, M. Tazzari
5#
6# This program is free software: you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation, either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program. If not, see <https://www.gnu.org/licenses/>
18#
19"""This module runs Frankenstein to fit a source's 1D radial brightness profile.
20 A default parameter file is used that specifies all options to run the fit
21 and output results. Alternatively a custom parameter file can be provided.
22"""
24import os
25import time
26import json
# Advisory check only: frank does not change the thread count itself
def _check_and_warn_if_parallel():
    """Warn when OMP_NUM_THREADS asks for more than one thread.

    Inspects the ``OMP_NUM_THREADS`` environment variable (treated as 1 when
    unset) and logs a warning if it requests more than one thread, since the
    code will likely run faster on a single thread.

    ``OMP_NUM_THREADS`` may legitimately hold a comma-separated list (one
    entry per nesting level); only the first entry is inspected. A value
    that cannot be parsed as an integer is ignored rather than raising,
    because this check is purely advisory and should never abort a fit.
    """
    raw_value = os.environ.get('OMP_NUM_THREADS', '1')
    try:
        num_threads = int(raw_value.split(',')[0])
    except ValueError:
        return

    if num_threads > 1:
        logging.warning("WARNING: You are running frank with "
                        "OMP_NUM_THREADS={}. ".format(num_threads) +
                        "The code will likely run faster on a single thread.\n"
                        "Use 'unset OMP_NUM_THREADS' or "
                        "'export OMP_NUM_THREADS=1' to disable this warning.")
39import numpy as np
40import logging
42import frank
43from frank import io, geometry, make_figs, radial_fitters, utilities
# Directory of the installed frank package; the bundled .json parameter
# files are looked up relative to it
frank_path = os.path.split(frank.__file__)[0]
def get_default_parameter_file():
    """Return the path of the default parameter file shipped with frank

    Returns
    -------
    path : string
        ``default_parameters.json`` inside the frank package directory
    """
    default_name = 'default_parameters.json'
    return os.path.join(frank_path, default_name)
def load_default_parameters():
    """Load the default parameters

    Returns
    -------
    params : object
        Parsed contents of the default parameter file
        (see get_default_parameter_file)
    """
    # A context manager closes the file deterministically instead of
    # leaving the handle for the garbage collector
    with open(get_default_parameter_file(), 'r') as f:
        return json.load(f)
def get_parameter_descriptions():
    """Load the descriptions of all fit parameters

    Returns
    -------
    descriptions : object
        Parsed contents of ``parameter_descriptions.json`` in the frank
        package directory
    """
    descrip_path = os.path.join(frank_path, 'parameter_descriptions.json')
    with open(descrip_path) as descrip_file:
        descriptions = json.load(descrip_file)
    return descriptions
def helper():
    """Print how to run frank from the terminal, followed by the description
    of every parameter (see get_parameter_descriptions)"""
    descriptions_json = json.dumps(get_parameter_descriptions(), indent=4)

    usage = """
 Fit a 1D radial brightness profile with Frankenstein (frank) from the
 terminal with `python -m frank.fit`. A .json parameter file is required;
 the default is default_parameters.json and is
 of the form:\n\n {}"""
    print(usage.format(descriptions_json))
def _check_length_two(value, err):
    """Raise ``err`` unless ``value`` supports len() and has exactly 2 items"""
    try:
        length = len(value)
    except TypeError:
        raise err from None
    if length != 2:
        raise err


def _validate_parameters(model):
    """Sanity check some of the .json parameters in ``model``

    Raises ValueError for invalid settings and logs a warning for settings
    that are allowed but discouraged. Modifies ``model`` in place:
    `hyperparameters : nonnegative` is set to False for a LogNormal fit.
    """
    hyper = model['hyperparameters']

    if hyper['method'] not in ["Normal", "LogNormal"]:
        raise ValueError("method should be 'Normal' or 'LogNormal'")

    is_lognormal = hyper['method'] == 'LogNormal'

    if model['geometry']['scale_height'] is not None and is_lognormal:
        logging.warning("WARNING: 'scale_height' is not 'None' AND 'method' is "
                        "'LogNormal'. It is suggested to use the 'Normal' method "
                        "for vertical inference.")

    if hyper['nonnegative'] and is_lognormal:
        logging.warning("WARNING: 'nonnegative' is 'true' AND 'method' is "
                        "'LogNormal' --> performing a LogNormal fit and setting "
                        "'nonnegative' to 'false'")
        hyper['nonnegative'] = False

    if is_lognormal and hyper['p0'] is not None and hyper['p0'] > 1e-30:
        raise ValueError("p0 = {}. If method is 'LogNormal', p0 should be "
                         "<~ 1e-35 (we recommend 1e-35)"
                         ".".format(hyper['p0']))

    plotting = model['plotting']
    if plotting['diag_plot'] and plotting['iter_plot_range'] is not None:
        _check_length_two(plotting['iter_plot_range'],
                          ValueError("iter_plot_range should be 'null' (None)"
                                     " or a list specifying the start and end"
                                     " points of the range to be plotted"))

    if plotting['stretch'] not in ["power", "asinh"]:
        raise ValueError("stretch should be 'power' or 'asinh'")

    baseline_range = model['modify_data']['baseline_range']
    if baseline_range is not None:
        _check_length_two(baseline_range,
                          ValueError("baseline_range should be 'null' (None)"
                                     " or a list specifying the low and high"
                                     " baselines [unit: \\lambda] outside of which the"
                                     " data will be truncated before fitting"))


def parse_parameters(*args):
    """
    Read in a .json parameter file to set the fit parameters

    Parameters
    ----------
    *args : optional
        Command line arguments, either as a single list of strings or as
        separate strings. With no arguments, argparse reads sys.argv.
        Recognized options:

        -p, --parameter_filename : string, default `default_parameters.json`
            Parameter file (.json; see frank.fit.helper)
        -uv, --uvtable_filename : string
            UVTable file with data to be fit (.txt, .dat, .npy, or .npz).
            The UVTable column format should be:
            u [lambda] v [lambda] Re(V) [Jy] Im(V) [Jy] Weight [Jy^-2]
        -desc, --print_parameter_description
            Print the full description of all fit parameters, then exit

    Returns
    -------
    model : dict
        Dictionary containing model parameters the fit uses
    param_path : string
        Path of the .json file the used parameters were saved to

    Raises
    ------
    ValueError
        If no UVTable is specified, or a parameter fails its sanity check
    """
    import argparse
    import sys

    default_param_file = os.path.join(frank_path, 'default_parameters.json')

    parser = argparse.ArgumentParser("Run a Frankenstein fit, by default using"
                                     " parameters in default_parameters.json")
    parser.add_argument("-p", "--parameter_filename",
                        default=default_param_file, type=str,
                        help="Parameter file (.json; see frank.fit.helper)")
    parser.add_argument("-uv", "--uvtable_filename", default=None, type=str,
                        help="UVTable file with data to be fit. See"
                             " frank.io.load_uvtable")
    parser.add_argument("-desc", "--print_parameter_description", default=None,
                        action="store_true",
                        help="Print the full description of all fit parameters")

    # main(*args) documents its arguments as strings. parse_args expects one
    # list, so gather separately-passed strings into an argv list.
    if args and all(isinstance(arg, str) for arg in args):
        args = (list(args),)
    cli = parser.parse_args(*args)

    if cli.print_parameter_description:
        helper()
        # sys.exit rather than the `exit` builtin, which is only installed
        # by the site module
        sys.exit()

    with open(cli.parameter_filename, 'r') as f:
        model = json.load(f)

    io_pars = model['input_output']
    if cli.uvtable_filename:
        io_pars['uvtable_filename'] = cli.uvtable_filename

    if not io_pars.get('uvtable_filename'):
        raise ValueError("uvtable_filename isn't specified."
                         " Set it in the parameter file or run Frankenstein with"
                         " python -m frank.fit -uv <uvtable_filename>")

    uv_path = io_pars['uvtable_filename']
    if not io_pars['save_dir']:
        # If not specified, use the UVTable directory as the save directory
        io_pars['save_dir'] = os.path.dirname(uv_path)

    # Add a save prefix to the .json parameter file for later use
    save_prefix = os.path.join(io_pars['save_dir'],
                               os.path.splitext(os.path.basename(uv_path))[0])
    io_pars['save_prefix'] = save_prefix
    io_pars['fig_save_prefix'] = \
        save_prefix if io_pars['save_figures'] is True else None

    frank.enable_logging(save_prefix + '_frank_fit.log')

    # Check whether the code runs in parallel now that the logging has been
    # initialized
    _check_and_warn_if_parallel()

    logging.info('\nRunning Frankenstein on {}'.format(uv_path))

    _validate_parameters(model)

    if io_pars['format'] is None:
        # Infer the format from the file extension, looking through a
        # trailing compression suffix (e.g. 'table.txt.gz' -> 'txt')
        stem, ext = os.path.splitext(uv_path)
        if ext in {'.gz', '.bz2'}:
            ext = os.path.splitext(stem)[1]
        io_pars['format'] = ext[1:]

    param_path = save_prefix + '_frank_used_pars.json'

    logging.info(' Saving parameters used to {}'.format(param_path))
    with open(param_path, 'w') as f:
        json.dump(model, f, indent=4)

    return model, param_path
def load_data(model):
    r"""
    Load the UVTable named in the parameter file (see frank.io.load_uvtable)

    Parameters
    ----------
    model : dict
        Dictionary containing model parameters the fit uses; the table path
        is read from `input_output : uvtable_filename`

    Returns
    -------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    """
    table = io.load_uvtable(model['input_output']['uvtable_filename'])
    u, v, vis, weights = table
    return u, v, vis, weights
def alter_data(u, v, vis, weights, geom, model):
    r"""
    Modify the data according to the `modify_data` section of the parameter
    file. Steps are applied in this order: (u, v) normalization, baseline
    cut, weight re-estimation.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    u, v, vis, weights : Parameters as above, each possibly replaced by the
        enabled modification steps
    """
    mods = model['modify_data']

    # Normalize (u, v) with `norm_wle` when it is set (see utilities.normalize_uv)
    norm_wle = mods['norm_wle']
    if norm_wle is not None:
        u, v = utilities.normalize_uv(u, v, norm_wle)

    # Truncate the data to the requested baseline range
    baseline_range = mods['baseline_range']
    if baseline_range:
        u, v, vis, weights = utilities.cut_data_by_baseline(
            u, v, vis, weights, baseline_range, geom)

    # Re-estimate the weights from the deprojected (u, v) points
    if mods['correct_weights']:
        u_deproj, v_deproj = geom.deproject(u, v)
        weights = utilities.estimate_weights(
            u_deproj, v_deproj, vis, use_median=mods['use_median_weight'])

    return u, v, vis, weights
def determine_geometry(u, v, vis, weights, model):
    r"""
    Work out the source geometry (inclination, position angle, phase offset)

    The `geometry : type` entry of the parameter file selects the approach.
    'known' builds a FixedGeometry from the parameter-file values. 'gaussian'
    and 'nonparametric' construct a fitting geometry and fit it to the data.
    For those two, the parameter-file inc/PA (or dRA/dDec) are passed as
    `inc_pa` (or `phase_centre`) when `fit_inc_pa` (or `fit_phase_offset`)
    is false.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    geom : SourceGeometry object
        A clone of the geometry that was used
        (see frank.geometry.SourceGeometry)

    Raises
    ------
    ValueError
        If `geometry : type` is not 'known', 'gaussian' or 'nonparametric'
    """
    logging.info(' Determining disc geometry')

    geo_pars = model['geometry']
    geo_type = geo_pars['type']

    if geo_type == 'known':
        logging.info(' Using your provided geometry for deprojection')

        known = (geo_pars['inc'], geo_pars['pa'],
                 geo_pars['dra'], geo_pars['ddec'])
        if all(value == 0 for value in known):
            logging.info(" N.B.: All geometry parameters are 0 --> No geometry"
                         " correction will be applied to the visibilities"
                         )

        geom = geometry.FixedGeometry(*known)

    elif geo_type in ('gaussian', 'nonparametric'):
        start = time.time()

        guess = ([geo_pars['inc'], geo_pars['pa'],
                  geo_pars['dra'], geo_pars['ddec']]
                 if geo_pars['initial_guess'] else None)
        inc_pa = (None if geo_pars['fit_inc_pa']
                  else (geo_pars['inc'], geo_pars['pa']))
        phase_centre = (None if geo_pars['fit_phase_offset']
                        else (geo_pars['dra'], geo_pars['ddec']))

        if geo_type == 'gaussian':
            geom = geometry.FitGeometryGaussian(
                inc_pa=inc_pa, phase_centre=phase_centre, guess=guess,
            )
        else:
            geom = geometry.FitGeometryFourierBessel(
                model['hyperparameters']['rout'], N=20,
                inc_pa=inc_pa, phase_centre=phase_centre, guess=guess
            )

        geom.fit(u, v, vis, weights)

        logging.info(' Time taken for geometry %.1f sec' %
                     (time.time() - start))

    else:
        raise ValueError("`geometry : type` in your parameter file must be one of"
                         " 'known', 'gaussian' or 'nonparametric'.")

    logging.info(' Using: inc = {:.2f} deg,\n PA = {:.2f} deg,\n'
                 ' dRA = {:.2e} arcsec,\n'
                 ' dDec = {:.2e} arcsec'.format(geom.inc, geom.PA,
                                                geom.dRA, geom.dDec))

    # Store geometry
    return geom.clone()
def get_scale_height(model):
    """
    Build the disc scale-height function described in the parameter file

    Parameters
    ----------
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    scale_height : function R --> H, or None
        None when `geometry : scale_height` is 'null'. Otherwise a function
        H(R) = h0 * R**a * exp(-(R / r0)**b), whose coefficients are read
        from `geometry : scale_height` each time it is called

    Raises
    ------
    ValueError
        If a scale height is given while `geometry : rescale_flux` is true
    """
    if model['geometry']['scale_height'] is None:
        return None

    if model['geometry']['rescale_flux']:
        raise ValueError("scale_height should be 'null' if rescale_flux is 'true'")

    def scale_height(R):
        """Power-law with cutoff, unit=[arcsec]"""
        coeffs = model['geometry']['scale_height']
        h0, a, r0, b = (coeffs[key] for key in ('h0', 'a', 'r0', 'b'))
        return h0 * R ** a * np.exp(-(R / r0) ** b)

    return scale_height
def perform_fit(u, v, vis, weights, geom, model):
    r"""
    Deproject the observed visibilities and fit them for the brightness profile

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    sol : _HankelRegressor object
        Reconstructed profile using Maximum a posteriori power spectrum
        (see frank.radial_fitters.FrankFitter). When `nonnegative` is true
        and `method` is 'Normal', it also carries the best fit nonnegative
        profile as the attribute `nonneg`
    iteration_diagnostics : object or None
        Diagnostics of the fit iteration (the fitter's
        `iteration_diagnostics`; see radial_fitters.FrankFitter.fit), or
        None when neither `input_output : iteration_diag` nor
        `plotting : diag_plot` is set
    """
    hyper = model['hyperparameters']

    # Only store per-iteration diagnostics when something downstream uses them
    need_iterations = model['input_output']['iteration_diag'] or \
        model['plotting']['diag_plot']

    scale_height = get_scale_height(model)

    t1 = time.time()
    FF = radial_fitters.FrankFitter(Rmax=hyper['rout'],
                                    N=hyper['n'],
                                    geometry=geom,
                                    alpha=hyper['alpha'],
                                    p_0=hyper['p0'],
                                    weights_smooth=hyper['wsmooth'],
                                    tol=hyper['iter_tol'],
                                    method=hyper['method'],
                                    I_scale=hyper['I_scale'],
                                    max_iter=hyper['max_iter'],
                                    store_iteration_diagnostics=need_iterations,
                                    convergence_failure=hyper['converge_failure'],
                                    scale_height=scale_height,
                                    assume_optically_thick=model['geometry']['rescale_flux']
                                    )

    sol = FF.fit(u, v, vis, weights)

    if hyper['nonnegative'] and hyper['method'] == 'Normal':
        # Add the best fit nonnegative solution to the fit's `sol` object
        logging.info(' `nonnegative` is `true` in your parameter file --> '
                     'Storing the best fit nonnegative profile as the attribute '
                     '`nonneg` in the `sol` object')
        sol.nonneg = sol.solve_non_negative()

    logging.info(' Time taken to fit profile (with {:.0e} visibilities and'
                 ' {:d} collocation points) {:.1f} sec'.format(len(u),
                                                              hyper['n'],
                                                              time.time() - t1)
                 )

    # Return a tuple in both cases (previously a list when no diagnostics)
    iteration_diagnostics = FF.iteration_diagnostics if need_iterations else None
    return sol, iteration_diagnostics
def run_multiple_fits(u, v, vis, weights, geom, model):
    r"""
    Perform and overplot multiple fits to a dataset by varying two of the
    model hyperparameters (`alpha` and `wsmooth`)

    Every combination of the given `alpha` and `wsmooth` values is fit, and
    each fit's results are saved via output_results. Its data files and
    figures get a '_alpha<alpha>_wsmooth<wsmooth>' suffix on their save
    prefixes so that fits do not overwrite one another.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fits use. `alpha` and
        `wsmooth` may each be a scalar or a list

    Returns
    -------
    multifit_fig : Matplotlib `.Figure` instance, or None
        All produced figures, including the GridSpecs. None when `alpha` or
        `wsmooth` has more than 2 values
    multifit_axes : Matplotlib `~.axes.Axes` class, or None
        Axes for each of the produced figures. None in the same case
    """
    import copy
    import itertools

    logging.info(' Looping fits over the hyperparameters `alpha` and `wsmooth`')

    def number_to_list(x):
        """Wrap a scalar in a list so scalars and lists loop alike"""
        return [x] if np.isscalar(x) else x

    alphas = number_to_list(model['hyperparameters']['alpha'])
    ws = number_to_list(model['hyperparameters']['wsmooth'])

    sols = []
    # Outer loop over alpha, inner over wsmooth
    for alpha, wsmooth in itertools.product(alphas, ws):
        this_model = copy.deepcopy(model)
        this_model['hyperparameters']['alpha'] = alpha
        this_model['hyperparameters']['wsmooth'] = wsmooth

        suffix = '_alpha{}_wsmooth{}'.format(alpha, wsmooth)
        this_model['input_output']['save_prefix'] += suffix
        # Give each fit's figures a distinct prefix too; otherwise every fit
        # saves its figures under the same names, overwriting the previous
        if this_model['input_output'].get('fig_save_prefix') is not None:
            this_model['input_output']['fig_save_prefix'] += suffix

        logging.info(' Running fit for alpha = {}, wsmooth = {}'.format(alpha, wsmooth))

        sol, iteration_diagnostics = perform_fit(u, v, vis, weights, geom, this_model)
        sols.append(sol)

        # Save the fit for the current choice of hyperparameter values
        output_results(u, v, vis, weights, sol, geom, this_model,
                       iteration_diagnostics=iteration_diagnostics)

    if len(alphas) in (1, 2) and len(ws) in (1, 2):
        multifit_fig, multifit_axes = make_figs.make_multifit_fig(u=u, v=v, vis=vis,
                                                                  weights=weights, sols=sols,
                                                                  bin_widths=model['plotting']['bin_widths'],
                                                                  dist=model['plotting']['distance'],
                                                                  force_style=model['plotting']['force_style'],
                                                                  save_prefix=model['input_output']['fig_save_prefix'],
                                                                  )
    else:
        logging.info('The multifit figure requires alpha and wsmooth to be lists of length <= 2. '
                     'Your lists are length {} and {} --> The multifit figure will not be made.'.format(len(alphas), len(ws)))
        multifit_fig, multifit_axes = None, None

    return multifit_fig, multifit_axes
def _load_comparison_profile(path):
    """Read a radial brightness profile to compare the fit against

    Parameters
    ----------
    path : string
        Text file with 2, 3 or 4 columns: r [arcsec], I [Jy / sr],
        negative uncertainty [Jy / sr] (optional), positive uncertainty
        [Jy / sr] (optional, taken equal to the negative uncertainty if
        not provided)

    Returns
    -------
    profile : dict
        Keys 'r', 'I', 'lo_err', 'hi_err'; the uncertainties are None when
        the file has only 2 columns

    Raises
    ------
    ValueError
        If the file does not have 2, 3 or 4 columns
    """
    # Transpose so that dat[i] is the i-th column
    dat = np.genfromtxt(path).T

    if len(dat) not in [2, 3, 4]:
        raise ValueError("The file in your .json's `analysis` --> "
                         "`compare_profile` must have 2, 3 or 4 "
                         "columns: r [arcsec], I [Jy / sr], "
                         "negative uncertainty [Jy / sr] (optional), "
                         "positive uncertainty [Jy / sr] (optional, "
                         "assumed equal to negative uncertainty if not "
                         "provided).")

    r_clean, I_clean = dat[0], dat[1]
    if len(dat) == 3:
        lo_err_clean, hi_err_clean = dat[2], dat[2]
    elif len(dat) == 4:
        lo_err_clean, hi_err_clean = dat[2], dat[3]
    else:
        lo_err_clean, hi_err_clean = None, None

    return {'r': r_clean, 'I': I_clean, 'lo_err': lo_err_clean,
            'hi_err': hi_err_clean}


def output_results(u, v, vis, weights, sol, geom, model, iteration_diagnostics=None):
    r"""
    Save datafiles of fit results; generate and save figures of fit results (see
    frank.io.save_fit, frank.make_figs)

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    sol : _HankelRegressor object
        Reconstructed profile using Maximum a posteriori power spectrum
        (see frank.radial_fitters.FrankFitter)
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses
    iteration_diagnostics : object, optional, default=None
        Diagnostics of the fit iteration
        (see radial_fitters.FrankFitter.fit)

    Returns
    -------
    figs : list of Matplotlib `.Figure` instances
        The produced figures, in the order deprojection, quick, full,
        diagnostic, clean comparison (only those enabled in `model`)
    axes : list
        Axes for each of the produced figures
    model : dict
        The `model` passed in, returned as-is
    """
    io_pars = model['input_output']

    if any([io_pars['save_solution'],
            io_pars['save_profile_fit'],
            io_pars['save_vis_fit'],
            io_pars['save_uvtables']
            ]):
        io.save_fit(u, v, vis, weights, sol,
                    io_pars['save_prefix'],
                    io_pars['save_solution'],
                    io_pars['save_profile_fit'],
                    io_pars['save_vis_fit'],
                    io_pars['save_uvtables'],
                    io_pars['iteration_diag'],
                    iteration_diagnostics,
                    io_pars['format']
                    )

    figs, axes = [], []
    plot_pars = model['plotting']

    if any([plot_pars['deprojec_plot'],
            plot_pars['quick_plot'],
            plot_pars['full_plot'],
            plot_pars['diag_plot'],
            model['analysis']['compare_profile']
            ]):
        logging.info(' Plotting results')

    if plot_pars['deprojec_plot']:
        deproj_fig, deproj_axes = make_figs.make_deprojection_fig(u=u, v=v,
                                                                  vis=vis, weights=weights,
                                                                  geom=geom,
                                                                  bin_widths=plot_pars['bin_widths'],
                                                                  force_style=plot_pars['force_style'],
                                                                  save_prefix=io_pars['fig_save_prefix']
                                                                  )
        figs.append(deproj_fig)
        axes.append(deproj_axes)

    if plot_pars['quick_plot']:
        quick_fig, quick_axes = make_figs.make_quick_fig(u=u, v=v, vis=vis,
                                                         weights=weights, sol=sol,
                                                         bin_widths=plot_pars['bin_widths'],
                                                         dist=plot_pars['distance'],
                                                         logx=plot_pars['plot_in_logx'],
                                                         force_style=plot_pars['force_style'],
                                                         save_prefix=io_pars['fig_save_prefix'],
                                                         stretch=plot_pars['stretch'],
                                                         gamma=plot_pars['gamma'],
                                                         asinh_a=plot_pars['asinh_a']
                                                         )
        figs.append(quick_fig)
        axes.append(quick_axes)

    if plot_pars['full_plot']:
        full_fig, full_axes = make_figs.make_full_fig(u=u, v=v, vis=vis,
                                                      weights=weights, sol=sol,
                                                      bin_widths=plot_pars['bin_widths'],
                                                      dist=plot_pars['distance'],
                                                      logx=plot_pars['plot_in_logx'],
                                                      force_style=plot_pars['force_style'],
                                                      save_prefix=io_pars['fig_save_prefix'],
                                                      norm_residuals=plot_pars['norm_residuals'],
                                                      stretch=plot_pars['stretch'],
                                                      gamma=plot_pars['gamma'],
                                                      asinh_a=plot_pars['asinh_a']
                                                      )
        figs.append(full_fig)
        axes.append(full_axes)

    if plot_pars['diag_plot']:
        diag_fig, diag_axes, _ = make_figs.make_diag_fig(r=sol.r, q=sol.q,
                                                         iteration_diagnostics=iteration_diagnostics,
                                                         iter_plot_range=plot_pars['iter_plot_range'],
                                                         force_style=plot_pars['force_style'],
                                                         save_prefix=io_pars['fig_save_prefix']
                                                         )
        figs.append(diag_fig)
        axes.append(diag_axes)

    if model['analysis']['compare_profile']:
        clean_profile = _load_comparison_profile(model['analysis']['compare_profile'])

        # Optionally convolve the fitted profile with the given clean beam
        MAP_convolved = None
        if model['analysis']['clean_beam']['bmaj'] is not None:
            MAP_convolved = utilities.convolve_profile(sol.r, sol.I,
                                                       geom.inc, geom.PA,
                                                       model['analysis']['clean_beam'])

        clean_fig, clean_axes = make_figs.make_clean_comparison_fig(u=u, v=v, vis=vis,
                                                                    weights=weights, sol=sol,
                                                                    clean_profile=clean_profile,
                                                                    bin_widths=plot_pars['bin_widths'],
                                                                    stretch=plot_pars['stretch'],
                                                                    gamma=plot_pars['gamma'],
                                                                    asinh_a=plot_pars['asinh_a'],
                                                                    MAP_convolved=MAP_convolved,
                                                                    dist=plot_pars['distance'],
                                                                    force_style=plot_pars['force_style'],
                                                                    save_prefix=io_pars['fig_save_prefix']
                                                                    )
        figs.append(clean_fig)
        axes.append(clean_axes)

    return figs, axes, model
def perform_bootstrap(u, v, vis, weights, geom, model):
    r"""
    Perform a bootstrap analysis for the Frankenstein fit to a dataset

    Each trial draws a bootstrap sample of the visibilities
    (utilities.draw_bootstrap_sample) and fits it with the fixed `geom`. The
    fitted profiles are saved to '<save_prefix>_bootstrap.npz' together with
    the collocation points of the last fit.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    boot_fig : Matplotlib `.Figure` instance
        The produced figure, including the GridSpec
    boot_axes : Matplotlib `~.axes.Axes` class
        The axes of the produced figure

    Raises
    ------
    ValueError
        If `alpha` or `wsmooth` is a list, or if `analysis : bootstrap_ntrials`
        is not at least 1
    """
    hyper = model['hyperparameters']
    if isinstance(hyper['alpha'], list) or isinstance(hyper['wsmooth'], list):
        raise ValueError("For the bootstrap, both `alpha` and `wsmooth` in your "
                         "parameter file must be a float, not a list.")

    ntrials = model['analysis']['bootstrap_ntrials']
    # Without at least one trial there is no fit (and no `sol`) to save
    if ntrials is None or ntrials < 1:
        raise ValueError("`analysis : bootstrap_ntrials` must be >= 1 to "
                         "perform a bootstrap; got {}".format(ntrials))

    use_nonneg = hyper['nonnegative']
    if use_nonneg:
        logging.info(' `nonnegative` is `true` in your parameter file --> '
                     'The best fit nonnegative profile (rather than the MAP '
                     'profile) will be saved and used to generate the bootstrap '
                     'figure')

    profiles_bootstrap = []
    for trial in range(ntrials):
        logging.info(' Bootstrap trial {} of {}'.format(trial + 1, ntrials))

        u_s, v_s, vis_s, w_s = utilities.draw_bootstrap_sample(
            u, v, vis, weights)

        sol, _ = perform_fit(u_s, v_s, vis_s, w_s, geom, model)

        profiles_bootstrap.append(sol.nonneg if use_nonneg else sol.I)

    bootstrap_path = model['input_output']['save_prefix'] + '_bootstrap.npz'

    logging.info(' Bootstrap complete. Saving fitted brightness profiles and'
                 ' the common set of collocation points')

    np.savez(bootstrap_path, r=sol.r, profiles=np.array(profiles_bootstrap))

    boot_fig, boot_axes = make_figs.make_bootstrap_fig(r=sol.r,
                                                       profiles=profiles_bootstrap,
                                                       force_style=model['plotting']['force_style'],
                                                       save_prefix=model['input_output']['fig_save_prefix']
                                                       )

    return boot_fig, boot_axes
def main(*args):
    """Run the full Frankenstein pipeline to fit a dataset

    The pipeline is: parse the parameters, load the UVTable, determine the
    geometry, optionally modify the data, check the (u, v) units, then run
    one of three modes. Bootstrap runs when `analysis : bootstrap_ntrials`
    is set. Otherwise a multi-hyperparameter run happens when `alpha` or
    `wsmooth` is a list, and a single fit happens in all other cases.

    Parameters
    ----------
    *args : strings
        Simulates the command line arguments (forwarded to parse_parameters)

    Returns
    -------
    Two values (figures, axes) from whichever mode ran
    """
    model, param_path = parse_parameters(*args)

    u, v, vis, weights = load_data(model)
    geom = determine_geometry(u, v, vis, weights, model)

    modify = model['modify_data']
    wants_modification = (modify['baseline_range'] or
                          modify['correct_weights'] or
                          modify['norm_wle'])
    if wants_modification:
        u, v, vis, weights = alter_data(u, v, vis, weights, geom, model)

    # Check the units of (u, v), after any norm_wle conversion above
    utilities.check_uv(u, v)

    if model['analysis']['bootstrap_ntrials']:
        boot_fig, boot_axes = perform_bootstrap(u, v, vis, weights, geom, model)
        return boot_fig, boot_axes

    hyper = model['hyperparameters']
    if type(hyper['alpha']) is list or type(hyper['wsmooth']) is list:
        multifit_fig, multifit_axes = run_multiple_fits(u, v, vis, weights,
                                                        geom, model)
        return multifit_fig, multifit_axes

    sol, iteration_diagnostics = perform_fit(u, v, vis, weights, geom, model)

    figs, axes, model = output_results(u, v, vis, weights, sol, geom, model,
                                       iteration_diagnostics)

    logging.info(' Updating {} with final parameters used'.format(param_path))
    with open(param_path, 'w') as params_file:
        json.dump(model, params_file, indent=4)

    logging.info("IT'S ALIVE!!\n")

    return figs, axes


if __name__ == "__main__":
    main()