Skip to content

Commit

Permalink
Explicit adjoints for basic_gates.rotation bloqs (#881)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Apr 24, 2024
1 parent 0c09d53 commit 0ea00a7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
23 changes: 22 additions & 1 deletion qualtran/bloqs/basic_gates/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property
from typing import Protocol, Union

import attrs
import cirq
import numpy as np
import sympy
Expand Down Expand Up @@ -88,6 +88,9 @@ def __pow__(self, power):
g = self.cirq_gate**power
return ZPowGate(g.exponent, g.global_shift, self.eps)

def adjoint(self) -> 'ZPowGate':
return attrs.evolve(self, exponent=-self.exponent)


@bloq_example
def _z_pow() -> ZPowGate:
Expand Down Expand Up @@ -120,6 +123,9 @@ def __pow__(self, power):
g = self.cirq_gate**power
return CZPowGate(g.exponent, g.global_shift, self.eps)

def adjoint(self) -> 'CZPowGate':
return attrs.evolve(self, exponent=-self.exponent)


@frozen
class XPowGate(CirqGateAsBloqBase):
Expand Down Expand Up @@ -173,6 +179,9 @@ def decompose_bloq(self) -> 'CompositeBloq':
def cirq_gate(self) -> cirq.Gate:
return cirq.XPowGate(exponent=self.exponent, global_shift=self.global_shift)

def adjoint(self) -> 'XPowGate':
return attrs.evolve(self, exponent=-self.exponent)


@bloq_example
def _x_pow() -> XPowGate:
Expand Down Expand Up @@ -235,6 +244,9 @@ def decompose_bloq(self) -> 'CompositeBloq':
def cirq_gate(self) -> cirq.Gate:
return cirq.YPowGate(exponent=self.exponent, global_shift=self.global_shift)

def adjoint(self) -> 'YPowGate':
return attrs.evolve(self, exponent=-self.exponent)


@bloq_example
def _y_pow() -> YPowGate:
Expand Down Expand Up @@ -274,6 +286,9 @@ def decompose_bloq(self) -> 'CompositeBloq':
def cirq_gate(self) -> cirq.Gate:
return cirq.rz(self.angle)

def adjoint(self) -> 'Rz':
return attrs.evolve(self, angle=-self.angle)


@frozen
class Rx(CirqGateAsBloqBase):
Expand All @@ -287,6 +302,9 @@ def decompose_bloq(self) -> 'CompositeBloq':
def cirq_gate(self) -> cirq.Gate:
return cirq.rx(self.angle)

def adjoint(self) -> 'Rx':
return attrs.evolve(self, angle=-self.angle)


@frozen
class Ry(CirqGateAsBloqBase):
Expand All @@ -300,6 +318,9 @@ def decompose_bloq(self) -> 'CompositeBloq':
def cirq_gate(self) -> cirq.Gate:
return cirq.ry(self.angle)

def adjoint(self) -> 'Ry':
return attrs.evolve(self, angle=-self.angle)


@bloq_example
def _rx() -> Rx:
Expand Down
16 changes: 15 additions & 1 deletion qualtran/bloqs/basic_gates/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import cirq
import numpy as np
import pytest
from cirq.ops import SimpleQubitManager

from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate
from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate
from qualtran.bloqs.basic_gates.rotation import _rx, _ry, _rz


Expand All @@ -29,6 +30,19 @@ def test_rotation_gates():
assert Rz(angle).t_complexity().t_incl_rotations() == 1


@pytest.mark.parametrize(
"bloq",
[Rx(0.01), Ry(0.01), Rz(0.01), ZPowGate(0.01), YPowGate(0.01), XPowGate(0.01), CZPowGate(0.01)],
)
def test_rotation_gates_adjoint(bloq):
assert type(bloq) == type(bloq.adjoint())
np.testing.assert_allclose(
bloq.tensor_contract() @ bloq.adjoint().tensor_contract(),
np.identity(2 ** bloq.signature.n_qubits()),
atol=1e-8,
)


def test_as_cirq_op():
bloq = Rx(angle=np.pi / 4.0, eps=1e-8)
quregs = get_named_qubits(bloq.signature.lefts())
Expand Down

0 comments on commit 0ea00a7

Please sign in to comment.