diff --git a/fastapi_rss/models/feed.py b/fastapi_rss/models/feed.py index e9bfe8c..1d4a2c1 100644 --- a/fastapi_rss/models/feed.py +++ b/fastapi_rss/models/feed.py @@ -46,7 +46,7 @@ def _generate_tree_list(root: etree.Element, key: str, value: List[dict]) -> Non if content is not None: itemroot.text = content else: - RSSFeed.generate_tree(itemroot, item) + RSSFeed.generate_tree(itemroot, item, {}) @staticmethod def _generate_tree_object(root: etree._Element, key: str, value: Union[dict, BaseModel]) -> None: @@ -92,7 +92,7 @@ def _generate_tree_default(root: etree._Element, key: str, value: Any) -> None: element.text = str(value) @staticmethod - def generate_tree(root: etree.Element, dict_: dict): + def generate_tree(root: etree.Element, dict_: dict, nsmap: dict): handlers = { (list, ): RSSFeed._generate_tree_list, (BaseModel, dict): RSSFeed._generate_tree_object, @@ -107,10 +107,15 @@ def generate_tree(root: etree.Element, dict_: dict): break else: RSSFeed._generate_tree_default(root, key, value) + if key == "docs" and 'http://www.w3.org/2005/Atom' in nsmap.values(): + atom_link = etree.SubElement(root, '{http://www.w3.org/2005/Atom}link', nsmap=nsmap) + atom_link.set('href', value) + atom_link.set('rel', 'self') + atom_link.set('type', 'application/rss+xml') def tostring(self, nsmap: Optional[Dict[str, str]] = None): - nsmap = nsmap or {} + nsmap = nsmap or {'atom': 'http://www.w3.org/2005/Atom'} rss = etree.Element('rss', version='2.0', nsmap=nsmap) channel = etree.SubElement(rss, 'channel') - RSSFeed.generate_tree(channel, self.dict()) - return etree.tostring(rss, pretty_print=True, xml_declaration=True) + RSSFeed.generate_tree(channel, self.dict(), nsmap) + return etree.tostring(rss, pretty_print=True, xml_declaration=True, encoding='UTF-8') diff --git a/tests/test_rss_response.py b/tests/test_rss_response.py index 82a60b1..421347b 100644 --- a/tests/test_rss_response.py +++ b/tests/test_rss_response.py @@ -8,7 +8,7 @@ def test_rss_sample_response(client: TestClient, expected_response: str, capsys) response = client.get('/1') assert response.status_code == 200 tree = etree.fromstring(response.content) - pretty: str = etree.tostring(tree, pretty_print=True, xml_declaration=True).decode('ascii') + pretty: str = etree.tostring(tree, pretty_print=True, xml_declaration=True).decode('utf-8') for line, expected_line in zip(pretty.splitlines(), expected_response.splitlines()): assert dedent(line) == dedent(expected_line), f'{line} != {expected_line}'