#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from netgen.geom2d import unit_square
import netgen.gui
import ngsolve as ngs
import matplotlib.pyplot as plt
import numpy as np

def SolveProblem(h=0.5, p=1, levels=1, precond='jacobi'):
    """
    Solve a Poisson problem on a sequence of refined meshes and record how
    many CG iterations each level needs.

    Modified version of
    https://ngsolve.org/docu/latest/i-tutorials/unit-2.1.1-preconditioners/preconditioner.html

    Parameters
    ----------
    h : float
        Maximal element size of the coarse mesh.
    p : int
        Polynomial degree of the H1 finite element space.
    levels : int
        Number of mesh levels; the mesh is refined once before every level
        except the first.
    precond : str or None
        "jacobi", "gaussseidel" or "symgaussseidel" select the custom
        smoother-based preconditioners defined in this file.  None means
        "no preconditioner" (only the Dirichlet dofs are masked out).  Any
        other string (e.g. "multigrid") is passed to ngs.Preconditioner as
        the name of an NGSolve built-in preconditioner.

    Returns
    -------
    list of tuple
        One ``(ndof, iterations)`` tuple per level, where ``iterations`` is
        the second entry returned by myCG.
    """
    custom = ("jacobi", "gaussseidel", "symgaussseidel")
    builtin = precond is not None and precond not in custom

    mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=h))
    fes = ngs.H1(mesh, order=p, dirichlet="bottom|left")

    u, v = fes.TnT()
    a = ngs.BilinearForm(fes)
    a += ngs.SymbolicBFI(ngs.grad(u) * ngs.grad(v))
    f = ngs.LinearForm(fes)
    f += ngs.SymbolicLFI(1 * v)
    gfu = ngs.GridFunction(fes)
    ngs.Draw(gfu)

    if builtin:
        # Built-in preconditioners must be registered with the bilinear form
        # BEFORE assembly so that they are rebuilt on every a.Assemble().
        c = ngs.Preconditioner(a, precond)

    steps = []
    for level in range(levels):
        if level > 0:
            mesh.Refine()
        fes.Update()
        a.Assemble()

        if builtin:
            # Fetch the operator after assembly so it matches the current level.
            pre = c.mat
        elif precond is None:
            # a.mat still contains the Dirichlet rows, i.e. it is the singular
            # Neumann matrix.  Plain CG on it would not converge; projecting
            # onto the free dofs restricts the iteration to the real problem.
            pre = ngs.Projector(fes.FreeDofs(), True)
        else:
            smoother = a.mat.CreateSmoother(freedofs=fes.FreeDofs())
            if precond == "jacobi":
                pre = smoother
            elif precond == "gaussseidel":
                pre = GS(smoother)
            else:  # "symgaussseidel"
                pre = SymmetricGS(smoother)

        f.Assemble()
        gfu.Update()

        # Conjugate gradient solver
        out = myCG(a.mat, f.vec, pre, maxsteps=1000,
                   printrates=False, tol=1e-08)

        gfu.vec.data = out[0]
        steps.append((fes.ndof, out[1]))
        ngs.Redraw()
    return steps


class GS(ngs.BaseMatrix):
    """
    Gauss-Seidel preconditioner.

    Applies a fixed number of ``smoother.Smooth`` sweeps (forward only) to a
    zero initial guess.  Because no backward sweep is performed, the resulting
    operator is not symmetric.  CG convergence theory therefore does not apply;
    compare with SymmetricGS.

    Cf. https://ngsolve.org/docu/latest/i-tutorials/unit-2.1.2-blockjacobi/blockjacobi.html

    Parameters
    ----------
    smoother : BaseMatrix
        Smoother providing ``Smooth``, e.g. from ``mat.CreateSmoother(...)``.
    sweeps : int
        Number of forward sweeps per application (default 3, as before).
    """
    def __init__(self, smoother, sweeps=3):
        super(GS, self).__init__()
        self.smoother = smoother
        self.sweeps = sweeps

    def Mult(self, x, y):
        # y = P x, obtained by smoothing A y = x starting from y = 0
        y[:] = 0.0
        for _ in range(self.sweeps):
            self.smoother.Smooth(y, x)

    def Height(self):
        return self.smoother.height

    def Width(self):
        return self.smoother.width

