#  Pyrogram - Telegram MTProto API Client Library for Python
#  Copyright (C) 2017-present Dan
#
#  This file is part of Pyrogram.
#
#  Pyrogram is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Lesser General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  Pyrogram is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Lesser General Public License for more details.
#
#  You should have received a copy of the GNU Lesser General Public License
#  along with Pyrogram.  If not, see <http://www.gnu.org/licenses/>.

import html
import logging
import re
from html.parser import HTMLParser
from typing import Optional

import pyrogram
from pyrogram import raw
from pyrogram.enums import MessageEntityType
from pyrogram.errors import PeerIdInvalid
from . import utils

log = logging.getLogger(__name__)


class Parser(HTMLParser):
    """Collect plain text and raw message entities from HTML-style markup.

    While markup is fed in, the tag-free text accumulates in ``self.text`` and
    one raw entity object per recognised tag is built. An entity is moved to
    ``self.entities`` when its closing tag is seen; whatever is still in
    ``self.tag_entities`` after feeding was never closed.
    """

    # Matches the href of an inline user mention and captures the user id.
    MENTION_RE = re.compile(r"tg://user\?id=(\d+)")

    def __init__(self, client: "pyrogram.Client"):
        super().__init__()

        self.client = client

        # Plain text collected so far, with all tags removed.
        self.text = ""
        # Completed entities, in the order their closing tags were seen.
        self.entities = []
        # Currently open entities: tag name -> stack of entity objects.
        self.tag_entities = {}

    def handle_starttag(self, tag, attrs):
        """Open a new entity for a recognised tag; ignore any other tag."""
        attrs = dict(attrs)
        extra = {}

        if tag in ("b", "strong"):
            entity = raw.types.MessageEntityBold
        elif tag in ("i", "em"):
            entity = raw.types.MessageEntityItalic
        elif tag == "u":
            entity = raw.types.MessageEntityUnderline
        elif tag in ("s", "del", "strike"):
            entity = raw.types.MessageEntityStrike
        elif tag == "blockquote":
            entity = raw.types.MessageEntityBlockquote
        elif tag == "code":
            entity = raw.types.MessageEntityCode
        elif tag == "pre":
            entity = raw.types.MessageEntityPre
            # A valueless attribute (<pre language>) is reported as None.
            extra["language"] = attrs.get("language") or ""
        elif tag == "spoiler":
            entity = raw.types.MessageEntitySpoiler
        elif tag == "a":
            # A valueless attribute (<a href>) is reported as None.
            url = attrs.get("href") or ""

            mention = Parser.MENTION_RE.match(url)

            if mention:
                entity = raw.types.InputMessageEntityMentionName
                extra["user_id"] = int(mention.group(1))
            else:
                entity = raw.types.MessageEntityTextUrl
                extra["url"] = url
        elif tag == "emoji":
            entity = raw.types.MessageEntityCustomEmoji

            try:
                extra["document_id"] = int(attrs.get("id"))
            except (TypeError, ValueError):
                # Missing or non-numeric id: treat it like an unknown tag so
                # the enclosed text is still kept as plain text.
                log.debug("Ignoring <emoji> tag with missing or invalid id: %r", attrs.get("id"))
                return
        else:
            return

        # The entity starts at the current end of the text; its length grows
        # in handle_data() while the tag stays open.
        self.tag_entities.setdefault(tag, []).append(
            entity(offset=len(self.text), length=0, **extra)
        )

    def handle_data(self, data):
        """Append text and extend every currently open entity by its length."""
        # HTMLParser is created with its default convert_charrefs=True, so
        # character references in ``data`` are already decoded here. Decoding
        # again would turn an escaped "&amp;lt;" into "<" instead of "&lt;".
        for entities in self.tag_entities.values():
            for entity in entities:
                entity.length += len(data)

        self.text += data

    def handle_endtag(self, tag):
        """Close the most recently opened entity for ``tag``, if there is one."""
        try:
            self.entities.append(self.tag_entities[tag].pop())
        except (KeyError, IndexError):
            line, offset = self.getpos()
            offset += 1

            log.debug("Unmatched closing tag </%s> at line %s:%s", tag, line, offset)
        else:
            # Drop the empty stack so only genuinely unclosed tags remain.
            if not self.tag_entities[tag]:
                self.tag_entities.pop(tag)

    def error(self, message):
        """Ignore parser errors; parsing is best-effort."""
        pass


