diff --git a/pytest_mh/utils/firewall.py b/pytest_mh/utils/firewall.py index f80066e..ac7a396 100644 --- a/pytest_mh/utils/firewall.py +++ b/pytest_mh/utils/firewall.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from random import randrange from typing import Any, Literal, TypeAlias from .. import MultihostHost, MultihostRole, MultihostUtility @@ -232,6 +233,7 @@ def __init__(self, host: MultihostHost) -> None: self.__inbound: FirewalldInboundRules = FirewalldInboundRules(self) self.__outbound: FirewalldOutboundRules = FirewalldOutboundRules(self) + self._policies: list[str] = [] self._priority: int = 30000 """ @@ -244,12 +246,24 @@ def __init__(self, host: MultihostHost) -> None: not remove "accept" rule but the "drop" rule takes precedence. """ + def setup(self) -> None: + """ + Set the firewall up. + + :meta private: + """ + super().setup() + self.add_policy(f"test-policy-{randrange(99999)}", ingress="HOST", egress="ANY") + self.host.conn.exec(["firewall-cmd", "--reload"], log_level=ProcessLogLevel.Error) + def teardown(self) -> None: """ Revert all firewall changes. :meta private: """ + for policy in list(self._policies): + self.remove_policy(policy) self.host.conn.exec(["firewall-cmd", "--reload"]) super().teardown() @@ -289,103 +303,99 @@ def _next_priority(self) -> int: self._priority -= 1 return priority - def add_direct_rule( + @property + def _default_policy(self) -> str: + """ + Returns the name of the default policy. + + The default policy is the first one to be added. + This is usually the one created at ``setup()``. + + If no policy was created (quite strange situation), an exception is raised. + """ + return self._policies[0] + + def add_policy( self, - chain: str, - args: list[Any], + name: str, *, - table: str = "filter", - ip_family: Literal["ipv4", "ipv6", "all"] = "all", + ingress: str | None = None, + egress: str | None = None, priority: int | None = None, + target: Literal["CONTINUE", "ACCEPT", "DROP", "REJECT"] | None = None, ) -> int: """ - Add a new direct rule. + Add a new (permanent) policy. - This methods returns a priority of this rule. You need to use this - priority if you remove the rule with :meth:`remove_direct_rule`. + Except for the name, all parameters are optional. When the priority is not provided, + the next priority is assigned. When the other parameters are not provided, + no value is assigned to the newly created policy and it defaults to ``firewalld``'s + defaults. - :param chain: iptables chain (e.g. INPUT or OUTPUT). - :type chain: str - :param args: iptables arguments - :type args: list[Any] - :param table: iptables table, defaults to "filter" - :type table: str, optional - :param ip_family: If the rules is added as IPv4, IPv6 rule or both, defaults to ``all``. - :type ip_family: Literal["ipv4", "ipv6", "all"], optional - :param priority: Rule priority, defaults to None (= auto-assign next value) - :type priority: int | None, optional - :return: Rule priority, to be used for rule removal. + :param name: The policy name + :type name: str + :param ingress: The ingress zone, not assigned if not provided. + :type ingress: str | None, optional. + :param egress: The egress zone, not assigned if not provided. + :type egress: str | None, optional. + :param priority: Rule priority, defaults to the next priority. + :type priority: int | None, optional. + :param target: Rule target, not assigned if not provided. + :type target: Literal["CONTINUE", "ACCEPT", "DROP", "REJECT"] | None, optional. + :return: Policy priority. :rtype: int """ + self.logger.info(f'Firewalld: adding policy "{name}"') + + cmd = ["firewall-cmd", "--permanent", "--new-policy", name] + self.host.conn.exec(cmd, log_level=ProcessLogLevel.Error) + self._policies.append(name) + cmd[2] = "--policy" + if priority is None: priority = self._next_priority + self.host.conn.exec([*cmd, "--set-priority", str(priority)], log_level=ProcessLogLevel.Error) - cmd = [table, chain, priority, *args] + if ingress is not None: + self.host.conn.exec([*cmd, "--add-ingress-zone", ingress], log_level=ProcessLogLevel.Error) - if ip_family in ["ipv4", "all"]: - self.logger.info(f'Firewalld: adding IPv4 direct firewall rule: {" ".join([str(x) for x in cmd])}') - self.host.conn.exec( - ["firewall-cmd", "--direct", "--add-rule", "ipv4", *cmd], log_level=ProcessLogLevel.Error - ) + if egress is not None: + self.host.conn.exec([*cmd, "--add-egress-zone", egress], log_level=ProcessLogLevel.Error) - if ip_family in ["ipv6", "all"]: - self.logger.info(f'Firewalld: adding IPv6 direct firewall rule: {" ".join([str(x) for x in cmd])}') - self.host.conn.exec( - ["firewall-cmd", "--direct", "--add-rule", "ipv6", *cmd], log_level=ProcessLogLevel.Error - ) + if target is not None: + self.host.conn.exec([*cmd, "--set-target", target], log_level=ProcessLogLevel.Error) return priority - def remove_direct_rule( - self, - priority: int, - chain: str, - args: list[Any], - *, - table: str = "filter", - ip_family: Literal["ipv4", "ipv6", "all"], - ) -> None: + def remove_policy(self, name: str) -> None: """ - Remove direct rule. + Remove a (permanent) policy. - :param priority: Rule priority. - :type priority: int - :param chain: iptables chain (e.g. INPUT or OUTPUT). - :type chain: str - :param args: iptables arguments - :type args: list[Any] - :param table: iptables table, defaults to "filter" - :type table: str, optional - :param ip_family: If the rules is removed from IPv4, IPv6 rules or both, defaults to ``all``. - :type ip_family: Literal["ipv4", "ipv6", "all"], optional + :param name: The name of the policy to be removed. + :type name: str """ - cmd = [table, chain, priority, *args] - - if ip_family in ["ipv4", "all"]: - self.logger.info(f'Firewalld: removing IPv4 direct firewall rule: {" ".join([str(x) for x in cmd])}') - self.host.conn.exec( - ["firewall-cmd", "--direct", "--remove-rule", "ipv4", *cmd], log_level=ProcessLogLevel.Error - ) + self.logger.info(f'Firewalld: removing policy "{name}"') + self.host.conn.exec(["firewall-cmd", "--permanent", "--delete-policy", name], log_level=ProcessLogLevel.Error) + self._policies.remove(name) - if ip_family in ["ipv6", "all"]: - self.logger.info(f'Firewalld: removing IPv6 direct firewall rule: {" ".join([str(x) for x in cmd])}') - self.host.conn.exec( - ["firewall-cmd", "--direct", "--remove-rule", "ipv6", *cmd], log_level=ProcessLogLevel.Error - ) - - def add_rich_rule(self, rule: str, priority: int | None = None) -> int: + def add_rich_rule(self, rule: str, policy: str | None = None, priority: int | None = None) -> int: """ Add rich rule. + When the policy is specified, the rule will be added to that policy, + or to the default policy when not specified. + The parameter "rule" is the part after "rule priority=X". This part is added automatically. That is: .. code-block:: console - $ firewall-cmd --add-rich-rule rule priority={priority} {rule} + $ firewall-cmd [--policy {policy}] --add-rich-rule rule priority={priority} {rule} :param rule: Firewalld rich rule. :type rule: str + :param policy: The policy to use. + :type policy: str | None, optional :param priority: Rule priority, defaults to None (= auto-assign next value) :type priority: int | None, optional @@ -395,13 +405,18 @@ def add_rich_rule(self, rule: str, priority: int | None = None) -> int: if priority is None: priority = self._next_priority + if policy is None: + policy = self._default_policy + rule = f"rule priority={priority} {rule}" self.logger.info(f'Firewalld: adding rich rule "{rule}"') - self.host.conn.exec(["firewall-cmd", "--add-rich-rule", rule], log_level=ProcessLogLevel.Error) + self.host.conn.exec( + ["firewall-cmd", "--policy", policy, "--add-rich-rule", rule], log_level=ProcessLogLevel.Error + ) return priority - def remove_rich_rule(self, priority: int, rule: str) -> None: + def remove_rich_rule(self, priority: int, rule: str, policy: str | None = None) -> None: """ Remove rich rule. @@ -412,14 +427,24 @@ def remove_rich_rule(self, priority: int, rule: str) -> None: $ firewall-cmd --remove-rich-rule rule priority="{priority}" {rule} + When the policy is specified, the rule will be removev from that policy, + or from the default policy when not specified. + :param priority: Rule priority :type priority: int :param rule: Firewalld rich rule. :type rule: str + :param policy: The policy to use. + :type policy: str | None, optional """ + if policy is None: + policy = self._default_policy + rule = f"rule priority={priority} {rule}" self.logger.info(f'Firewalld: removing rich rule "{rule}"') - self.host.conn.exec(["firewall-cmd", "--remove-rich-rule", rule], log_level=ProcessLogLevel.Error) + self.host.conn.exec( + ["firewall-cmd", "--policy", policy, "--remove-rich-rule", rule], log_level=ProcessLogLevel.Error + ) class FirewalldInboundRules(FirewallInboundRules): @@ -500,16 +525,16 @@ def __init__(self, firewall: Firewalld) -> None: self.firewall: Firewalld = firewall def accept_port(self, port: PortSpec | list[PortSpec]) -> None: - self.__add_port(port, action="ACCEPT") + self.__add_port(port, action="accept") def reject_port(self, port: PortSpec | list[PortSpec]) -> None: - self.__add_port(port, action="REJECT") + self.__add_port(port, action="reject") def drop_port(self, port: PortSpec | list[PortSpec]) -> None: - self.__add_port(port, action="DROP") + self.__add_port(port, action="drop") def accept_host(self, host: HostSpec | list[HostSpec]) -> None: - self.__add_host(host, action="ACCEPT") + self.__add_host(host, action="accept") def reject_host(self, host: HostSpec | list[HostSpec]) -> None: """ @@ -530,16 +555,16 @@ def reject_host(self, host: HostSpec | list[HostSpec]) -> None: :param host: Hostname, MultihostHost or MultihostRole object. :type host: HostSpec | list[HostSpec] """ - self.__add_host(host, action="REJECT") + self.__add_host(host, action="reject") def drop_host(self, host: HostSpec | list[HostSpec]) -> None: - self.__add_host(host, action="DROP") + self.__add_host(host, action="drop") def __add_port( self, port: PortSpec | list[PortSpec], *, - action: Literal["ACCEPT", "REJECT", "DROP"], + action: Literal["accept", "reject", "drop"], ) -> None: items = port if isinstance(port, list) else [port] for item in items: @@ -549,15 +574,13 @@ def __add_port( else: port, protocol = self.firewall.parse_port_spec(item) - self.firewall.add_direct_rule( - chain="OUTPUT", args=["-p", protocol, "--dport", port, "-j", action], table="filter" - ) + self.firewall.add_rich_rule(f"port port={port} protocol={protocol} {action}") def __add_host( self, host: HostSpec | list[HostSpec], *, - action: Literal["ACCEPT", "REJECT", "DROP"], + action: Literal["accept", "reject", "drop"], ) -> None: items = host if isinstance(host, list) else [host] for item in items: @@ -576,14 +599,10 @@ def __add_host( ) for ip in ipv4s: - self.firewall.add_direct_rule( - chain="OUTPUT", args=["--destination", ip, "-j", action], table="filter", ip_family="ipv4" - ) + self.firewall.add_rich_rule(f"family=ipv4 destination address={ip} {action}") for ip in ipv6s: - self.firewall.add_direct_rule( - chain="OUTPUT", args=["--destination", ip, "-j", action], table="filter", ip_family="ipv6" - ) + self.firewall.add_rich_rule(f"family=ipv6 destination address={ip} {action}") def __resolve_hostname(self, hostname: str, type: Literal["A", "AAAA"]) -> list[str]: result = self.firewall.host.conn.exec(["dig", "+short", "-t", type, hostname], log_level=ProcessLogLevel.Error)