from . import idnadata
import bisect
import unicodedata
import re
from typing import Union, Optional
from .intranges import intranges_contain

# Canonical combining class value of Virama; consulted by the CONTEXTJ
# rules for ZERO WIDTH (NON-)JOINER in valid_contextj().
_virama_combining_class = 9
# ACE prefix that marks an A-label.
_alabel_prefix = b'xn--'
# Label separators accepted in non-strict mode: U+002E, U+3002, U+FF0E, U+FF61.
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')

class IDNAError(UnicodeError):
    """ Base exception for all IDNA-encoding related problems """
    pass


class IDNABidiError(IDNAError):
    """ Exception when bidirectional requirements are not satisfied """
    pass


class InvalidCodepoint(IDNAError):
    """ Exception when a disallowed or unallocated codepoint is used """
    pass


class InvalidCodepointContext(IDNAError):
    """ Exception when the codepoint is not valid in the context it is used """
    pass


def _combining_class(cp: int) -> int:
    """Return the canonical combining class of code point *cp*.

    A class of 0 is ambiguous: it is also what unicodedata reports for
    characters it has no data for. So when the class is 0 and unicodedata
    has no name for the character, raise ValueError instead of returning 0.
    """
    v = unicodedata.combining(chr(cp))
    if v == 0:
        # Without a default, unicodedata.name() raises its own
        # ValueError('no such name'), which previously made the message
        # below unreachable.
        if unicodedata.name(chr(cp), None) is None:
            raise ValueError('Unknown character in unicodedata')
    return v

def _is_script(cp: str, script: str) -> bool:
    """Return True if character *cp* lies in idnadata's ranges for *script*."""
    return intranges_contain(ord(cp), idnadata.scripts[script])

def _punycode(s: str) -> bytes:
    """Punycode-encode *s*; the caller adds the ACE prefix."""
    return s.encode('punycode')

def _unot(s: int) -> str:
    """Format code point *s* as ``U+XXXX`` (at least four hex digits)."""
    return 'U+{:04X}'.format(s)


def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True if *label* is at most 63 units long."""
    return len(label) <= 63


def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True if the whole domain *label* fits the length limit.

    The limit is 253, or 254 when the domain carries a trailing dot.
    """
    return len(label) <= (254 if trailing_dot else 253)


def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Validate *label* against the Bidi rule (RFC 5893, rules 1-6).

    The rules are only enforced when the label contains a character of
    bidi class R, AL or AN, unless *check_ltr* is true.

    Returns True on success. Raises IDNABidiError on a violation, or when
    any character has no bidi class in this Python's Unicode database.
    """
    # Bidi rules should only be applied if string contains RTL characters.
    # The whole label is scanned (no early exit) so that an unknown
    # directionality anywhere is always reported.
    bidi_label = False
    for (idx, cp) in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)
        if direction == '':
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
        if direction in ('R', 'AL', 'AN'):
            bidi_label = True
    if not bidi_label and not check_ltr:
        return True

    # Bidi rule 1: the first character decides whether this is an RTL label.
    direction = unicodedata.bidirectional(label[0])
    if direction in ('R', 'AL'):
        rtl = True
    elif direction == 'L':
        rtl = False
    else:
        raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))

    valid_ending = False
    number_type = None  # type: Optional[str]
    for (idx, cp) in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)

        if rtl:
            # Bidi rule 2
            if direction not in ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
            # Bidi rule 3: must end in R/AL/EN/AN, optionally followed by NSMs
            # (NSM leaves valid_ending unchanged).
            if direction in ('R', 'AL', 'EN', 'AN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False
            # Bidi rule 4: EN and AN digits must not both appear.
            if direction in ('AN', 'EN'):
                if not number_type:
                    number_type = direction
                elif number_type != direction:
                    raise IDNABidiError('Can not mix numeral types in a right-to-left label')
        else:
            # Bidi rule 5
            if direction not in ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
            # Bidi rule 6: must end in L/EN, optionally followed by NSMs.
            if direction in ('L', 'EN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False

    if not valid_ending:
        raise IDNABidiError('Label ends with illegal codepoint directionality')

    return True


def check_initial_combiner(label: str) -> bool:
    """Raise IDNAError if *label* starts with a combining mark (category M*)."""
    if unicodedata.category(label[0])[0] == 'M':
        raise IDNAError('Label begins with an illegal combining character')
    return True


def check_hyphen_ok(label: str) -> bool:
    """Raise IDNAError on "--" at positions 3-4 or a leading/trailing hyphen."""
    if label[2:4] == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    if label[0] == '-' or label[-1] == '-':
        raise IDNAError('Label must not start or end with a hyphen')
    return True


def check_nfc(label: str) -> None:
    """Raise IDNAError unless *label* is already in Unicode NFC."""
    if unicodedata.normalize('NFC', label) != label:
        raise IDNAError('Label must be in Normalization Form C')


def valid_contextj(label: str, pos: int) -> bool:
    """Return True if the CONTEXTJ joiner at ``label[pos]`` is allowed there.

    * U+200C (ZWNJ) is allowed after a Virama. Otherwise it needs a
      joining type L or D before it and R or D after it, skipping any
      transparent (T) characters in between.
    * U+200D (ZWJ) is allowed only after a Virama.

    Any other code point returns False.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x200c:

        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True

        # Look left: first non-transparent character must join to the right.
        ok = False
        for i in range(pos-1, -1, -1):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            elif joining_type in [ord('L'), ord('D')]:
                ok = True
                break
            else:
                break

        if not ok:
            return False

        # Look right: first non-transparent character must join to the left.
        ok = False
        for i in range(pos+1, len(label)):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            elif joining_type in [ord('R'), ord('D')]:
                ok = True
                break
            else:
                break
        return ok

    if cp_value == 0x200d:

        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True
        return False

    else:

        return False


