"""Codec glue that exposes this package's conversions as ``idna2008``.

Importing this module has a side effect: :func:`search_function` is
registered with :mod:`codecs`.
"""

from .core import encode, decode, alabel, ulabel, IDNAError
import codecs
import re
from typing import Any, Tuple, Optional

# Characters treated as label separators: U+002E, U+3002, U+FF0E, U+FF61.
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')


class Codec(codecs.Codec):
    """Stateless codec that only accepts the ``'strict'`` error mode."""

    def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
        """Return ``(encode(data), len(data))``; empty input gives ``(b'', 0)``.

        Raises:
            IDNAError: if *errors* is anything other than ``'strict'``.
        """
        if errors != 'strict':
            raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

        # ``encode`` resolves to the function imported from ``.core``;
        # method names are not visible as bare names inside a method body.
        return (encode(data), len(data)) if data else (b"", 0)

    def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
        """Return ``(decode(data), len(data))``; empty input gives ``('', 0)``.

        Raises:
            IDNAError: if *errors* is anything other than ``'strict'``.
        """
        if errors != 'strict':
            raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

        # ``decode`` resolves to the function imported from ``.core``.
        return (decode(data), len(data)) if data else ('', 0)


class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
    """Incremental encoder that converts one complete label at a time."""

    def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
        """Convert the complete labels in *data* with ``alabel``.

        Args:
            data: Text to encode, split on the separators matched by
                ``_unicode_dots_re``.
            errors: Error mode; must be ``'strict'``.
            final: When false, a non-empty last label is treated as possibly
                incomplete and is neither converted nor counted as consumed.

        Returns:
            A ``(encoded_bytes, consumed_length)`` pair. Input that is not
            counted as consumed is left for a later call (presumably
            re-presented by the buffered base class).

        Raises:
            IDNAError: if *errors* is anything other than ``'strict'``.
        """
        if errors != 'strict':
            raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

        if not data:
            return b'', 0

        parts = _unicode_dots_re.split(data)
        suffix = b''
        if parts:
            tail = parts.pop()
            if not tail:
                # Input ended on a separator: emit it after the labels.
                suffix = b'.'
            elif final:
                # Last chunk: the tail is a complete label, put it back.
                parts.append(tail)
            elif parts:
                # The tail is held back, but the separator in front of it
                # still belongs to the output produced now.
                suffix = b'.'

        converted = []
        consumed = 0
        for part in parts:
            converted.append(alabel(part))
            # A separator is counted only once something has been consumed.
            consumed += len(part) + (1 if consumed else 0)

        # Labels are always joined with U+002E, whichever separator was read.
        return b'.'.join(converted) + suffix, consumed + len(suffix)


class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
    """Incremental decoder that converts one complete label at a time."""

    def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
        """Convert the complete labels in *data* with ``ulabel``.

        Args:
            data: Input to decode; anything that is not already ``str`` is
                first decoded as ASCII.
            errors: Error mode; must be ``'strict'``.
            final: When false, a non-empty last label is treated as possibly
                incomplete and is neither converted nor counted as consumed.

        Returns:
            A ``(decoded_text, consumed_length)`` pair. Input that is not
            counted as consumed is left for a later call (presumably
            re-presented by the buffered base class).

        Raises:
            IDNAError: if *errors* is anything other than ``'strict'``.
        """
        if errors != 'strict':
            raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

        if not data:
            return ('', 0)

        text = data if isinstance(data, str) else str(data, 'ascii')

        parts = _unicode_dots_re.split(text)
        suffix = ''
        if parts:
            tail = parts.pop()
            if not tail:
                # Input ended on a separator: emit it after the labels.
                suffix = '.'
            elif final:
                # Last chunk: the tail is a complete label, put it back.
                parts.append(tail)
            elif parts:
                # The tail is held back, but the separator in front of it
                # still belongs to the output produced now.
                suffix = '.'

        converted = []
        consumed = 0
        for part in parts:
            converted.append(ulabel(part))
            # A separator is counted only once something has been consumed.
            consumed += len(part) + (1 if consumed else 0)

        return ('.'.join(converted) + suffix, consumed + len(suffix))


class StreamWriter(Codec, codecs.StreamWriter):
    """Stream writer that takes its encoding from :class:`Codec`."""


class StreamReader(Codec, codecs.StreamReader):
    """Stream reader that takes its decoding from :class:`Codec`."""


def search_function(name: str) -> Optional[codecs.CodecInfo]:
    """Return the codec description for ``'idna2008'``, else ``None``.

    Args:
        name: Codec name being looked up.

    Returns:
        A ``codecs.CodecInfo`` wired to the classes in this module when
        *name* is exactly ``'idna2008'``; ``None`` for any other name.
    """
    if name == 'idna2008':
        return codecs.CodecInfo(
            name=name,
            encode=Codec().encode,
            decode=Codec().decode,
            incrementalencoder=IncrementalEncoder,
            incrementaldecoder=IncrementalDecoder,
            streamwriter=StreamWriter,
            streamreader=StreamReader,
        )
    return None


codecs.register(search_function)