#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
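Phase-plane analysis for Figures 4.2, 4.3, 4.4A and 4.5, and the
continuation (dose-response) diagram in Figure 4.18.
Created on Mon Dec 12 17:37:29 2022
NOTE: this header previously also named "qmichaelis_menten.py" and
"Figure 3.3 Michaelis-Menten kinetics"; those lines appear to be left over
from another script - this file does not implement Michaelis-Menten kinetics.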


@author: bingalls
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint



# Model parameters (module-level globals read by the model right-hand side
# and the phase-plane calculations below).
k1, k2, k3, k4, k5 = 20, 5, 5, 5, 2  # rate constants
K = 1  # constant K in the inhibition term k1/(1 + (s2/K)**n)
n = 4  # exponent n in the inhibition term k1/(1 + (s2/K)**n)


#set time_grid for simulation
t_min, t_max, dt = 0, 1.5, 0.001
# np.linspace (rather than a float-step np.arange up to t_max+dt) gives a
# grid whose length and end points do not depend on floating-point rounding:
# it always runs exactly from t_min to t_max with spacing dt.
times = np.linspace(t_min, t_max, int(round((t_max - t_min) / dt)) + 1)

#declare right-hand-side for model
def dSdt_original(S,t):
    """Return the right-hand side [ds1/dt, ds2/dt] of the model, for odeint.

    ds1/dt = k1/(1 + (s2/K)**n) - k3*s1 - k5*s1
    ds2/dt = k2 - k4*s2 + k5*s1

    s1 is produced at a rate that decreases as s2 grows. It is removed at
    rate k3*s1 and converted into s2 at rate k5*s1. s2 is produced at the
    constant rate k2, gains k5*s1 and is removed at rate k4*s2.

    Parameters
    ----------
    S : sequence of length 2
        Current state [s1, s2]. The arithmetic is element-wise, so the two
        entries may also be equally shaped numpy arrays (used below to
        evaluate the vector field on a grid).
    t : float
        Time. Unused (the system is autonomous) but required by odeint.

    Returns
    -------
    list
        [ds1/dt, ds2/dt].

    The parameters k1..k5, K and n are read from module-level globals at call
    time, so reassigning one of them (as the k1 sweep below does) changes the
    model seen by subsequent odeint calls.
    """
    s1, s2 = S[0], S[1]
    # K is included so the model agrees with the nullcline formula
    # k1/(1 + (Y/K)**n) used for the phase-plane figures.
    ds1 = k1/(1 + np.power(s2/K, n)) - k3*s1 - k5*s1
    ds2 = k2 - k4*s2 + k5*s1
    return [ds1, ds2]

# --- Full model -------------------------------------------------------------
# Initial conditions for the five trajectories; each state vector is [s1, s2].
S1_0 = [0, 0]
S2_0 = [0.5, 0.6]
S3_0 = [0.17, 1.1]
S4_0 = [0.25, 1.9]
S5_0 = [1.85, 1.7]

# Integrate the model over `times` from each initial condition in turn.
S1, S2, S3, S4, S5 = (odeint(dSdt_original, start, times)
                      for start in (S1_0, S2_0, S3_0, S4_0, S5_0))

# Figure 4.2A: time courses of s1 and s2 starting from the first initial condition.
fig, ax = plt.subplots()
ax.plot(times, S1[:, 0], label="s1", linewidth=2)
ax.plot(times, S1[:, 1], label="s2", linewidth=2)
ax.set_xlabel("time")
ax.set_ylabel("concentration")
ax.legend()
ax.set_title('Figure 4.2A')

#plot  figure 4.2B: the trajectory of Figure 4.2A drawn in the (s1, s2) phase plane
plt.figure()
# Label the curve by its starting point; name the axes so s1 and s2 can be
# told apart (both used to read just "concentration").
plt.plot(S1[:, 0], S1[:, 1], label="s(0) = ({}, {})".format(*S1_0), linewidth=2)
plt.xlabel("s1 concentration")
plt.ylabel("s2 concentration")
plt.legend()
plt.title('Figure 4.2B')

#plot figure 4.3A: time courses from all five initial conditions
plt.figure() #generate figure
for idx, traj in enumerate((S1, S2, S3, S4, S5)):
    # All s1 curves share one colour and all s2 curves another. Only the first
    # pair is labelled, so the legend has exactly one entry per species
    # instead of ten indistinguishable "s1"/"s2" entries.
    plt.plot(times, traj[:, 0], color="C0", linewidth=2,
             label="s1" if idx == 0 else None)
    plt.plot(times, traj[:, 1], color="C1", linewidth=2,
             label="s2" if idx == 0 else None)
plt.xlabel("time")
plt.ylabel("concentration")
plt.legend()
plt.title('Figure 4.3A')

#plot  figure 4.3B: the five trajectories in the (s1, s2) phase plane
plt.figure()
# Label each trajectory by its initial condition so the legend distinguishes them.
for traj, s0 in zip((S1, S2, S3, S4, S5), (S1_0, S2_0, S3_0, S4_0, S5_0)):
    plt.plot(traj[:, 0], traj[:, 1], label="s(0) = ({}, {})".format(*s0), linewidth=2)
plt.xlabel("s1 concentration")
plt.ylabel("s2 concentration")
plt.legend()
plt.title('Figure 4.3B')


#plot  figure 4.4A: the five trajectories on top of the direction field

# Direction field on a 20 x 20 grid over [0, 2] x [0, 2]. It is evaluated with
# the model's own right-hand side (dSdt_original is element-wise, so it accepts
# the meshgrid arrays), which keeps the arrows in sync with the simulated model.
x = np.linspace(0, 2, 20)
y = np.linspace(0, 2, 20)
xx, yy = np.meshgrid(x, y)
xdot, ydot = dSdt_original([xx, yy], 0)
L = np.hypot(xdot, ydot)  # vector lengths
L[L == 0] = 1  # a zero vector (equilibrium on a grid point) stays zero instead of 0/0 = NaN

plt.figure()
for traj, s0 in zip((S1, S2, S3, S4, S5), (S1_0, S2_0, S3_0, S4_0, S5_0)):
    plt.plot(traj[:, 0], traj[:, 1], label="s(0) = ({}, {})".format(*s0), linewidth=2)
plt.quiver(xx, yy, xdot/L, ydot/L)  # unit-length arrows: direction only
plt.xlabel("s1 concentration")
plt.ylabel("s2 concentration")
plt.legend()
plt.title('Figure 4.4A')

#plot  figure 4.5A: nullclines and the five trajectories

plt.figure()

# Evaluate both right-hand sides on a fine grid over [0, 2) x [0, 2), using
# the model function itself so the nullclines match the simulated system.
# The zero-level contours of F (ds1/dt) and G (ds2/dt) are the nullclines.
delta = 0.025
xrange = np.arange(0, 2, delta)
yrange = np.arange(0, 2, delta)
X, Y = np.meshgrid(xrange, yrange)
F, G = dSdt_original([X, Y], 0)

#plot nullclines. Contour sets do not create legend entries, so empty proxy
#lines with the same style supply them.
plt.contour(X, Y, F, [0], colors="k", linestyles="dashed")
plt.contour(X, Y, G, [0], colors="k", linestyles="dotted")
plt.plot([], [], color="k", linestyle="--", label="s1 nullcline")
plt.plot([], [], color="k", linestyle=":", label="s2 nullcline")

#plot trajectories, labelled by initial condition
for traj, s0 in zip((S1, S2, S3, S4, S5), (S1_0, S2_0, S3_0, S4_0, S5_0)):
    plt.plot(traj[:, 0], traj[:, 1], label="s(0) = ({}, {})".format(*s0), linewidth=2)

plt.xlabel("s1 concentration")
plt.ylabel("s2 concentration")
plt.legend()
plt.title('Figure 4.5A')

#plot  figure 4.5B: nullclines together with the direction field

plt.figure()

# Evaluate both right-hand sides on a fine grid over [0, 2) x [0, 2), using
# the model function itself; the zero-level contours of F (ds1/dt) and
# G (ds2/dt) are the nullclines.
delta = 0.025
xrange = np.arange(0, 2, delta)
yrange = np.arange(0, 2, delta)
X, Y = np.meshgrid(xrange, yrange)
F, G = dSdt_original([X, Y], 0)

#plot nullclines, with empty proxy lines providing their legend entries
plt.contour(X, Y, F, [0], colors="k", linestyles="dashed")
plt.contour(X, Y, G, [0], colors="k", linestyles="dotted")
plt.plot([], [], color="k", linestyle="--", label="s1 nullcline")
plt.plot([], [], color="k", linestyle=":", label="s2 nullcline")

#generate direction field on a 20 x 20 grid from the model's right-hand side
x = np.linspace(0, 2, 20)
y = np.linspace(0, 2, 20)
xx, yy = np.meshgrid(x, y)
xdot, ydot = dSdt_original([xx, yy], 0)
L = np.hypot(xdot, ydot)  # vector lengths
L[L == 0] = 1  # a zero vector stays zero instead of 0/0 = NaN
plt.quiver(xx, yy, xdot/L, ydot/L)  # unit-length arrows: direction only

plt.xlabel("s1 concentration")
plt.ylabel("s2 concentration")
plt.legend()
plt.title('Figure 4.5B')
# Figure 4.18: a dose-response (continuation) diagram, generated by plotting
# the steady state reached from a collection of simulations.

#Run N simulations to steady state for evenly spaced values of k_1 in [0, 40]
#set time_grid for simulation (exact end points, see np.linspace)
t_min, t_max, dt = 0, 10, 0.01
bif_times = np.linspace(t_min, t_max, int(round((t_max - t_min) / dt)) + 1)

N = 30
k1_values = np.linspace(0, 40, N)  # N values, both 0 and 40 included
S1_state = np.zeros(N)

# dSdt_original reads the module-level k1, so the sweep rebinds that global.
# The nominal value is saved and restored afterwards so later code still
# sees the original parameter set.
k1_default = k1
for i, k1 in enumerate(k1_values):
    S = odeint(dSdt_original, [0, 0], bif_times)  # run simulation
    S1_state[i] = S[-1, 0]  # final s1 value, taken as the steady state
k1 = k1_default

plt.figure()
plt.plot(k1_values, S1_state, linewidth=2)
plt.xlabel("$k_1$")
plt.ylabel("Steady state $S_1$ concentration")
plt.xlim(0, 40)
plt.ylim(0, 1.4)
plt.title('Figure 4.18')

plt.show()  # display all figures when the file is run as a plain script


