;+
;
; NAME:
;   vis_wrfiwt
;
; PURPOSE:
;   Wavelet (FIWT) regularization deconvolution method.
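;   Solves the minimization problem  min_x { ||B - Hx||_2^2 + \lambda ||Wx||_1 },
;   where B is the back-projection (dirty) map, H is convolution with the
;   normalized dirty beam (PSF), and W is the wavelet transform (FIWT).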
;
; CALLING SEQUENCE:
;   map = vis_wrfiwt(vis, imsize=imsize, /autolam)
;
; INPUTS:
;   vis: input visibility structure in standard format
;
; KEYWORDS:
;   IMSIZE: output map size in pixels (odd number, default is 129)
;   PIXEL: pixel size in arcsec (default is 1)
;   NITER: maximum number of iterations (default is 200)
;   AUTOLAM: automatically estimate the regularization parameter \lambda (default is 1)
;   LAM: regularization parameter \lambda, used when AUTOLAM is 0 (default is 0.05)
;   SILENT: if not set, prints the iteration number, objective value, total absolute
;           flux and data-fit term at each iteration, and the final \lambda (default is 1)
;   MAKEMAP: if set, returns the map structure; otherwise returns the 2D image array (default is 1)
;
; RETURNS:
;   map: image map in the structure format produced by the routine make_map.pro,
;        or the non-negative 2D image array when MAKEMAP is 0
;
; RESTRICTIONS:
;   -IMSIZE should be odd
;
; HISTORY:
;   May-2017 Written by Miguel A. Duval-Poo
;
; CONTACT:
;   duvalpoo [at] dima.unige.it
;
;-
function vis_wrfiwt, vis, IMSIZE=imsize, PIXEL=pixel, NITER=niter, LAM=lam, SILENT=silent, AUTOLAM=autolam, MAKEMAP=makemap

  ; Accelerated proximal-gradient (FISTA-style) reconstruction of
  ;   min_x ||B - Hx||^2_2 + lam*||Wx||_1
  ; B is the back-projection (dirty) map, H is convolution with the
  ; unit-sum dirty beam (applied in Fourier space) and W is the FIWT.
  ; Returns a make_map structure when MAKEMAP is set (default), otherwise
  ; the non-negative 2D image.
  ; NOTE: IMSIZE and LAM are keyword variables that this routine assigns to,
  ; so a caller passing named variables gets them back modified (IMSIZE
  ; becomes a scalar; with AUTOLAM, LAM receives the estimated value).

  default, imsize, 129
  default, lam, 0.05
  default, niter, 200
  default, silent, 1
  default, autolam, 1
  default, pixel, 1
  default, makemap, 1

  ; wavelet spectra consumed by fiwt (second argument presumably the number of scales - TODO confirm)
  psi = fiwt_spectra(imsize, 3)

  ; back-projection (dirty) map
  vis_bpmap, vis, map=dirty_map, bp_fov=imsize[0]*pixel, pixel=pixel

  ; dirty beam on the same grid
  psf = vis_psf(vis, pixel=pixel, image_dim=imsize[0])

  B = real_part(dirty_map)
  P = real_part(psf)/total(psf) ; normalized to unit sum

  ; compare the full dimension vectors (comparing their sums accepts e.g. 100x50 vs 75x75)
  if ~array_equal(size(B, /DIMENSIONS), size(P, /DIMENSIONS)) then message, 'Error: size(dirty_map) must be equal to size(psf).'

  imsize = size(B, /DIMENSIONS)
  center = fix(imsize/2)

  if imsize[0] ne imsize[1] then message, 'Error: dirty_map and psf matrix must be square.'

  imsize = imsize[0]

  ; target total flux: largest real visibility among those with the highest isc value
  cidx = where(vis.isc eq max(vis.isc))
  dc = max(real_part(vis.obsvis[cidx]))

  ; transfer function of H: FFT of the circularly shifted PSF, multiplied by
  ; n_elements because IDL's forward FFT includes the 1/N factor
  ; NOTE(review): the shift is 1-center (MATLAB-style); confirm against where vis_psf places the PSF peak
  P = fft(shift(P, 1-center))*n_elements(P)

  ; initial scaling so that the dirty map carries the target flux
  B = B*(dc/total(B))

  Btrans = fft(B)

  ; Lipschitz constant of the gradient of ||B - Hx||^2
  L = 2*max(abs(P)^2)

  old_total_val = 0
  X_iter = B
  Y = X_iter
  t_new = 1

  ; with AUTOLAM the first iteration runs unregularized and lam is estimated from its result
  if keyword_set(autolam) then lam = 0

  for i = 1, niter do begin
    X_old = X_iter
    t_old = t_new

    ; gradient step on the data term: Y <- Y - (1/L) * 2*H^T(HY - B)
    D = P*fft(Y) - Btrans
    Y = Y - 2./L*fft(conj(P)*D, /inverse)

    WY = fiwt(real_part(Y), psi)

    ; soft thresholding (prox of lam*||.||_1): shrink each coefficient's magnitude
    ; by lam while keeping its sign/phase; the (wmag eq 0) term avoids 0/0.
    ; The threshold is lam, not lam/L, as in the original implementation.
    wmag = abs(WY)
    wshrunk = (wmag - lam) > 0
    WY = WY*(wshrunk/(wmag + (wmag eq 0)))

    X_iter = real_part(fiwt(WY, psi, /inverse))

    ; flux constraint: offset the image so that its total equals dc
    X_iter = X_iter - (total(X_iter)-dc)/(imsize*imsize)

    ; momentum update
    t_new = (1+sqrt(1.+4*t_old^2))/2.
    Y = X_iter + ((t_old-1)/t_new)*(X_iter-X_old)

    ; objective: data-fit term plus weighted wavelet l1 norm
    residual = B - real_part(fft(P*fft(X_iter), /inverse))
    likelihood = norm(abs(residual[*]))^2
    sparsity = total(abs(fiwt(X_iter, psi)))

    if i eq 1 && keyword_set(autolam) then lam = likelihood/sparsity

    total_val = likelihood + lam*sparsity

    if ~keyword_set(silent) then print, i, total_val, total(abs(X_iter)), likelihood

    ; stop as soon as the objective stops decreasing (only checked after 10 iterations)
    if i gt 10 && old_total_val le total_val then break

    old_total_val = total_val
  endfor

  if ~keyword_set(silent) then print, 'Lambda: ', lam

  ; enforce non-negativity
  X_iter = real_part(X_iter*(X_iter gt 0))

  return, makemap ? make_map(X_iter,xcen=vis.xyoffset[0],ycen=vis.xyoffset[1], dx=pixel, dy=pixel,time=anytim(vis[0].trange[0],/ecs), id = 'WR-FIWT') : X_iter

end
