view env/lib/python3.7/site-packages/networkx/readwrite/gexf.py @ 3:758bc20232e8 draft

"planemo upload commit 2a0fe2cc28b09e101d37293e53e82f61762262ec"
author shellac
date Thu, 14 May 2020 16:20:52 -0400
parents 26e78fe6e8c4
children
line wrap: on
line source

# Copyright (C) 2013-2019 by
#
# Authors: Aric Hagberg <hagberg@lanl.gov>
#          Dan Schult <dschult@colgate.edu>
#          Pieter Swart <swart@lanl.gov>
# All rights reserved.
# BSD license.
# Based on the NetworkX GraphML reader.
"""Read and write graphs in GEXF format.

GEXF (Graph Exchange XML Format) is a language for describing complex
network structures, their associated data and dynamics.

This implementation does not support mixed graphs (directed and
undirected edges together).

Format
------
GEXF is an XML format.  See https://gephi.org/gexf/format/schema.html for the
specification and https://gephi.org/gexf/format/basic.html for examples.
"""
import itertools
import time

import networkx as nx
from networkx.utils import open_file, make_str
try:
    from xml.etree.cElementTree import (Element, ElementTree, SubElement,
                                        tostring)
except ImportError:
    try:
        from xml.etree.ElementTree import (Element, ElementTree, SubElement,
                                           tostring)
    except ImportError:
        pass

# Names exported by ``from networkx.readwrite.gexf import *``; the GEXF,
# GEXFWriter and GEXFReader classes are intentionally not listed.
__all__ = ['write_gexf', 'read_gexf', 'relabel_gexf_graph', 'generate_gexf']


@open_file(1, mode='wb')
def write_gexf(G, path, encoding='utf-8', prettyprint=True,
               version='1.2draft'):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    path : file or string
       File or file name to write.
       File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
       Version of the GEXF file format
       (see https://gephi.org/gexf/format/schema.html).
       Supported values: "1.1draft", "1.2draft".

    Raises
    ------
    NetworkXError
       If `version` is not a supported GEXF version, or if a 'start'/'end'
       value is not an int, float or str.
    TypeError
       If a node or edge attribute value has a type that cannot be
       written as a GEXF attribute.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]['viz'] = {'size': 54}
    >>> G.nodes[0]['viz']['position'] = {'x' : 0, 'y' : 1}
    >>> G.nodes[0]['viz']['color'] = {'r' : 0, 'g' : 0, 'b' : 255}


    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    .. [2] GEXF viz schema 1.1, https://gephi.org/gexf/1.1draft/viz
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                        version=version)
    writer.add_graph(G)
    writer.write(path)


def generate_gexf(G, encoding='utf-8', prettyprint=True, version='1.2draft'):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
       Version of the GEXF file format
       (see https://gephi.org/gexf/format/schema.html).
       Supported values: "1.1draft", "1.2draft".

    Yields
    ------
    str
       One line of the GEXF document at a time, without line terminators.
       The whole document is built before the first line is yielded.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10) # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))  # doctest: +SKIP
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...    print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                        version=version)
    writer.add_graph(G)
    # Serialize once, then hand the text out line by line.
    for line in str(writer).splitlines():
        yield line


