Coverage for frank/fit.py: 78%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

259 statements  

1# Frankenstein: 1D disc brightness profile reconstruction from Fourier data 

2# using non-parametric Gaussian Processes 

3# 

4# Copyright (C) 2019-2020 R. Booth, J. Jennings, M. Tazzari 

5# 

6# This program is free software: you can redistribute it and/or modify 

7# it under the terms of the GNU General Public License as published by 

8# the Free Software Foundation, either version 3 of the License, or 

9# (at your option) any later version. 

10# 

11# This program is distributed in the hope that it will be useful, 

12# but WITHOUT ANY WARRANTY; without even the implied warranty of 

13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

14# GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License 

17# along with this program. If not, see <https://www.gnu.org/licenses/> 

18# 

19"""This module runs Frankenstein to fit a source's 1D radial brightness profile. 

20 A default parameter file is used that specifies all options to run the fit 

21 and output results. Alternatively a custom parameter file can be provided. 

22""" 

23 

24import os 

25import time 

26import json 

27 

# Warn (without changing any settings) if OMP_NUM_THREADS requests more than one thread

29def _check_and_warn_if_parallel(): 

30 """Check numpy is running in parallel""" 

31 num_threads = int(os.environ.get('OMP_NUM_THREADS', '1')) 

32 if num_threads > 1: 

33 logging.warning("WARNING: You are running frank with " 

34 "OMP_NUM_THREADS={}.".format(num_threads) + 

35 "The code will likely run faster on a single thread.\n" 

36 "Use 'unset OMP_NUM_THREADS' or " 

37 "'export OMP_NUM_THREADS=1' to disable this warning.") 

38 

39import numpy as np 

40import logging 

41 

42import frank 

43from frank import io, geometry, make_figs, radial_fitters, utilities 

44 

# Directory of the installed frank package; the bundled .json parameter
# files are looked up relative to it
frank_path = os.path.split(frank.__file__)[0]

46 

47 

def get_default_parameter_file():
    """Return the path of the default parameter file shipped with frank

    Returns
    -------
    path : string
        'default_parameters.json' inside the frank package directory
    """
    default_name = 'default_parameters.json'
    return os.path.join(frank_path, default_name)

51 

52 

def load_default_parameters():
    """Load the default parameters

    Returns
    -------
    params : parsed JSON
        Contents of the file returned by get_default_parameter_file()
    """
    # Context manager so the file handle is closed deterministically
    with open(get_default_parameter_file(), 'r') as f:
        return json.load(f)

56 

57 

def get_parameter_descriptions():
    """Load the description of each fit parameter

    Returns
    -------
    param_descrip : parsed JSON
        Contents of 'parameter_descriptions.json' in the frank package
        directory
    """
    descrip_file = os.path.join(frank_path, 'parameter_descriptions.json')
    with open(descrip_file) as handle:
        descriptions = json.load(handle)
    return descriptions

63 

64 

def helper():
    """Print how to run frank from the terminal, followed by the
    description of every fit parameter (see get_parameter_descriptions)"""
    descriptions = json.dumps(get_parameter_descriptions(), indent=4)

    message = """
 Fit a 1D radial brightness profile with Frankenstein (frank) from the
 terminal with `python -m frank.fit`. A .json parameter file is required;
 the default is default_parameters.json and is
 of the form:\n\n {}"""

    print(message.format(descriptions))

73 

74 

def _raise_unless_pair(value, err):
    """Raise `err` unless `value` supports len() and has length 2"""
    try:
        length = len(value)
    except TypeError:
        raise err
    if length != 2:
        raise err


