comparison env/lib/python3.7/site-packages/networkx/readwrite/gexf.py @ 2:6af9afd405e9 draft

"planemo upload commit 0a63dd5f4d38a1f6944587f52a8cd79874177fc1"
author shellac
date Thu, 14 May 2020 14:56:58 -0400
parents 26e78fe6e8c4
children
comparison
equal deleted inserted replaced
1:75ca89e9b81c 2:6af9afd405e9
1 # Copyright (C) 2013-2019 by
2 #
3 # Authors: Aric Hagberg <hagberg@lanl.gov>
4 # Dan Schult <dschult@colgate.edu>
5 # Pieter Swart <swart@lanl.gov>
6 # All rights reserved.
7 # BSD license.
8 # Based on GraphML NetworkX GraphML reader
9 """Read and write graphs in GEXF format.
10
11 GEXF (Graph Exchange XML Format) is a language for describing complex
12 network structures, their associated data and dynamics.
13
14 This implementation does not support mixed graphs (directed and
15 undirected edges together).
16
17 Format
18 ------
19 GEXF is an XML format. See https://gephi.org/gexf/format/schema.html for the
20 specification and https://gephi.org/gexf/format/basic.html for examples.
21 """
22 import itertools
23 import time
24
25 import networkx as nx
26 from networkx.utils import open_file, make_str
27 try:
28 from xml.etree.cElementTree import (Element, ElementTree, SubElement,
29 tostring)
30 except ImportError:
31 try:
32 from xml.etree.ElementTree import (Element, ElementTree, SubElement,
33 tostring)
34 except ImportError:
35 pass
36
__all__ = [
    'write_gexf',
    'read_gexf',
    'relabel_gexf_graph',
    'generate_gexf',
]
38
39
@open_file(1, mode='wb')
def write_gexf(G, path, encoding='utf-8', prettyprint=True,
               version='1.2draft'):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    path : file or string
       File or file name to write.
       File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
       Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
       Supported values: "1.1draft", "1.2draft"

    Raises
    ------
    NetworkXError
       If `version` is not one of the supported values.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]['viz'] = {'size': 54}
    >>> G.nodes[0]['viz']['position'] = {'x' : 0, 'y' : 1}
    >>> G.nodes[0]['viz']['color'] = {'r' : 0, 'g' : 0, 'b' : 255}


    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    .. [2] GEXF viz schema 1.1, https://gephi.org/gexf/1.1draft/viz
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                        version=version)
    writer.add_graph(G)
    writer.write(path)
93
94
def generate_gexf(G, encoding='utf-8', prettyprint=True, version='1.2draft'):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
       Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
       Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    str
       One line of the serialized GEXF document at a time.

    Raises
    ------
    NetworkXError
       If `version` is not one of the supported values (raised when the
       generator is first advanced, since the writer is built lazily).

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10) # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))  # doctest: +SKIP
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...    print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                        version=version)
    writer.add_graph(G)
    yield from str(writer).splitlines()
140
141
@open_file(0, mode='rb')
def read_gexf(path, node_type=None, relabel=False, version='1.2draft'):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
       File or file name to read.
       File names ending in .gz or .bz2 will be decompressed.
    node_type: Python type (default: None)
       Convert node ids to this type if not None.
    relabel : bool (default: False)
       If True relabel the nodes to use the GEXF node "label" attribute
       instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
       Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html)
       Supported values: "1.1draft", "1.2draft". If the document's graph is
       not found under this version's namespace, the other supported
       versions are tried as well.

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Raises
    ------
    NetworkXError
        If `version` is unknown, the document has no <graph> element, or
        `relabel` is True and node labels are missing or not unique.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    reader = GEXFReader(node_type=node_type, version=version)
    G = reader(path)
    if relabel:
        G = relabel_gexf_graph(G)
    return G
