Coverage for /builds/alexhroom/ase/ase/dft/bandgap.py: 63.08%

130 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-05 14:37 +0000

1from dataclasses import dataclass 

2import warnings 

3 

4import numpy as np 

5 

6 

# Error message raised when a caller still passes the removed ``spin``
# keyword; callers should instead pass eigenvalues for the wanted spins.
spin_error = ('The spin keyword is no longer supported. '
              'Please call the function with the energies '
              'corresponding to the desired spins.')

# Sentinel default for the removed ``spin`` keyword, so that any explicit
# value (including None) can be detected and rejected.
_deprecated = object()

11 

12 

def get_band_gap(calc, direct=False, spin=_deprecated):
    """Deprecated wrapper around ``bandgap()``.

    Emits a warning, then delegates to ``bandgap(calc, direct, spin=spin)``.

    Returns (gap, p1, p2).  When the calculator reports two spins, p1 and
    p2 are (s, k) tuples; otherwise they are bare k-point indices.  Band
    indices are not returned.
    """
    # stacklevel=2 attributes the warning to the caller's line (the code
    # that should migrate to bandgap()) rather than to this module.
    warnings.warn('Please use ase.dft.bandgap.bandgap() instead!',
                  stacklevel=2)
    gap, (s1, k1, _), (s2, k2, _) = bandgap(calc, direct, spin=spin)
    if calc.get_number_of_spins() == 2:
        return gap, (s1, k1), (s2, k2)
    return gap, k1, k2

20 

21 

@dataclass
class GapInfo:
    """Fundamental (indirect) and direct band-gap information.

    Both gaps are computed once, when the object is created.  Eigenvalues
    are taken relative to the Fermi level: negative values count as
    occupied states.
    """

    # Eigenvalues minus the Fermi level, shape (nspin, nkpt, nband) or
    # (nkpt, nband); a 2-D array is treated as a single spin channel.
    eigenvalues: np.ndarray

    def __post_init__(self):
        e_skn = np.asarray(self.eigenvalues)
        if e_skn.ndim == 2:
            # Single-channel (spinor) input: _bandgap needs a spin axis.
            # Only this local view is expanded; self.eigenvalues keeps the
            # caller's shape.
            e_skn = e_skn[np.newaxis]
        # Each entry is (gap, (s1, k1, n1), (s2, k2, n2)) from _bandgap.
        self._gapinfo = _bandgap(e_skn, direct=False)
        self._direct_gapinfo = _bandgap(e_skn, direct=True)

    @classmethod
    def fromcalc(cls, calc):
        """Build from a calculator's irreducible-k-point eigenvalues.

        Eigenvalues of every spin and IBZ k-point are collected and shifted
        by the calculator's Fermi level.
        """
        nk = len(calc.get_ibz_k_points())
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])
        return cls(eigenvalues - calc.get_fermi_level())

    def gap(self):
        """Return (gap, (s1, k1, n1), (s2, k2, n2)) for the fundamental gap."""
        return self._gapinfo

    def direct_gap(self):
        """Return (gap, (s1, k1, n1), (s2, k2, n2)) for the direct gap."""
        return self._direct_gapinfo

    @property
    def is_metallic(self) -> bool:
        """True when the fundamental gap is 0.0 (bands cross the Fermi level)."""
        return self._gapinfo[0] == 0.0

    @property
    def gap_is_direct(self) -> bool:
        """Whether the direct and indirect gaps are the same transition."""
        return self._gapinfo[1:] == self._direct_gapinfo[1:]

    def description(self, *, ibz_kpoints=None) -> str:
        """Return human-friendly description of direct/indirect gap.

        If ibz_kpoints are given, the coordinates of each k-point involved
        in a transition are printed as well.  They are looked up with the
        transition's k index; presumably one coordinate triple per
        irreducible k-point — confirm with callers.
        """
        lines = []
        add = lines.append

        def skn(indices):
            """Convert k-point indices (s, k, n) to string."""
            description = 's={}, k={}, n={}'.format(*indices)
            if ibz_kpoints is not None:
                coordtxt = '[{:.2f}, {:.2f}, {:.2f}]'.format(
                    *ibz_kpoints[indices[1]])
                # coordtxt already carries its own brackets; wrapping it
                # again would print "[[x, y, z]]".
                description = f'{description}, {coordtxt}'
            return f'({description})'

        gap, skn1, skn2 = self.gap()
        direct_gap, skn_direct1, skn_direct2 = self.direct_gap()

        if self.is_metallic:
            add('No gap')
        else:
            add(f'Gap: {gap:.3f} eV')
            add('Transition (v -> c):')
            add(f'  {skn(skn1)} -> {skn(skn2)}')

        if self.gap_is_direct:
            add('No difference between direct/indirect transitions')
        else:
            add('Direct/indirect transitions are different')
            add(f'Direct gap: {direct_gap:.3f} eV')
            if skn_direct1[0] == skn_direct2[0]:
                add(f'Transition at: {skn(skn_direct1)}')
            else:
                # Spin-flip direct transition: show both spin indices.
                transition = skn((f'{skn_direct1[0]}->{skn_direct2[0]}',
                                  *skn_direct1[1:]))
                add(f'Transition at: {transition}')

        return '\n'.join(lines)