def _validate_parameters(model):
    """Sanity check the fit parameters in `model` (modified in place).

    Raises ValueError for invalid settings and logs warnings for
    discouraged combinations. If 'nonnegative' is true while 'method' is
    'LogNormal', 'nonnegative' is set to False.
    """
    hyp = model['hyperparameters']
    method = hyp['method']

    if method not in ["Normal", "LogNormal"]:
        raise ValueError("method should be 'Normal' or 'LogNormal'")

    if model['geometry']['scale_height'] is not None and \
            method == 'LogNormal':
        logging.warning("WARNING: 'scale_height' is not 'None' AND 'method' is "
                        "'LogNormal'. It is suggested to use the 'Normal' method "
                        "for vertical inference.")

    if hyp['nonnegative'] and method == 'LogNormal':
        logging.warning("WARNING: 'nonnegative' is 'true' AND 'method' is "
                        "'LogNormal' --> performing a LogNormal fit and setting "
                        "'nonnegative' to 'false'")
        hyp['nonnegative'] = False

    if method == 'LogNormal' and hyp['p0'] is not None and hyp['p0'] > 1e-30:
        raise ValueError("p0 = {}. If method is 'LogNormal', p0 should be "
                         "<~ 1e-35 (we recommend 1e-35)"
                         ".".format(hyp['p0']))

    if model['plotting']['diag_plot'] and \
            model['plotting']['iter_plot_range'] is not None:
        _raise_unless_pair(model['plotting']['iter_plot_range'],
                           ValueError("iter_plot_range should be 'null' (None)"
                                      " or a list specifying the start and end"
                                      " points of the range to be plotted"))

    if model['plotting']['stretch'] not in ["power", "asinh"]:
        raise ValueError("stretch should be 'power' or 'asinh'")

    if model['modify_data']['baseline_range'] is not None:
        _raise_unless_pair(model['modify_data']['baseline_range'],
                           ValueError("baseline_range should be 'null' (None)"
                                      " or a list specifying the low and high"
                                      " baselines [unit: \\lambda] outside of which the"
                                      " data will be truncated before fitting"))


def parse_parameters(*args):
    """
    Read in a .json parameter file to set the fit parameters

    Parameters
    ----------
    *args : list of strings, optional
        Command line arguments, forwarded to argparse's `parse_args`; if none
        are given, the process's command line is parsed. Recognised options:
        -p / --parameter_filename : Parameter file (.json; see
            frank.fit.helper), default `default_parameters.json`
        -uv / --uvtable_filename : UVTable file with data to be fit (.txt,
            .dat, .npy, or .npz). The UVTable column format should be:
            u [lambda] v [lambda] Re(V) [Jy] Im(V) [Jy] Weight [Jy^-2]
        -desc / --print_parameter_description : print the parameter
            descriptions (see helper) and exit

    Returns
    -------
    model : dict
        Dictionary containing model parameters the fit uses
    param_path : string
        Path of the .json file the used parameters were saved to

    Raises
    ------
    ValueError
        If no UVTable is specified, or a parameter fails a sanity check
    """
    import argparse
    import sys

    # `description=` must be passed by keyword: the first positional argument
    # of ArgumentParser is `prog` (the program name shown in the usage line)
    parser = argparse.ArgumentParser(
        description="Run a Frankenstein fit, by default using"
                    " parameters in default_parameters.json")
    parser.add_argument("-p", "--parameter_filename",
                        default=get_default_parameter_file(), type=str,
                        help="Parameter file (.json; see frank.fit.helper)")
    parser.add_argument("-uv", "--uvtable_filename", default=None, type=str,
                        help="UVTable file with data to be fit. See"
                             " frank.io.load_uvtable")
    parser.add_argument("-desc", "--print_parameter_description",
                        action="store_true",
                        help="Print the full description of all fit parameters")

    args = parser.parse_args(*args)

    if args.print_parameter_description:
        helper()
        # sys.exit rather than the site-provided `exit`, which is not
        # guaranteed to exist (e.g. `python -S`) and closes sys.stdin
        sys.exit()

    with open(args.parameter_filename, 'r') as f:
        model = json.load(f)

    io_pars = model['input_output']

    if args.uvtable_filename:
        io_pars['uvtable_filename'] = args.uvtable_filename

    if not io_pars.get('uvtable_filename'):
        raise ValueError("uvtable_filename isn't specified."
                         " Set it in the parameter file or run Frankenstein with"
                         " python -m frank.fit -uv <uvtable_filename>")

    uv_path = io_pars['uvtable_filename']
    if not io_pars['save_dir']:
        # If not specified, use the UVTable directory as the save directory
        io_pars['save_dir'] = os.path.dirname(uv_path)

    # Add a save prefix to the .json parameter file for later use
    save_prefix = os.path.join(io_pars['save_dir'],
                               os.path.splitext(os.path.basename(uv_path))[0])
    io_pars['save_prefix'] = save_prefix

    if io_pars['save_figures'] is True:
        io_pars['fig_save_prefix'] = save_prefix
    else:
        io_pars['fig_save_prefix'] = None

    frank.enable_logging(save_prefix + '_frank_fit.log')

    # Check whether the code runs in parallel now that the logging has been
    # initialized.
    _check_and_warn_if_parallel()

    logging.info('\nRunning Frankenstein on'
                 ' {}'.format(io_pars['uvtable_filename']))

    _validate_parameters(model)

    if io_pars['format'] is None:
        # Infer the format from the extension, looking through compression
        stem, ext = os.path.splitext(uv_path)
        if ext in {'.gz', '.bz2'}:
            ext = os.path.splitext(stem)[1]
        io_pars['format'] = ext[1:]

    param_path = save_prefix + '_frank_used_pars.json'

    logging.info(
        ' Saving parameters used to {}'.format(param_path))
    with open(param_path, 'w') as f:
        json.dump(model, f, indent=4)

    return model, param_path

