# Codec glue that exposes the package's IDNA conversion routines through
# Python's ``codecs`` registry under the name "idna2008".

import codecs
import re
from typing import Any, Optional, Tuple

from .core import IDNAError, alabel, decode, encode, ulabel

# Characters treated as label separators: U+002E, U+3002, U+FF0E and U+FF61.
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")


def _require_strict(errors: str) -> None:
    """Raise ``IDNAError`` unless *errors* is the ``"strict"`` handler."""
    if errors != "strict":
        raise IDNAError('Unsupported error handling "{}"'.format(errors))


def _split_labels(data: str, final: bool) -> Tuple[list, bool]:
    """Split *data* on the dot characters and pick the labels to convert now.

    Returns ``(labels, trailing_dot)``.  ``trailing_dot`` is true when a
    separator must be appended after the converted labels, either because
    *data* itself ends with a dot or because the last (possibly unfinished)
    label was held back while earlier labels remain.
    """
    pieces = _unicode_dots_re.split(data)
    if not pieces:
        return pieces, False

    if not pieces[-1]:
        # Input ends with a separator: drop the empty tail, keep the dot.
        pieces.pop()
        return pieces, True

    if final:
        # Last call: every label, including the last one, is complete.
        return pieces, False

    # Keep potentially unfinished label until the next call
    pieces.pop()
    return pieces, bool(pieces)


def _consumed_length(labels: list, trailing_dot: bool) -> int:
    """Count the input characters covered by *labels* and their separators."""
    total = 0
    for label in labels:
        # One separator is counted before a label once something has
        # already been counted.
        total += len(label) + (1 if total else 0)
    return total + (1 if trailing_dot else 0)


class Codec(codecs.Codec):
    """Stateless codec delegating to the module-level ``encode``/``decode``."""

    def encode(self, data: str, errors: str = "strict") -> Tuple[bytes, int]:
        """Encode *data*, returning ``(encoded bytes, len(data))``.

        Raises ``IDNAError`` for any *errors* value other than ``"strict"``.
        Empty input yields ``(b"", 0)`` without calling the encoder.
        """
        _require_strict(errors)
        return (encode(data), len(data)) if data else (b"", 0)

    def decode(self, data: bytes, errors: str = "strict") -> Tuple[str, int]:
        """Decode *data*, returning ``(decoded text, len(data))``.

        Raises ``IDNAError`` for any *errors* value other than ``"strict"``.
        Empty input yields ``("", 0)`` without calling the decoder.
        """
        _require_strict(errors)
        return (decode(data), len(data)) if data else ("", 0)


class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
    """Incremental encoder that converts one complete label at a time."""

    def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
        """Encode the complete labels found in *data*.

        Returns the encoded bytes and the number of input characters they
        account for.  Unless *final* is true, a last label that is not
        followed by a separator is left out of both.
        Raises ``IDNAError`` for any *errors* value other than ``"strict"``.
        """
        _require_strict(errors)
        if not data:
            return b"", 0

        labels, trailing_dot = _split_labels(data, final)
        encoded = [alabel(label) for label in labels]
        suffix = b"." if trailing_dot else b""

        # Join with U+002E
        return b".".join(encoded) + suffix, _consumed_length(labels, trailing_dot)


class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
    """Incremental decoder that converts one complete label at a time."""

    def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
        """Decode the complete labels found in *data*.

        Input that is not already ``str`` is first decoded as ASCII.  Returns
        the decoded text and the number of input characters it accounts for.
        Unless *final* is true, a last label that is not followed by a
        separator is left out of both.
        Raises ``IDNAError`` for any *errors* value other than ``"strict"``.
        """
        _require_strict(errors)
        if not data:
            return ("", 0)

        text = data if isinstance(data, str) else str(data, "ascii")

        labels, trailing_dot = _split_labels(text, final)
        decoded = [ulabel(label) for label in labels]
        suffix = "." if trailing_dot else ""
        return (".".join(decoded) + suffix, _consumed_length(labels, trailing_dot))


class StreamWriter(Codec, codecs.StreamWriter):
    """Stream writer combining ``Codec`` with ``codecs.StreamWriter``."""


class StreamReader(Codec, codecs.StreamReader):
    """Stream reader combining ``Codec`` with ``codecs.StreamReader``."""


def search_function(name: str) -> Optional[codecs.CodecInfo]:
    """Return the codec description for ``"idna2008"``, or ``None`` otherwise."""
    if name == "idna2008":
        return codecs.CodecInfo(
            name=name,
            encode=Codec().encode,
            decode=Codec().decode,
            incrementalencoder=IncrementalEncoder,
            incrementaldecoder=IncrementalDecoder,
            streamwriter=StreamWriter,
            streamreader=StreamReader,
        )
    return None


# Importing this module registers the lookup function with the codec registry.
codecs.register(search_function)