11import numpy as np
2+ import sympy as sp
3+ from devito import Dimension , Constant , TimeFunction , Eq , solve , Operator
24#import matplotlib.pyplot as plt
35import scitools .std as plt
46
@@ -11,27 +13,33 @@ def solver(I, V, m, b, s, F, dt, T, damping='linear'):
1113 'quadratic', f(u')=b*u'*abs(u').
1214 F(t) and s(u) are Python functions.
1315 """
14- dt = float (dt ); b = float (b ); m = float (m ) # avoid integer div.
16+ dt = float (dt )
17+ b = float (b )
18+ m = float (m )
1519 Nt = int (round (T / dt ))
16- u = np .zeros (Nt + 1 )
17- t = np .linspace (0 , Nt * dt , Nt + 1 )
20+ t = Dimension ('t' , spacing = Constant ('h_t' ))
21+
22+ u = TimeFunction (name = 'u' , dimensions = (t ,),
23+ shape = (Nt + 1 ,), space_order = 2 )
24+
25+ u .data [0 ] = I
1826
19- u [0 ] = I
2027 if damping == 'linear' :
21- u [1 ] = u [0 ] + dt * V + dt ** 2 / (2 * m )* (- b * V - s (u [0 ]) + F (t [0 ]))
28+ # dtc for central difference (default for time is forward, 1st order)
29+ eqn = m * u .dt2 + b * u .dtc + s (u ) - F (u )
30+ stencil = Eq (u .forward , solve (eqn , u .forward ))
2231 elif damping == 'quadratic' :
23- u [1 ] = u [0 ] + dt * V + \
24- dt ** 2 / (2 * m )* (- b * V * abs (V ) - s (u [0 ]) + F (t [0 ]))
32+ # fd_order set as backward derivative used is 1st order
33+ eqn = m * u .dt2 + b * u .dt * sp .Abs (u .dtl (fd_order = 1 )) + s (u ) - F (u )
34+ stencil = Eq (u .forward , solve (eqn , u .forward ))
35+ # First timestep needs to have the backward timestep substituted
36+ stencil_init = stencil .subs (u .backward , u .forward - 2 * t .spacing * V )
37+ op_init = Operator (stencil_init , name = 'first_timestep' )
38+ op = Operator (stencil , name = 'main_loop' )
39+ op_init .apply (h_t = dt , t_M = 1 )
40+ op .apply (h_t = dt , t_m = 1 , t_M = Nt - 1 )
2541
26- for n in range (1 , Nt ):
27- if damping == 'linear' :
28- u [n + 1 ] = (2 * m * u [n ] + (b * dt / 2 - m )* u [n - 1 ] +
29- dt ** 2 * (F (t [n ]) - s (u [n ])))/ (m + b * dt / 2 )
30- elif damping == 'quadratic' :
31- u [n + 1 ] = (2 * m * u [n ] - m * u [n - 1 ] + b * u [n ]* abs (u [n ] - u [n - 1 ])
32- + dt ** 2 * (F (t [n ]) - s (u [n ])))/ \
33- (m + b * abs (u [n ] - u [n - 1 ]))
34- return u , t
42+ return u .data , np .linspace (0 , Nt * dt , Nt + 1 )
3543
3644def visualize (u , t , title = '' , filename = 'tmp' ):
3745 plt .plot (t , u , 'b-' )
@@ -46,8 +54,6 @@ def visualize(u, t, title='', filename='tmp'):
4654 plt .savefig (filename + '.pdf' )
4755 plt .show ()
4856
49- import sympy as sym
50-
5157def test_constant ():
5258 """Verify a constant solution."""
5359 u_exact = lambda t : I
@@ -68,24 +74,24 @@ def test_constant():
6874
6975def lhs_eq (t , m , b , s , u , damping = 'linear' ):
7076 """Return lhs of differential equation as sympy expression."""
71- v = sym .diff (u , t )
77+ v = sp .diff (u , t )
7278 if damping == 'linear' :
73- return m * sym .diff (u , t , t ) + b * v + s (u )
79+ return m * sp .diff (u , t , t ) + b * v + s (u )
7480 else :
75- return m * sym .diff (u , t , t ) + b * v * sym .Abs (v ) + s (u )
81+ return m * sp .diff (u , t , t ) + b * v * sp .Abs (v ) + s (u )
7682
7783def test_quadratic ():
7884 """Verify a quadratic solution."""
7985 I = 1.2 ; V = 3 ; m = 2 ; b = 0.9
8086 s = lambda u : 4 * u
81- t = sym .Symbol ('t' )
87+ t = sp .Symbol ('t' )
8288 dt = 0.2
8389 T = 2
8490
8591 q = 2 # arbitrary constant
8692 u_exact = I + V * t + q * t ** 2
87- F = sym .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'linear' ))
88- u_exact = sym .lambdify (t , u_exact , modules = 'numpy' )
93+ F = sp .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'linear' ))
94+ u_exact = sp .lambdify (t , u_exact , modules = 'numpy' )
8995 u1 , t1 = solver (I , V , m , b , s , F , dt , T , 'linear' )
9096 diff = np .abs (u_exact (t1 ) - u1 ).max ()
9197 tol = 1E-13
@@ -94,8 +100,8 @@ def test_quadratic():
94100 # In the quadratic damping case, u_exact must be linear
95101 # in order exactly recover this solution
96102 u_exact = I + V * t
97- F = sym .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'quadratic' ))
98- u_exact = sym .lambdify (t , u_exact , modules = 'numpy' )
103+ F = sp .lambdify (t , lhs_eq (t , m , b , s , u_exact , 'quadratic' ))
104+ u_exact = sp .lambdify (t , u_exact , modules = 'numpy' )
99105 u2 , t2 = solver (I , V , m , b , s , F , dt , T , 'quadratic' )
100106 diff = np .abs (u_exact (t2 ) - u2 ).max ()
101107 assert diff < tol
@@ -127,11 +133,11 @@ def test_mms():
127133 """Use method of manufactured solutions."""
128134 m = 4. ; b = 1
129135 w = 1.5
130- t = sym .Symbol ('t' )
131- u_exact = 3 * sym .exp (- 0.2 * t )* sym .cos (1.2 * t )
136+ t = sp .Symbol ('t' )
137+ u_exact = 3 * sp .exp (- 0.2 * t )* sp .cos (1.2 * t )
132138 I = u_exact .subs (t , 0 ).evalf ()
133- V = sym .diff (u_exact , t ).subs (t , 0 ).evalf ()
134- u_exact_py = sym .lambdify (t , u_exact , modules = 'numpy' )
139+ V = sp .diff (u_exact , t ).subs (t , 0 ).evalf ()
140+ u_exact_py = sp .lambdify (t , u_exact , modules = 'numpy' )
135141 s = lambda u : u ** 3
136142 dt = 0.2
137143 T = 6
@@ -140,14 +146,14 @@ def test_mms():
140146 # Run grid refinements and compute exact error
141147 for i in range (5 ):
142148 F_formula = lhs_eq (t , m , b , s , u_exact , 'linear' )
143- F = sym .lambdify (t , F_formula )
149+ F = sp .lambdify (t , F_formula )
144150 u1 , t1 = solver (I , V , m , b , s , F , dt , T , 'linear' )
145151 error = np .sqrt (np .sum ((u_exact_py (t1 ) - u1 )** 2 )* dt )
146152 errors_linear .append ((dt , error ))
147153
148154 F_formula = lhs_eq (t , m , b , s , u_exact , 'quadratic' )
149155 #print sym.latex(F_formula, mode='plain')
150- F = sym .lambdify (t , F_formula )
156+ F = sp .lambdify (t , F_formula )
151157 u2 , t2 = solver (I , V , m , b , s , F , dt , T , 'quadratic' )
152158 error = np .sqrt (np .sum ((u_exact_py (t2 ) - u2 )** 2 )* dt )
153159 errors_quadratic .append ((dt , error ))
0 commit comments