#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#' # Blatt 8

import numpy as np
from numpy.linalg import norm
from scipy.linalg import hessenberg, qr, hilbert, eig
import matplotlib.pyplot as plt
import matplotlib.cm as cmc
import time

#' ## Aufgabe 32 - QR-Algorithmus ohne Shifts

#' #(a)+(c) - qralg (wende den ungeshifteten QR-Algorithmus auf obere Hessenbergmatrix an)

def qralg(H, tol, plot=True):
#Eingabe: Matrix H und die Toleranz für den Abbruch
#Ausgabe: Matrix H mit Eigenwerten auf der Diagonalen
#Plot=True erstellt Plot, der zeigt, in welcher Reihenfolge die Einträge der
#Nebendiagonale verschwinden, bzw. wie sich die anderen Einträge verändern
    werteH = [abs(np.diag(H, -1))]

    if plot:
        fig = plt.figure(1)
        ax = fig.gca()
        p = ax.imshow(np.log(abs(H)+10**(-15)))
        fig.colorbar(p, ax=ax)

    while norm(np.diag(H, -1), ord=np.inf) >= tol:
        Q, R = qr(H)
        H = R @ Q
        werteH.append(abs(np.diag(H, -1)))
        if plot:
            p.set_data(np.log(abs(H)+10**(-15)))
            plt.pause(0.01)

    return H, werteH


#' #(b) - findeeigenwerte (EW mit Hilfe von qralg)

def findeeigenwerte(A, tol=1e-16, plot=True, timer=False):
#Eingabe: Matrix A und die Toleranz für den Abbruch beim QR Algorithmus
#Ausgabe: Vektor ew mit Approximationen der Eigenwerte
    H = hessenberg(A)
    if timer:
        start = time.time()
        H, werteH = qralg(H, tol, plot=False)
        end = time.time()
        return end-start
    else:
        H, werteH = qralg(H, tol, plot)
        ew = np.diag(H)
        return ew, np.array(werteH)


#' #(d) Test der Implementierung
# Erstelle mit Hilfe von Ähnlichkeitstrafo eine (nxn)-Matrix mit Eigenwerten n,n-1,...,1
def bspMatrix(n):
    ew = np.arange(1, n+1)[::-1]
    Q, R = qr(np.random.rand(n, n))
    np.fill_diagonal(R, ew)
    A = Q.conj().T @ R @ Q
    return A, ew
#%%
if __name__ == '__main__':
    n = 10
    A, ew_ex = bspMatrix(n)
    ew, werteH = findeeigenwerte(A)
    print("Der betragsgrößte Fehler der berechneten Eigenwerte beträgt "  +\
          str(norm(ew-ew_ex)) + ".")
    fig1 = plt.figure(2)
    ax1 = fig1.gca()
    for kk in range(n-1):
        ax1.semilogy(np.arange(len(werteH)), werteH[:, kk],\
                     label='|H_({0},{1})|'.format(kk+2, kk+1),\
                     color=cmc.rainbow(kk/(n-1)))
    ax1.legend(ncol=2)
    ax1.set_title("Matrix mit rellen Eigenwerten")
    ax1.set_xlabel("Anzahl der QR-Zerlegungen")
    ax1.set_ylabel("Nebendiagonale")

    #%%
    n = 15
    B = hilbert(n)
    ew, werteH = findeeigenwerte(B)
    ew_ex, ev_ex = eig(B)
    ew_ex = sorted(ew_ex, reverse=True)
    print("Der betragsgrößte Fehler der berechneten Eigenwerte beträgt "  +\
          str(norm(ew-ew_ex)) + ".")
    fig2 = plt.figure(3)
    ax2 = fig2.gca()
    for kk in range(n-1):
        ax2.semilogy(np.arange(len(werteH)), werteH[:, kk],\
                     label='|H_({0},{1})|'.format(kk+2, kk+1),\
                     color=cmc.rainbow(kk/(n-1)))
    ax2.legend(ncol=2)
    ax2.set_title("Hilbertmatrix")
    ax2.set_xlabel("Anzahl der QR-Zerlegungen")
    ax2.set_ylabel("Nebendiagonale")
