#!/usr/bin/env python3
# -*- coding: utf-8 -*-

#' # Blatt 9

import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt
import scipy.integrate as ode

#' # Aufgabe 33 - Räuber-Beute-Modell

#Expliziten Euler importieren:
from ExpliziterEuler import ExpliziterEuler

#Impliziten Euler implementieren:
def ImpliziterEuler(f, Jacf, tspan, yinit, N, tol=1e-10, maxiter=50):
    '''
    Solve y' = f(t, y) with the implicit (backward) Euler method.

    Every step solves the nonlinear equation
        g(x) = x - y_n - h * f(t_{n+1}, x) = 0
    for x = y_{n+1} with Newton's method. Newton starts at y_n and uses the
    matrix I - h * Jacf(t_{n+1}, x), which is the Jacobian of g.

    f : function f(t, y)
    Jacf : Jacobian of f with respect to y, called as Jacf(t, y)
    tspan = [t_0, t_N] : start and end time
    yinit = y(t_0) (array-like; it is copied, never modified)
    N : number of steps, uniform step size h = (t_N - t_0) / N
    tol : Newton stops once the 2-norm of g is <= tol
    maxiter : maximum number of Newton iterations per time step

    Returns (tvec, yvec):
        tvec is a list of the N+1 time points.
        yvec is an array whose first axis runs over those time points
        (yvec[k] approximates y(tvec[k])).

    Raises RuntimeError if Newton does not reach tol within maxiter
    iterations in some step (this also covers a NaN iterate).
    '''
    # Float copy of the start value; also accepts lists.
    y = np.asarray(yinit) * 1.
    h = (tspan[1] - tspan[0]) / N
    # linspace hits t_N exactly, instead of accumulating rounding errors via t += h.
    tgrid = np.linspace(tspan[0], tspan[1], N + 1)
    Id = np.eye(y.shape[0])
    yvec = [y]

    for t in tgrid[1:]:
        yprev = y
        residual = y - yprev - h * f(t, y)
        n_iter = 0
        # 'not <=' rather than '>' so that a NaN residual is never
        # accepted as converged; it runs into the maxiter guard instead.
        while not la.norm(residual) <= tol:
            if n_iter == maxiter:
                raise RuntimeError(
                    f"Newton iteration did not converge at t={t} within "
                    f"{maxiter} iterations (residual norm {la.norm(residual)})"
                )
            y = y - la.solve(Id - h * Jacf(t, y), residual)
            residual = y - yprev - h * f(t, y)
            n_iter += 1
        yvec.append(y)

    return tgrid.tolist(), np.asarray(yvec)

#%% Anwendung auf das Räuber-Beute-Modell

def f(t, y):
    '''
    Right-hand side of the predator-prey (Lotka-Volterra) system:
        y0' = -2*y0 + y0*y1
        y1' =  y1   - y0*y1
    y[0] is the predator (Räuber), y[1] the prey (Beute).
    t is unused because the system is autonomous.
    Returns a float array of the same shape as y.
    '''
    predator = y[0]
    prey = y[1]
    encounter = predator * prey  # interaction term shared by both equations
    dydt = np.zeros(y.shape)
    dydt[0] = encounter - 2 * predator
    dydt[1] = prey - encounter
    return dydt

def Jacf(t, y):
    '''
    Jacobian of f with respect to y:
        [[-2 + y1,  y0    ],
         [-y1,      1 - y0]]
    Accepts the state either as a column vector of shape (2, 1) or as a flat
    array of shape (2,) (the layout scipy's solve_ivp passes to callbacks).
    t is unused. Returns a (2, 2) array.
    '''
    z = np.ravel(y)  # works for both (2, 1) and (2,) input
    return np.array([[-2 + z[1], z[0]], [-z[1], 1 - z[0]]])

# Initial state as a column vector: y[0] = Räuber (predator), y[1] = Beute (prey).
y0 = np.array([1, 1.5]).reshape(-1, 1)
tspan = (0, 30)
dt = 0.1
# Derive the number of steps from the interval length, not just the end time,
# so that dt stays the step size even if tspan[0] != 0 (here N = 300).
# round() absorbs the floating-point error of 30 / 0.1.
N = round((tspan[1] - tspan[0]) / dt)

tExp, yExp = ExpliziterEuler(f, tspan, y0, N)
yExp = np.asarray(yExp)
tImp, yImp = ImpliziterEuler(f, Jacf, tspan, y0, N)
# Reference solution from scipy's BDF solver; solve_ivp expects a flat state vector.
yref = ode.solve_ivp(f, tspan, y0.flatten(), atol=10**(-7), rtol=10**(-7), method='BDF')

#%% Plots: left column = time series, right column = phase portrait (Räuber vs. Beute)
fig = plt.figure(1, figsize=(15, 10))
ax1 = plt.subplot(321)
ax2 = plt.subplot(322)
ax3 = plt.subplot(323)
ax4 = plt.subplot(324)
ax5 = plt.subplot(325)
ax6 = plt.subplot(326)


def _plot_solution(ax_time, ax_phase, title, t, predator, prey):
    '''Draw one solution: time series on ax_time, phase portrait on ax_phase.

    predator and prey are the y[0] and y[1] components sampled at the times t.
    The starting point is marked with a red star in the phase portrait.
    '''
    ax_time.set_title(title)
    ax_time.plot(t, predator, t, prey)
    ax_time.legend(('Räuber', 'Beute'))
    ax_time.set_xlabel('t')

    ax_phase.plot(predator, prey)
    ax_phase.plot(predator[0], prey[0], 'r*')
    ax_phase.set_xlabel('Räuber')
    ax_phase.set_ylabel('Beute')


_plot_solution(ax1, ax2, 'Expliziter Euler', tExp, yExp[:, 0], yExp[:, 1])
_plot_solution(ax3, ax4, 'Impliziter Euler', tImp, yImp[:, 0], yImp[:, 1])
_plot_solution(ax5, ax6, 'Referenz-Lösung', yref.t, yref.y[0, :], yref.y[1, :])

fig.tight_layout()
# Without show(), nothing is displayed when the file is run as a plain script.
plt.show()
