;+
; Name: hsi_vis_multiscale_clean
;
; Purpose: This function implements a multi-scale version of the CLEAN algorithm. It returns the multiscale-clean map (clean
; components plus residuals) in the case of RHESSI visibilities (a generalization for STIX visibilities is in progress).
; 
; Method: In this multiscale approach, the solution is constructed as a sum of sources of different size modeled by using a family
; of basis functions (playing the role of the clean beam) each one characterized by a specific scale (the support in the case of
; paraboloids basis functions and the FWHM in the case of Gaussian basis functions).
; This approach explicitly relies on the fact that RHESSI and STIX instruments sample the (u,v) plane according to circles 
; characterized by increasing radii with the following implications:
;  1) the PSF is the sum of PSF components, each one associated to a specific circle (or groups of circles) of sampled visibilities 
;  in the (u,v) plane.
;  2) the dirty map corresponding to the measured visibilities can be interpreted as the sum of several dirty maps, each one 
;  corresponding to one PSF component.
; 
; Choice of the scales: The choice of the scales is made by the user by grouping the circles in the (u,v) plane (i.e., by grouping
; the detectors). Once this choice has been done, the code automatically computes the corresponding PSF component and the FWHM of its 
; central peak. Then it uses this value to build a basis function characterized by this scale value.
;  
; In the RHESSI case (for example):
; 
; grouping detectors   1 and 2   corresponds to  a scale of   3 arcsec 
; grouping detectors   1, 2, 3   corresponds to  a scale of   3 arcsec 
; grouping detectors   2, 3      corresponds to  a scale of   5 arcsec 
; grouping detectors   2, 3, 4   corresponds to  a scale of   6 arcsec
; grouping detectors   3, 4      corresponds to  a scale of   8 arcsec 
; grouping detectors   3, 4, 5   corresponds to  a scale of  10 arcsec 
; grouping detectors   4, 5      corresponds to  a scale of  14 arcsec
; grouping detectors   4, 5, 6   corresponds to  a scale of  17 arcsec
; grouping detectors   5, 6      corresponds to  a scale of  25 arcsec
; grouping detectors   5, 6, 7   corresponds to  a scale of  29 arcsec
; grouping detectors   6, 7      corresponds to  a scale of  41 arcsec
; grouping detectors   6, 7, 8   corresponds to  a scale of  46 arcsec
; grouping detectors   7, 8      corresponds to  a scale of  69 arcsec
; grouping detectors   7, 8, 9   corresponds to  a scale of  74 arcsec
; grouping detectors   8, 9      corresponds to  a scale of 116 arcsec
;
; Inputs:
;   - vis - visibility bag
;
; Keyword inputs:
;   - image_dim    number of pixels in x and y, 1 or 2 element long vector or scalar (default 65)
;                  images are square so the second number isn't used; an even value is increased by one (odd size enforced)
;   - pixel_size   pixel size in asec (pixels are square, default 1.0)
;   - det_scale    a 9-element vector describing how the 9 detectors are grouped in a number of scales. The position in the array
;                  corresponds to the number of the detector while the entry value corresponds to the assigned scale. 
;                  For instance, if 4 scales are used, the entries can assume the values:
;                  0  -->   the corresponding detector is not used 
;                  1  -->   the corresponding detector is assigned to group/scale 1
;                  2  -->   the corresponding detector is assigned to group/scale 2
;                  3  -->   the corresponding detector is assigned to group/scale 3
;                  4  -->   the corresponding detector is assigned to group/scale 4
;                  (i.e., [1 1 2 2 3 3 3 4 4] means detectors 1 and 2 assigned to scale 1; detectors 3 and 4 assigned to scale 2;
;                  detectors 5, 6 and 7 assigned to scale 3; detectors 8 and 9 to scale 4)   
;                  See restrictions below.            
;   - niter        max iterations  (default 100)
;   - gain         clean loop gain factor (default 0.1)
;   - gauss_basis  1b for gaussian basis-functions
;                  0b for parabolic basis-functions (default)
;   - progress_map 1b to see the progress of the solution (default 1)
;   - progress_iter   number of iterations after which the progress of the solution has to be updated (default 5)
;   - negative_max if set, a scale stops being cleaned once the absolute maximum of its dirty map is a negative value (default 1)
;   - quiet        if set, suppress the informational printout (default 0)
;
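; Restrictions:
;   - Scale values corresponding to not-used detectors must be = 0 (i.e., [0 1 2 2 3 3 3 4 4] means detector 1 not used).
;     Detectors without visibilities in vis are set to 0 automatically.
;   - Scale values must be in non-decreasing order (i.e., [1 1 3 3 2 2 2 4 4] is NOT allowed). If this is violated, a warning
;     is printed and the non-zero scales are replaced by the default values [1 1 1 2 2 2 3 3 3].
;   - Scale values must cover the whole range between 1 and the max value (i.e., [1 1 2 2 4 4 4 5 5] is NOT allowed). If this
;     is violated, a warning is printed and the non-zero scales are renumbered consecutively starting from 1.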
; 
; History:
; - October 2013 R.A. Schwartz, developed initial version
; -   April 2014 S. Giordano, updated clean loop, introduced both gaussians and paraboloids as pixons
; -   June  2014 S. Giordano, added residuals
; -   June  2014 A.M. Massone, introduced scale bias, took into account the different number of visibilities at each scale
; -   Jan   2017 R.A. Schwartz, removed "absmax" function in favor of max() with /ABSOLUTE flag
; -   June  2017 A.M. Massone, improved description of the method (in particular relation between beam and scales), deleted the 
;                              dirty maps cleaning progress shown as surfaces, added the dirty maps cleaning progress shown as maps,
;                              added the keyword progress_iter
;-

function hsi_vis_multiscale_clean, vis, image_dim = image_dim, pixel_size = pixel_size, $
                               det_scale = det_scale, niter = niter, gain  = gain, $
                               gauss_basis = gauss_basis, progress_map = progress_map, $
                               progress_iter = progress_iter, negative_max = negative_max, $
                               _extra=_extra, quiet=quiet

  ; Multiscale CLEAN of a visibility bag. Returns an image_dim0 x image_dim0 map
  ; (image_dim0 = image_dim forced to an odd number) equal to the clean-component map plus
  ; the sum of the residual per-scale dirty maps normalized by the total PSF flux.
  ;
  ; The caller's image_dim and det_scale variables are NOT modified. The odd-size rounding,
  ; the enlargement to the basis-function support and the scale validation are all done
  ; on local copies, so repeated calls with the same variables give the same result.

  det_scale_default = [1, 1, 1, 2, 2, 2, 3, 3, 3]

  default, niter, 100
  default, image_dim, 65
  default, pixel_size, 1.0
  default, negative_max, 1
  default, det_scale, det_scale_default
  default, gain, 0.1
  default, progress_map, 1
  default, progress_iter, 5
  default, gauss_basis, 0b
  default, quiet, 0

  ;;;;;;;;;;;;;;;;;; Requested output size, forced to an odd number (an even size is increased by one)
  image_dim0 = floor( image_dim[0] ) / 2 * 2 + 1

  ;;;;;;;;;;;;;;;;;; Local working copy of the detector -> scale assignment
  dscale = det_scale

  ;;;;;;;;;;;;;;;;;; Check 1: detectors without visibilities are not used (scale = 0)
  not_present = where_arr( indgen(9), get_uniq( vis.isc ), /notequal, not_present_count )
  if not_present_count ge 1 then dscale[ not_present ] = 0

  ;;;;;;;;;;;;;;;;;; Detectors associated with non-zero scales
  in_use = where( dscale, n_in_use )
  if n_in_use eq 0 then message, 'No detector with visibilities is assigned to a non-zero scale'
  det_scale_in = dscale[ in_use ]

  ;;;;;;;;;;;;;;;;;; Check 2: non-zero scale values must be in non-decreasing order
  dummy = where( det_scale_in - det_scale_in[ sort( det_scale_in ) ], count )
  if count ne 0 then begin
    print, 'Warning: scales are not in increasing order ==> Non-zero scales set to the default values'
    ind_zero = where( dscale eq 0, count )
    dscale = det_scale_default
    if count ne 0 then dscale[ind_zero] = 0
  endif

  ;;;;;;;;;;;;;;;;;; Check 3: scale values must cover the whole range between 1 and max value
  det_scale_in = dscale[ where( dscale ) ]
  uniq_scale = get_uniq( det_scale_in )
  if ( max( uniq_scale ) ne n_elements( uniq_scale ) ) then begin
    print, 'Warning: scales must cover the whole range between 1 and max value ==> Non-zero scales re-arranged accordingly'
    ;; Renumber the scale values present, in ascending order, as 1, 2, ..., n_elements(uniq_scale)
    k = 1
    i = 1
    while k ne n_elements( uniq_scale )+1 do begin
      ind_i = where( dscale eq i, count )
      if count ne 0 then begin
        dscale[ind_i] = k
        k = k+1
      endif
      i = i+1
    endwhile
  endif

  ;;;;;;;;;;;;;;;;;; Number of scales. Computed after ALL the checks, because Check 1 can remove
  ;;;;;;;;;;;;;;;;;; the largest scale without triggering the renumbering of Check 3.
  nscale = max( dscale )

  if ~quiet then begin
    print,'det_scale', dscale
    print,'nscale', nscale
  endif

  ;;;;;;; Basis functions: Gaussian (gauss_basis = 1) or parabolic (gauss_basis = 0)
  pixon = hsi_vis_multiscale_pixon_gen( vis, npixel = image_dim0 * 2 + 1, $
                                    pixel_size = pixel_size, psf = psf, $
                                    dim_pixon = dim_pixon, detect = dscale, gauss_basis = gauss_basis,$
                                    quiet = quiet )

  ;;;;;;; Working image size: the requested size enlarged, if needed, to the basis-function size
  work_dim = dim_pixon[0] > image_dim0
  dirty_pixon_image_dim = work_dim * 2 + 1

  ;;;;;;; Cross-convolutions between basis functions and PSFs at the different scales.
  ;;;;;;; init_psf is set only when the working size differs from the requested size
  ;;;;;;; (presumably the psf from the pixon generator must then be recomputed - see hsi_vis_multiscale_psf_gen).
  dirty_pixon_space = hsi_vis_multiscale_psf_gen( vis, nscale, pixon, npixel = dirty_pixon_image_dim, $
                                            pixel_size = pixel_size, psf = psf, max_dirty_pixon = $
                                            max_dirty_pixon, init_psf = work_dim ne image_dim0, detect = dscale )

  ;;;;;; Full dirty map, used as template for the per-scale dirty maps.
  ;;;;;; NOTE(review): `extra` is never defined in this function, so no extra keywords reach vis_bproj
  ;;;;;; (the _extra received by this function is not forwarded) - confirm whether this is intended.
  dirty_map_base = vis_bproj( vis, image_dim = work_dim, pixel_size = pixel_size, info = info, _ref_extra = extra )
  nvis = n_elements( vis )

  ;;;;;; Dirty maps at each scale, each weighted by the fraction of visibilities it uses
  ;;;;;; (the backprojection normalizes each map by its own number of visibilities)
  dirty_map_scale = reproduce( dirty_map_base, nscale )
  for i = 0, nscale-1 do begin
    ind = where( dscale eq i+1 )
    ivis = where( vis.isc ge min( ind ) and vis.isc le max( ind ), countvis )
    if countvis eq 0 then message, 'No visibilities found for scale ' + strtrim( i+1, 2 )
    dirty_map_scale[*, *, i] = vis_bproj( vis[ivis], image_dim = work_dim, pixel_size = pixel_size, $
                                          info = info, _ref_extra = extra ) * countvis / nvis
  endfor

  ;;;;;;;;;;; Scale bias, based on the logarithmic sampling of the (u,v) plane (ratio sqrt(3)
  ;;;;;;;;;;; between consecutive detectors) and on the max of each dirty map. The scale bias
  ;;;;;;;;;;; emphasizes small scales first. Scales whose dirty map is dominated by negative
  ;;;;;;;;;;; values (i.e. just noise) get a zero bias and are never selected.
  ratio = sqrt( 3.0 )
  scl_bias_0 = 1. / ratio^indgen( 9 )
  scl_bias = fltarr( nscale )
  for i = 0, nscale-1 do begin
    ind = where( dscale eq i+1 )
    minmax_val = minmax( dirty_map_scale[*, *, i] )
    if ( minmax_val[1] le 0 ) || ( abs( minmax_val[0] ) gt minmax_val[1] ) then begin
      scl_bias[i] = 0.
    endif else begin
      scl_bias[i] = scl_bias_0[min( ind )] / minmax_val[1]
    endelse
  endfor

  if ~quiet then for i = 0, nscale-1 do $
    print,'Dirty map max, scale',i+1,', scaled by scale bias : ', max( dirty_map_scale[*, *, i] )*scl_bias[i]

  ;;;;;;;;;;; Enter the clean loop
  clean_component = replicate( {multi_clean_component, space_index: 0, flux: 0.0, location:0L}, niter > 1 )
  scale_mask = intarr( nscale ) + 1
  clean_map = fltarr( work_dim, work_dim )      ; allocated up-front so an early stop still returns a map
  progress_step = long( progress_iter ) > 1     ; avoids a modulo by zero
  show_progress = keyword_set( progress_map )
  use_negative_max = keyword_set( negative_max )

  if show_progress then begin
    window, /free
    windows = [!d.window, 0]
    window, /free
    windows[1] = !d.window
  endif

  ;;;;;;;;;;; Pixel range of the requested image_dim0 x image_dim0 map inside the working map
  i0 = work_dim / 2 - image_dim0 / 2
  i1 = work_dim / 2 + image_dim0 / 2

  for ii = 0, niter-1 do begin
    zscale = where( scale_mask, nzscale )
    if nzscale eq 0 then begin
      if ~quiet then print,'scale'
      break
    endif

    ;;;; Find the overall (scale-biased) absolute max among the active scales
    all_max = 0.
    max_jscale = -1L
    for jz = 0, nzscale-1 do begin
      j = zscale[jz]
      amax = max( dirty_map_scale[*, *, j], jc, /absolute )
      ;; With negative_max set, a scale whose absolute max is negative is retired and
      ;; must not be cleaned in this iteration either
      if use_negative_max && ( amax lt 0 ) then begin
        scale_mask[j] = 0
        continue
      endif
      this_max = abs( amax ) * scl_bias[j]
      if this_max gt all_max then begin
        xypsf = jc
        all_max = this_max
        peak = amax
        max_jscale = j
      endif
    endfor

    ;;;; No usable maximum at any scale (all retired, or all with zero scale bias)
    if max_jscale lt 0 then begin
      if ~quiet then print, 'No admissible maximum at any scale ==> clean loop stopped'
      break
    endif

    if ~quiet then print,'Max at scale: ', max_jscale+1

    ;;;;;;;; Component flux: gain times the true (unbiased, signed) peak value, expressed in units
    ;;;;;;;; of the basis function at scale max_jscale. max_dirty_pixon[max_jscale] comes from the
    ;;;;;;;; cross convolution between that basis function and the PSF at the same scale.
    ;;;;;;;; The signed peak is used so that a negative peak (possible when negative_max is not set)
    ;;;;;;;; is removed with a negative component instead of being deepened by a positive one.
    flux = gain * peak / max_dirty_pixon[max_jscale]
    clean_component[ii] = {multi_clean_component, max_jscale, flux, xypsf }

    ;;;;;;;; Update the dirty maps at each scale, subtracting the cross convolution between
    ;;;;;;;; the basis function at scale max_jscale and the PSF at each scale
    for iscale = 0, nscale-1 do begin
      dirty_map_scale[*, *, iscale] -= flux * vis_psf( vis, xypsf, $
                                        psf00 = dirty_pixon_space[max_jscale].basis[iscale].map, $
                                        image_dim = work_dim, init = 0b )
    endfor

    ;;;;;;;; Add the component (basis function at scale max_jscale) to the clean map
    clean_map = psf_add( xypsf, clean_map, flux * pixon[*, *, max_jscale] )

    ;;;;;;;; Progress maps: per-scale dirty maps in the first window, clean map in the second
    if show_progress && ( ii mod progress_step eq 0 ) then begin
      wset, windows[0]
      ncol = ( nscale + 1 ) / 2
      !p.multi = [0, ncol, 2]
      for im = 0, nscale-1 do begin
        dirty_map_progress = make_map( dirty_map_scale[i0:i1, i0:i1, im], $
          id = ' ', $
          xc = vis[0].xyoffset[0], yc = vis[0].xyoffset[1], $
          dx = pixel_size, dy = pixel_size, $
          xunits = 'arcsec', yunits = 'arcsec' )
        string_tit = 'Dirty map at scale ' + string( format = '(i1.1)', im+1 )
        plot_map, dirty_map_progress, title = string_tit
      endfor

      wset, windows[1]
      !p.multi = 0
      wait, .5
      clean_map_progress = make_map( clean_map[i0:i1, i0:i1], $
                       id = ' ', $
                       xc = vis[0].xyoffset[0], yc = vis[0].xyoffset[1], $
                       dx = pixel_size, dy = pixel_size, $
                       xunits = 'arcsec', yunits = 'arcsec', $
                       time = anytim( vis[0].trange[0], /ecs ) )
      plot_map, clean_map_progress, title = 'Clean Component map in progress', /limb, /cbar
    endif

  endfor

  ;;;;;;;; Clean map in the requested dimensions
  result = clean_map[i0:i1, i0:i1]

  ;;;;;;;; Residuals: sum of the remaining per-scale dirty maps, normalized by the total PSF flux
  resid = dirty_map_scale[*, *, 0]
  for iscale = 1, nscale-1 do resid += dirty_map_scale[*, *, iscale]

  psftot = vis_psf( vis, image_dim = work_dim, pixel = pixel_size, init = 1 )
  flux_psf = total( psftot ) * pixel_size * pixel_size

  residual = resid[i0:i1, i0:i1] / flux_psf

  ;;;;;;;; Return clean map + residuals
  return, result + residual

end