@open_file(0, mode='rb')
def read_gexf(path, node_type=None, relabel=False, version='1.2draft'):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
       File or file name to read.
       File names ending in .gz or .bz2 will be decompressed.
    node_type: Python type (default: None)
       Convert node ids to this type if not None.
    relabel : bool (default: False)
       If True relabel the nodes to use the GEXF node "label" attribute
       instead of the node "id" attribute as the NetworkX node label.
    version : string (default: '1.2draft')
       Version of the GEXF file format
       (see https://gephi.org/gexf/format/schema.html).
       Supported values: "1.1draft", "1.2draft".
       This version's namespace is tried first; if no <graph> element is
       found there, the other supported versions are tried as well.

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Raises
    ------
    NetworkXError
        If `version` is not a supported GEXF version, or if the file
        contains no <graph> element in any supported namespace.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    reader = GEXFReader(node_type=node_type, version=version)
    G = reader(path)
    if relabel:
        # Use the GEXF "label" attribute instead of "id" as node name.
        G = relabel_gexf_graph(G)
    return G


class GEXF(object):
    """Constants and helpers shared by the GEXF reader and writer.

    Class attributes:

    * ``versions`` maps a version key ('1.1draft', '1.2draft') to its
      namespaces, schema location and version string.
    * ``xml_type`` maps Python (and, if available, NumPy) types to GEXF
      attribute type names.
    * ``python_type`` maps GEXF attribute type names back to Python types.
    * ``convert_bool`` maps boolean spellings to bool values.
    """
    versions = {}
    d = {'NS_GEXF': "http://www.gexf.net/1.1draft",
         'NS_VIZ': "http://www.gexf.net/1.1draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.1draft',
                                     'http://www.gexf.net/1.1draft/gexf.xsd']),
         'VERSION': '1.1'}
    versions['1.1draft'] = d
    d = {'NS_GEXF': "http://www.gexf.net/1.2draft",
         'NS_VIZ': "http://www.gexf.net/1.2draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.2draft',
                                     'http://www.gexf.net/1.2draft/gexf.xsd']),
         'VERSION': '1.2'}
    versions['1.2draft'] = d

    # (python type, GEXF type) pairs.  When a table below is built from
    # this list, later pairs overwrite earlier ones with the same key.
    types = [(int, "integer"),
             (float, "float"),
             (float, "double"),
             (bool, "boolean"),
             (list, "string"),
             (dict, "string"),
             (int, "long"),
             (str, "liststring"),
             (str, "anyURI"),
             (str, "string")]

    # These additions to types allow writing numpy types
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # Prepended, so the plain Python types listed above win when the
        # table is inverted into python_type (last entry wins).
        # np.int and np.float_ are deliberately absent.  They were aliases
        # of the builtin int and of np.float64, so listing them never
        # changed either table.  They were also removed in NumPy 1.24 and
        # 2.0, where referencing them raised AttributeError here and made
        # this module impossible to import.
        types = [(np.float64, "float"), (np.float32, "float"),
                 (np.float16, "float"),
                 (np.int8, "int"), (np.int16, "int"),
                 (np.int32, "int"), (np.int64, "int"),
                 (np.uint8, "int"), (np.uint16, "int"),
                 (np.uint32, "int"), (np.uint64, "int"),
                 (np.int_, "int"), (np.intc, "int"),
                 (np.intp, "int"),
                 ] + types

    xml_type = dict(types)
    python_type = dict(reversed(a) for a in types)

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        'true': True, 'false': False,
        'True': True, 'False': False,
        '0': False, 0: False,
        '1': True, 1: True
    }

    def set_version(self, version):
        """Select the namespaces and version string for a GEXF version.

        Parameters
        ----------
        version : str
            A key of ``versions`` ('1.1draft' or '1.2draft').

        Raises
        ------
        NetworkXError
            If `version` is not a known GEXF version.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError('Unknown GEXF version %s.' % version)
        self.NS_GEXF = d['NS_GEXF']
        self.NS_VIZ = d['NS_VIZ']
        self.NS_XSI = d['NS_XSI']
        self.SCHEMALOCATION = d['SCHEMALOCATION']
        self.VERSION = d['VERSION']
        self.version = version


class GEXFWriter(GEXF):
    """Build a GEXF XML document from NetworkX graphs.

    Use write_gexf() or generate_gexf() rather than this class directly.
    Graphs are added with add_graph() (or through the `graph` constructor
    argument); the document is serialized with write() or str().
    """

    def __init__(self, graph=None, encoding='utf-8', prettyprint=True,
                 version='1.2draft'):
        try:
            import xml.etree.ElementTree as ET
        except ImportError:
            raise ImportError('GEXF writer requires '
                              'xml.elementtree.ElementTree')
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element('gexf',
                           {'xmlns': self.NS_GEXF,
                            'xmlns:xsi': self.NS_XSI,
                            'xsi:schemaLocation': self.SCHEMALOCATION,
                            'version': self.VERSION})

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element('meta')
        subelement_text = 'NetworkX {}'.format(nx.__version__)
        SubElement(meta_element, 'creator').text = subelement_text
        meta_element.set('lastmodifieddate', time.strftime('%Y-%m-%d'))
        self.xml.append(meta_element)

        # Register 'viz' as the serialization prefix of the viz namespace.
        ET.register_namespace('viz', self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        # string forms of every edge id in use (user supplied or generated)
        self.all_edge_ids = set()
        # attribute ids, indexed as attr['node'|'edge']['static'|'dynamic'][title]
        self.attr = {
            'node': {'dynamic': {}, 'static': {}},
            'edge': {'dynamic': {}, 'static': {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as text (indented first if prettyprint)."""
        if self.prettyprint:
            self.indent(self.xml)
        # tostring() uses its default "us-ascii" encoding here, so it returns
        # pure ASCII bytes (other characters become character references).
        # Decoding them with self.encoding garbled the text for encodings
        # that are not ASCII compatible, such as 'utf-16'.
        return tostring(self.xml).decode('us-ascii')

    def add_graph(self, G):
        """Append a <graph> element describing G (nodes, then edges)."""
        # first pass through G collecting user-supplied edge ids, so that
        # generated ids never collide with them
        for u, v, dd in G.edges(data=True):
            eid = dd.get('id')
            if eid is not None:
                self.all_edge_ids.add(make_str(eid))
        # set graph attributes
        mode = 'dynamic' if G.graph.get('mode') == 'dynamic' else 'static'
        # Add a graph element to the XML
        default = 'directed' if G.is_directed() else 'undirected'
        name = G.graph.get('name', '')
        graph_element = Element('graph', defaultedgetype=default, mode=mode,
                                name=name)
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def _pop_xml_attributes(self, data, kw, names):
        # Move each key of `names` present in `data` into the XML attribute
        # dict `kw` (as a string); 'start'/'end' also make the graph dynamic.
        for name in names:
            if name in data:
                value = data.pop(name)
                kw[name] = make_str(value)
                if name in ('start', 'end'):
                    self.alter_graph_mode_timeformat(value)

    def add_nodes(self, G, graph_element):
        """Append a <nodes> element with one <node> per node of G."""
        nodes_element = Element('nodes')
        default = G.graph.get('node_default', {})
        for node, data in G.nodes(data=True):
            # work on a copy: reserved keys are popped as they are written
            node_data = data.copy()
            kw = {'id': make_str(node_data.pop('id', node)),
                  'label': make_str(node_data.pop('label', node))}
            self._pop_xml_attributes(node_data, kw, ('pid', 'start', 'end'))
            # add node element with attributes
            node_element = Element('node', **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            # GEXF 1.1 uses <slices>, later versions <spells>.  Compare the
            # schema version ('1.1'); the version key is '1.1draft' and
            # never matched, so slices were never written.
            if self.VERSION == '1.1':
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes('node', node_element,
                                            node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _new_edge_id(self):
        # Draw the next generated edge id, skipping ids already in use.
        edge_id = next(self.edge_id)
        while make_str(edge_id) in self.all_edge_ids:
            edge_id = next(self.edge_id)
        self.all_edge_ids.add(make_str(edge_id))
        return edge_id

    def add_edges(self, G, graph_element):
        """Append an <edges> element with one <edge> per edge of G."""
        def edge_key_data(G):
            # Unify multigraph and graph edge iteration.  Yields
            # (u, v, edge_id, copy_of_edge_data); multigraph keys are kept
            # in the data under 'key'.
            if G.is_multigraph():
                edges = ((u, v, dict(data, key=key))
                         for u, v, key, data in G.edges(data=True, keys=True))
            else:
                edges = ((u, v, data.copy())
                         for u, v, data in G.edges(data=True))
            for u, v, edge_data in edges:
                edge_id = edge_data.pop('id', None)
                if edge_id is None:
                    edge_id = self._new_edge_id()
                yield u, v, edge_id, edge_data

        edges_element = Element('edges')
        default = G.graph.get('edge_default', {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {'id': make_str(key)}
            self._pop_xml_attributes(edge_data, kw,
                                     ('weight', 'type', 'start', 'end'))
            source_id = make_str(G.nodes[u].get('id', u))
            target_id = make_str(G.nodes[v].get('id', v))
            edge_element = Element('edge',
                                   source=source_id, target=target_id, **kw)
            # GEXF 1.1 uses <slices>, later versions <spells> (see add_nodes).
            if self.VERSION == '1.1':
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes('edge', edge_element,
                                            edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _format_value(value):
        # Text of one attvalue: XML Schema booleans are lowercase, and GEXF
        # spells the special floats INF, -INF and NaN.
        if isinstance(value, bool):
            return make_str(value).lower()
        text = make_str(value)
        if isinstance(value, float):
            text = {'inf': 'INF', '-inf': '-INF', 'nan': 'NaN'}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an <attvalues> element holding `data` to `xml_obj`.

        Parameters
        ----------
        node_or_edge : 'node' or 'edge'
            Attribute class the values belong to.
        xml_obj : Element
            The <node> or <edge> element to extend.
        data : dict
            Remaining user attributes.  A list value is dynamic data, a
            sequence of (value, start, end) triples; any other value is
            static.  The key 'key' is written as 'networkx_key'.
        default : dict
            Default value per attribute title, declared on first use.

        Returns
        -------
        dict
            `data`, unchanged.

        Raises
        ------
        TypeError
            If a value (or a value inside dynamic data) has a type not in
            ``xml_type``.
        """
        if len(data) == 0:
            return data
        attvalues = Element('attvalues')
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == 'key':
                k = 'networkx_key'
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError('attribute value type is not allowed: %s'
                                % val_type)
            if isinstance(v, list):
                # Dynamic data.  The mode is decided per attribute; it used
                # to leak 'dynamic' into later list-valued attributes.
                mode = 'static'
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = 'dynamic'
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                if val_type not in self.xml_type:
                    raise TypeError('attribute value type is not allowed: %s'
                                    % val_type)
                attr_id = self.get_attr_id(make_str(k),
                                           self.xml_type[val_type],
                                           node_or_edge, default, mode)
                for val, start, end in v:
                    e = Element('attvalue')
                    e.attrib['for'] = attr_id
                    e.attrib['value'] = self._format_value(val)
                    if start is not None:
                        e.attrib['start'] = make_str(start)
                    if end is not None:
                        e.attrib['end'] = make_str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(make_str(k),
                                           self.xml_type[val_type],
                                           node_or_edge, default, 'static')
                e = Element('attvalue')
                e.attrib['for'] = attr_id
                e.attrib['value'] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute `title`, declaring it on first use.

        A new attribute gets the next id from ``self.attr_id``.  Its
        <attribute> declaration includes a <default> child when `default`
        has a value for `title`.  The declaration goes into the <attributes>
        element of the matching class and mode, which is inserted at the
        front of the graph element if it does not exist yet.
        """
        known = self.attr[edge_or_node][mode]
        if title in known:
            return known[title]
        # generate new id
        new_id = str(next(self.attr_id))
        known[title] = new_id
        attribute = Element('attribute', id=new_id, title=title,
                            type=attr_type)
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element('default')
            default_element.text = make_str(default_title)
            attribute.append(default_element)
        # now insert it into the XML
        attributes_element = None
        for a in self.graph_element.findall('attributes'):
            # find existing attributes element by class and mode
            if (a.get('class') == edge_or_node
                    and a.get('mode', 'static') == mode):
                attributes_element = a
        if attributes_element is None:
            # create new attributes element ('class' is a Python keyword)
            attr_kwargs = {'mode': mode, 'class': edge_or_node}
            attributes_element = Element('attributes', **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Pop the 'viz' dict from node_data and append viz:* children.

        Recognized keys are:

        * 'color': r, g, b, plus alpha 'a' except for GEXF 1.1.
        * 'size' and 'thickness'.
        * 'shape': a value starting with 'http' is written as an image shape
          with that uri.
        * 'position': x and y, plus optional z.
        """
        viz = node_data.pop('viz', False)
        if viz:
            color = viz.get('color')
            if color is not None:
                channels = {'r': str(color.get('r')),
                            'g': str(color.get('g')),
                            'b': str(color.get('b'))}
                if self.VERSION != '1.1':
                    # Default to opaque rather than writing a="None".
                    channels['a'] = str(color.get('a', 1.0))
                element.append(Element('{%s}color' % self.NS_VIZ, **channels))

            size = viz.get('size')
            if size is not None:
                e = Element('{%s}size' % self.NS_VIZ, value=str(size))
                element.append(e)

            thickness = viz.get('thickness')
            if thickness is not None:
                e = Element('{%s}thickness' % self.NS_VIZ,
                            value=str(thickness))
                element.append(e)

            shape = viz.get('shape')
            if shape is not None:
                if shape.startswith('http'):
                    e = Element('{%s}shape' % self.NS_VIZ,
                                value='image', uri=str(shape))
                else:
                    e = Element('{%s}shape' % self.NS_VIZ, value=str(shape))
                element.append(e)

            position = viz.get('position')
            if position is not None:
                coords = {'x': str(position.get('x')),
                          'y': str(position.get('y'))}
                # z is optional; omit it rather than writing z="None".
                z = position.get('z')
                if z is not None:
                    coords['z'] = str(z)
                element.append(Element('{%s}position' % self.NS_VIZ,
                                       **coords))
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop 'parents' from node_data and append a <parents> element."""
        parents = node_data.pop('parents', False)
        if parents:
            parents_element = Element('parents')
            for p in parents:
                e = Element('parent')
                e.attrib['for'] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop 'slices' ((start, end) pairs) and append a <slices> element."""
        slices = node_or_edge_data.pop('slices', False)
        if slices:
            slices_element = Element('slices')
            for start, end in slices:
                e = Element('slice', start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop 'spells' ((start, end) pairs) and append a <spells> element.

        A None start or end is omitted; any other value makes the graph
        dynamic.
        """
        spells = node_or_edge_data.pop('spells', False)
        if spells:
            spells_element = Element('spells')
            for start, end in spells:
                e = Element('spell')
                if start is not None:
                    e.attrib['start'] = make_str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib['end'] = make_str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Switch a static graph to dynamic mode on its first time value.

        The timeformat is derived from the value's type: str -> 'date',
        float -> 'double', int -> 'long'.  None is ignored.

        Raises
        ------
        NetworkXError
            If the value is of any other type.
        """
        if self.graph_element.get('mode') == 'static':
            if start_or_end is not None:
                if isinstance(start_or_end, str):
                    timeformat = 'date'
                elif isinstance(start_or_end, float):
                    timeformat = 'double'
                elif isinstance(start_or_end, int):
                    timeformat = 'long'
                else:
                    raise nx.NetworkXError(
                        'timeformat should be of the type int, float or str')
                self.graph_element.set('timeformat', timeformat)
                self.graph_element.set('mode', 'dynamic')

    def write(self, fh):
        """Serialize the document, with XML declaration, to open file fh."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Indent `elem` and its descendants in place (two spaces per level)."""
        i = "\n" + "  " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "  "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.indent(child, level + 1)
            # The last child's tail precedes the parent's closing tag, so
            # dedent it back to this level.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i


class GEXFReader(GEXF):
    """Read GEXF XML into a NetworkX graph.

    Use the ``read_gexf()`` function rather than this class directly.  An
    instance is called with an open stream (``reader(stream)``) and returns
    the parsed graph.  Parsing first builds a ``MultiGraph`` or
    ``MultiDiGraph``.  If no parallel edges were seen, the result is then
    converted to a ``Graph`` or ``DiGraph``.
    """

    def __init__(self, node_type=None, version='1.2draft'):
        """Create a reader.

        Parameters
        ----------
        node_type : callable, optional
            Applied to node ids and to edge source/target ids read from the
            file (e.g. ``int``).  ``None`` leaves them as strings.
        version : str, default '1.2draft'
            GEXF version whose XML namespace is tried first.  It is passed to
            ``set_version``, which is defined on the ``GEXF`` base class.

        Raises
        ------
        ImportError
            If ``xml.etree.ElementTree`` cannot be imported.
        """
        try:
            import xml.etree.ElementTree
        except ImportError:
            raise ImportError('GEXF reader requires '
                              'xml.elementtree.ElementTree.')
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse *stream* and return the graph it contains.

        The ``<graph>`` element is first looked up under the namespace of the
        current version.  If it is not found there, each entry of
        ``self.versions`` is tried in turn, calling ``set_version`` for each
        one.

        Raises
        ------
        NetworkXError
            If no ``<graph>`` element is found under any namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find('{%s}graph' % self.NS_GEXF)
        if g is not None:
            return self.make_graph(g)
        # try all the versions
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find('{%s}graph' % self.NS_GEXF)
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError('No <graph> element in GEXF file.')

    def make_graph(self, graph_xml):
        """Build a NetworkX graph from a GEXF ``<graph>`` element.

        - The graph-level ``name``, ``start``, ``end`` and ``mode`` are
          copied into ``G.graph``.
        - ``<attributes>`` declarations are collected.  Their defaults are
          stored in ``G.graph['node_default']`` and
          ``G.graph['edge_default']``.
        - Nodes and edges are then added.

        Returns
        -------
        G : Graph, DiGraph, MultiGraph or MultiDiGraph
            A multigraph type only if parallel edges were found.

        Raises
        ------
        NetworkXError
            If an ``<attributes>`` element has a ``class`` other than
            ``'node'`` or ``'edge'``.
        """
        # Reset per parse so a reused reader does not inherit the
        # multigraph flag from a previously parsed file.
        self.simple_graph = True

        # start with empty MultiDiGraph or MultiGraph; collapsed at the end
        edgedefault = graph_xml.get('defaultedgetype', None)
        if edgedefault == 'directed':
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get('name', '')
        if graph_name != '':
            G.graph['name'] = graph_name
        graph_start = graph_xml.get('start')
        if graph_start is not None:
            G.graph['start'] = graph_start
        graph_end = graph_xml.get('end')
        if graph_end is not None:
            G.graph['end'] = graph_end
        graph_mode = graph_xml.get('mode', '')
        if graph_mode == 'dynamic':
            G.graph['mode'] = 'dynamic'
        else:
            G.graph['mode'] = 'static'

        # timeformat: GEXF 'date' values are converted with the 'string'
        # entry of self.python_type
        self.timeformat = graph_xml.get('timeformat')
        if self.timeformat == 'date':
            self.timeformat = 'string'

        # node and edge attributes
        attributes_elements = graph_xml.findall('{%s}attributes' %
                                                self.NS_GEXF)
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get('class')
            if attr_class == 'node':
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph['node_default'] = node_default
            elif attr_class == 'edge':
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph['edge_default'] = edge_default
            else:
                # A bare ``raise`` here had no active exception to re-raise
                # and produced an unrelated RuntimeError.
                raise nx.NetworkXError(
                    'Unknown attribute class: %s' % attr_class)

        # Hack to handle Gephi0.7beta bug: always declare a static double
        # 'weight' edge attribute (overriding any attribute with id 'weight').
        edge_attr['weight'] = {'type': 'double', 'mode': 'static',
                               'title': 'weight'}
        G.graph['edge_default'] = edge_default

        # add nodes
        nodes_element = graph_xml.find('{%s}nodes' % self.NS_GEXF)
        if nodes_element is not None:
            for node_xml in nodes_element.findall('{%s}node' % self.NS_GEXF):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find('{%s}edges' % self.NS_GEXF)
        if edges_element is not None:
            for edge_xml in edges_element.findall('{%s}edge' % self.NS_GEXF):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by *node_xml*, and its subnodes, to *G*.

        Parameters
        ----------
        G : graph
            Graph being built.
        node_xml : Element
            A ``<node>`` element.
        node_attr : dict
            Attribute declarations keyed by attribute id, as returned by
            ``find_gexf_attributes``.
        node_pid : optional
            Parent node id.  Used when the element has no ``pid`` attribute;
            recursion passes the enclosing node's id for nested ``<nodes>``.
        """
        # get attributes and subattributues for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        # Compare VERSION (as add_viz does).  ``self.version`` holds the
        # version key such as '1.2draft', so testing it against '1.1' never
        # matched.  Spells are still read for every version, as before, so
        # files carrying <spells> keep loading.
        if self.VERSION == '1.1':
            data = self.add_slices(data, node_xml)  # add slices
        data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get('id')
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (None if the attribute is absent)
        node_label = node_xml.get('label')
        data['label'] = node_label

        # parent node id
        node_pid = node_xml.get('pid', node_pid)
        if node_pid is not None:
            data['pid'] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find('{%s}nodes' % self.NS_GEXF)
        if subnodes is not None:
            for subnode_xml in subnodes.findall('{%s}node' % self.NS_GEXF):
                self.add_node(G, subnode_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy converted ``start``/``end`` attributes of *xml* into *data*.

        Values are converted with ``self.python_type[self.timeformat]``.
        Absent attributes are not added.
        """
        node_start = xml.get('start')
        if node_start is not None:
            data['start'] = self._convert_time(node_start)
        node_end = xml.get('end')
        if node_end is not None:
            data['end'] = self._convert_time(node_end)
        return data

    def _convert_time(self, value):
        """Convert a time string using the graph's timeformat.

        ``None`` (a missing bound) is returned unchanged rather than being
        passed to the converter.
        """
        if value is None:
            return None
        return self.python_type[self.timeformat](value)

    def add_viz(self, data, node_xml):
        """Collect ``viz:`` child elements of *node_xml* into ``data['viz']``.

        Reads color, size, thickness, shape and position.  For an
        ``image`` shape, the ``uri`` is stored as the shape.  For versions
        other than 1.1, color alpha defaults to 1.  ``data['viz']`` is only
        set if at least one viz element is present.
        """
        viz = {}
        color = node_xml.find('{%s}color' % self.NS_VIZ)
        if color is not None:
            if self.VERSION == '1.1':
                viz['color'] = {'r': int(color.get('r')),
                                'g': int(color.get('g')),
                                'b': int(color.get('b'))}
            else:
                viz['color'] = {'r': int(color.get('r')),
                                'g': int(color.get('g')),
                                'b': int(color.get('b')),
                                'a': float(color.get('a', 1))}

        size = node_xml.find('{%s}size' % self.NS_VIZ)
        if size is not None:
            viz['size'] = float(size.get('value'))

        thickness = node_xml.find('{%s}thickness' % self.NS_VIZ)
        if thickness is not None:
            viz['thickness'] = float(thickness.get('value'))

        shape = node_xml.find('{%s}shape' % self.NS_VIZ)
        if shape is not None:
            viz['shape'] = shape.get('shape')
            if viz['shape'] == 'image':
                viz['shape'] = shape.get('uri')

        position = node_xml.find('{%s}position' % self.NS_VIZ)
        if position is not None:
            viz['position'] = {'x': float(position.get('x', 0)),
                               'y': float(position.get('y', 0)),
                               'z': float(position.get('z', 0))}

        if len(viz) > 0:
            data['viz'] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the ``for`` values of ``<parents>/<parent>`` in a list.

        The list is stored as ``data['parents']``, with values kept as raw
        strings.
        """
        parents_element = node_xml.find('{%s}parents' % self.NS_GEXF)
        if parents_element is not None:
            data['parents'] = []
            for p in parents_element.findall('{%s}parent' % self.NS_GEXF):
                parent = p.get('for')
                data['parents'].append(parent)
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store ``<slices>/<slice>`` bounds in ``data['slices']``.

        Each entry is a ``(start, end)`` pair of raw attribute strings, or
        ``None`` for a missing bound.
        """
        slices_element = node_or_edge_xml.find('{%s}slices' % self.NS_GEXF)
        if slices_element is not None:
            data['slices'] = []
            for s in slices_element.findall('{%s}slice' % self.NS_GEXF):
                start = s.get('start')
                end = s.get('end')
                data['slices'].append((start, end))
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store ``<spells>/<spell>`` bounds in ``data['spells']``.

        Each entry is a ``(start, end)`` pair converted with the graph's
        timeformat.  A missing bound is stored as ``None`` (the writer omits
        ``None`` bounds) instead of being fed to the converter.
        """
        spells_element = node_or_edge_xml.find('{%s}spells' % self.NS_GEXF)
        if spells_element is not None:
            data['spells'] = []
            for s in spells_element.findall('{%s}spell' % self.NS_GEXF):
                start = self._convert_time(s.get('start'))
                end = self._convert_time(s.get('end'))
                data['spells'].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by *edge_element* to *G*.

        A ``'mutual'`` edge is added in both directions.  If the edge already
        exists, the reader marks the graph as a multigraph.

        Raises
        ------
        NetworkXError
            If the edge's ``type`` conflicts with the graph's directedness.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get('type')
        if G.is_directed() and edge_direction == 'undirected':
            raise nx.NetworkXError(
                'Undirected edge found in directed graph.')
        if (not G.is_directed()) and edge_direction == 'directed':
            raise nx.NetworkXError(
                'Directed edge found in undirected graph.')

        # Get source and target and recast type if required
        source = edge_element.get('source')
        target = edge_element.get('target')
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        # Same version handling as add_node: VERSION is '1.1' for GEXF 1.1;
        # spells are read for every version for backward compatibility.
        if self.VERSION == '1.1':
            data = self.add_slices(data, edge_element)  # add slices
        data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get('id')
        if edge_id is not None:
            data['id'] = edge_id

        # check if there is a 'multigraph_key' and use that as edge_id
        multigraph_key = data.pop('networkx_key', None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get('weight')
        if weight is not None:
            data['weight'] = float(weight)

        edge_label = edge_element.get('label')
        if edge_label is not None:
            data['label'] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == 'mutual':
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the ``<attvalues>`` of *obj_xml* using *gexf_keys*.

        Parameters
        ----------
        gexf_keys : dict
            Attribute declarations keyed by id, as returned by
            ``find_gexf_attributes``.
        obj_xml : Element
            A ``<node>`` or ``<edge>`` element.

        Returns
        -------
        dict
            Maps each attribute title to its converted value.  For attributes
            declared with ``mode='dynamic'``, the value is a list of
            ``(value, start, end)`` tuples; missing bounds are ``None``.

        Raises
        ------
        NetworkXError
            If an ``<attvalue>`` refers to an undeclared attribute id.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find('{%s}attvalues' % self.NS_GEXF)
        if attr_element is not None:
            # loop over <attvalue> elements
            for a in attr_element.findall('{%s}attvalue' % self.NS_GEXF):
                key = a.get('for')  # for is required
                try:  # should be in our gexf_keys dictionary
                    title = gexf_keys[key]['title']
                except KeyError:
                    raise nx.NetworkXError('No attribute defined for=%s.'
                                           % key)
                atype = gexf_keys[key]['type']
                value = a.get('value')
                if atype == 'boolean':
                    value = self.convert_bool[value]
                else:
                    value = self.python_type[atype](value)
                if gexf_keys[key]['mode'] == 'dynamic':
                    # for dynamic graphs use list of three-tuples
                    # [(value1,start1,end1), (value2,start2,end2), etc]
                    start = self._convert_time(a.get('start'))
                    end = self._convert_time(a.get('end'))
                    attr.setdefault(title, []).append((value, start, end))
                else:
                    # for static graphs just assign the value
                    attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Extract attribute declarations and defaults from ``<attributes>``.

        Returns
        -------
        attrs : dict
            Maps attribute id to ``{'title', 'type', 'mode'}``.  ``mode`` is
            taken from the ``<attributes>`` element itself.
        defaults : dict
            Maps attribute title to its converted ``<default>`` value, for
            attributes that declare one.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get('mode')
        for k in attributes_element.findall('{%s}attribute' % self.NS_GEXF):
            attr_id = k.get('id')
            title = k.get('title')
            atype = k.get('type')
            attrs[attr_id] = {'title': title, 'type': atype, 'mode': mode}
            # check for the 'default' subelement of key element and add
            default = k.find('{%s}default' % self.NS_GEXF)
            if default is not None:
                if atype == 'boolean':
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults


def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
       A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
      A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute.  It also handles relabeling the specific GEXF
    node attributes "parents", and "pid".  Each new node keeps its old
    node id in its "id" attribute, and its "label" attribute is removed.
    An empty graph is returned as an (empty) copy.
    """
    missing_msg = ('Failed to relabel nodes: '
                   'missing node labels found. '
                   'Use relabel=False.')
    # build mapping of node labels, do some error checking
    try:
        mapping = [(u, G.nodes[u]['label']) for u in G]
    except KeyError:
        raise nx.NetworkXError(missing_msg)
    labels = [label for _, label in mapping]
    # GEXFReader stores label=None when the XML node has no label attribute;
    # None cannot become a node, so report it as a missing label.
    if any(label is None for label in labels):
        raise nx.NetworkXError(missing_msg)
    # (The old ``x, y = zip(*mapping)`` crashed on an empty graph.)
    if len(set(labels)) != len(G):
        raise nx.NetworkXError('Failed to relabel nodes: '
                               'duplicate node labels found. '
                               'Use relabel=False.')
    mapping = dict(mapping)
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]['id'] = n
        H.nodes[m].pop('label')
        if 'pid' in H.nodes[m]:
            H.nodes[m]['pid'] = mapping[G.nodes[n]['pid']]
        if 'parents' in H.nodes[m]:
            H.nodes[m]['parents'] = [mapping[p] for p in G.nodes[n]['parents']]
    return H


# fixture for pytest
def setup_module(module):
    """Skip this module's tests if the ElementTree backend is unavailable.

    The previous body assigned to ``xml.etree.cElementTree``, but the name
    ``xml`` is never bound in this module (only names are imported from
    it), so it raised NameError.  The module falls back to
    ``xml.etree.ElementTree`` when cElementTree is missing, and the reader
    requires ElementTree, so that is the module checked here.
    """
    import pytest
    pytest.importorskip('xml.etree.ElementTree')


# fixture for pytest
def teardown_module(module):
    """Remove ``test.gexf`` from the current directory, if present.

    This is best effort: any OS error from the delete (e.g. the file does
    not exist) is ignored, so teardown never fails the run.
    """
    import os
    try:
        os.unlink('test.gexf')
    except OSError:
        # os.unlink signals all failures via OSError subclasses
        pass