219 

220 

def load_data(model):
    r"""
    Load the UVTable named in the parameter file (see frank.io.load_uvtable)

    Parameters
    ----------
    model : dict
        Dictionary containing model parameters the fit uses; the table path
        is read from model['input_output']['uvtable_filename']

    Returns
    -------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    """
    uvtable = model['input_output']['uvtable_filename']
    u, v, vis, weights = io.load_uvtable(uvtable)
    return u, v, vis, weights

245 

246 

def alter_data(u, v, vis, weights, geom, model):
    r"""
    Modify the data according to the 'modify_data' section of the parameters

    The steps run in this order, each only if enabled: normalize (u, v) by
    'norm_wle'; cut by 'baseline_range'; re-estimate the weights
    ('correct_weights').

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry).
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    u, v, vis, weights : Parameters as above, with any or all altered according
        to the modification operations specified in model
    """
    mods = model['modify_data']

    if mods['norm_wle'] is not None:
        u, v = utilities.normalize_uv(u, v, mods['norm_wle'])

    if mods['baseline_range']:
        u, v, vis, weights = utilities.cut_data_by_baseline(
            u, v, vis, weights, mods['baseline_range'], geom)

    if mods['correct_weights']:
        # Weights are estimated from the deprojected baselines
        u_deproj, v_deproj = geom.deproject(u, v)
        weights = utilities.estimate_weights(
            u_deproj, v_deproj, vis, use_median=mods['use_median_weight'])

    return u, v, vis, weights

287 

288 

