"""\
A library of useful helper classes to the SAX classes, for the
convenience of application and driver writers.
"""

import os, urllib.parse, urllib.request
import io
import codecs
from . import handler
from . import xmlreader

def __dict_replace(s, d):
    """Replace substrings of a string using a dictionary."""
    for key, value in d.items():
        s = s.replace(key, value)
    return s

def escape(data, entities={}):
    """Escape &, <, and > in a string of data.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """

    # must do ampersand first, otherwise the '&' introduced by the
    # other replacements would itself be escaped again
    data = data.replace("&", "&amp;")
    data = data.replace(">", "&gt;")
    data = data.replace("<", "&lt;")
    if entities:
        data = __dict_replace(data, entities)
    return data

def unescape(data, entities={}):
    """Unescape &amp;, &lt;, and &gt; in a string of data.

    You can unescape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """
    data = data.replace("&lt;", "<")
    data = data.replace("&gt;", ">")
    if entities:
        data = __dict_replace(data, entities)
    # must do ampersand last, so that e.g. "&amp;lt;" becomes "&lt;"
    # and is not unescaped a second time
    return data.replace("&amp;", "&")

def quoteattr(data, entities={}):
    """Escape and quote an attribute value.

    Escape &, <, and > in a string of data, then quote it for use as
    an attribute value.  The \" character will be escaped as well, if
    necessary.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """
    # Build a new dict rather than updating the caller's (or the shared
    # default) one.  Line breaks and tabs are written as character
    # references so they are emitted literally in the attribute value.
    entities = {**entities, '\n': '&#10;', '\r': '&#13;', '\t':'&#9;'}
    data = escape(data, entities)
    if '"' in data:
        if "'" in data:
            # Both quote characters occur: use double quotes as the
            # delimiter and escape the embedded double quotes.
            data = '"%s"' % data.replace('"', "&quot;")
        else:
            data = "'%s'" % data
    else:
        data = '"%s"' % data
    return data


def _gettextwriter(out, encoding):
    """Return an object with a text-mode write() for the given output.

    None selects sys.stdout.  Text streams and codecs stream writers are
    returned unchanged.  Anything else is treated as a binary writer and
    wrapped in an io.TextIOWrapper using the given encoding, with
    unencodable characters written as XML character references.
    """
    if out is None:
        import sys
        return sys.stdout

    if isinstance(out, io.TextIOBase):
        # use a text writer as is
        return out

    if isinstance(out, (codecs.StreamWriter, codecs.StreamReaderWriter)):
        # use a codecs stream writer as is
        return out

    # wrap a binary writer with TextIOWrapper
    if isinstance(out, io.RawIOBase):
        # Keep the original file open when the TextIOWrapper is
        # destroyed
        class _wrapper:
            __class__ = out.__class__
            def __getattr__(self, name):
                return getattr(out, name)
        buffer = _wrapper()
        buffer.close = lambda: None
    else:
        # This is to handle passed objects that aren't in the
        # IOBase hierarchy, but just have a write method
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if BOM (for UTF-16, etc) should be added
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            # Best effort: the object simply does not offer them.
            pass
    return io.TextIOWrapper(buffer, encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)

class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the SAX events it receives as XML text.

    out is the destination (see _gettextwriter for the accepted kinds),
    encoding is used both for the XML declaration and for encoding the
    output, and short_empty_elements selects whether elements without
    content are written as <a/> instead of <a></a>.
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing '>'
        # (only used when short_empty_elements is enabled)
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self,endElement=False):
        """Write the '>' of a start tag that was left open, if any."""
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def endDocument(self):
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        # Save a copy of the current mapping so endPrefixMapping can
        # restore it, then record the new binding for the next start tag.
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        self._finish_pending_start_element()
        self._write('<' + name)
        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        if self._short_empty_elements:
            # Defer the '>' so endElement can emit '/>' for empty elements
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))

        # Declare every prefix mapping started since the last start tag
        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                self._write(' xmlns:%s="%s"' % (prefix, uri))
            else:
                self._write(' xmlns="%s"' % uri)
        self._undeclared_ns_maps = []

        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        if self._short_empty_elements:
            # Defer the '>' so endElementNS can emit '/>' for empty elements
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))


class XMLFilterBase(xmlreader.XMLReader):
    """This class is designed to sit between an XMLReader and the
    client application's event handlers.  By default, it does nothing
    but pass requests up to the reader and events on to the handlers
    unmodified, but subclasses can override specific methods to modify
    the event stream or the configuration requests as they pass
    through."""

    def __init__(self, parent = None):
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # ErrorHandler methods

    def error(self, exception):
        self._err_handler.error(exception)

    def fatalError(self, exception):
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        self._err_handler.warning(exception)

    # ContentHandler methods

    def setDocumentLocator(self, locator):
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        self._cont_handler.startDocument()

    def endDocument(self):
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods

    def notationDecl(self, name, publicId, systemId):
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods

    def resolveEntity(self, publicId, systemId):
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods

    def parse(self, source):
        # Register this filter as every kind of handler on the parent so
        # that all events pass through it before reaching the client.
        self._parent.setContentHandler(self)
        self._parent.setErrorHandler(self)
        self._parent.setEntityResolver(self)
        self._parent.setDTDHandler(self)
        self._parent.parse(source)

    def setLocale(self, locale):
        self._parent.setLocale(locale)

    def getFeature(self, name):
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        return self._parent

    def setParent(self, parent):
        self._parent = parent

# --- Utility functions

def prepare_input_source(source, base=""):
    """This function takes an InputSource and an optional base URL and
    returns a fully resolved InputSource object ready for reading."""

    if isinstance(source, os.PathLike):
        source = os.fspath(source)
    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        # A zero-length read reveals whether the stream yields str or
        # bytes without consuming any input.
        if isinstance(f.read(0), str):
            source.setCharacterStream(f)
        else:
            source.setByteStream(f)
        if hasattr(f, "name") and isinstance(f.name, str):
            source.setSystemId(f.name)

    if source.getCharacterStream() is None and source.getByteStream() is None:
        # No stream attached yet: resolve the system id against base,
        # preferring a local file and falling back to opening it as a URL.
        sysid = source.getSystemId()
        basehead = os.path.dirname(os.path.normpath(base))
        sysidfilename = os.path.join(basehead, sysid)
        if os.path.isfile(sysidfilename):
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            source.setSystemId(urllib.parse.urljoin(base, sysid))
            f = urllib.request.urlopen(source.getSystemId())

        source.setByteStream(f)

    return source