diff xarray_tool.py @ 0:fea8a53f8099 draft

"planemo upload for repository https://github.com/galaxyecology/tools-ecology/tree/master/tools/data_manipulation/xarray/ commit 57b6d23e3734d883e71081c78e77964d61be82ba"
author ecology
date Sun, 06 Jun 2021 08:50:43 +0000
parents
children 3e73f657a998
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/xarray_tool.py	Sun Jun 06 08:50:43 2021 +0000
@@ -0,0 +1,365 @@
+# xarray tool for:
+# - getting metadata information
+# - select data and save results in csv file for further post-processing
+
+import argparse
+import csv
+import os
+import warnings
+
+import geopandas as gdp
+
+import pandas as pd
+
+from shapely.geometry import Point
+from shapely.ops import nearest_points
+
+import xarray as xr
+
+
+class XarrayTool ():
+    def __init__(self, infile, outfile_info="", outfile_summary="",
+                 select="", outfile="", outputdir="", latname="",
+                 latvalN="", latvalS="", lonname="", lonvalE="",
+                 lonvalW="", filter_list="", coords="", time="",
+                 verbose=False, no_missing=False, coords_info=None,
+                 tolerance=None):
+        self.infile = infile
+        self.outfile_info = outfile_info
+        self.outfile_summary = outfile_summary
+        self.select = select
+        self.outfile = outfile
+        self.outputdir = outputdir
+        self.latname = latname
+        if tolerance != "" and tolerance is not None:
+            self.tolerance = float(tolerance)
+        else:
+            self.tolerance = -1
+        if latvalN != "" and latvalN is not None:
+            self.latvalN = float(latvalN)
+        else:
+            self.latvalN = ""
+        if latvalS != "" and latvalS is not None:
+            self.latvalS = float(latvalS)
+        else:
+            self.latvalS = ""
+        self.lonname = lonname
+        if lonvalE != "" and lonvalE is not None:
+            self.lonvalE = float(lonvalE)
+        else:
+            self.lonvalE = ""
+        if lonvalW != "" and lonvalW is not None:
+            self.lonvalW = float(lonvalW)
+        else:
+            self.lonvalW = ""
+        self.filter = filter_list
+        self.time = time
+        self.coords = coords
+        self.verbose = verbose
+        self.no_missing = no_missing
+        # initialization
+        self.dset = None
+        self.gset = None
+        self.coords_info = coords_info
+        if self.verbose:
+            print("infile: ", self.infile)
+            print("outfile_info: ", self.outfile_info)
+            print("outfile_summary: ", self.outfile_summary)
+            print("outfile: ", self.outfile)
+            print("select: ", self.select)
+            print("outfile: ", self.outfile)
+            print("outputdir: ", self.outputdir)
+            print("latname: ", self.latname)
+            print("latvalN: ", self.latvalN)
+            print("latvalS: ", self.latvalS)
+            print("lonname: ", self.lonname)
+            print("lonvalE: ", self.lonvalE)
+            print("lonvalW: ", self.lonvalW)
+            print("filter: ", self.filter)
+            print("time: ", self.time)
+            print("coords: ", self.coords)
+            print("coords_info: ", self.coords_info)
+
+    def info(self):
+        f = open(self.outfile_info, 'w')
+        ds = xr.open_dataset(self.infile)
+        ds.info(f)
+        f.close()
+
+    def summary(self):
+        f = open(self.outfile_summary, 'w')
+        ds = xr.open_dataset(self.infile)
+        writer = csv.writer(f, delimiter='\t')
+        header = ['VariableName', 'NumberOfDimensions']
+        for idx, val in enumerate(ds.dims.items()):
+            header.append('Dim' + str(idx) + 'Name')
+            header.append('Dim' + str(idx) + 'Size')
+        writer.writerow(header)
+        for name, da in ds.data_vars.items():
+            line = [name]
+            line.append(len(ds[name].shape))
+            for d, s in zip(da.shape, da.sizes):
+                line.append(s)
+                line.append(d)
+            writer.writerow(line)
+        for name, da in ds.coords.items():
+            line = [name]
+            line.append(len(ds[name].shape))
+            for d, s in zip(da.shape, da.sizes):
+                line.append(s)
+                line.append(d)
+            writer.writerow(line)
+        f.close()
+
+    def rowfilter(self, single_filter):
+        split_filter = single_filter.split('#')
+        filter_varname = split_filter[0]
+        op = split_filter[1]
+        ll = float(split_filter[2])
+        if (op == 'bi'):
+            rl = float(split_filter[3])
+        if filter_varname == self.select:
+            # filter on values of the selected variable
+            if op == 'bi':
+                self.dset = self.dset.where(
+                     (self.dset <= rl) & (self.dset >= ll)
+                     )
+            elif op == 'le':
+                self.dset = self.dset.where(self.dset <= ll)
+            elif op == 'ge':
+                self.dset = self.dset.where(self.dset >= ll)
+            elif op == 'e':
+                self.dset = self.dset.where(self.dset == ll)
+        else:  # filter on other dimensions of the selected variable
+            if op == 'bi':
+                self.dset = self.dset.sel({filter_varname: slice(ll, rl)})
+            elif op == 'le':
+                self.dset = self.dset.sel({filter_varname: slice(None, ll)})
+            elif op == 'ge':
+                self.dset = self.dset.sel({filter_varname: slice(ll, None)})
+            elif op == 'e':
+                self.dset = self.dset.sel({filter_varname: ll},
+                                          method='nearest')
+
+    def selection(self):
+        if self.dset is None:
+            self.ds = xr.open_dataset(self.infile)
+            self.dset = self.ds[self.select]  # select variable
+            if self.time:
+                self.datetime_selection()
+            if self.filter:
+                self.filter_selection()
+
+        self.area_selection()
+        if self.gset.count() > 1:
+            # convert to dataframe if several rows and cols
+            self.gset = self.gset.to_dataframe().dropna(how='all'). \
+                        reset_index()
+            self.gset.to_csv(self.outfile, header=True, sep='\t')
+        else:
+            data = {
+                self.latname: [self.gset[self.latname].values],
+                self.lonname: [self.gset[self.lonname].values],
+                self.select: [self.gset.values]
+            }
+
+            df = pd.DataFrame(data, columns=[self.latname, self.lonname,
+                                             self.select])
+            df.to_csv(self.outfile, header=True, sep='\t')
+
+    def datetime_selection(self):
+        split_filter = self.time.split('#')
+        time_varname = split_filter[0]
+        op = split_filter[1]
+        ll = split_filter[2]
+        if (op == 'sl'):
+            rl = split_filter[3]
+            self.dset = self.dset.sel({time_varname: slice(ll, rl)})
+        elif (op == 'to'):
+            self.dset = self.dset.sel({time_varname: slice(None, ll)})
+        elif (op == 'from'):
+            self.dset = self.dset.sel({time_varname: slice(ll, None)})
+        elif (op == 'is'):
+            self.dset = self.dset.sel({time_varname: ll}, method='nearest')
+
+    def filter_selection(self):
+        for single_filter in self.filter:
+            self.rowfilter(single_filter)
+
+    def area_selection(self):
+
+        if self.latvalS != "" and self.lonvalW != "":
+            # Select geographical area
+            self.gset = self.dset.sel({self.latname:
+                                       slice(self.latvalS, self.latvalN),
+                                       self.lonname:
+                                       slice(self.lonvalW, self.lonvalE)})
+        elif self.latvalN != "" and self.lonvalE != "":
+            # select nearest location
+            if self.no_missing:
+                self.nearest_latvalN = self.latvalN
+                self.nearest_lonvalE = self.lonvalE
+            else:
+                # find nearest location without NaN values
+                self.nearest_location()
+            if self.tolerance > 0:
+                self.gset = self.dset.sel({self.latname: self.nearest_latvalN,
+                                           self.lonname: self.nearest_lonvalE},
+                                          method='nearest',
+                                          tolerance=self.tolerance)
+            else:
+                self.gset = self.dset.sel({self.latname: self.nearest_latvalN,
+                                           self.lonname: self.nearest_lonvalE},
+                                          method='nearest')
+        else:
+            self.gset = self.dset
+
+    def nearest_location(self):
+        # Build a geopandas dataframe with all first elements in each dimension
+        # so we assume null values correspond to a mask that is the same for
+        # all dimensions in the dataset.
+        dsel_frame = self.dset
+        for dim in self.dset.dims:
+            if dim != self.latname and dim != self.lonname:
+                dsel_frame = dsel_frame.isel({dim: 0})
+        # transform to pandas dataframe
+        dff = dsel_frame.to_dataframe().dropna().reset_index()
+        # transform to geopandas to collocate
+        gdf = gdp.GeoDataFrame(dff,
+                               geometry=gdp.points_from_xy(dff[self.lonname],
+                                                           dff[self.latname]))
+        # Find nearest location where values are not null
+        point = Point(self.lonvalE, self.latvalN)
+        multipoint = gdf.geometry.unary_union
+        queried_geom, nearest_geom = nearest_points(point, multipoint)
+        self.nearest_latvalN = nearest_geom.y
+        self.nearest_lonvalE = nearest_geom.x
+
+    def selection_from_coords(self):
+        fcoords = pd.read_csv(self.coords, sep='\t')
+        for row in fcoords.itertuples():
+            self.latvalN = row[0]
+            self.lonvalE = row[1]
+            self.outfile = (os.path.join(self.outputdir,
+                            self.select + '_' +
+                            str(row.Index) + '.tabular'))
+            self.selection()
+
+    def get_coords_info(self):
+        ds = xr.open_dataset(self.infile)
+        for c in ds.coords:
+            filename = os.path.join(self.coords_info,
+                                    c.strip() +
+                                    '.tabular')
+            pd = ds.coords[c].to_pandas()
+            pd.index = range(len(pd))
+            pd.to_csv(filename, header=False, sep='\t')
+
+
+if __name__ == '__main__':
+    warnings.filterwarnings("ignore")
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        'infile',
+        help='netCDF input filename'
+    )
+    parser.add_argument(
+        '--info',
+        help='Output filename where metadata information is stored'
+    )
+    parser.add_argument(
+        '--summary',
+        help='Output filename where data summary information is stored'
+    )
+    parser.add_argument(
+        '--select',
+        help='Variable name to select'
+    )
+    parser.add_argument(
+        '--latname',
+        help='Latitude name'
+    )
+    parser.add_argument(
+        '--latvalN',
+        help='North latitude value'
+    )
+    parser.add_argument(
+        '--latvalS',
+        help='South latitude value'
+    )
+    parser.add_argument(
+        '--lonname',
+        help='Longitude name'
+    )
+    parser.add_argument(
+        '--lonvalE',
+        help='East longitude value'
+    )
+    parser.add_argument(
+        '--lonvalW',
+        help='West longitude value'
+    )
+    parser.add_argument(
+        '--tolerance',
+        help='Maximum distance between original and selected value for '
+             ' inexact matches e.g. abs(index[indexer] - target) <= tolerance'
+    )
+    parser.add_argument(
+        '--coords',
+        help='Input file containing Latitude and Longitude'
+             'for geographical selection'
+    )
+    parser.add_argument(
+        '--coords_info',
+        help='output-folder where for each coordinate, coordinate values '
+             ' are being printed in the corresponding outputfile'
+    )
+    parser.add_argument(
+        '--filter',
+        nargs="*",
+        help='Filter list variable#operator#value_s#value_e'
+    )
+    parser.add_argument(
+        '--time',
+        help='select timeseries variable#operator#value_s[#value_e]'
+    )
+    parser.add_argument(
+        '--outfile',
+        help='csv outfile for storing results of the selection'
+             '(valid only when --select)'
+    )
+    parser.add_argument(
+        '--outputdir',
+        help='folder name for storing results with multiple selections'
+             '(valid only when --select)'
+    )
+    parser.add_argument(
+        "-v", "--verbose",
+        help="switch on verbose mode",
+        action="store_true"
+    )
+    parser.add_argument(
+        "--no_missing",
+        help="""Do not take into account possible null/missing values
+                (only valid for single location)""",
+        action="store_true"
+    )
+    args = parser.parse_args()
+
+    p = XarrayTool(args.infile, args.info, args.summary, args.select,
+                   args.outfile, args.outputdir, args.latname,
+                   args.latvalN, args.latvalS, args.lonname,
+                   args.lonvalE, args.lonvalW, args.filter,
+                   args.coords, args.time, args.verbose,
+                   args.no_missing, args.coords_info, args.tolerance)
+    if args.info:
+        p.info()
+    if args.summary:
+        p.summary()
+    if args.coords:
+        p.selection_from_coords()
+    elif args.select:
+        p.selection()
+    elif args.coords_info:
+        p.get_coords_info()