def determine_geometry(u, v, vis, weights, model):
    r"""
    Determine the source geometry (inclination, position angle, phase offset)

    'geometry : type' selects the approach: 'known' uses the given values
    as-is; 'gaussian' and 'nonparametric' fit the geometry to the data,
    optionally holding (inc, pa) and/or (dra, ddec) fixed.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)

    Raises
    ------
    ValueError
        If 'geometry : type' is not 'known', 'gaussian' or 'nonparametric'
    """
    logging.info(' Determining disc geometry')

    geo = model['geometry']
    fit_type = geo['type']

    if fit_type not in ('known', 'gaussian', 'nonparametric'):
        raise ValueError("`geometry : type` in your parameter file must be one of"
                         " 'known', 'gaussian' or 'nonparametric'.")

    if fit_type == 'known':
        logging.info(' Using your provided geometry for deprojection')

        inc, pa, dra, ddec = geo['inc'], geo['pa'], geo['dra'], geo['ddec']
        if inc == 0 and pa == 0 and dra == 0 and ddec == 0:
            logging.info(" N.B.: All geometry parameters are 0 --> No geometry"
                         " correction will be applied to the visibilities")

        geom = geometry.FixedGeometry(inc, pa, dra, ddec)
    else:
        start = time.time()

        guess = ([geo['inc'], geo['pa'], geo['dra'], geo['ddec']]
                 if geo['initial_guess'] else None)
        # A value of None lets the fitter determine that quantity
        inc_pa = None if geo['fit_inc_pa'] else (geo['inc'], geo['pa'])
        phase_centre = (None if geo['fit_phase_offset']
                        else (geo['dra'], geo['ddec']))

        if fit_type == 'gaussian':
            geom = geometry.FitGeometryGaussian(inc_pa=inc_pa,
                                                phase_centre=phase_centre,
                                                guess=guess)
        else:
            geom = geometry.FitGeometryFourierBessel(
                model['hyperparameters']['rout'], N=20,
                inc_pa=inc_pa, phase_centre=phase_centre, guess=guess)

        geom.fit(u, v, vis, weights)

        logging.info(' Time taken for geometry %.1f sec' %
                     (time.time() - start))

    logging.info(' Using: inc = {:.2f} deg,\n PA = {:.2f} deg,\n'
                 ' dRA = {:.2e} arcsec,\n'
                 ' dDec = {:.2e} arcsec'.format(geom.inc, geom.PA,
                                                 geom.dRA, geom.dDec))

    # Return a clone of the geometry object (see SourceGeometry.clone)
    return geom.clone()

383 

384 

def get_scale_height(model):
    """
    Parse the functional form for disc scale-height in the parameter file

    Parameters
    ----------
    model : dict
        Dictionary containing model parameters the fit uses. Reads
        model['geometry']['scale_height'] (None, or a dict with keys 'h0',
        'a', 'r0', 'b') and model['geometry']['rescale_flux']

    Returns
    -------
    scale_height : function R --> H, or None
        None if scale_height is 'null' in the input parameter file.
        Else, a function for the vertical thickness of the disc,
        H(R) = h0 * R**a * exp(-(R / r0)**b), unit=[arcsec]

    Raises
    ------
    ValueError
        If a scale height is given while rescale_flux is true
    KeyError
        If the scale_height dict lacks any of 'h0', 'a', 'r0', 'b'
    """
    params = model['geometry']['scale_height']
    if params is None:
        return None

    if model['geometry']['rescale_flux']:
        raise ValueError("scale_height should be 'null' if rescale_flux is 'true'")

    # Read the coefficients once, up front: a missing key is reported here
    # rather than on first evaluation, and later edits to `model` cannot
    # silently change the returned function
    h0, a, r0, b = params['h0'], params['a'], params['r0'], params['b']

    def scale_height(R):
        """Power-law with cutoff, unit=[arcsec]"""
        return h0 * R ** a * np.exp(-(R / r0) ** b)

    return scale_height

419 

420 

def perform_fit(u, v, vis, weights, geom, model):
    r"""
    Deproject the observed visibilities and fit them for the brightness profile

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    sol : _HankelRegressor object
        Reconstructed profile using Maximum a posteriori power spectrum
        (see frank.radial_fitters.FrankFitter). If 'nonnegative' is true and
        'method' is 'Normal', the best fit nonnegative profile is stored as
        the attribute `sol.nonneg`
    iteration_diagnostics : FrankFitter.iteration_diagnostics, or None
        Diagnostics of the fit iteration if 'iteration_diag' or 'diag_plot'
        is set, else None (see radial_fitters.FrankFitter.fit)
    """
    hyp = model['hyperparameters']

    need_iterations = model['input_output']['iteration_diag'] or \
        model['plotting']['diag_plot']

    scale_height = get_scale_height(model)

    t1 = time.time()
    FF = radial_fitters.FrankFitter(Rmax=hyp['rout'],
                                    N=hyp['n'],
                                    geometry=geom,
                                    alpha=hyp['alpha'],
                                    p_0=hyp['p0'],
                                    weights_smooth=hyp['wsmooth'],
                                    tol=hyp['iter_tol'],
                                    method=hyp['method'],
                                    I_scale=hyp['I_scale'],
                                    max_iter=hyp['max_iter'],
                                    store_iteration_diagnostics=need_iterations,
                                    convergence_failure=hyp['converge_failure'],
                                    scale_height=scale_height,
                                    assume_optically_thick=model['geometry']['rescale_flux']
                                    )

    sol = FF.fit(u, v, vis, weights)

    if hyp['nonnegative'] and hyp['method'] == 'Normal':
        # Add the best fit nonnegative solution to the fit's `sol` object
        logging.info(' `nonnegative` is `true` in your parameter file --> '
                     'Storing the best fit nonnegative profile as the attribute '
                     '`nonneg` in the `sol` object')
        sol.nonneg = sol.solve_non_negative()

    logging.info(' Time taken to fit profile (with {:.0e} visibilities and'
                 ' {:d} collocation points) {:.1f} sec'.format(len(u),
                                                               hyp['n'],
                                                               time.time() - t1)
                 )

    # Always return a 2-tuple (previously a list when diagnostics were off)
    iteration_diagnostics = FF.iteration_diagnostics if need_iterations else None
    return sol, iteration_diagnostics

