Commit e18c3e43 authored by mischa's avatar mischa
Browse files

Complex Roots Bugfix

parent 2a14da20
...@@ -182,7 +182,8 @@ class Lattice: ...@@ -182,7 +182,8 @@ class Lattice:
def c_s_sq(self): def c_s_sq(self):
if self._c_s_sq: if self._c_s_sq:
return self._c_s_sq return self._c_s_sq
return self.velocity_set()[0] self._c_s_sq = self.velocity_set()[0]
return self.c_s_sq
@property @property
def weights(self): def weights(self):
...@@ -516,7 +517,7 @@ class Lattice: ...@@ -516,7 +517,7 @@ class Lattice:
self.solution_expected() self.solution_expected()
self._calculate_coefficients() self._calculate_coefficients()
c_s_sq = sp.Symbol("c_s_sq") c_s_sq = sp.Symbol("c_s_sq", real=True)
if approximate: if approximate:
self._weight_polynomials = [ self._weight_polynomials = [
sp.Poly([approximate_ratio(x) for x in pol_coffs], c_s_sq) sp.Poly([approximate_ratio(x) for x in pol_coffs], c_s_sq)
...@@ -565,6 +566,11 @@ class Lattice: ...@@ -565,6 +566,11 @@ class Lattice:
weights = [x if abs(x) >= 1e-10 else 0 for x in weights] weights = [x if abs(x) >= 1e-10 else 0 for x in weights]
for i, shell in enumerate(self.shells): for i, shell in enumerate(self.shells):
shell.set_weight(c_s_sq, weights[i]) shell.set_weight(c_s_sq, weights[i])
if isinstance(c_s_sq, sp.CRootOf):
c_s_sq = complex(c_s_sq)
if c_s_sq.imag != 0:
raise ValueError("Imaginary root returned. This should not happen")
c_s_sq = float(c_s_sq.real)
self._c_s_sq = c_s_sq self._c_s_sq = c_s_sq
self._reduced_weights = weights self._reduced_weights = weights
return weights return weights
......
...@@ -139,6 +139,16 @@ class TestInitSchemes(unittest.TestCase): ...@@ -139,6 +139,16 @@ class TestInitSchemes(unittest.TestCase):
self.assertTrue(self.lattice.reduced_weights == self.lattice_from_order.reduced_weights) self.assertTrue(self.lattice.reduced_weights == self.lattice_from_order.reduced_weights)
class TestVelocitySets(unittest.TestCase):
def setUp(self):
self.seed = 20
self.lattice = Lattice.from_name("D2V37")
def test_velocity_set(self):
c_s_sq, weights, velocities = self.lattice.velocity_set()
self.assertAlmostEqual(c_s_sq, 0.6979533220196831, places=10)
if __name__ == "__main__": if __name__ == "__main__":
logger.disabled = True logger.disabled = True
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment