Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions emu_mps/baths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import torch

from emu_mps.mps import MPS
from emu_mps.mpo import MPO


class Baths:
"""
Helper container managing left and right environment ("bath") tensors
used during MPS sweeps.

Maintains stacks of left and right baths and provides utilities to
update them during left-to-right and right-to-left sweeps.
"""

_left: list[torch.Tensor]
_right: list[torch.Tensor]

def __init__(
self,
state: MPS,
hamiltonian: MPO,
):
"""
Initialize baths from the given MPS state and Hamiltonian.
Left baths start with a trivial identity-like tensor.
Right baths are initialized from the full system.
"""
self.device = state.factors[0].device
self.dtype = state.factors[0].dtype

self._left = [torch.ones(1, 1, 1, dtype=self.dtype, device=self.device)]
self._right = self._right_baths(state, hamiltonian, final_qubit=2)
assert len(self._right) == len(state.factors) - 1

def current_left(self) -> torch.Tensor:
return self._left[-1]

def current_right(self) -> torch.Tensor:
return self._right[-1]

def current(self) -> tuple[torch.Tensor, torch.Tensor]:
return self.current_left(), self.current_right()

def _new_left_bath(
self,
bath: torch.Tensor,
state: torch.Tensor,
op: torch.Tensor,
) -> torch.Tensor:
# this order is more efficient than contracting the op first in general
bath = torch.tensordot(bath, state.conj(), ([0], [0]))
bath = torch.tensordot(bath, op.to(bath.device), ([0, 2], [0, 1]))
bath = torch.tensordot(bath, state, ([0, 2], [0, 1]))
return bath

def append_left(self, state: MPS, hamiltonian: MPO, sweep_index: int) -> None:
new_node = self._new_left_bath(
self.current_left(),
state.factors[sweep_index],
hamiltonian.factors[sweep_index],
)
self._left.append(new_node.to(state.factors[sweep_index + 1].device))
return

def _new_right_bath(
self, bath: torch.Tensor, state: torch.Tensor, op: torch.Tensor
) -> torch.Tensor:
bath = torch.tensordot(state, bath, ([2], [2]))
bath = torch.tensordot(op.to(bath.device), bath, ([2, 3], [1, 3]))
bath = torch.tensordot(state.conj(), bath, ([1, 2], [1, 3]))
return bath

def _right_baths(
self,
state: MPS,
op: MPO,
final_qubit: int,
) -> list[torch.Tensor]:
"""
function to compute the right baths. The three indices in the bath are as follows:
(bond of state conj, bond of operator, bond of state)
The baths have shape
-xx
-xx
-xx
with the index ordering (top, middle, bottom)
bath tensors are put on the device of the factor to the left
"""

state_factor = state.factors[-1]
bath = torch.ones(1, 1, 1, device=state_factor.device, dtype=state_factor.dtype)
baths = [bath]
for i in range(len(state.factors) - 1, final_qubit - 1, -1):
bath = self._new_right_bath(bath, state.factors[i], op.factors[i])
bath = bath.to(state.factors[i - 1].device)
baths.append(bath)
return baths

def append_right(self, state: MPS, hamiltonian: MPO, sweep_index: int) -> None:
self._right.append(
self._new_right_bath(
self.current_right(),
state.factors[sweep_index],
hamiltonian.factors[sweep_index],
).to(state.factors[sweep_index - 1].device)
)

def pop_left(self) -> torch.Tensor:
return self._left.pop()

def pop_right(self) -> torch.Tensor:
return self._right.pop()
75 changes: 27 additions & 48 deletions emu_mps/mps_backend_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@
evolve_pair,
evolve_single,
minimize_energy_pair,
new_right_bath,
right_baths,
)
from emu_mps.utils import (
extended_mpo_factors,
extended_mps_factors,
get_extended_site_index,
new_left_bath,
)
from emu_mps.baths import Baths

dtype = torch.complex128

Expand Down Expand Up @@ -101,8 +99,7 @@ class MPSBackendImpl:
well_prepared_qubits_filter: Optional[torch.Tensor]
hamiltonian: MPO
state: MPS
left_baths: list[torch.Tensor]
right_baths: list[torch.Tensor]
baths: Baths
target_time: float
results: Results
_swipe_direction: SwipeDirection = SwipeDirection.LEFT_TO_RIGHT
Expand Down Expand Up @@ -307,17 +304,7 @@ def update_H_no_noise(self) -> None:
)

def init_baths(self) -> None:
self.left_baths = [
torch.ones(1, 1, 1, dtype=dtype, device=self.state.factors[0].device)
]
self.right_baths = right_baths(self.state, self.hamiltonian, final_qubit=2)
assert len(self.right_baths) == self.qubit_count - 1

def get_current_right_bath(self) -> torch.Tensor:
return self.right_baths[-1]

def get_current_left_bath(self) -> torch.Tensor:
return self.left_baths[-1]
self.baths = Baths(self.state, self.hamiltonian)

def init(self) -> None:
self.init_dark_qubits()
Expand All @@ -335,13 +322,13 @@ def _evolve(
) -> None:
"""
Time-evolve the state's tensors located at the given 1 or 2 indices by dt,
using the baths stored in self.left_baths and self.right_baths.
using the baths stored in self.baths.
When 2 indices are given, they need to be consecutive.
Updates the state's orthogonality center according to orth_center_right.
"""
assert 1 <= len(indices) <= 2

baths = (self.get_current_left_bath(), self.get_current_right_bath())
baths = self.baths.current()

