"""Python codec for IDNA 2008, registered under the name ``"idna2008"``.

Importing this module passes :func:`search_function` to
:func:`codecs.register`, so the name ``"idna2008"`` resolves to the codec
classes defined here through the standard :mod:`codecs` lookup machinery.
"""

import codecs
import re
from typing import Any, Optional, Tuple

from .core import IDNAError, alabel, decode, encode, ulabel

# Label separators: U+002E FULL STOP, U+3002 IDEOGRAPHIC FULL STOP,
# U+FF0E FULLWIDTH FULL STOP and U+FF61 HALFWIDTH IDEOGRAPHIC FULL STOP.
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")


def _check_errors(errors: str) -> None:
    """Raise IDNAError unless *errors* is ``"strict"``, the only supported mode."""
    if errors != "strict":
        raise IDNAError('Unsupported error handling "{}"'.format(errors))


def _split_labels(text: str, final: bool) -> Tuple[list, bool]:
    """Split *text* on any separator dot into labels ready for conversion.

    Returns ``(labels, dotted)``, where *dotted* is true when a separator
    should be emitted after the last returned label.

    An empty last piece (input ending in a separator) is dropped and
    reported as ``dotted``.  When *final* is false, a non-empty last piece
    may be an unfinished label, so it is withheld.  The caller then leaves
    those characters unconsumed for the next call.
    """
    pieces = _unicode_dots_re.split(text)
    if not pieces:
        return pieces, False
    if pieces[-1] == "":
        pieces.pop()
        return pieces, True
    if final:
        return pieces, False
    pieces.pop()
    # Only the labels before the withheld piece are returned; each of them
    # was followed by a separator in the input.
    return pieces, bool(pieces)


class Codec(codecs.Codec):
    """Stateless codec that converts a whole domain name in a single call."""

    def encode(self, data: str, errors: str = "strict") -> Tuple[bytes, int]:
        """Encode *data* with the core ``encode`` helper.

        Returns ``(encoded_bytes, number_of_input_characters_consumed)``.
        Raises IDNAError for any *errors* mode other than ``"strict"``.
        """
        _check_errors(errors)
        if data:
            return encode(data), len(data)
        return b"", 0

    def decode(self, data: bytes, errors: str = "strict") -> Tuple[str, int]:
        """Decode *data* with the core ``decode`` helper.

        Returns ``(decoded_text, number_of_input_items_consumed)``.
        Raises IDNAError for any *errors* mode other than ``"strict"``.
        """
        _check_errors(errors)
        if data:
            return decode(data), len(data)
        return "", 0


class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
    """Incremental encoder that converts one complete label at a time.

    Input after the last separator is reported as unconsumed unless
    ``final`` is true.  The BufferedIncrementalEncoder base class then keeps
    it buffered until more text arrives.
    """

    def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
        _check_errors(errors)
        if not data:
            return b"", 0

        labels, dotted = _split_labels(data, final)
        encoded = []
        consumed = 0
        for label in labels:
            encoded.append(alabel(label))
            # Count this label, plus one for the separator before it unless
            # nothing has been counted yet.
            consumed += len(label) + (1 if consumed else 0)

        suffix = b"." if dotted else b""
        # Output is always joined with ASCII U+002E, whichever dot the input used.
        return b".".join(encoded) + suffix, consumed + len(suffix)


class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
    """Incremental decoder that converts one complete label at a time.

    Non-``str`` input is first decoded as ASCII.  Input after the last
    separator is reported as unconsumed unless ``final`` is true, so the
    BufferedIncrementalDecoder base class keeps it buffered.
    """

    def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
        _check_errors(errors)
        if not data:
            return ("", 0)

        text = data if isinstance(data, str) else str(data, "ascii")
        labels, dotted = _split_labels(text, final)
        decoded = []
        consumed = 0
        for label in labels:
            decoded.append(ulabel(label))
            # Count this label, plus one for the separator before it unless
            # nothing has been counted yet.
            consumed += len(label) + (1 if consumed else 0)

        suffix = "." if dotted else ""
        return ".".join(decoded) + suffix, consumed + len(suffix)


class StreamWriter(Codec, codecs.StreamWriter):
    """Stream writer that uses Codec.encode."""


class StreamReader(Codec, codecs.StreamReader):
    """Stream reader that uses Codec.decode."""


def search_function(name: str) -> Optional[codecs.CodecInfo]:
    """Codec search hook: describe this codec for ``"idna2008"``; ``None`` otherwise."""
    if name == "idna2008":
        # Codec holds no state, so one instance can back both directions.
        codec = Codec()
        return codecs.CodecInfo(
            name=name,
            encode=codec.encode,
            decode=codec.decode,
            incrementalencoder=IncrementalEncoder,
            incrementaldecoder=IncrementalDecoder,
            streamwriter=StreamWriter,
            streamreader=StreamReader,
        )
    return None


codecs.register(search_function)