def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Return True if the CONTEXTO code point at ``label[pos]`` is allowed.

    Handled code points:

    * U+00B7 needs 'l' on both sides.
    * U+0375 needs a following Greek character.
    * U+05F3 and U+05F4 need a preceding Hebrew character.
    * U+30FB needs some Hiragana, Katakana or Han character in the label.
    * Arabic-Indic digits (U+0660..U+0669) and extended Arabic-Indic digits
      (U+06F0..U+06F9) must not be mixed.

    Anything else returns False. *exception* is accepted but not used.

    NOTE(review): the U+30FB and digit rules scan the whole label on every
    call, and check_label() calls this once per CONTEXTO position. A long
    label made of these code points therefore costs O(n**2). alabel() only
    applies its length limit after check_label() has run.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x00b7:
        if 0 < pos < len(label)-1:
            if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
                return True
        return False

    elif cp_value == 0x0375:
        if pos < len(label)-1 and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    elif cp_value == 0x05f3 or cp_value == 0x05f4:
        if pos > 0:
            return _is_script(label[pos - 1], 'Hebrew')
        return False

    elif cp_value == 0x30fb:
        return any(
            _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han')
            for cp in label
            if cp != '\u30fb'
        )

    elif 0x660 <= cp_value <= 0x669:
        return not any(0x6f0 <= ord(cp) <= 0x06f9 for cp in label)

    elif 0x6f0 <= cp_value <= 0x6f9:
        return not any(0x660 <= ord(cp) <= 0x0669 for cp in label)

    return False


def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single label, raising an IDNAError subclass on failure.

    Bytes input is decoded as UTF-8 first. The checks, in order:

    1. the label is non-empty and in NFC;
    2. the hyphen rules hold and the label does not start with a combiner;
    3. every code point is PVALID, or is CONTEXTJ/CONTEXTO and valid in
       context;
    4. the label satisfies the Bidi rule.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if len(label) == 0:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
            if not valid_contextj(label, pos):
                raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)


def alabel(label: str) -> bytes:
    """Convert one label to its A-label (ASCII) bytes.

    Pure-ASCII labels are validated through ulabel() and returned
    unchanged, including their original letter case. Other labels are
    validated with check_label() and then Punycode-encoded behind the
    ``xn--`` prefix.

    Raises IDNAError (or a subclass) on invalid input or when the result
    exceeds 63 bytes.
    """
    try:
        label_bytes = label.encode('ascii')
        ulabel(label_bytes)
        if not valid_label_length(label_bytes):
            raise IDNAError('Label too long')
        return label_bytes
    except UnicodeEncodeError:
        # Not ASCII: fall through to the Punycode path.
        pass

    check_label(label)
    label_bytes = _alabel_prefix + _punycode(label)

    if not valid_label_length(label_bytes):
        raise IDNAError('Label too long')

    return label_bytes