491 

492 

def run_multiple_fits(u, v, vis, weights, geom, model):
    r"""
    Perform and overplot multiple fits to a dataset by varying two of the
    model hyperparameters

    A fit is run (and its results saved via output_results) for every
    combination of the values of `alpha` and `wsmooth`. Each fit's output
    files get the suffix '_alpha<alpha>_wsmooth<wsmooth>'.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fits use

    Returns
    -------
    multifit_fig : Matplotlib `.Figure` instance, or None
        All produced figures, including the GridSpecs. None if `alpha` or
        `wsmooth` has more than 2 values
    multifit_axes : Matplotlib `~.axes.Axes` class, or None
        Axes for each of the produced figures
    """
    import copy
    import itertools

    logging.info(' Looping fits over the hyperparameters `alpha` and `wsmooth`')

    def number_to_list(x):
        """Wrap a scalar in a list; return anything else unchanged"""
        if np.isscalar(x):
            return [x]
        return x

    alphas = number_to_list(model['hyperparameters']['alpha'])
    ws = number_to_list(model['hyperparameters']['wsmooth'])
    sols = []

    # Same order as nested loops: alpha outer, wsmooth inner
    for alpha, wsmooth in itertools.product(alphas, ws):
        this_model = copy.deepcopy(model)
        this_model['hyperparameters']['alpha'] = alpha
        this_model['hyperparameters']['wsmooth'] = wsmooth

        suffix = '_alpha{}_wsmooth{}'.format(alpha, wsmooth)
        this_model['input_output']['save_prefix'] += suffix
        # Without a per-fit figure prefix, every fit's figures would be
        # saved under the same prefix and could overwrite one another
        if this_model['input_output'].get('fig_save_prefix') is not None:
            this_model['input_output']['fig_save_prefix'] += suffix

        logging.info(' Running fit for alpha = {}, wsmooth = {}'.format(alpha, wsmooth))

        sol, iteration_diagnostics = perform_fit(u, v, vis, weights, geom, this_model)
        sols.append(sol)

        # Save the fit for the current choice of hyperparameter values
        output_results(u, v, vis, weights, sol, geom, this_model,
                       iteration_diagnostics=iteration_diagnostics)

    if len(alphas) in [1, 2] and len(ws) in [1, 2]:
        multifit_fig, multifit_axes = make_figs.make_multifit_fig(u=u, v=v, vis=vis,
                                                                  weights=weights, sols=sols,
                                                                  bin_widths=model['plotting']['bin_widths'],
                                                                  dist=model['plotting']['distance'],
                                                                  force_style=model['plotting']['force_style'],
                                                                  save_prefix=model['input_output']['fig_save_prefix'],
                                                                  )
    else:
        logging.info('The multifit figure requires alpha and wsmooth to be lists of length <= 2. '
                     'Your lists are length {} and {} --> The multifit figure will not be made.'.format(len(alphas), len(ws)))
        multifit_fig, multifit_axes = None, None

    return multifit_fig, multifit_axes

