#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

File: symmetric_network.m (original MATLAB script)
Model of an oscillatory network.
This code generates Figures 4.15, 4.16, and 4.17.


@author: bingalls
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint



# Model parameters. dSdt_original looks these module-level names up at call
# time, so reassigning any of them later in the script changes the model.
k0 = 8   # constant production term of s1
k1 = 1   # rate constant of the s1 -> s2 conversion
k2 = 5   # rate constant of s2 removal
n = 2    # exponent of s2 in the (1 + s2^n) feedback factor


def dSdt_original(S, t):
    """Right-hand side of the two-species network, in odeint's f(y, t) form.

    The model is
        ds1/dt = k0 - k1*s1*(1 + s2^n)
        ds2/dt = k1*s1*(1 + s2^n) - k2*s2

    Parameters
    ----------
    S : indexable of length >= 2
        Current state; S[0] is s1 and S[1] is s2. Elements may also be
        NumPy arrays, in which case the derivatives are computed elementwise.
    t : float
        Current time. Unused (the system is autonomous) but required by odeint.

    Returns
    -------
    list
        [ds1/dt, ds2/dt].
    """
    s1 = S[0]
    s2 = S[1]
    # s1 is converted into s2 at a rate that s2 itself enhances.
    conversion = k1 * s1 * (1 + np.power(s2, n))
    return [k0 - conversion, conversion - k2 * s2]

# Time grid shared by the short simulations in this script.
times = np.linspace(0, 8, 1000)

# Initial conditions; the state vector is [s1, s2].
S1_0 = [1.5, 1]
S1 = odeint(dSdt_original, S1_0, times)  # time course for Figure 4.15A
S2_0 = [0, 1]
S2 = odeint(dSdt_original, S2_0, times)  # phase-plane trajectories for 4.15B
S3_0 = [0, 3]
S3 = odeint(dSdt_original, S3_0, times)
S4_0 = [2, 0]
S4 = odeint(dSdt_original, S4_0, times)


# Figure 4.15A: s1 and s2 versus time.
plt.figure()
plt.title('Figure 4.15A')
plt.plot(times, S1[:, 0], label="s1", linewidth=2)
plt.plot(times, S1[:, 1], label="s2", linewidth=2)
plt.legend()
plt.ylabel("concentration")
plt.xlabel("time")
plt.ylim(0, 3.5)


# Figure 4.15B: phase portrait with trajectories, nullclines and direction field.
plt.figure()
# Label each trajectory with its initial state so plt.legend() has entries
# to show (without labels it only emits a "no labelled artists" warning).
for traj, s0 in ((S2, S2_0), (S3, S3_0), (S4, S4_0)):
    plt.plot(traj[:, 0], traj[:, 1], linewidth=2,
             label=f"initial state {tuple(s0)}")
plt.xlabel("concentration of $S_1$")
plt.ylabel("concentration of $S_2$")
plt.legend()
plt.title('Figure 4.15B')

# Nullclines, each written as s1 as a function of s2.
# s1 nullcline (ds1/dt = 0): s1 = k0 / (k1 * (1 + s2^n))
ns12 = np.linspace(0, 4, 100)
ns11 = k0 / (k1 * (1 + np.power(ns12, n)))
# s2 nullcline (ds2/dt = 0): s1 = k2 * s2 / (k1 * (1 + s2^n))
ns22 = np.linspace(0, 4, 100)
ns21 = k2 * ns22 / (k1 * (1 + np.power(ns22, n)))
plt.plot(ns11, ns12, 'k--')
plt.plot(ns21, ns22, 'k--')

# Direction field: evaluate the model's own right-hand side on the mesh
# (it only uses elementwise NumPy arithmetic), so the arrows always match
# the simulated equations. Arrows are normalised to unit length.
x = np.linspace(0, 4, 10)
y = np.linspace(0, 4, 10)
xx, yy = np.meshgrid(x, y)
xdot, ydot = dSdt_original([xx, yy], 0)
L = np.hypot(xdot, ydot)  # vector lengths
plt.quiver(xx, yy, xdot / L, ydot / L)

plt.xlim(0, 4)
plt.ylim(0, 4)









### Figure 4.16: same network with exponent n = 2.5 ###

# Only n changes; k0, k1 and k2 are restated for clarity. dSdt_original
# reads these module-level names at call time, so every simulation from
# here on (including Figure 4.17) uses n = 2.5.
k0 = 8
k1 = 1
k2 = 5
n = 2.5

# Initial conditions; the state vector is [s1, s2].
S1_0 = [0, 1]
S1 = odeint(dSdt_original, S1_0, times)
S2_0 = [0, 3]
S2 = odeint(dSdt_original, S2_0, times)
S3_0 = [2, 0]
S3 = odeint(dSdt_original, S3_0, times)


# Figure 4.16A: s1 and s2 versus time, starting from S1_0.
plt.figure()
plt.title('Figure 4.16A')
plt.plot(times, S1[:, 0], label="s1", linewidth=2)
plt.plot(times, S1[:, 1], label="s2", linewidth=2)
plt.legend()
plt.ylabel("concentration")
plt.xlabel("time")
plt.ylim(0, 4)


# Figure 4.16B: phase portrait with trajectories, nullclines and direction field.
plt.figure()
# Label each trajectory with its initial state so plt.legend() has entries.
for traj, s0 in ((S1, S1_0), (S2, S2_0), (S3, S3_0)):
    plt.plot(traj[:, 0], traj[:, 1], linewidth=2,
             label=f"initial state {tuple(s0)}")
plt.xlabel("concentration of $S_1$")
plt.ylabel("concentration of $S_2$")
plt.legend()
plt.title('Figure 4.16B')  # was mislabelled 'Figure 4.15B'

# Nullclines, each written as s1 as a function of s2.
# s1 nullcline (ds1/dt = 0): s1 = k0 / (k1 * (1 + s2^n))
ns12 = np.linspace(0, 4, 100)
ns11 = k0 / (k1 * (1 + np.power(ns12, n)))
# s2 nullcline (ds2/dt = 0): s1 = k2 * s2 / (k1 * (1 + s2^n))
ns22 = np.linspace(0, 4, 100)
ns21 = k2 * ns22 / (k1 * (1 + np.power(ns22, n)))
plt.plot(ns11, ns12, 'k--')
plt.plot(ns21, ns22, 'k--')

# Direction field from the model's own (elementwise) right-hand side,
# normalised to unit-length arrows.
x = np.linspace(0, 4, 10)
y = np.linspace(0, 4, 10)
xx, yy = np.meshgrid(x, y)
xdot, ydot = dSdt_original([xx, yy], 0)
L = np.hypot(xdot, ydot)  # vector lengths
plt.quiver(xx, yy, xdot / L, ydot / L)

plt.xlim(0, 4)
plt.ylim(0, 4)


# Figure 4.17: a trajectory approaching the limit cycle. This uses whatever
# parameter values are current at this point (n = 2.5 in this script).

longtimes = np.linspace(0, 100, 10000)

# The same initial state simulated over a short and a long horizon.
S1_0 = [1.8, 1.6]
S1 = odeint(dSdt_original, S1_0, times)      # short transient
S2_0 = [1.8, 1.6]
S2 = odeint(dSdt_original, S2_0, longtimes)  # long run

# Cut off the first half of the long trajectory so that what remains is
# (approximately) the limit cycle itself.
S2 = S2[len(longtimes) // 2:]

plt.figure()
plt.plot(S1[:, 0], S1[:, 1], linewidth=2,
         label=f"trajectory from {tuple(S1_0)}")
plt.plot(S2[:, 0], S2[:, 1], linewidth=2, label="limit cycle")
plt.xlabel("concentration of $S_1$")
plt.ylabel("concentration of $S_2$")
plt.legend()
plt.title('Figure 4.17')

# Nullclines, each written as s1 as a function of s2.
# s1 nullcline (ds1/dt = 0): s1 = k0 / (k1 * (1 + s2^n))
ns12 = np.linspace(0, 4, 100)
ns11 = k0 / (k1 * (1 + np.power(ns12, n)))
# s2 nullcline (ds2/dt = 0): s1 = k2 * s2 / (k1 * (1 + s2^n))
ns22 = np.linspace(0, 4, 100)
ns21 = k2 * ns22 / (k1 * (1 + np.power(ns22, n)))
plt.plot(ns11, ns12, 'k--')
plt.plot(ns21, ns22, 'k--')

# Finer direction field over the zoomed-in region, using the model's own
# (elementwise) right-hand side, normalised to unit-length arrows.
x = np.linspace(1, 2.8, 15)
y = np.linspace(1, 2.5, 15)
xx, yy = np.meshgrid(x, y)
xdot, ydot = dSdt_original([xx, yy], 0)
L = np.hypot(xdot, ydot)  # vector lengths
plt.quiver(xx, yy, xdot / L, ydot / L)

# Explicit limits disable autoscaling, so the longer nullcline and
# direction-field ranges above do not widen the zoomed view.
plt.xlim(1, 2.7)
plt.ylim(1.1, 2.3)

# Needed to display the figures when the file is run as a plain script.
plt.show()




