"""\
A library of useful helper classes to the SAX classes, for the
convenience of application and driver writers.
"""

import os
import urllib.parse
import urllib.request
import io
import codecs
from . import handler
from . import xmlreader


def __dict_replace(s, d):
    """Replace substrings of a string using a dictionary.

    Each key of *d* found in *s* is replaced by its value, in the
    iteration order of *d*; the resulting string is returned.
    """
    for key, value in d.items():
        s = s.replace(key, value)
    return s


def escape(data, entities={}):
    """Escape &, <, and > in a string of data.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """
    # NOTE: the shared default dict is only read, never mutated.

    # must do ampersand first, otherwise the '&' introduced by the
    # other replacements would be escaped a second time
    data = data.replace("&", "&amp;")
    data = data.replace(">", "&gt;")
    data = data.replace("<", "&lt;")
    if entities:
        data = __dict_replace(data, entities)
    return data


def unescape(data, entities={}):
    """Unescape &amp;, &lt;, and &gt; in a string of data.

    You can unescape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """
    data = data.replace("&lt;", "<")
    data = data.replace("&gt;", ">")
    if entities:
        data = __dict_replace(data, entities)
    # must do ampersand last, so that e.g. "&amp;lt;" yields "&lt;"
    # rather than "<"
    return data.replace("&amp;", "&")


def quoteattr(data, entities={}):
    """Escape and quote an attribute value.

    Escape &, <, and > in a string of data, then quote it for use as
    an attribute value.  The \" character will be escaped as well, if
    necessary.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    """
    # Newline, carriage return and tab are written as character
    # references; these entries override same-named keys in `entities`.
    entities = {**entities, '\n': '&#10;', '\r': '&#13;', '\t': '&#9;'}
    data = escape(data, entities)
    if '"' in data:
        if "'" in data:
            # Both quote kinds occur: use double quotes and escape the
            # embedded double quotes.
            data = '"%s"' % data.replace('"', "&quot;")
        else:
            data = "'%s'" % data
    else:
        data = '"%s"' % data
    return data


def _gettextwriter(out, encoding):
    """Return an object with a text-mode write() for *out*.

    None selects sys.stdout.  Text streams and codecs stream writers
    are returned unchanged.  Anything else is treated as a binary
    writer and wrapped in an io.TextIOWrapper that uses *encoding* and
    the 'xmlcharrefreplace' error handler.
    """
    if out is None:
        import sys
        return sys.stdout

    if isinstance(out, io.TextIOBase):
        # use a text writer as is
        return out

    if isinstance(out, (codecs.StreamWriter, codecs.StreamReaderWriter)):
        # use a codecs stream writer as is
        return out

    # wrap a binary writer with TextIOWrapper
    if isinstance(out, io.RawIOBase):
        # Keep the original file open when the TextIOWrapper is
        # destroyed
        class _wrapper:
            __class__ = out.__class__

            def __getattr__(self, name):
                return getattr(out, name)
        buffer = _wrapper()
        buffer.close = lambda: None
    else:
        # This is to handle passed objects that aren't in the
        # IOBase hierarchy, but just have a write method
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if BOM (for UTF-16, etc) should be added
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            pass
    return io.TextIOWrapper(buffer, encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)


class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the events it receives out as XML.

    out -- object to write to; see _gettextwriter() for what is
        accepted (None means sys.stdout).
    encoding -- encoding named in the XML declaration and used when a
        binary writer has to be wrapped.
    short_empty_elements -- if true, an element that receives no
        content before its end event is written as <name/>.
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing
        # '>' (only happens when short_empty_elements is enabled).
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self,endElement=False):
        """Write the '>' of a start tag that is still open, if any."""
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        """Write the XML declaration, naming the configured encoding."""
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def endDocument(self):
        """Flush the underlying writer."""
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        """Record a prefix -> uri mapping to declare on the next element."""
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Restore the namespace context saved by startPrefixMapping()."""
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        """Write the start tag for *name* with the attributes in *attrs*."""
        self._finish_pending_start_element()
        self._write('<' + name)
        for (name, value) in attrs.items():
            self._write(' %s=%s' % (name, quoteattr(value)))
        if self._short_empty_elements:
            # Defer the '>' so that endElement() can emit '/>' instead.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        """Write the end tag for *name*, or '/>' for an empty element."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        """Write a namespace-aware start tag.

        *name* and the keys of *attrs* are (uri, localname) pairs.
        Prefix mappings recorded since the previous element are written
        as xmlns attributes.
        """
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))

        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                self._write(' xmlns:%s="%s"' % (prefix, uri))
            else:
                self._write(' xmlns="%s"' % uri)
        self._undeclared_ns_maps = []

        for (name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(name), quoteattr(value)))
        if self._short_empty_elements:
            # Defer the '>' so that endElementNS() can emit '/>' instead.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        """Write a namespace-aware end tag, or '/>' for an empty element."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        """Write *content* with &, < and > escaped; bytes are decoded first."""
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        """Write *content* unescaped; bytes are decoded first."""
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        """Write a processing instruction built from *target* and *data*."""
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))


class XMLFilterBase(xmlreader.XMLReader):
    """This class is designed to sit between an XMLReader and the
    client application's event handlers.  By default, it does nothing
    but pass requests up to the reader and events on to the handlers
    unmodified, but subclasses can override specific methods to modify
    the event stream or the configuration requests as they pass
    through."""

    def __init__(self, parent = None):
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # ErrorHandler methods

    def error(self, exception):
        self._err_handler.error(exception)

    def fatalError(self, exception):
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        self._err_handler.warning(exception)

    # ContentHandler methods

    def setDocumentLocator(self, locator):
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        self._cont_handler.startDocument()

    def endDocument(self):
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods

    def notationDecl(self, name, publicId, systemId):
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods

    def resolveEntity(self, publicId, systemId):
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods

    def parse(self, source):
        # Install this filter as every handler of the parent reader so
        # that all events pass through it before reaching the client's
        # handlers.
        self._parent.setContentHandler(self)
        self._parent.setErrorHandler(self)
        self._parent.setEntityResolver(self)
        self._parent.setDTDHandler(self)
        self._parent.parse(source)

    def setLocale(self, locale):
        self._parent.setLocale(locale)

    def getFeature(self, name):
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        return self._parent

    def setParent(self, parent):
        self._parent = parent

# --- Utility functions

def prepare_input_source(source, base=""):
    """This function takes an InputSource and an optional base URL and
    returns a fully resolved InputSource object ready for reading."""

    if isinstance(source, os.PathLike):
        source = os.fspath(source)
    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        f = source
        source = xmlreader.InputSource()
        # A zero-length read reveals whether the stream yields str or
        # bytes without consuming any data.
        if isinstance(f.read(0), str):
            source.setCharacterStream(f)
        else:
            source.setByteStream(f)
        if hasattr(f, "name") and isinstance(f.name, str):
            source.setSystemId(f.name)

    if source.getCharacterStream() is None and source.getByteStream() is None:
        sysid = source.getSystemId()
        basehead = os.path.dirname(os.path.normpath(base))
        sysidfilename = os.path.join(basehead, sysid)
        if os.path.isfile(sysidfilename):
            source.setSystemId(sysidfilename)
            f = open(sysidfilename, "rb")
        else:
            # Not a local file: resolve against the base URL and fetch it.
            source.setSystemId(urllib.parse.urljoin(base, sysid))
            f = urllib.request.urlopen(source.getSystemId())

        # The opened stream is handed to the InputSource; it is not
        # closed here.
        source.setByteStream(f)

    return source