Mercurial > repos > perssond > ashlar
diff pyramid_upgrade.py @ 0:b3054f3d42b2 draft
"planemo upload for repository https://github.com/ohsu-comp-bio/ashlar commit 27f0c9be58e9e5aecc69067d0e60b5cb945de4b2-dirty"
author | perssond |
---|---|
date | Fri, 12 Mar 2021 00:14:49 +0000 |
parents | |
children | f183d9de4622 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pyramid_upgrade.py Fri Mar 12 00:14:49 2021 +0000 @@ -0,0 +1,566 @@ +import sys +import os +import argparse +import struct +import re +import fractions +import io +import xml.etree.ElementTree +import collections +import reprlib +import dataclasses +from typing import List, Any + + +datatype_formats = { + 1: "B", # BYTE + 2: "s", # ASCII + 3: "H", # SHORT + 4: "I", # LONG + 5: "I", # RATIONAL (pairs) + 6: "b", # SBYTE + 7: "B", # UNDEFINED + 8: "h", # SSHORT + 9: "i", # SLONG + 10: "i", # SRATIONAL (pairs) + 11: "f", # FLOAT + 12: "d", # DOUBLE + 13: "I", # IFD + 16: "Q", # LONG8 + 17: "q", # SLONG8 + 18: "Q", # IFD8 +} +rational_datatypes = {5, 10} + + +class TiffSurgeon: + """Read, manipulate and write IFDs in BigTIFF files.""" + + def __init__(self, path, *, writeable=False, encoding=None): + self.path = path + self.writeable = writeable + self.encoding = encoding + self.endian = "" + self.ifds = None + self.file = open(self.path, "r+b" if self.writeable else "rb") + self._validate() + + def _validate(self): + signature = self.read("2s") + signature = signature.decode("ascii", errors="ignore") + if signature == "II": + self.endian = "<" + elif signature == "MM": + self.endian = ">" + else: + raise FormatError(f"Not a TIFF file (signature is '{signature}').") + version = self.read("H") + if version == 42: + raise FormatError("Cannot process classic TIFF, only BigTIFF.") + offset_size, reserved, first_ifd_offset = self.read("H H Q") + if version != 43 or offset_size != 8 or reserved != 0: + raise FormatError("Malformed TIFF, giving up!") + self.first_ifd_offset = first_ifd_offset + + def read(self, fmt, *, file=None): + if file is None: + file = self.file + endian = self.endian or "=" + size = struct.calcsize(endian + fmt) + raw = file.read(size) + value = self.unpack(fmt, raw) + return value + + def write(self, fmt, *values): + if not self.writeable: + raise ValueError("File is opened as read-only.") + raw = self.pack(fmt, *values) + self.file.write(raw) + + def unpack(self, fmt, raw): + assert self.endian or re.match(r"\d+s", fmt), \ + "can't unpack non-string before endianness is detected" + fmt = self.endian + fmt + size = struct.calcsize(fmt) + values = struct.unpack(fmt, raw[:size]) + if len(values) == 1: + return values[0] + else: + return values + + def pack(self, fmt, *values): + assert self.endian, "can't pack without endian set" + fmt = self.endian + fmt + raw = struct.pack(fmt, *values) + return raw + + def read_ifds(self): + ifds = [self.read_ifd(self.first_ifd_offset)] + while ifds[-1].offset_next: + ifds.append(self.read_ifd(ifds[-1].offset_next)) + self.ifds = ifds + + def read_ifd(self, offset): + self.file.seek(offset) + num_tags = self.read("Q") + buf = io.BytesIO(self.file.read(num_tags * 20)) + offset_next = self.read("Q") + try: + tags = TagSet([self.read_tag(buf) for i in range(num_tags)]) + except FormatError as e: + raise FormatError(f"IFD at offset {offset}, {e}") from None + ifd = Ifd(tags, offset, offset_next) + return ifd + + def read_tag(self, buf): + tag = Tag(*self.read("H H Q 8s", file=buf)) + value, offset_range = self.tag_value(tag) + tag = dataclasses.replace(tag, value=value, offset_range=offset_range) + return tag + + def append_ifd_sequence(self, ifds): + """Write list of IFDs as a chained sequence at the end of the file. + + Returns a list of new Ifd objects with updated offsets. + + """ + self.file.seek(0, os.SEEK_END) + new_ifds = [] + for ifd in ifds: + offset = self.file.tell() + self.write("Q", len(ifd.tags)) + for tag in ifd.tags: + self.write_tag(tag) + offset_next = self.file.tell() + 8 if ifd is not ifds[-1] else 0 + self.write("Q", offset_next) + new_ifd = dataclasses.replace( + ifd, offset=offset, offset_next=offset_next + ) + new_ifds.append(new_ifd) + return new_ifds + + def append_tag_data(self, code, datatype, value): + """Build new tag and write data to the end of the file if necessary. + + Returns a Tag object corresponding to the passed parameters. This + function only writes any "overflow" data and not the IFD entry itself, + so the returned Tag must still be written to an IFD. + + If the value is small enough to fit in the data field within an IFD, no + data will actually be written to the file and the returned Tag object + will have the value encoded in its data attribute. Otherwise the data + will be appended to the file and the returned Tag's data attribute will + encode the corresponding offset. + + """ + fmt = datatype_formats[datatype] + # FIXME Should we perform our own check that values match datatype? + # struct.pack will do it but the exception won't be as understandable. + original_value = value + if isinstance(value, str): + if not self.encoding: + raise ValueError( + "ASCII tag values must be bytes if encoding is not set" + ) + value = [value.encode(self.encoding) + b"\x00"] + count = len(value[0]) + elif isinstance(value, bytes): + value = [value + b"\x00"] + count = len(value[0]) + else: + try: + len(value) + except TypeError: + value = [value] + count = len(value) + struct_count = count + if datatype in rational_datatypes: + value = [i for v in value for i in v.as_integer_ratio()] + count //= 2 + byte_count = struct_count * struct.calcsize(fmt) + if byte_count <= 8: + data = self.pack(str(struct_count) + fmt, *value) + data += bytes(8 - byte_count) + else: + self.file.seek(0, os.SEEK_END) + data = self.pack("Q", self.file.tell()) + self.write(str(count) + fmt, *value) + # TODO Compute and set offset_range. + tag = Tag(code, datatype, count, data, original_value) + return tag + + def write_first_ifd_offset(self, offset): + self.file.seek(8) + self.write("Q", offset) + + def write_tag(self, tag): + self.write("H H Q 8s", tag.code, tag.datatype, tag.count, tag.data) + + def tag_value(self, tag): + """Return decoded tag data and the file offset range.""" + fmt = datatype_formats[tag.datatype] + count = tag.count + if tag.datatype in rational_datatypes: + count *= 2 + byte_count = count * struct.calcsize(fmt) + if byte_count <= 8: + value = self.unpack(str(count) + fmt, tag.data) + offset_range = range(0, 0) + else: + offset = self.unpack("Q", tag.data) + self.file.seek(offset) + value = self.read(str(count) + fmt) + offset_range = range(offset, offset + byte_count) + if tag.datatype == 2: + value = value.rstrip(b"\x00") + if self.encoding: + try: + value = value.decode(self.encoding) + except UnicodeDecodeError as e: + raise FormatError(f"tag {tag.code}: {e}") from None + elif tag.datatype in rational_datatypes: + value = [ + fractions.Fraction(*v) for v in zip(value[::2], value[1::2]) + ] + if len(value) == 1: + value = value[0] + return value, offset_range + + def close(self): + self.file.close() + + +@dataclasses.dataclass(frozen=True) +class Tag: + code: int + datatype: int + count: int + data: bytes + value: Any = None + offset_range: range = None + + _vrepr = reprlib.Repr() + _vrepr.maxstring = 60 + _vrepr.maxother = 60 + vrepr = _vrepr.repr + + def __repr__(self): + return ( + self.__class__.__qualname__ + "(" + + f"code={self.code!r}, datatype={self.datatype!r}, " + + f"count={self.count!r}, data={self.data!r}, " + + f"value={self.vrepr(self.value)}" + + ")" + ) + +@dataclasses.dataclass(frozen=True) +class TagSet: + """Container for Tag objects as stored in a TIFF IFD. + + Tag objects are maintained in a list that's always sorted in ascending order + by the tag code. Only one tag for a given code may be present, which is where + the "set" name comes from. + + """ + + tags: List[Tag] = dataclasses.field(default_factory=list) + + def __post_init__(self): + if len(self.codes) != len(set(self.codes)): + raise ValueError("Duplicate tag codes are not allowed.") + + def __repr__(self): + ret = type(self).__name__ + "([" + if self.tags: + ret += "\n" + ret += "".join([f" {t},\n" for t in self.tags]) + ret += "])" + return ret + + @property + def codes(self): + return [t.code for t in self.tags] + + def __getitem__(self, code): + for t in self.tags: + if code == t.code: + return t + else: + raise KeyError(code) + + def __delitem__(self, code): + try: + i = self.codes.index(code) + except ValueError: + raise KeyError(code) from None + self.tags[:] = self.tags[:i] + self.tags[i+1:] + + def __contains__(self, code): + return code in self.codes + + def __len__(self): + return len(self.tags) + + def __iter__(self): + return iter(self.tags) + + def get(self, code, default=None): + try: + return self[code] + except KeyError: + return default + + def get_value(self, code, default=None): + tag = self.get(code) + if tag: + return tag.value + else: + return default + + def insert(self, tag): + """Add a new tag or replace an existing one.""" + for i, t in enumerate(self.tags): + if tag.code == t.code: + self.tags[i] = tag + return + elif tag.code < t.code: + break + else: + i = len(self.tags) + n = len(self.tags) + self.tags[i:n+1] = [tag] + self.tags[i:n] + + +@dataclasses.dataclass(frozen=True) +class Ifd: + tags: TagSet + offset: int + offset_next: int + + @property + def nbytes(self): + return len(self.tags) * 20 + 16 + + @property + def offset_range(self): + return range(self.offset, self.offset + self.nbytes) + + +class FormatError(Exception): + pass + + +def fix_attrib_namespace(elt): + """Prefix un-namespaced XML attributes with the tag's namespace.""" + # This fixes ElementTree's inability to round-trip XML with a default + # namespace ("cannot use non-qualified names with default_namespace option" + # error). 7-year-old BPO issue here: https://bugs.python.org/issue17088 + # Code inspired by https://gist.github.com/provegard/1381912 . + if elt.tag[0] == "{": + uri, _ = elt.tag[1:].rsplit("}", 1) + new_attrib = {} + for name, value in elt.attrib.items(): + if name[0] != "{": + # For un-namespaced attributes, copy namespace from element. + name = f"{{{uri}}}{name}" + new_attrib[name] = value + elt.attrib = new_attrib + for child in elt: + fix_attrib_namespace(child) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an OME-TIFF legacy pyramid to the BioFormats 6" + " OME-TIFF pyramid format in-place.", + ) + parser.add_argument("image", help="OME-TIFF file to convert") + parser.add_argument( + "-n", + dest="channel_names", + nargs="+", + default=[], + metavar="NAME", + help="Channel names to be inserted into OME metadata. Number of names" + " must match number of channels in image. Be sure to put quotes" + " around names containing spaces or other special shell characters." + ) + args = parser.parse_args() + return args + + +def main(): + + args = parse_args() + + image_path = sys.argv[1] + try: + tiff = TiffSurgeon(image_path, encoding="utf-8", writeable=True) + except FormatError as e: + print(f"TIFF format error: {e}") + sys.exit(1) + + tiff.read_ifds() + + # ElementTree doesn't parse xml declarations so we'll just run some sanity + # checks that we do have UTF-8 and give it a decoded string instead of raw + # bytes. We need to both ensure that the raw tag bytes decode properly and + # that the declaration encoding is UTF-8 if present. + try: + omexml = tiff.ifds[0].tags.get_value(270, "") + except FormatError: + print("ImageDescription tag is not a valid UTF-8 string (not an OME-TIFF?)") + sys.exit(1) + if re.match(r'<\?xml [^>]*encoding="(?!UTF-8)[^"]*"', omexml): + print("OME-XML is encoded with something other than UTF-8.") + sys.exit(1) + + xml_ns = {"ome": "http://www.openmicroscopy.org/Schemas/OME/2016-06"} + + if xml_ns["ome"] not in omexml: + print("Not an OME-TIFF.") + sys.exit(1) + if ( + "Faas" not in tiff.ifds[0].tags.get_value(305, "") + or 330 in tiff.ifds[0].tags + ): + print("Not a legacy OME-TIFF pyramid.") + sys.exit(1) + + # All XML manipulation assumes the document is valid OME-XML! + root = xml.etree.ElementTree.fromstring(omexml) + image = root.find("ome:Image", xml_ns) + pixels = image.find("ome:Pixels", xml_ns) + size_x = int(pixels.get("SizeX")) + size_y = int(pixels.get("SizeY")) + size_c = int(pixels.get("SizeC")) + size_z = int(pixels.get("SizeZ")) + size_t = int(pixels.get("SizeT")) + num_levels = len(root.findall("ome:Image", xml_ns)) + page_dims = [(ifd.tags[256].value, ifd.tags[257].value) for ifd in tiff.ifds] + + if len(root) != num_levels: + print("Top-level OME-XML elements other than Image are not supported.") + if size_z != 1 or size_t != 1: + print("Z-stacks and multiple timepoints are not supported.") + sys.exit(1) + if size_c * num_levels != len(tiff.ifds): + print("TIFF page count does not match OME-XML Image elements.") + sys.exit(1) + if any(dims != (size_x, size_y) for dims in page_dims[:size_c]): + print(f"TIFF does not begin with SizeC={size_c} full-size pages.") + sys.exit(1) + for level in range(1, num_levels): + level_dims = page_dims[level * size_c : (level + 1) * size_c] + if len(set(level_dims)) != 1: + print( + f"Pyramid level {level + 1} out of {num_levels} has inconsistent" + f" sizes:\n{level_dims}" + ) + sys.exit(1) + if args.channel_names and len(args.channel_names) != size_c: + print( + f"Wrong number of channel names -- image has {size_c} channels but" + f" {len(args.channel_names)} names were specified:" + ) + for i, n in enumerate(args.channel_names, 1): + print(f"{i:4}: {n}") + sys.exit(1) + + print("Input image summary") + print("===================") + print(f"Dimensions: {size_x} x {size_y}") + print(f"Number of channels: {size_c}") + print(f"Pyramid sub-resolutions ({num_levels - 1} total):") + for dim_x, dim_y in page_dims[size_c::size_c]: + print(f" {dim_x} x {dim_y}") + software = tiff.ifds[0].tags.get_value(305, "<not set>") + print(f"Software: {software}") + print() + + print("Updating OME-XML metadata...") + # We already verified there is nothing but Image elements under the root. + for other_image in root[1:]: + root.remove(other_image) + for tiffdata in pixels.findall("ome:TiffData", xml_ns): + pixels.remove(tiffdata) + new_tiffdata = xml.etree.ElementTree.Element( + f"{{{xml_ns['ome']}}}TiffData", + attrib={"IFD": "0", "PlaneCount": str(size_c)}, + ) + # A valid OME-XML Pixels begins with size_c Channels; then comes TiffData. + pixels.insert(size_c, new_tiffdata) + + if args.channel_names: + print("Renaming channels...") + channels = pixels.findall("ome:Channel", xml_ns) + for channel, name in zip(channels, args.channel_names): + channel.attrib["Name"] = name + + fix_attrib_namespace(root) + # ElementTree.tostring would have been simpler but it only supports + # xml_declaration and default_namespace starting with Python 3.8. + xml_file = io.BytesIO() + tree = xml.etree.ElementTree.ElementTree(root) + tree.write( + xml_file, + encoding="utf-8", + xml_declaration=True, + default_namespace=xml_ns["ome"], + ) + new_omexml = xml_file.getvalue() + + print("Writing new TIFF headers...") + stale_ranges = [ifd.offset_range for ifd in tiff.ifds] + main_ifds = tiff.ifds[:size_c] + channel_sub_ifds = [tiff.ifds[c + size_c : : size_c] for c in range(size_c)] + for i, (main_ifd, sub_ifds) in enumerate(zip(main_ifds, channel_sub_ifds)): + for ifd in sub_ifds: + if 305 in ifd.tags: + stale_ranges.append(ifd.tags[305].offset_range) + del ifd.tags[305] + ifd.tags.insert(tiff.append_tag_data(254, 3, 1)) + if i == 0: + stale_ranges.append(main_ifd.tags[305].offset_range) + stale_ranges.append(main_ifd.tags[270].offset_range) + old_software = main_ifd.tags[305].value.replace("Faas", "F*a*a*s") + new_software = f"pyramid_upgrade.py (was {old_software})" + main_ifd.tags.insert(tiff.append_tag_data(305, 2, new_software)) + main_ifd.tags.insert(tiff.append_tag_data(270, 2, new_omexml)) + else: + if 305 in main_ifd.tags: + stale_ranges.append(main_ifd.tags[305].offset_range) + del main_ifd.tags[305] + sub_ifds[:] = tiff.append_ifd_sequence(sub_ifds) + offsets = [ifd.offset for ifd in sub_ifds] + main_ifd.tags.insert(tiff.append_tag_data(330, 16, offsets)) + main_ifds = tiff.append_ifd_sequence(main_ifds) + tiff.write_first_ifd_offset(main_ifds[0].offset) + + print("Clearing old headers and tag values...") + # We overwrite all the old IFDs and referenced data values with obvious + # "filler" as a courtesy to anyone who might need to poke around in the TIFF + # structure down the road. A real TIFF parser wouldn't see the stale data, + # but a human might just scan for the first thing that looks like a run of + # OME-XML and not realize it's been replaced with something else. The filler + # content is the repeated string "unused " with square brackets at the + # beginning and end of each filled IFD or data value. + filler = b"unused " + f_len = len(filler) + for r in stale_ranges: + tiff.file.seek(r.start) + tiff.file.write(b"[") + f_total = len(r) - 2 + for i in range(f_total // f_len): + tiff.file.write(filler) + tiff.file.write(b" " * (f_total % f_len)) + tiff.file.write(b"]") + + tiff.close() + + print() + print("Success!") + + +if __name__ == "__main__": + main()