[关闭]
@wudawufanfan 2016-12-28T07:05:51.000000Z 字数 6089 阅读 467

Schrod：以无限深方势阱本征态为基求解一维薛定谔方程（Python）

未分类


下面的 Python 代码定义了 `Schrod` 类：它把哈密顿量在无限深方势阱本征态基中展开并对角化，得到给定势 V(x) 的本征值与本征态，并可计算波函数随时间的演化。

from __future__ import print_function
import collections
import warnings

import numpy as np
from scipy import fftpack

try:
    # SciPy >= 1.6 provides `simpson`; `simps` was deprecated in 1.12 and
    # removed in 1.14, so prefer the new name and alias it.
    from scipy.integrate import simpson as simps
except ImportError:  # older SciPy only provides the original name
    from scipy.integrate import simps
  7. class Schrod:
  8. def __init__(self, x, V, n_basis=20):
  9. """
  10. Parameters
  11. ----------
  12. x : array_like, float
  13. Length-N array of evenly spaced spatial coordinates
  14. V : array_like, float
  15. Length-N array giving the potential at each x
  16. n_basis : int
  17. The number of square-well basis states used in the calculation (default=20)
  18. """
  19. # Set the inputs
  20. self.x = x
  21. self.V = V
  22. self.n_basis = n_basis
  23. # Validate the inputs
  24. N = self.x.size
  25. assert N > 1
  26. assert self.x.shape == (N,)
  27. assert (np.diff(x) >= 0).all()
  28. V_shape = self.V.shape
  29. V_shape_len = len(V_shape)
  30. assert V_shape_len == 1 or V_shape_len == 2
  31. if V_shape_len == 1:
  32. assert V_shape == (N,)
  33. elif V_shape_len == 2:
  34. assert V_shape[1] == N
  35. assert isinstance(n_basis, int)
  36. assert n_basis > 0
  37. # Set the derived quantities
  38. self._N = N
  39. self.dx = x[1] - x[0]
  40. self._x_min = x[0]
  41. self._x_max = x[-1]
  42. self.box_size = np.abs(self._x_max - self._x_min)
  43. self._x_center = self._x_min + self.box_size / 2.
  44. self.dk = 2 * np.pi / self.box_size
  45. self.k = -0.5 * (self._N-1) * self.dk + self.dk * np.arange(self._N)
  46. # Allocate memory for eigenvalues and eigenvectors
  47. self.eigs = np.zeros(shape=(V_shape[0], n_basis))
  48. self.vecs = np.zeros(shape=(V_shape[0], n_basis, n_basis))
  49. def solve(self, verbose=False):
  50. if verbose:
  51. print("Calculating the Hamiltonian matrices...")
  52. Hs = self._H(verbose)
  53. if verbose:
  54. print("Diagonalizing the Hamiltonian matrices...")
  55. self.eigs, self.vecs = np.linalg.eigh(Hs, UPLO='L')
  56. def solve_to_tol(self, n_eig, tol=1e-6, n_init=5, n_max=50, n_step=5, err_type = "max_mean", verbose=False):
  57. """
  58. Increase the basis size until the maximum change among the lowest
  59. <n_eig> eigenvalues is less than <tol>
  60. :param n_eig: The number of eigenvalues (lowest) to check for convergence.
  61. :param tol: The desired maximum relative error in the output
  62. :param n_init: The initial number of basis states. Must satisfy n_init >= n_eig
  63. :param n_max: The maximum number of basis states to include before stopping.
  64. :param n_step: The number of basis states added in each convergence test.
  65. :return: A tuple with the following information (A, B, C):
  66. A: Boolean. True if maximum measured relative error less than or equal to <tol>
  67. B: The number of basis states included when convergence reached.
  68. C: The measured relative error
  69. :param err_type: The type of error to calculate, one of
  70. "max": the maximum error is no greater than :param tol:
  71. "mean": the mean error is no greater than :param tol:
  72. "max_mean" (default): the maximum of the mean error per eigenvalue is no greater than :param tol:
  73. :param verbose: bool, whether to print diagnostic information
  74. """
  75. # The initial truncation
  76. self.set_n_basis(n_init)
  77. self.solve()
  78. eigs = self.eigs[..., 0:n_eig]
  79. err = 1
  80. # The maximum relative error in the lowst <n_eig> eigenvalues
  81. measured_tol = 1 # the max rel. err. of the eigenvalues
  82. n = n_init
  83. passed = False
  84. while (not passed) and (n+n_step <= n_max):
  85. n += n_step
  86. self.set_n_basis(n)
  87. self.solve()
  88. eigs_new = self.eigs[..., 0:n_eig]
  89. err = np.abs((eigs - eigs_new) / 0.5 / (eigs + eigs_new))
  90. if err_type is "max":
  91. measured_tol = np.max(err)
  92. elif err_type is "mean":
  93. measured_tol = np.mean(err)
  94. elif err_type is "max_mean":
  95. measured_tol = np.max(np.mean(err, axis=0))
  96. else:
  97. warnings.warn("Invalid <err_type>. Must be one of 'max', 'mean' or 'max_mean'. Assuming 'max_mean'.")
  98. measured_tol = np.max(np.mean(err, axis=0))
  99. eigs = eigs_new
  100. if verbose:
  101. print("Number of basis states: %i \n Measured error: %.2e" % (n, measured_tol))
  102. if measured_tol <= tol:
  103. passed = True
  104. if not passed:
  105. warnings.warn("Unable to achieve desired tolerance of %.2e.\n"
  106. "Achieved a tolerance of %.2e.\n"
  107. "Try increasing the maximum number of basis states." % (tol, measured_tol))
  108. solution = collections.namedtuple('Solution',
  109. ['eig_errs', 'n_basis_converged', 'passed'])
  110. return solution(err, n, passed)
  111. def psi_eig_x(self):
  112. basis_vec = np.arange(1, self.vecs.shape[-1] + 1)
  113. return np.tensordot(self.vecs,
  114. self._psi0(basis_vec, self.x), axes=(-2,0))
  115. def prob_eig_x(self):
  116. return self.psi_eig_x() ** 2
  117. def psi_tx(self, psi_0_x, t_vec):
  118. # Caculate the overlap of the initial wavefunction with all of the eigenstates
  119. psis = self.psi_eig_x()
  120. coeffs = simps(x=self.x, y=psi_0_x * psis, axis=-1)
  121. # Calculate the complex phases at each time
  122. phases = np.exp(-1j * np.outer(t_vec, self.eigs))
  123. # Calculate the wavefunction on the grid at each time slice
  124. psi_of_t = np.dot(phases * coeffs, psis)
  125. print(psi_of_t.shape)
  126. return psi_of_t
  127. def prob_tx(self, psi_0, t_vec):
  128. return np.absolute(self.psi_tx(psi_0, t_vec)) ** 2
  129. def psi_tk(self, psi_0, t_vec):
  130. psitx = self.psi_tx(psi_0,t_vec)
  131. psitk = fftpack.fft(psitx, overwrite_x=True)
  132. return fftpack.fftshift(psitk, axes=-1)
  133. def prob_tk(self, psi_0, t_vec):
  134. return np.absolute(self.psi_tk(psi_0, t_vec)) ** 2
  135. def expected_E(self, psi):
  136. integrand = -0.5 * np.conj(psi) * np.gradient(np.gradient(psi, self.dx, axis=-1), self.dx, axis=-1) + \
  137. np.absolute(psi)**2 * self.V
  138. return np.real(simps(x=self.x, y=integrand, axis=-1))
  139. # Set functions
  140. def set_x(self, x):
  141. self.x = x
  142. self._x_min = x[0]
  143. self._x_max = x[-1]
  144. self.box_size = self._x_max = self._x_min
  145. self._x_center = self._x_min + self.box_size / 2.
  146. def set_V(self, V):
  147. self.V = V
  148. def set_n_basis(self, n_basis):
  149. self.n_basis = n_basis
  150. # Private functions:
  151. def _H(self, verbose=False):
  152. n_matels = self.n_basis * (self.n_basis + 1) / 2
  153. # initialize an empty hamiltonian(s)
  154. h=0
  155. if len(self.V.shape) is 2:
  156. h = np.zeros((self.V.shape[0], self.n_basis, self.n_basis))
  157. elif len(self.V.shape) is 1:
  158. h = np.zeros(( self.n_basis, self.n_basis))
  159. for m in range(self.n_basis):
  160. for n in range(m + 1):
  161. h[..., m, n] = self._Vmn(m, n)
  162. # Print a status
  163. n_sofar = (m + 1) * m / 2 + n + 1
  164. percent = n_sofar / n_matels * 100
  165. if verbose:
  166. print("\r Status: %0.2f %% complete" % percent, end='')
  167. if verbose:
  168. print("")
  169. return h + np.diag(self._E0(np.arange(1, self.n_basis + 1)))
  170. def _psi0(self, n, x):
  171. """
  172. Evaluate the nth box state at x
  173. :param n: array-like, 1-indexed state labels
  174. :param x: array-like, positions between -1 and 1
  175. :return: an array of shape (len(n), len(x))
  176. """
  177. kn = n * np.pi / self.box_size
  178. return np.sqrt(2 / self.box_size) * \
  179. np.sin(np.outer(kn, x - self._x_center + self.box_size / 2))
  180. def _E0(self, n):
  181. """
  182. The nth energy level in the box
  183. :param n: the state label
  184. :return: the energy
  185. """
  186. return n ** 2 * np.pi ** 2 / (2. * self.box_size ** 2)
  187. def _matel_integrand(self, m, n):
  188. """
  189. The n,m matrix element of the V potential evaluated at the x coordinates
  190. :param n: the row index
  191. :param m: the column index
  192. :param x: array-like, a vector of x coordinates
  193. :param V: array-like, an array of potential values. The rows correspond to
  194. the entries in x. The columns correspond to different potentials
  195. :return: the integrand of the matrix element
  196. """
  197. return self._psi0(m + 1, self.x) * self._psi0(n + 1, self.x) * self.V
  198. def _Vmn(self, m, n):
  199. return simps(x=self.x, y=self._matel_integrand(m, n), axis=-1)
  200. if __name__ == "__main__":
  201. x = np.linspace(-1,1,10)
  202. V = 1/2 * x**2
  203. eqn = Schrod(x, V)
  204. print(eqn.k)
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注