class HTML:
    """Convert between HTML-style markup and text plus message entities."""

    def __init__(self, client: Optional["pyrogram.Client"]):
        self.client = client

    async def parse(self, text: str):
        """Parse markup into plain text and a list of raw entities.

        Returns a dict with two keys: ``"message"`` holds the text without
        tags, ``"entities"`` holds the non-empty entities sorted by offset, or
        ``None`` when there are none. User mentions are resolved through
        ``self.client.resolve_peer`` when a client is set; a mention whose
        peer is invalid is dropped.
        """
        # Strip whitespaces from the beginning and the end, but preserve
        # leading opening tags and trailing closing tags
        text = re.sub(r"^\s*(<[\w<>=\s\"]*>)\s*", r"\1", text)
        text = re.sub(r"\s*(</[\w<>=\s\"]*>)\s*$", r"\1", text)

        parser = Parser(self.client)
        parser.feed(utils.add_surrogates(text))
        parser.close()

        if parser.tag_entities:
            unclosed_tags = [
                f"<{tag}> (x{len(entities)})"
                for tag, entities in parser.tag_entities.items()
            ]

            log.info("Unclosed tags: %s", ", ".join(unclosed_tags))

        entities = []

        for entity in parser.entities:
            if isinstance(entity, raw.types.InputMessageEntityMentionName):
                try:
                    if self.client is not None:
                        entity.user_id = await self.client.resolve_peer(entity.user_id)
                except PeerIdInvalid:
                    continue

            entities.append(entity)

        # Remove zero-length entities
        entities = [entity for entity in entities if entity.length > 0]

        return {
            "message": utils.remove_surrogates(parser.text),
            "entities": sorted(entities, key=lambda e: e.offset) or None
        }

    @staticmethod
    def unparse(text: str, entities: list):
        """Render text and its entities back into HTML-style markup.

        Entities of an unsupported type are skipped. Text lying between the
        first and the last inserted tag is HTML-escaped. The ``entities`` list
        passed in is not modified.
        """

        def parse_one(entity):
            """
            Parses a single entity and returns (start_tag, start), (end_tag, end),
            or None when the entity type has no HTML representation
            """
            entity_type = entity.type
            start = entity.offset
            end = start + entity.length

            if entity_type in (
                MessageEntityType.BOLD,
                MessageEntityType.ITALIC,
                MessageEntityType.UNDERLINE,
                MessageEntityType.STRIKETHROUGH,
            ):
                # First letter of the enum member name: b, i, u, s
                name = entity_type.name[0].lower()
                start_tag = f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type == MessageEntityType.PRE:
                name = entity_type.name.lower()
                language = getattr(entity, "language", "") or ""
                if language:
                    # Escaped so a quote in the value cannot end the attribute
                    start_tag = f'<{name} language="{html.escape(language)}">'
                else:
                    start_tag = f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type in (
                MessageEntityType.CODE,
                MessageEntityType.BLOCKQUOTE,
                MessageEntityType.SPOILER,
            ):
                name = entity_type.name.lower()
                start_tag = f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type == MessageEntityType.TEXT_LINK:
                # Escaped so a quote in the URL cannot end the attribute
                url = html.escape(entity.url or "")
                start_tag = f'<a href="{url}">'
                end_tag = "</a>"
            elif entity_type == MessageEntityType.TEXT_MENTION:
                user = entity.user
                # Same href form that Parser.MENTION_RE recognises
                start_tag = f'<a href="tg://user?id={user.id}">'
                end_tag = "</a>"
            elif entity_type == MessageEntityType.CUSTOM_EMOJI:
                custom_emoji_id = entity.custom_emoji_id
                start_tag = f'<emoji id="{custom_emoji_id}">'
                end_tag = "</emoji>"
            else:
                return

            return (start_tag, start), (end_tag, end)

        def recursive(entity_i: int) -> int:
            """
            Takes the index of the entity to start parsing from, returns the number of parsed entities inside it.
            Uses entities_offsets as a stack, pushing (start_tag, start) first, then parsing nested entities,
            and finally pushing (end_tag, end) to the stack.
            No need to sort at the end.
            """
            this = parse_one(entities[entity_i])
            if this is None:
                return 1
            (start_tag, start), (end_tag, end) = this
            entities_offsets.append((start_tag, start))
            internal_i = entity_i + 1
            # while the next entity is inside the current one, keep parsing
            while internal_i < len(entities) and entities[internal_i].offset < end:
                internal_i += recursive(internal_i)
            entities_offsets.append((end_tag, end))
            return internal_i - entity_i

        text = utils.add_surrogates(text)

        entities_offsets = []

        # Outer entities must come before the ones nested inside them. Sort a
        # copy so the caller's list is left untouched.
        entities = sorted(entities or [], key=lambda e: (e.offset, -e.length))

        # main loop for first-level entities
        i = 0
        while i < len(entities):
            i += recursive(i)

        if entities_offsets:
            last_offset = entities_offsets[-1][1]
            # no need to sort, but still add entities starting from the end
            # so that offsets into the untouched prefix stay valid.
            # NOTE: only the text between two consecutive tags is escaped; the
            # text before the first tag and after the last tag is left as-is.
            for entity, offset in reversed(entities_offsets):
                text = text[:offset] + entity + html.escape(text[offset:last_offset]) + text[last_offset:]
                last_offset = offset

        return utils.remove_surrogates(text)