if len(indices) == 1:
assert orth_center_right is None
Expand Down Expand Up @@ -420,15 +407,13 @@ def _left_to_right_update_tdvp(self, delta_time: float) -> None:
dt=delta_time / 2,
orth_center_right=True,
)
self.left_baths.append(
new_left_bath(
self.get_current_left_bath(),
self.state.factors[self._sweep_index],
self.hamiltonian.factors[self._sweep_index],
).to(self.state.factors[self._sweep_index + 1].device)
self.baths.append_left(
self.state,
self.hamiltonian,
self._sweep_index,
)
self._evolve(self._sweep_index + 1, dt=-delta_time / 2)
self.right_baths.pop()
self.baths.pop_right()
self._sweep_index += 1
else:
# Time-evolution of the rightmost 2 tensors
Expand All @@ -442,18 +427,16 @@ def _left_to_right_update_tdvp(self, delta_time: float) -> None:

def _right_to_left_update_tdvp(self, delta_time: float) -> None:
if self._sweep_index > 0:
self.right_baths.append(
new_right_bath(
self.get_current_right_bath(),
self.state.factors[self._sweep_index + 1],
self.hamiltonian.factors[self._sweep_index + 1],
).to(self.state.factors[self._sweep_index].device)
self.baths.append_right(
self.state,
self.hamiltonian,
self._sweep_index + 1,
)
if not self.has_lindblad_noise:
# Free memory because it won't be used anymore
deallocate_tensor(self.right_baths[-2])
deallocate_tensor(self.baths._right[-2])
self._evolve(self._sweep_index, dt=-delta_time / 2)
self.left_baths.pop()
self.baths.pop_left()
self._evolve(
self._sweep_index - 1,
self._sweep_index,
Expand Down Expand Up @@ -774,7 +757,7 @@ def progress(self) -> None:
new_L, new_R, energy = minimize_energy_pair(
state_factors=self.state.factors[idx : idx + 2],
ham_factors=self.hamiltonian.factors[idx : idx + 2],
baths=(self.left_baths[-1], self.right_baths[-1]),
baths=self.baths.current(),
orth_center_right=orth_center_right,
config=self.config,
residual_tolerance=self.config.precision,
Expand All @@ -795,29 +778,25 @@ def progress(self) -> None:

def _left_to_right_update(self, idx: int) -> None:
if idx < self.qubit_count - 2:
self.left_baths.append(
new_left_bath(
self.get_current_left_bath(),
self.state.factors[idx],
self.hamiltonian.factors[idx],
).to(self.state.factors[idx + 1].device)
self.baths.append_left(
self.state,
self.hamiltonian,
idx,
)
self.right_baths.pop()
self.baths.pop_right()
self._sweep_index += 1

if self._sweep_index == self.qubit_count - 2:
self._swipe_direction = SwipeDirection.RIGHT_TO_LEFT

def _right_to_left_update(self, idx: int) -> None:
if idx > 0:
self.right_baths.append(
new_right_bath(
self.get_current_right_bath(),
self.state.factors[idx + 1],
self.hamiltonian.factors[idx + 1],
).to(self.state.factors[idx].device)
self.baths.append_right(
self.state,
self.hamiltonian,
idx + 1,
)
self.left_baths.pop()
self.baths.pop_left()
self._sweep_index -= 1

if self._sweep_index == 0:
Expand Down
23 changes: 11 additions & 12 deletions emu_mps/solver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,18 @@ def new_right_bath(
return bath


"""
function to compute the right baths. The three indices in the bath are as follows:
(bond of state conj, bond of operator, bond of state)
The baths have shape
-xx
-xx
-xx
with the index ordering (top, middle, bottom)
bath tensors are put on the device of the factor to the left
"""


def right_baths(state: MPS, op: MPO, final_qubit: int) -> list[torch.Tensor]:
"""
function to compute the right baths. The three indices in the bath are as follows:
(bond of state conj, bond of operator, bond of state)
The baths have shape
-xx
-xx
-xx
with the index ordering (top, middle, bottom)
bath tensors are put on the device of the factor to the left
"""

state_factor = state.factors[-1]
bath = torch.ones(1, 1, 1, device=state_factor.device, dtype=state_factor.dtype)
baths = [bath]
Expand Down
10 changes: 5 additions & 5 deletions test/emu_mps/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,25 +676,25 @@ def check_baths(impl: MPSBackendImpl):
# Mocking MPSBackendImpl.get_current_right_bath() to check that
# the right baths administration happens properly when a quantum jump occurs.

assert len(impl.right_baths) in [
assert len(impl.baths._right) in [
impl.state.num_sites - impl._sweep_index,
impl.state.num_sites - impl._sweep_index - 1,
]

expected_right_baths = right_baths(
impl.state,
impl.hamiltonian,
final_qubit=impl.state.num_sites - len(impl.right_baths) + 1,
final_qubit=impl.state.num_sites - len(impl.baths.right) + 1,
)
assert all(
torch.allclose(actual, expected)
for actual, expected in zip(impl.right_baths, expected_right_baths)
for actual, expected in zip(impl.baths.right, expected_right_baths)
)

return impl.right_baths[-1]
return impl.baths.right[-1]

with patch(
"emu_mps.mps_backend_impl.MPSBackendImpl.get_current_right_bath", autospec=True
"emu_mps.mps_backend_impl.MPSBackendImpl.Baths.current_right", autospec=True
) as get_current_right_bath_mock:
get_current_right_bath_mock.side_effect = check_baths

Expand Down
Loading