97 return '\n'.join(lines) 

98 

99 

def bandgap(calc=None, direct=False, spin=_deprecated,
            eigenvalues=None, efermi=None, output=None, kpts=None):
    """Calculates the band-gap.

    Parameters:

    calc: Calculator object
        Electronic structure calculator object.  When given, eigenvalues
        are read from it (the *eigenvalues* argument is then ignored) and,
        unless *efermi* is given, so is the Fermi level.
    direct: bool
        Calculate direct band-gap.
    eigenvalues: array-like of shape (nspin, nkpt, nband) or (nkpt, nband)
        Eigenvalues.  Used only when *calc* is not given.
    efermi: float
        Fermi level (defaults to 0.0).
    spin, output, kpts:
        Kept for backwards compatibility only.  Passing *spin* raises
        RuntimeError; *output* and *kpts* are ignored.

    Returns a (gap, p1, p2) tuple where p1 and p2 are tuples of indices of the
    valence and conduction points (s, k, n).  For 2-D *eigenvalues* the spin
    index is dropped and p1, p2 are (k, n).  When bands cross the Fermi level
    the gap is 0.0 and all indices are None.

    Raises TypeError if neither *calc* nor *eigenvalues* is given, and
    ValueError for eigenvalues that are not 2-D/3-D, are not finite, or
    leave some k-point without an occupied or unoccupied band.

    Example:

    >>> gap, p1, p2 = bandgap(silicon.calc)
    >>> print(gap, p1, p2)
    1.2 (0, 0, 3), (0, 5, 4)
    >>> gap, p1, p2 = bandgap(silicon.calc, direct=True)
    >>> print(gap, p1, p2)
    3.4 (0, 0, 3), (0, 0, 4)
    """

    if spin is not _deprecated:
        raise RuntimeError(spin_error)

    if calc is not None:
        nk = len(calc.get_ibz_k_points())
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])
        if efermi is None:
            efermi = calc.get_fermi_level()
    elif eigenvalues is None:
        raise TypeError('Either calc or eigenvalues must be given')

    eigenvalues = np.asarray(eigenvalues)
    efermi = efermi or 0.0

    if eigenvalues.ndim == 2:
        e_skn = eigenvalues[np.newaxis] - efermi  # spinors
    elif eigenvalues.ndim == 3:
        e_skn = eigenvalues - efermi
    else:
        raise ValueError('eigenvalues must have shape (nspin, nkpt, nband) '
                         'or (nkpt, nband), got shape '
                         f'{eigenvalues.shape}')

    # Validate before any gap search so bad input gives a clear error.
    if not np.isfinite(e_skn).all():
        raise ValueError('Bad eigenvalues!')

    gap, (s1, k1, n1), (s2, k2, n2) = _bandgap(e_skn, direct)

    if eigenvalues.ndim != 3:
        p1 = (k1, n1)
        p2 = (k2, n2)
    else:
        p1 = (s1, k1, n1)
        p2 = (s2, k2, n2)

    return gap, p1, p2

162 

163 

def _bandgap(e_skn, direct):
    """Find the band gap of eigenvalues given relative to the Fermi level.

    Parameters:

    e_skn: ndarray of shape (nspin, nkpt, nband)
        Eigenvalues minus the Fermi level; negative values are occupied.
        The valence/conduction band is taken by occupation count, which
        assumes bands are sorted ascending within each (s, k) — TODO
        confirm with callers.
    direct: bool
        Look for the smallest same-k-point gap instead of the fundamental
        gap.  For two spins, spin-flip transitions are considered too.

    Returns (gap, (s1, k1, n1), (s2, k2, n2)) for the valence (1) and
    conduction (2) states.  If the number of occupied bands varies with k
    within any spin channel, (0.0, (None, None, None), (None, None, None))
    is returned.

    Raises ValueError('Too few bands!') if some (s, k) has no occupied or
    no unoccupied band.
    """
    ns, nk, nb = e_skn.shape
    no_gap = 0.0, (None, None, None), (None, None, None)

    N_sk = (e_skn < 0.0).sum(2)  # number of occupied bands per (s, k)

    # A band crossing the Fermi level shows up as an occupation count that
    # changes with k inside one spin channel.
    if (np.ptp(N_sk, axis=1) > 0).any():
        return no_gap

    if (N_sk == 0).any() or (N_sk == nb).any():
        raise ValueError('Too few bands!')

    # Highest occupied (valence) and lowest unoccupied (conduction)
    # eigenvalue at every (s, k); indices are valid after the check above.
    iv_sk1 = (N_sk - 1)[:, :, np.newaxis]
    ev_sk = np.take_along_axis(e_skn, iv_sk1, axis=2)[:, :, 0]
    ec_sk = np.take_along_axis(e_skn, iv_sk1 + 1, axis=2)[:, :, 0]

    if ns == 1:
        gap, k1, k2 = find_gap(ev_sk[0], ec_sk[0], direct)
        n1 = N_sk[0, 0] - 1
        return gap, (0, k1, n1), (0, k2, n1 + 1)

    # Two spins: search over the flattened (s, k) index s * nk + k.
    gap, k1, k2 = find_gap(ev_sk.ravel(), ec_sk.ravel(), direct)
    if direct:
        # Check also spin flips:
        for s in [0, 1]:
            g, k, _ = find_gap(ev_sk[s], ec_sk[1 - s], direct)
            if g < gap:
                gap = g
                k1 = k + nk * s
                k2 = k + nk * (1 - s)

    if gap > 0.0:
        s1, k1 = divmod(k1, nk)
        s2, k2 = divmod(k2, nk)
        n1 = N_sk[s1, k1] - 1
        n2 = N_sk[s2, k2]
        return gap, (s1, k1, n1), (s2, k2, n2)
    return no_gap

213 

214 

def find_gap(ev_k, ec_k, direct):
    """Locate the smallest gap between a valence and a conduction band.

    ev_k and ec_k are 1-D arrays of valence and conduction energies indexed
    alike.  Returns (gap, i_valence, i_conduction).  With ``direct`` both
    indices are the same position; otherwise they are the valence maximum
    and the conduction minimum.
    """
    if not direct:
        i_top = ev_k.argmax()
        i_bottom = ec_k.argmin()
        return ec_k[i_bottom] - ev_k[i_top], i_top, i_bottom

    diffs = ec_k - ev_k
    i_min = diffs.argmin()
    return diffs[i_min], i_min, i_min