184
185
class GEXF(object):
    """Shared GEXF tables: per-version namespaces and value type maps.

    Subclassed by GEXFWriter and GEXFReader. Call set_version() to load the
    namespace/schema constants of one of the keys in ``versions``.
    """
    versions = {}
    d = {'NS_GEXF': "http://www.gexf.net/1.1draft",
         'NS_VIZ': "http://www.gexf.net/1.1draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.1draft',
                                     'http://www.gexf.net/1.1draft/gexf.xsd']),
         'VERSION': '1.1'}
    versions['1.1draft'] = d
    d = {'NS_GEXF': "http://www.gexf.net/1.2draft",
         'NS_VIZ': "http://www.gexf.net/1.2draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.2draft',
                                     'http://www.gexf.net/1.2draft/gexf.xsd']),
         'VERSION': '1.2'}
    versions['1.2draft'] = d

    # (python type, GEXF attribute type) pairs. Both lookup dicts below are
    # built with "last entry wins", so order matters: python -> GEXF uses the
    # last pair for a python type, GEXF -> python the last pair for a GEXF type.
    types = [(int, "integer"),
             (float, "float"),
             (float, "double"),
             (bool, "boolean"),
             (list, "string"),
             (dict, "string"),
             (int, "long"),
             (str, "liststring"),
             (str, "anyURI"),
             (str, "string")]

    # These additions to types allow writing numpy types
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # prepend so that python types are created upon read (last entry wins).
        # The aliases np.int (removed in NumPy 1.24) and np.float_ (removed in
        # NumPy 2.0) are deliberately not listed: referencing them made class
        # creation -- and so importing this module -- fail, and they were only
        # aliases of int and np.float64, which are already covered.
        types = [(np.float64, "float"), (np.float32, "float"),
                 (np.float16, "float"),
                 (np.int8, "int"),
                 (np.int16, "int"), (np.int32, "int"),
                 (np.int64, "int"), (np.uint8, "int"),
                 (np.uint16, "int"), (np.uint32, "int"),
                 (np.uint64, "int"), (np.int_, "int"),
                 (np.intc, "int"), (np.intp, "int"),
                 ] + types

    xml_type = dict(types)
    python_type = dict(reversed(a) for a in types)

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        'true': True, 'false': False,
        'True': True, 'False': False,
        '0': False, 0: False,
        '1': True, 1: True
    }

    def set_version(self, version):
        """Load the namespace and schema constants for `version`.

        Raises NetworkXError if `version` is not a key of ``versions``.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError('Unknown GEXF version %s.' % version)
        self.NS_GEXF = d['NS_GEXF']
        self.NS_VIZ = d['NS_VIZ']
        self.NS_XSI = d['NS_XSI']
        self.SCHEMALOCATION = d['SCHEMALOCATION']
        # short schema version ('1.1' / '1.2'); self.version keeps the key
        self.VERSION = d['VERSION']
        self.version = version
252
253
class GEXFWriter(GEXF):
    """Build a GEXF XML document from a NetworkX graph.

    Normally used through write_gexf() or generate_gexf().

    Parameters
    ----------
    graph : NetworkX graph, optional
        If given, it is added right away with add_graph().
    encoding : string (default: 'utf-8')
        Encoding used by write() and to decode the bytes in __str__().
    prettyprint : bool (default: True)
        If True, indent the XML tree before serializing it.
    version : string (default: '1.2draft')
        Key of GEXF.versions; an unknown key raises NetworkXError.
    """

    # GEXF spellings of the special float values
    _special_floats = {'inf': 'INF', '-inf': '-INF', 'nan': 'NaN'}

    def __init__(self, graph=None, encoding='utf-8', prettyprint=True,
                 version='1.2draft'):
        try:
            import xml.etree.ElementTree as ET
        except ImportError:
            raise ImportError('GEXF writer requires '
                              'xml.elementtree.ElementTree')
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element('gexf',
                           {'xmlns': self.NS_GEXF,
                            'xmlns:xsi': self.NS_XSI,
                            'xsi:schemaLocation': self.SCHEMALOCATION,
                            'version': self.VERSION})

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element('meta')
        subelement_text = 'NetworkX {}'.format(nx.__version__)
        SubElement(meta_element, 'creator').text = subelement_text
        meta_element.set('lastmodifieddate', time.strftime('%Y-%m-%d'))
        self.xml.append(meta_element)

        # serialize elements in the viz namespace with the 'viz' prefix
        ET.register_namespace('viz', self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        # string forms of every edge id in use (user-supplied or generated)
        self.all_edge_ids = set()
        # attribute ids, looked up as attr[node_or_edge][mode][title]
        self.attr = {
            'node': {'dynamic': {}, 'static': {}},
            'edge': {'dynamic': {}, 'static': {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        if self.prettyprint:
            self.indent(self.xml)
        s = tostring(self.xml).decode(self.encoding)
        return s

    def add_graph(self, G):
        """Append a <graph> element describing G (nodes, then edges)."""
        # first pass through G collecting user-supplied edge ids, so that
        # generated ids never collide with them
        for u, v, dd in G.edges(data=True):
            eid = dd.get('id')
            if eid is not None:
                self.all_edge_ids.add(make_str(eid))
        # set graph attributes; the mode may later switch to 'dynamic' when
        # start/end times are encountered (alter_graph_mode_timeformat)
        mode = 'dynamic' if G.graph.get('mode') == 'dynamic' else 'static'
        default = 'directed' if G.is_directed() else 'undirected'
        name = G.graph.get('name', '')
        graph_element = Element('graph', defaultedgetype=default, mode=mode,
                                name=name)
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_nodes(self, G, graph_element):
        """Append a <nodes> element with one <node> per node of G."""
        nodes_element = Element('nodes')
        # attribute defaults are shared by all nodes
        default = G.graph.get('node_default', {})
        for node, data in G.nodes(data=True):
            # work on a copy: the add_* helpers pop the keys they consume
            node_data = data.copy()
            node_id = make_str(node_data.pop('id', node))
            kw = {'id': node_id}
            kw['label'] = make_str(node_data.pop('label', node))
            if 'pid' in node_data:
                kw['pid'] = make_str(node_data.pop('pid'))
            if 'start' in node_data:
                start = node_data.pop('start')
                kw['start'] = make_str(start)
                self.alter_graph_mode_timeformat(start)
            if 'end' in node_data:
                end = node_data.pop('end')
                kw['end'] = make_str(end)
                self.alter_graph_mode_timeformat(end)
            # add node element with attributes
            node_element = Element('node', **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            # GEXF 1.1 uses <slices>, later versions use <spells>
            if self.VERSION == '1.1':
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes('node', node_element,
                                            node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def add_edges(self, G, graph_element):
        """Append an <edges> element with one <edge> per edge of G."""
        def edge_key_data(G):
            # helper to unify multigraph and graph edge iteration; yields
            # (u, v, edge_id, edge_data) with edge_data a private copy
            if G.is_multigraph():
                # keep the multigraph key; it is written as 'networkx_key'
                edges = ((u, v, dict(data, key=key))
                         for u, v, key, data in G.edges(data=True, keys=True))
            else:
                edges = ((u, v, data.copy())
                         for u, v, data in G.edges(data=True))
            for u, v, edge_data in edges:
                edge_id = edge_data.pop('id', None)
                if edge_id is None:
                    # generate an id that does not clash with one in use
                    edge_id = next(self.edge_id)
                    while make_str(edge_id) in self.all_edge_ids:
                        edge_id = next(self.edge_id)
                    self.all_edge_ids.add(make_str(edge_id))
                yield u, v, edge_id, edge_data

        edges_element = Element('edges')
        # attribute defaults are shared by all edges
        default = G.graph.get('edge_default', {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {'id': make_str(key)}
            if 'weight' in edge_data:
                kw['weight'] = make_str(edge_data.pop('weight'))
            if 'type' in edge_data:
                kw['type'] = make_str(edge_data.pop('type'))
            if 'start' in edge_data:
                start = edge_data.pop('start')
                kw['start'] = make_str(start)
                self.alter_graph_mode_timeformat(start)
            if 'end' in edge_data:
                end = edge_data.pop('end')
                kw['end'] = make_str(end)
                self.alter_graph_mode_timeformat(end)
            # endpoints refer to the node ids written by add_nodes
            source_id = make_str(G.nodes[u].get('id', u))
            target_id = make_str(G.nodes[v].get('id', v))
            edge_element = Element('edge',
                                   source=source_id, target=target_id, **kw)
            # GEXF 1.1 uses <slices>, later versions use <spells>
            if self.VERSION == '1.1':
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes('edge', edge_element,
                                            edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    def _attvalue_string(self, value):
        # Format one value for <attvalue value="...">: booleans lowercased
        # ('true'/'false', as xsd:boolean expects) and plain-float
        # inf/-inf/nan spelled INF/-INF/NaN.
        if isinstance(value, bool):
            return make_str(value).lower()
        text = make_str(value)
        if type(value) is float:
            text = self._special_floats.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an <attvalues> element holding every key left in data.

        List values are dynamic attributes given as (value, start, end)
        triples; any other value is written as a static attribute.
        Raises TypeError for value types without a GEXF mapping.
        """
        if len(data) == 0:
            return data
        attvalues = Element('attvalues')
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == 'key':
                k = 'networkx_key'
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError('attribute value type is not allowed: %s'
                                % val_type)
            if isinstance(v, list):
                # dynamic data; the mode is decided per attribute so that a
                # 'dynamic' result never carries over to the next list
                mode = 'static'
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = 'dynamic'
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(make_str(k),
                                           self.xml_type[val_type],
                                           node_or_edge, default, mode)
                for val, start, end in v:
                    e = Element('attvalue')
                    e.attrib['for'] = attr_id
                    e.attrib['value'] = self._attvalue_string(val)
                    if start is not None:
                        e.attrib['start'] = make_str(start)
                    if end is not None:
                        e.attrib['end'] = make_str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(make_str(k),
                                           self.xml_type[val_type],
                                           node_or_edge, default, 'static')
                e = Element('attvalue')
                e.attrib['for'] = attr_id
                e.attrib['value'] = self._attvalue_string(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute `title`, declaring it on first use.

        A new id is registered in an <attributes> element of the matching
        class and mode, created at the top of the <graph> if needed.
        """
        try:
            return self.attr[edge_or_node][mode][title]
        except KeyError:
            pass
        # generate new id
        new_id = str(next(self.attr_id))
        self.attr[edge_or_node][mode][title] = new_id
        attribute = Element('attribute',
                            **{'id': new_id, 'title': title,
                               'type': attr_type})
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element('default')
            default_element.text = make_str(default_title)
            attribute.append(default_element)
        # find an existing attributes element by class and mode
        attributes_element = None
        for a in self.graph_element.findall('attributes'):
            if (a.get('class') == edge_or_node
                    and a.get('mode', 'static') == mode):
                attributes_element = a
        if attributes_element is None:
            # create new attributes element
            attributes_element = Element('attributes',
                                         **{'mode': mode,
                                            'class': edge_or_node})
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Pop node_data['viz'] and append the matching viz:* children."""
        viz = node_data.pop('viz', False)
        if viz:
            color = viz.get('color')
            if color is not None:
                color_attrs = {'r': str(color.get('r')),
                               'g': str(color.get('g')),
                               'b': str(color.get('b'))}
                if self.VERSION != '1.1':
                    # versions after 1.1 carry an alpha channel; default to
                    # opaque (1.0) rather than writing the string 'None'
                    color_attrs['a'] = str(color.get('a', 1.0))
                e = Element('{%s}color' % self.NS_VIZ, **color_attrs)
                element.append(e)

            size = viz.get('size')
            if size is not None:
                e = Element('{%s}size' % self.NS_VIZ, value=str(size))
                element.append(e)

            thickness = viz.get('thickness')
            if thickness is not None:
                e = Element('{%s}thickness' % self.NS_VIZ,
                            value=str(thickness))
                element.append(e)

            shape = viz.get('shape')
            if shape is not None:
                if shape.startswith('http'):
                    # a URL is written as an image shape
                    e = Element('{%s}shape' % self.NS_VIZ,
                                value='image', uri=str(shape))
                else:
                    e = Element('{%s}shape' % self.NS_VIZ, value=str(shape))
                element.append(e)

            position = viz.get('position')
            if position is not None:
                def coord(axis):
                    # missing coordinates default to 0.0; str(None) would
                    # produce 'None', which cannot be parsed back as a float
                    value = position.get(axis)
                    return str(0.0 if value is None else value)
                e = Element('{%s}position' % self.NS_VIZ,
                            x=coord('x'), y=coord('y'), z=coord('z'))
                element.append(e)
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop node_data['parents'] and append a <parents> element."""
        parents = node_data.pop('parents', False)
        if parents:
            parents_element = Element('parents')
            for p in parents:
                e = Element('parent')
                e.attrib['for'] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop the 'slices' (start, end) pairs and append <slices> (GEXF 1.1)."""
        slices = node_or_edge_data.pop('slices', False)
        if slices:
            slices_element = Element('slices')
            for start, end in slices:
                e = Element('slice', start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop the 'spells' (start, end) pairs and append <spells>.

        A None start or end is omitted from its <spell> element.
        """
        spells = node_or_edge_data.pop('spells', False)
        if spells:
            spells_element = Element('spells')
            for start, end in spells:
                e = Element('spell')
                if start is not None:
                    e.attrib['start'] = make_str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib['end'] = make_str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Switch a static graph to dynamic mode on the first time value.

        The timeformat is chosen from that value's type: str -> 'date',
        float -> 'double', int -> 'long'; other types raise NetworkXError.
        """
        if self.graph_element.get('mode') == 'static':
            if start_or_end is not None:
                if isinstance(start_or_end, str):
                    timeformat = 'date'
                elif isinstance(start_or_end, float):
                    timeformat = 'double'
                elif isinstance(start_or_end, int):
                    timeformat = 'long'
                else:
                    raise nx.NetworkXError(
                        'timeformat should be of the type int, float or str')
                self.graph_element.set('timeformat', timeformat)
                self.graph_element.set('mode', 'dynamic')

    def write(self, fh):
        # Serialize graph G in GEXF to the open fh
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        # in-place prettyprint formatter: sets .text/.tail whitespace so each
        # element starts on its own line, two spaces deeper per level
        i = "\n" + "  " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "  "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.indent(child, level + 1)
            # the last child's tail precedes the parent's closing tag,
            # so it is dedented back to this level
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
645 elem.text = i + " "
646 if not elem.tail or not elem.tail.strip():
647 elem.tail = i
648 for elem in elem:
649 self.indent(elem, level + 1)
650 if not elem.tail or not elem.tail.strip():
651 elem.tail = i
652 else:
653 if level and (not elem.tail or not elem.tail.strip()):
654 elem.tail = i
655
656
657 class GEXFReader(GEXF):
658 # Class to read GEXF format files
659 # use read_gexf() function
660 def __init__(self, node_type=None, version='1.2draft'):
661 try:
662 import xml.etree.ElementTree
663 except ImportError:
664 raise ImportError('GEXF reader requires '
665 'xml.elementtree.ElementTree.')
666 self.node_type = node_type
667 # assume simple graph and test for multigraph on read
668 self.simple_graph = True
669 self.set_version(version)
670
671 def __call__(self, stream):
672 self.xml = ElementTree(file=stream)
673 g = self.xml.find('{%s}graph' % self.NS_GEXF)
674 if g is not None:
675 return self.make_graph(g)
676 # try all the versions
677 for version in self.versions:
678 self.set_version(version)
679 g = self.xml.find('{%s}graph' % self.NS_GEXF)
680 if g is not None:
681 return self.make_graph(g)
682 raise nx.NetworkXError('No <graph> element in GEXF file.')
683
684 def make_graph(self, graph_xml):
685 # start with empty DiGraph or MultiDiGraph
686 edgedefault = graph_xml.get('defaultedgetype', None)
687 if edgedefault == 'directed':
688 G = nx.MultiDiGraph()
689 else:
690 G = nx.MultiGraph()
691
692 # graph attributes
693 graph_name = graph_xml.get('name', '')
694 if graph_name != '':
695 G.graph['name'] = graph_name
696 graph_start = graph_xml.get('start')
697 if graph_start is not None:
698 G.graph['start'] = graph_start
699 graph_end = graph_xml.get('end')
700 if graph_end is not None:
701 G.graph['end'] = graph_end
702 graph_mode = graph_xml.get('mode', '')
703 if graph_mode == 'dynamic':
704 G.graph['mode'] = 'dynamic'
705 else:
706 G.graph['mode'] = 'static'
707
708 # timeformat
709 self.timeformat = graph_xml.get('timeformat')
710 if self.timeformat == 'date':
711 self.timeformat = 'string'
712
713 # node and edge attributes
714 attributes_elements = graph_xml.findall('{%s}attributes' %
715 self.NS_GEXF)
716 # dictionaries to hold attributes and attribute defaults
717 node_attr = {}
718 node_default = {}
719 edge_attr = {}
720 edge_default = {}
721 for a in attributes_elements:
722 attr_class = a.get('class')
723 if attr_class == 'node':
724 na, nd = self.find_gexf_attributes(a)
725 node_attr.update(na)
726 node_default.update(nd)
727 G.graph['node_default'] = node_default
728 elif attr_class == 'edge':
729 ea, ed = self.find_gexf_attributes(a)
730 edge_attr.update(ea)
731 edge_default.update(ed)
732 G.graph['edge_default'] = edge_default
733 else:
734 raise # unknown attribute class
735
736 # Hack to handle Gephi0.7beta bug
737 # add weight attribute
738 ea = {'weight': {'type': 'double', 'mode': 'static',
739 'title': 'weight'}}
740 ed = {}
741 edge_attr.update(ea)
742 edge_default.update(ed)
743 G.graph['edge_default'] = edge_default
744
745 # add nodes
746 nodes_element = graph_xml.find('{%s}nodes' % self.NS_GEXF)
747 if nodes_element is not None:
748 for node_xml in nodes_element.findall('{%s}node' % self.NS_GEXF):
749 self.add_node(G, node_xml, node_attr)
750
751 # add edges
752 edges_element = graph_xml.find('{%s}edges' % self.NS_GEXF)
753 if edges_element is not None:
754 for edge_xml in edges_element.findall('{%s}edge' % self.NS_GEXF):
755 self.add_edge(G, edge_xml, edge_attr)
756
757 # switch to Graph or DiGraph if no parallel edges were found.
758 if self.simple_graph:
759 if G.is_directed():
760 G = nx.DiGraph(G)
761 else:
762 G = nx.Graph(G)
763 return G
764
765 def add_node(self, G, node_xml, node_attr, node_pid=None):
766 # add a single node with attributes to the graph
767
768 # get attributes and subattributues for node
769 data = self.decode_attr_elements(node_attr, node_xml)
770 data = self.add_parents(data, node_xml) # add any parents
771 if self.version == '1.1':
772 data = self.add_slices(data, node_xml) # add slices
773 else:
774 data = self.add_spells(data, node_xml) # add spells
775 data = self.add_viz(data, node_xml) # add viz
776 data = self.add_start_end(data, node_xml) # add start/end
777
778 # find the node id and cast it to the appropriate type
779 node_id = node_xml.get('id')
780 if self.node_type is not None:
781 node_id = self.node_type(node_id)
782
783 # every node should have a label
784 node_label = node_xml.get('label')
785 data['label'] = node_label
786
787 # parent node id
788 node_pid = node_xml.get('pid', node_pid)
789 if node_pid is not None:
790 data['pid'] = node_pid
791
792 # check for subnodes, recursive
793 subnodes = node_xml.find('{%s}nodes' % self.NS_GEXF)
794 if subnodes is not None:
795 for node_xml in subnodes.findall('{%s}node' % self.NS_GEXF):
796 self.add_node(G, node_xml, node_attr, node_pid=node_id)
797
798 G.add_node(node_id, **data)
799
800 def add_start_end(self, data, xml):
801 # start and end times
802 ttype = self.timeformat
803 node_start = xml.get('start')
804 if node_start is not None:
805 data['start'] = self.python_type[ttype](node_start)
806 node_end = xml.get('end')
807 if node_end is not None:
808 data['end'] = self.python_type[ttype](node_end)
809 return data
810
811 def add_viz(self, data, node_xml):
812 # add viz element for node
813 viz = {}
814 color = node_xml.find('{%s}color' % self.NS_VIZ)
815 if color is not None:
816 if self.VERSION == '1.1':
817 viz['color'] = {'r': int(color.get('r')),
818 'g': int(color.get('g')),
819 'b': int(color.get('b'))}
820 else:
821 viz['color'] = {'r': int(color.get('r')),
822 'g': int(color.get('g')),
823 'b': int(color.get('b')),
824 'a': float(color.get('a', 1))}
825
826 size = node_xml.find('{%s}size' % self.NS_VIZ)
827 if size is not None:
828 viz['size'] = float(size.get('value'))
829
830 thickness = node_xml.find('{%s}thickness' % self.NS_VIZ)
831 if thickness is not None:
832 viz['thickness'] = float(thickness.get('value'))
833
834 shape = node_xml.find('{%s}shape' % self.NS_VIZ)
835 if shape is not None:
836 viz['shape'] = shape.get('shape')
837 if viz['shape'] == 'image':
838 viz['shape'] = shape.get('uri')
839
840 position = node_xml.find('{%s}position' % self.NS_VIZ)
841 if position is not None:
842 viz['position'] = {'x': float(position.get('x', 0)),
843 'y': float(position.get('y', 0)),
844 'z': float(position.get('z', 0))}
845
846 if len(viz) > 0:
847 data['viz'] = viz
848 return data
849
850 def add_parents(self, data, node_xml):
851 parents_element = node_xml.find('{%s}parents' % self.NS_GEXF)
852 if parents_element is not None:
853 data['parents'] = []
854 for p in parents_element.findall('{%s}parent' % self.NS_GEXF):
855 parent = p.get('for')
856 data['parents'].append(parent)
857 return data
858
859 def add_slices(self, data, node_or_edge_xml):
860 slices_element = node_or_edge_xml.find('{%s}slices' % self.NS_GEXF)
861 if slices_element is not None:
862 data['slices'] = []
863 for s in slices_element.findall('{%s}slice' % self.NS_GEXF):
864 start = s.get('start')
865 end = s.get('end')
866 data['slices'].append((start, end))
867 return data
868
869 def add_spells(self, data, node_or_edge_xml):
870 spells_element = node_or_edge_xml.find('{%s}spells' % self.NS_GEXF)
871 if spells_element is not None:
872 data['spells'] = []
873 ttype = self.timeformat
874 for s in spells_element.findall('{%s}spell' % self.NS_GEXF):
875 start = self.python_type[ttype](s.get('start'))
876 end = self.python_type[ttype](s.get('end'))
877 data['spells'].append((start, end))
878 return data
879
880 def add_edge(self, G, edge_element, edge_attr):
881 # add an edge to the graph
882
883 # raise error if we find mixed directed and undirected edges
884 edge_direction = edge_element.get('type')
885 if G.is_directed() and edge_direction == 'undirected':
886 raise nx.NetworkXError(
887 'Undirected edge found in directed graph.')
888 if (not G.is_directed()) and edge_direction == 'directed':
889 raise nx.NetworkXError(
890 'Directed edge found in undirected graph.')
891
892 # Get source and target and recast type if required
893 source = edge_element.get('source')
894 target = edge_element.get('target')
895 if self.node_type is not None:
896 source = self.node_type(source)
897 target = self.node_type(target)
898
899 data = self.decode_attr_elements(edge_attr, edge_element)
900 data = self.add_start_end(data, edge_element)
901
902 if self.version == '1.1':
903 data = self.add_slices(data, edge_element) # add slices
904 else:
905 data = self.add_spells(data, edge_element) # add spells
906
907 # GEXF stores edge ids as an attribute
908 # NetworkX uses them as keys in multigraphs
909 # if networkx_key is not specified as an attribute
910 edge_id = edge_element.get('id')
911 if edge_id is not None:
912 data['id'] = edge_id
913
914 # check if there is a 'multigraph_key' and use that as edge_id
915 multigraph_key = data.pop('networkx_key', None)
916 if multigraph_key is not None:
917 edge_id = multigraph_key
918
919 weight = edge_element.get('weight')
920 if weight is not None:
921 data['weight'] = float(weight)
922
923 edge_label = edge_element.get('label')
924 if edge_label is not None:
925 data['label'] = edge_label
926
927 if G.has_edge(source, target):
928 # seen this edge before - this is a multigraph
929 self.simple_graph = False
930 G.add_edge(source, target, key=edge_id, **data)
931 if edge_direction == 'mutual':
932 G.add_edge(target, source, key=edge_id, **data)
933
934 def decode_attr_elements(self, gexf_keys, obj_xml):
935 # Use the key information to decode the attr XML
936 attr = {}
937 # look for outer '<attvalues>' element
938 attr_element = obj_xml.find('{%s}attvalues' % self.NS_GEXF)
939 if attr_element is not None:
940 # loop over <attvalue> elements
941 for a in attr_element.findall('{%s}attvalue' % self.NS_GEXF):
942 key = a.get('for') # for is required
943 try: # should be in our gexf_keys dictionary
944 title = gexf_keys[key]['title']
945 except KeyError:
946 raise nx.NetworkXError('No attribute defined for=%s.'
947 % key)
948 atype = gexf_keys[key]['type']
949 value = a.get('value')
950 if atype == 'boolean':
951 value = self.convert_bool[value]
952 else:
953 value = self.python_type[atype](value)
954 if gexf_keys[key]['mode'] == 'dynamic':
955 # for dynamic graphs use list of three-tuples
956 # [(value1,start1,end1), (value2,start2,end2), etc]
957 ttype = self.timeformat
958 start = self.python_type[ttype](a.get('start'))
959 end = self.python_type[ttype](a.get('end'))
960 if title in attr:
961 attr[title].append((value, start, end))
962 else:
963 attr[title] = [(value, start, end)]
964 else:
965 # for static graphs just assign the value
966 attr[title] = value
967 return attr
968
969 def find_gexf_attributes(self, attributes_element):
970 # Extract all the attributes and defaults
971 attrs = {}
972 defaults = {}
973 mode = attributes_element.get('mode')
974 for k in attributes_element.findall('{%s}attribute' % self.NS_GEXF):
975 attr_id = k.get('id')
976 title = k.get('title')
977 atype = k.get('type')
978 attrs[attr_id] = {'title': title, 'type': atype, 'mode': mode}
979 # check for the 'default' subelement of key element and add
980 default = k.find('{%s}default' % self.NS_GEXF)
981 if default is not None:
982 if atype == 'boolean':
983 value = self.convert_bool[default.text]
984 else:
985 value = self.python_type[atype](default.text)
986 defaults[title] = value
987 return attrs, defaults
988
989
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
       A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
      A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing (absent or None) or not unique.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid". An empty graph is returned
    as an (empty) copy.
    """
    # build mapping of node labels, do some error checking; GEXFReader
    # stores label=None for nodes without a label, so None counts as missing
    mapping = {}
    for u in G:
        label = G.nodes[u].get('label')
        if label is None:
            raise nx.NetworkXError('Failed to relabel nodes: '
                                   'missing node labels found. '
                                   'Use relabel=False.')
        mapping[u] = label
    if len(set(mapping.values())) != len(G):
        raise nx.NetworkXError('Failed to relabel nodes: '
                               'duplicate node labels found. '
                               'Use relabel=False.')
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    # NOTE(review): 'pid'/'parents' values come from XML attribute strings;
    # if the graph was read with node_type, these lookups may not match the
    # converted node ids -- confirm with callers.
    for n in G:
        m = mapping[n]
        H.nodes[m]['id'] = n
        H.nodes[m].pop('label')
        if 'pid' in H.nodes[m]:
            H.nodes[m]['pid'] = mapping[G.nodes[n]['pid']]
        if 'parents' in H.nodes[m]:
            H.nodes[m]['parents'] = [mapping[p] for p in G.nodes[n]['parents']]
    return H
1038
1039
# fixture for pytest
def setup_module(module):
    """Skip this module's tests/doctests if ElementTree is unavailable.

    The previous version assigned to ``xml.etree.cElementTree`` although no
    ``xml`` name is bound in this module (NameError), and cElementTree no
    longer exists on Python 3.9+; the module itself falls back to
    xml.etree.ElementTree, so that is what is checked here.
    """
    import pytest
    pytest.importorskip('xml.etree.ElementTree')
1044
1045
# fixture for pytest
def teardown_module(module):
    """Remove the 'test.gexf' file written by the doctests, if present."""
    import os
    try:
        os.unlink('test.gexf')
    except OSError:
        # the file may be missing (e.g. doctests skipped) -- nothing to do
        pass