564 

565 

def _load_comparison_profile(filename):
    """Read a comparison brightness profile from a text file

    Parameters
    ----------
    filename : string
        Text file (read with np.genfromtxt) with 2, 3 or 4 columns:
        r [arcsec], I [Jy / sr], negative uncertainty [Jy / sr] (optional),
        positive uncertainty [Jy / sr] (optional; if absent, taken equal to
        the negative uncertainty)

    Returns
    -------
    profile : dict
        Keys 'r', 'I', 'lo_err', 'hi_err'; the uncertainties are None when
        only 2 columns are given

    Raises
    ------
    ValueError
        If the file does not have 2, 3 or 4 columns
    """
    dat = np.genfromtxt(filename).T

    # For a multi-row file, len of the transposed array is the column count
    if len(dat) not in [2, 3, 4]:
        raise ValueError("The file in your .json's `analysis` --> "
                         "`compare_profile` must have 2, 3 or 4 "
                         "columns: r [arcsec], I [Jy / sr], "
                         "negative uncertainty [Jy / sr] (optional), "
                         "positive uncertainty [Jy / sr] (optional, "
                         "assumed equal to negative uncertainty if not "
                         "provided).")

    if len(dat) == 4:
        lo_err, hi_err = dat[2], dat[3]
    elif len(dat) == 3:
        # A single uncertainty column is used for both sides
        lo_err, hi_err = dat[2], dat[2]
    else:
        lo_err, hi_err = None, None

    return {'r': dat[0], 'I': dat[1], 'lo_err': lo_err, 'hi_err': hi_err}


