#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb  6 13:38:54 2021

@author: schaedle
"""
import numpy as np
from scipy.integrate import solve_ivp
from scipy.linalg import eig
from scipy.linalg import expm
import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt

class awp(object):
    """Initial value problem y' = f(t, y), y(t0) = y0, with plotting helpers.

    The flow map ``fluss`` is evaluated with ``scipy.integrate.solve_ivp``
    (method Radau), so ``f`` and ``jacf`` must follow the solve_ivp calling
    convention ``f(t, y)`` and ``jacf(t, y)``.
    """

    def __init__(self, t0=0, f=None, y0=None, jacf=None):
        """Store the problem data.

        Parameters
        ----------
        t0 : float
            Initial time (stored on the instance).
        f : callable, optional
            Right-hand side ``f(t, y)``. Defaults to ``y' = -y``.
        y0 : array_like, optional
            1-D initial value. Defaults to ``np.array([1.0])``.
        jacf : callable, optional
            Jacobian ``jacf(t, y)`` of ``f``; passed to solve_ivp as ``jac``.
        """
        self.t0 = t0
        # solve_ivp calls fun(t, y), so the default must accept both arguments.
        self.f = f if f is not None else (lambda t, y: -y)
        self.jacf = jacf
        # solve_ivp requires a 1-D y0, hence an array instead of the scalar 1.
        self.y0 = y0 if y0 is not None else np.array([1.0])
        # Dimension of the system, taken from self.y0 so the default works too.
        self.m = np.size(self.y0)

    def fluss(self, t0, t, y0):
        """Return the solution at time ``t`` of the IVP started at ``(t0, y0)``.

        Raises
        ------
        RuntimeError
            If solve_ivp reports failure (the end time ``t`` was not reached).
        """
        sol = solve_ivp(self.f, [t0, t], y0, jac=self.jacf, method='Radau',
                        rtol=1e-4, atol=1e-8)
        if not sol.success:
            # Otherwise the last accepted state (at some time before t) would
            # be returned silently as if it were the value at t.
            raise RuntimeError(
                "solve_ivp failed on [{}, {}]: {}".format(t0, t, sol.message))
        return sol.y[:, -1]

    def _trajektorien(self, y0s, t):
        """Return trajectories as an array of shape (m, len(t), len(y0s)).

        Entry ``[:, i, j]`` is the state at ``t[i]`` for the initial value
        ``y0s[j]`` given at ``t[0]``; each time step advances the previous
        state with ``self.fluss``.
        """
        traj = np.zeros((self.m, len(t), len(y0s)))
        for j, start in enumerate(y0s):
            y = np.array(start)
            traj[:, 0, j] = y
            for i, (t_alt, t_neu) in enumerate(zip(t[:-1], t[1:]), start=1):
                y = self.fluss(t_alt, t_neu, y)
                traj[:, i, j] = y
        return traj

    def plotLsg(self, fig, ax, y0s, t, xlim=(-3, 3), ylim=(-3, 3)):
        """Plot and animate trajectories in the (y[0], y[1]) phase plane.

        Parameters
        ----------
        fig : matplotlib Figure
            Figure the animation is attached to.
        ax : sequence of two Axes
            ``ax[0]`` receives the trajectories and the moving markers;
            ``ax[1]`` only the initial values.
        y0s : sequence of array_like
            Initial values at ``t[0]``.
        t : array_like
            Time grid; one animation frame per entry.
        xlim, ylim : tuple, optional
            Axis limits applied to both axes (default ``(-3, 3)``).

        Returns
        -------
        fig, ax, ani
            ``ani`` is the FuncAnimation. Callers must keep a reference to it,
            or it is garbage-collected before it runs.

        Raises
        ------
        ValueError
            If the system has fewer than two components.
        """
        if self.m < 2:
            raise ValueError("plotLsg needs a system of dimension >= 2")
        traj = self._trajektorien(y0s, t)

        # Initial values (first time step) of all trajectories.
        y01 = traj[0, 0, :]
        y02 = traj[1, 0, :]
        for y1, y2 in zip(traj[0].T, traj[1].T):
            ax[0].plot(y1, y2)
        ax[1].plot(y01, y02, 'ro')

        # Markers moved along the trajectories by the animation.
        punkte, = ax[0].plot(y01, y02, 'bo')

        for achse in (ax[0], ax[1]):
            achse.axis('equal')
            achse.set(xlim=xlim, ylim=ylim)

        def animate(i):
            punkte.set_data(traj[0, i, :], traj[1, i, :])
            return (punkte,)

        ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(t),
                                                 repeat=False)
        # Drawn after the moving markers so the initial values stay on top.
        ax[0].plot(y01, y02, 'ro')
        return fig, ax, ani
    
class awp_linah(awp):
    """Linear homogeneous IVP y' = A y, y(t0) = y0, solved exactly via expm."""

    def __init__(self, t0, A, y0):
        """Store the system matrix ``A`` and initialise the base problem.

        Parameters
        ----------
        t0 : float
            Initial time.
        A : ndarray
            Square system matrix (required by ``expm`` in ``fluss``).
        y0 : array_like
            Initial value.
        """
        self.A = A

        def rhs(*args):
            # y' = A y. Accepts the solve_ivp convention f(t, y) used by the
            # base class as well as the original one-argument form f(y).
            return A @ args[-1]

        # The Jacobian of A y is the constant matrix A (solve_ivp jac(t, y) form).
        super().__init__(t0, rhs, y0, jacf=lambda t, y: A)

    def fluss(self, t0, t, y):
        """Return the exact solution ``expm((t - t0) A) @ y`` at time ``t``."""
        return expm((t - t0) * self.A) @ y
    
if __name__ == "__main__":
    # Demo: y' = A y with A = I, trajectories from two initial values.
    A = np.array([[1., 0], [0, 1]])
    y0 = np.array([1, -1])
    dgl1 = awp_linah(0, A, y0)
    fig, ax = plt.subplots(2)
    y0s = [[1, 0], [0, 1]]
    t = np.linspace(0, 1)
    # Keep a reference to the animation; otherwise it is garbage-collected
    # before it is rendered and nothing moves.
    fig, ax, ani = dgl1.plotLsg(fig, ax, y0s, t)
    plt.show()