class SymmetricGS(ngs.BaseMatrix):
    """
    Symmetric Gauss-Seidel preconditioner.

    Applies one ``smoother.Smooth`` sweep followed by one
    ``smoother.SmoothBack`` sweep to a zero initial guess.  The forward sweep
    followed by the backward sweep gives a symmetric operator, as required for
    a CG preconditioner.

    Cf. https://ngsolve.org/docu/latest/i-tutorials/unit-2.1.2-blockjacobi/blockjacobi.html

    Parameters
    ----------
    smoother : BaseMatrix
        Smoother providing ``Smooth`` and ``SmoothBack``, e.g. from
        ``mat.CreateSmoother(...)``.
    """
    def __init__(self, smoother):
        super(SymmetricGS, self).__init__()
        self.smoother = smoother

    def Mult(self, x, y):
        # y = P x: forward sweep then backward sweep, starting from y = 0
        y[:] = 0.0
        self.smoother.Smooth(y, x)
        self.smoother.SmoothBack(y, x)

    def Height(self):
        return self.smoother.height

    def Width(self):
        return self.smoother.width

  
def myCG(mat, rhs, pre=None, sol=None, tol=1e-12, maxsteps=100, printrates=True, initialize=True, conjugate=False):
    """Preconditioned conjugate gradient method.

    Modified version of the NGSolve CG
    (https://ngsolve.org/docu/latest/i-tutorials/unit-2.1.2-blockjacobi/blockjacobi.html)

    Parameters
    ----------

    mat : Matrix
      The left hand side of the equation to solve. The matrix has to be
      symmetric positive definite (or Hermitian positive definite).

    rhs : Vector
      The right hand side of the equation.

    pre : Preconditioner
      If provided (not None) the preconditioner is used.

    sol : Vector
      Start vector for CG method, if initialize is set False. Gets overwritten
      by the solution vector. If sol is None then a new vector is created.

    tol : double
      Relative tolerance: CG stops once the preconditioned residual norm
      sqrt(|<P r, r>|) drops below tol times its initial value.

    maxsteps : int
      Maximal number of CG iterations. If it is reached before the tolerance
      is reached, CG stops and prints a warning.

    printrates : bool
      If set to True then the error of each iteration is printed.

    initialize : bool
      If set to True then the initial guess is set to zero. Otherwise the
      values of the vector sol, if provided, are used.

    conjugate : bool
      If set to True, then the complex inner product is used.


    Returns
    -------
    (vector, int, float)
      The solution vector, the number of CG iterations performed, and the
      final preconditioned residual norm.

    """

    u = sol if sol is not None else rhs.CreateVector()
    d = rhs.CreateVector()
    w = rhs.CreateVector()
    s = rhs.CreateVector()

    if initialize:
        u[:] = 0.0
    # residual d = b - A u and preconditioned residual w = P d
    d.data = rhs - mat * u
    w.data = pre * d if pre is not None else d
    s.data = w

    wdn = w.InnerProduct(d, conjugate=conjugate)
    err0 = np.sqrt(abs(wdn))

    if wdn == 0:
        # The initial guess already solves the system; keep the tuple shape.
        return u, 0, 0.0

    it = 0
    err = err0
    for it in range(1, maxsteps + 1):
        w.data = mat * s
        wd = wdn
        as_s = s.InnerProduct(w, conjugate=conjugate)
        alpha = wd / as_s
        u.data += alpha * s
        d.data += (-alpha) * w

        w.data = pre * d if pre is not None else d
        wdn = w.InnerProduct(d, conjugate=conjugate)

        # Test the NEW residual: testing the previous one lags one step and,
        # once the residual is exactly zero, leads to s == 0 and a division
        # by zero in the next alpha.
        err = np.sqrt(abs(wdn))
        if printrates:
            print ("it = ", it, " err = ", err)
        if err < tol * err0:
            break

        beta = wdn / wd
        s *= beta
        s.data += w
    else:
        print("Warning: CG did not converge to TOL")

    return u, it, err

if __name__ == '__main__':
    # Compare CG iteration counts of all preconditioners over 8 mesh levels.
    for name in ("jacobi", "multigrid", "gaussseidel", "symgaussseidel", None):
        print(SolveProblem(levels=8, precond=name))