def output_results(u, v, vis, weights, sol, geom, model, iteration_diagnostics=None):
    r"""
    Save datafiles of fit results; generate and save figures of fit results (see
    frank.io.save_fit, frank.make_figs)

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    sol : _HankelRegressor object
        Reconstructed profile using Maximum a posteriori power spectrum
        (see frank.radial_fitters.FrankFitter)
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses
    iteration_diagnostics : optional, default=None
        Diagnostics of the fit iteration, as returned by perform_fit
        (see radial_fitters.FrankFitter.fit)

    Returns
    -------
    figs : list of Matplotlib `.Figure` instances
        All produced figures, including the GridSpecs
    axes : list of Matplotlib `~.axes.Axes`
        Axes for each of the produced figures
    model : dict
        The `model` passed in (this function only reads it)
    """
    io_pars = model['input_output']

    if any([io_pars['save_solution'],
            io_pars['save_profile_fit'],
            io_pars['save_vis_fit'],
            io_pars['save_uvtables']
            ]):

        io.save_fit(u, v, vis, weights, sol,
                    io_pars['save_prefix'],
                    io_pars['save_solution'],
                    io_pars['save_profile_fit'],
                    io_pars['save_vis_fit'],
                    io_pars['save_uvtables'],
                    io_pars['iteration_diag'],
                    iteration_diagnostics,
                    io_pars['format']
                    )

    figs, axes = [], []

    plot_pars = model['plotting']
    fig_prefix = io_pars['fig_save_prefix']

    if any([plot_pars['deprojec_plot'],
            plot_pars['quick_plot'],
            plot_pars['full_plot'],
            plot_pars['diag_plot'],
            model['analysis']['compare_profile']
            ]):
        logging.info(' Plotting results')

    if plot_pars['deprojec_plot']:
        deproj_fig, deproj_axes = make_figs.make_deprojection_fig(u=u, v=v,
                                                                  vis=vis, weights=weights,
                                                                  geom=geom,
                                                                  bin_widths=plot_pars['bin_widths'],
                                                                  force_style=plot_pars['force_style'],
                                                                  save_prefix=fig_prefix
                                                                  )

        figs.append(deproj_fig)
        axes.append(deproj_axes)

    if plot_pars['quick_plot']:
        quick_fig, quick_axes = make_figs.make_quick_fig(u=u, v=v, vis=vis,
                                                         weights=weights, sol=sol,
                                                         bin_widths=plot_pars['bin_widths'],
                                                         dist=plot_pars['distance'],
                                                         logx=plot_pars['plot_in_logx'],
                                                         force_style=plot_pars['force_style'],
                                                         save_prefix=fig_prefix,
                                                         stretch=plot_pars['stretch'],
                                                         gamma=plot_pars['gamma'],
                                                         asinh_a=plot_pars['asinh_a']
                                                         )

        figs.append(quick_fig)
        axes.append(quick_axes)

    if plot_pars['full_plot']:
        full_fig, full_axes = make_figs.make_full_fig(u=u, v=v, vis=vis,
                                                      weights=weights, sol=sol,
                                                      bin_widths=plot_pars['bin_widths'],
                                                      dist=plot_pars['distance'],
                                                      logx=plot_pars['plot_in_logx'],
                                                      force_style=plot_pars['force_style'],
                                                      save_prefix=fig_prefix,
                                                      norm_residuals=plot_pars['norm_residuals'],
                                                      stretch=plot_pars['stretch'],
                                                      gamma=plot_pars['gamma'],
                                                      asinh_a=plot_pars['asinh_a']
                                                      )

        figs.append(full_fig)
        axes.append(full_axes)

    if plot_pars['diag_plot']:
        diag_fig, diag_axes, _ = make_figs.make_diag_fig(r=sol.r, q=sol.q,
                                                         iteration_diagnostics=iteration_diagnostics,
                                                         iter_plot_range=plot_pars['iter_plot_range'],
                                                         force_style=plot_pars['force_style'],
                                                         save_prefix=fig_prefix
                                                         )

        figs.append(diag_fig)
        axes.append(diag_axes)

    if model['analysis']['compare_profile']:
        clean_profile = _load_comparison_profile(model['analysis']['compare_profile'])

        MAP_convolved = None
        if model['analysis']['clean_beam']['bmaj'] is not None:
            MAP_convolved = utilities.convolve_profile(sol.r, sol.I,
                                                       geom.inc, geom.PA,
                                                       model['analysis']['clean_beam'])

        clean_fig, clean_axes = make_figs.make_clean_comparison_fig(u=u, v=v, vis=vis,
                                                                    weights=weights, sol=sol,
                                                                    clean_profile=clean_profile,
                                                                    bin_widths=plot_pars['bin_widths'],
                                                                    stretch=plot_pars['stretch'],
                                                                    gamma=plot_pars['gamma'],
                                                                    asinh_a=plot_pars['asinh_a'],
                                                                    MAP_convolved=MAP_convolved,
                                                                    dist=plot_pars['distance'],
                                                                    force_style=plot_pars['force_style'],
                                                                    save_prefix=fig_prefix
                                                                    )

        figs.append(clean_fig)
        axes.append(clean_axes)

    return figs, axes, model

729 

730 