def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert one label to its validated U-label (Unicode) form.

    * A non-ASCII str is validated and returned unchanged.
    * ASCII input is lower-cased. With the ``xn--`` prefix it is
      Punycode-decoded and the result validated. Without the prefix it is
      validated and returned as lower-case str.

    Raises IDNAError (or a subclass) for malformed A-labels or labels
    that fail check_label().
    """
    if not isinstance(label, (bytes, bytearray)):
        try:
            label_bytes = label.encode('ascii')
        except UnicodeEncodeError:
            check_label(label)
            return label
    else:
        label_bytes = label

    label_bytes = label_bytes.lower()
    if label_bytes.startswith(_alabel_prefix):
        label_bytes = label_bytes[len(_alabel_prefix):]
        if not label_bytes:
            raise IDNAError('Malformed A-label, no Punycode eligible content found')
        # Test the raw bytes: decoding to ASCII first let non-ASCII input
        # escape as a bare UnicodeDecodeError instead of an IDNAError. Such
        # input now fails in the Punycode decoder below and is reported as
        # an invalid A-label.
        if label_bytes.endswith(b'-'):
            raise IDNAError('A-label must not end with a hyphen')
    else:
        check_label(label_bytes)
        return label_bytes.decode('ascii')

    try:
        label = label_bytes.decode('punycode')
    except UnicodeError as exc:
        raise IDNAError('Invalid A-label') from exc
    check_label(label)
    return label


def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each character is looked up in the uts46data table and then:

    * kept (status 'V'; 'D' when not transitional; '3' without a mapping
      when std3_rules is off);
    * replaced by its mapping ('M'; '3' when std3_rules is off; 'D' when
      transitional);
    * dropped ('I').

    Any other outcome raises InvalidCodepoint. The result is NFC-normalized.
    """
    from .uts46data import uts46data
    output = []

    def disallowed(pos: int, code_point: int) -> InvalidCodepoint:
        # Single source for the error raised on both failure paths below.
        return InvalidCodepoint(
            'Codepoint {} not allowed at position {} in {}'.format(
                _unot(code_point), pos + 1, repr(domain)))

    for pos, char in enumerate(domain):
        code_point = ord(char)
        try:
            # Code points below 256 index the table directly (one row per
            # code point is assumed). Above that, bisect finds the last row
            # starting at or before code_point; the 'Z' sentinel presumably
            # sorts after every status letter, so an exact match is
            # included (TODO confirm against uts46data).
            uts46row = uts46data[code_point if code_point < 256 else
                bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
            status = uts46row[1]
        except IndexError:
            raise disallowed(pos, code_point)
        replacement = uts46row[2] if len(uts46row) == 3 else None  # type: Optional[str]
        if (status == 'V' or
                (status == 'D' and not transitional) or
                (status == '3' and not std3_rules and replacement is None)):
            output.append(char)
        elif replacement is not None and (status == 'M' or
                (status == '3' and not std3_rules) or
                (status == 'D' and transitional)):
            output.append(replacement)
        elif status != 'I':
            raise disallowed(pos, code_point)

    # Collect pieces and join once instead of growing a str with +=.
    return unicodedata.normalize('NFC', ''.join(output))


def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
    """Encode domain name *s* to its ASCII (A-label) form.

    Args:
        s: the domain. bytes/bytearray input must be pure ASCII.
        strict: split labels only on U+002E. Otherwise every separator
            matched by ``_unicode_dots_re`` is accepted.
        uts46: apply uts46_remap() to the whole domain first.
        std3_rules: passed through to uts46_remap().
        transitional: passed through to uts46_remap().

    Returns the encoded labels joined with b'.', keeping a trailing dot if
    the input had one.

    Raises IDNAError for non-ASCII bytes input, an empty domain or label,
    or an overlong domain. Errors from alabel() and uts46_remap() propagate.
    """
    if not isinstance(s, str):
        try:
            s = str(s, 'ascii')
        except UnicodeDecodeError:
            raise IDNAError('should pass a unicode string to the function rather than a byte string.')
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)
    trailing_dot = False
    result = []
    if strict:
        labels = s.split('.')
    else:
        labels = _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')
    if labels[-1] == '':
        # A single trailing separator denotes the root; keep it for output.
        del labels[-1]
        trailing_dot = True
    for label in labels:
        encoded_label = alabel(label)
        if encoded_label:
            result.append(encoded_label)
        else:
            raise IDNAError('Empty label')
    if trailing_dot:
        result.append(b'')
    encoded = b'.'.join(result)
    if not valid_string_length(encoded, trailing_dot):
        raise IDNAError('Domain too long')
    return encoded


def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
    """Decode domain name *s* to its Unicode (U-label) form.

    Args:
        s: the domain. bytes/bytearray input must be pure ASCII.
        strict: split labels only on U+002E. Otherwise every separator
            matched by ``_unicode_dots_re`` is accepted.
        uts46: apply non-transitional uts46_remap() to the domain first.
        std3_rules: passed through to uts46_remap().

    Returns the decoded labels joined with '.', keeping a trailing dot if
    the input had one.

    Raises IDNAError for non-ASCII bytes input or an empty domain or label.
    Errors from ulabel() and uts46_remap() propagate.
    """
    try:
        if not isinstance(s, str):
            s = str(s, 'ascii')
    except UnicodeDecodeError:
        raise IDNAError('Invalid ASCII in A-label')
    if uts46:
        s = uts46_remap(s, std3_rules, False)
    trailing_dot = False
    result = []
    if not strict:
        labels = _unicode_dots_re.split(s)
    else:
        labels = s.split('.')
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')
    if not labels[-1]:
        # A single trailing separator denotes the root; keep it for output.
        del labels[-1]
        trailing_dot = True
    for label in labels:
        decoded_label = ulabel(label)
        if decoded_label:
            result.append(decoded_label)
        else:
            raise IDNAError('Empty label')
    if trailing_dot:
        result.append('')
    return '.'.join(result)