def perform_bootstrap(u, v, vis, weights, geom, model):
    r"""
    Perform a bootstrap analysis for the Frankenstein fit to a dataset

    Each of the model['analysis']['bootstrap_ntrials'] trials draws a
    bootstrap sample (utilities.draw_bootstrap_sample) and fits it with
    perform_fit. The fitted profiles and the collocation points are saved to
    '<save_prefix>_bootstrap.npz'.

    Parameters
    ----------
    u, v : array, unit = :math:`\lambda`
        u and v coordinates of observations
    vis : array, unit = Jy
        Observed visibilities (complex: real + imag * 1j)
    weights : array, unit = Jy^-2
        Weights assigned to observed visibilities, of the form
        :math:`1 / \sigma^2`
    geom : SourceGeometry object
        Fitted geometry (see frank.geometry.SourceGeometry)
    model : dict
        Dictionary containing model parameters the fit uses

    Returns
    -------
    boot_fig : Matplotlib `.Figure` instance
        The produced figure, including the GridSpec
    boot_axes : Matplotlib `~.axes.Axes` class
        The axes of the produced figure

    Raises
    ------
    ValueError
        If `alpha` or `wsmooth` is a list, or fewer than 1 trial is requested
    """
    hyp = model['hyperparameters']

    if isinstance(hyp['alpha'], list) or isinstance(hyp['wsmooth'], list):
        raise ValueError("For the bootstrap, both `alpha` and `wsmooth` in your "
                         "parameter file must be a float, not a list.")

    ntrials = model['analysis']['bootstrap_ntrials']
    if not ntrials or ntrials < 1:
        # With no trials there would be no fit to take `r` from below
        raise ValueError("`bootstrap_ntrials` must be >= 1 to perform a "
                         "bootstrap; got {}".format(ntrials))

    # perform_fit only stores `sol.nonneg` for the 'Normal' method
    use_nonneg = hyp['nonnegative'] and hyp['method'] == 'Normal'

    if use_nonneg:
        logging.info(' `nonnegative` is `true` in your parameter file --> '
                     'The best fit nonnegative profile (rather than the MAP '
                     'profile) will be saved and used to generate the bootstrap '
                     'figure')

    profiles_bootstrap = []
    for trial in range(ntrials):
        logging.info(' Bootstrap trial {} of {}'.format(trial + 1, ntrials))

        u_s, v_s, vis_s, w_s = utilities.draw_bootstrap_sample(
            u, v, vis, weights)

        sol, _ = perform_fit(u_s, v_s, vis_s, w_s, geom, model)

        profiles_bootstrap.append(sol.nonneg if use_nonneg else sol.I)

    bootstrap_path = model['input_output']['save_prefix'] + '_bootstrap.npz'

    logging.info(' Bootstrap complete. Saving fitted brightness profiles and'
                 ' the common set of collocation points')

    np.savez(bootstrap_path, r=sol.r, profiles=np.array(profiles_bootstrap))

    boot_fig, boot_axes = make_figs.make_bootstrap_fig(r=sol.r,
                                                       profiles=profiles_bootstrap,
                                                       force_style=model['plotting']['force_style'],
                                                       save_prefix=model['input_output']['fig_save_prefix']
                                                       )

    return boot_fig, boot_axes

798 

799 

def main(*args):
    """Run the full Frankenstein pipeline to fit a dataset

    Parameters
    ----------
    *args : strings
        Simulates the command line arguments (forwarded to parse_parameters)

    Returns
    -------
    Two values, which depend on the mode:
        - bootstrap requested: the bootstrap figure and axes
        - `alpha` or `wsmooth` is a list: the multifit figure and axes
        - otherwise: the lists of figures and axes from output_results
    """
    model, param_path = parse_parameters(*args)

    u, v, vis, weights = load_data(model)

    geom = determine_geometry(u, v, vis, weights, model)

    mods = model['modify_data']
    if mods['baseline_range'] or mods['correct_weights'] or mods['norm_wle']:
        u, v, vis, weights = alter_data(u, v, vis, weights, geom, model)

    # Mode 1: bootstrap
    if model['analysis']['bootstrap_ntrials']:
        boot_fig, boot_axes = perform_bootstrap(u, v, vis, weights, geom, model)
        return boot_fig, boot_axes

    # Mode 2: loop over several hyperparameter values
    hyp = model['hyperparameters']
    if type(hyp['alpha']) is list or type(hyp['wsmooth']) is list:
        multifit_fig, multifit_axes = run_multiple_fits(u, v, vis, weights,
                                                        geom, model)
        return multifit_fig, multifit_axes

    # Mode 3: a single fit
    sol, iteration_diagnostics = perform_fit(u, v, vis, weights, geom, model)

    figs, axes, model = output_results(u, v, vis, weights, sol, geom, model,
                                       iteration_diagnostics)

    logging.info(' Updating {} with final parameters used'.format(param_path))
    with open(param_path, 'w') as param_file:
        json.dump(model, param_file, indent=4)

    logging.info("IT'S ALIVE!!\n")

    return figs, axes

850 

851 

if __name__ == "__main__":
    # Command-line entry point, e.g. `python -m frank.fit -uv <uvtable_filename>`
    main()