import sys
import os
import subprocess
import random
import math


"""
Plot simple curves in R
"""

class RPlotter(object):
  """
  Plot some curves with R: the data are written to temporary files, an R
  script is generated, and the script is run with "R CMD BATCH".
  @ivar nbColors:   number of different colors
  @type nbColors:   int
  @ivar fileName:   name of the file
  @type fileName:   string
  @ivar lines:      lines to be plotted
  @type lines:      array of dict
  @ivar names:      name of the lines
  @type names:      array of strings
  @ivar colors:     color of the lines (as R expressions)
  @type colors:     array of strings
  @ivar types:      type of the lines ("solid" or "dashed")
  @type types:      array of strings
  @ivar format:     format of the picture
  @type format:     string
  @ivar width:      width of the picture (passed as is to the R device)
  @type width:      int
  @ivar height:     height of the picture (passed as is to the R device)
  @type height:     int
  @ivar xMin:       minimum value taken on the x-axis
  @type xMin:       int
  @ivar xMax:       maximum value taken on the x-axis
  @type xMax:       int
  @ivar yMin:       minimum value taken on the y-axis
  @type yMin:       int
  @ivar yMax:       maximum value taken on the y-axis
  @type yMax:       int
  @ivar minimumX:   minimum value allowed on the x-axis
  @type minimumX:   int
  @ivar maximumX:   maximum value allowed on the x-axis
  @type maximumX:   int
  @ivar minimumY:   minimum value allowed on the y-axis
  @type minimumY:   int
  @ivar maximumY:   maximum value allowed on the y-axis
  @type maximumY:   int
  @ivar logX:       use log scale on the x-axis
  @type logX:       boolean
  @ivar logY:       use log scale on the y-axis
  @type logY:       boolean
  @ivar logZ:       use log scale on the z-axis (the color)
  @type logZ:       boolean
  @ivar fill:       if a value is not given, fill it with given value
  @type fill:       int
  @ivar bucket:     cluster the data into buckets of given size
  @type bucket:     int
  @ivar seed:       a random number, used to name the temporary files
  @type seed:       int
  @ivar legend:     set the legend
  @type legend:     boolean
  @ivar xLabel:     label for the x-axis
  @type xLabel:     string
  @ivar yLabel:     label for the y-axis
  @type yLabel:     string
  @ivar title:      title of the plot
  @type title:      string
  @ivar barplot:    use a barplot representation instead
  @type barplot:    boolean
  @ivar points:     use a point cloud instead
  @type points:     boolean
  @ivar heatPoints: use a colored point cloud instead
  @type heatPoints: boolean
  @ivar verbosity:  verbosity of the class
  @type verbosity:  int
  @ivar keep:       keep temporary files
  @type keep:       boolean
  """

  def __init__(self, fileName, verbosity = 0, keep = False):
    """
    Constructor
    @param fileName:  name of the file to produce
    @type  fileName:  string
    @param verbosity: verbosity
    @type  verbosity: int
    @param keep:      keep temporary files
    @type  keep:      boolean
    """
    self.nbColors   = 9
    self.fileName   = fileName
    self.verbosity  = verbosity
    self.keep       = keep
    self.format     = "png"
    self.fill       = None
    self.bucket     = None
    self.lines      = []
    self.names      = []
    self.colors     = []
    self.types      = []
    self.xMin       = None
    self.xMax       = None
    self.yMin       = None
    self.yMax       = None
    self.seed       = random.randint(0, 10000)
    self.minimumX   = None
    self.maximumX   = None
    self.minimumY   = None
    self.maximumY   = None
    self.logX       = False
    self.logY       = False
    self.logZ       = False
    # NOTE(review): the same values are used for every device; R's pdf()
    # expects inches rather than pixels -- confirm 1000x500 is wanted for pdf.
    self.width      = 1000
    self.height     = 500
    self.legend     = False
    self.xLabel     = ""
    self.yLabel     = ""
    self.title      = None
    self.points     = False
    self.heatPoints = False
    self.barplot    = False
    # names of the temporary files created so far, removed by the destructor
    self._tmpFiles  = []


  def __del__(self):
    """
    Destructor
    Remove the temporary files that were actually created, unless they must be kept
    """
    # getattr: the destructor may run on a partially-initialized instance
    if getattr(self, "keep", True):
      return
    for tmpFileName in getattr(self, "_tmpFiles", []):
      # the R output file, for instance, only exists if R has been run
      if os.path.exists(tmpFileName):
        os.remove(tmpFileName)


  def _registerTmpFile(self, tmpFileName):
    """Remember a temporary file, so that the destructor can remove it."""
    if tmpFileName not in self._tmpFiles:
      self._tmpFiles.append(tmpFileName)


  def _scriptFileName(self):
    """Name of the temporary R script."""
    return "tmpScript-%d.R" % (self.seed)


  def _dataFileName(self, index):
    """Name of the temporary data file of the index-th line."""
    return "tmpData-%d-%d.dat" % (self.seed, index)


  @staticmethod
  def _rString(text):
    """
    Escape a text so that it can be put between double quotes in an R script
    @param text: the text (None is rendered as "None", as with '%s')
    @type  text: string
    @return:     the escaped text
    """
    text = "%s" % (text,)
    return text.replace("\\", "\\\\").replace("\"", "\\\"")


  def _runR(self, scriptFileName):
    """
    Run an R script in batch mode; exit the program if R fails.
    The R executable is "R", unless overridden by the SMARTRPATH variable.
    Temporary files are kept on failure, to help debugging.
    @param scriptFileName: name of the R script
    @type  scriptFileName: string
    """
    rCommand = os.environ.get("SMARTRPATH", "R")
    try:
      # an argument list (no shell) copes with any character in the path to R
      status = subprocess.call([rCommand, "CMD", "BATCH", scriptFileName])
    except OSError as error:
      # e.g. the R executable cannot be found
      status = error
    if status != 0:
      self.keep = True
      sys.exit("Problem with the execution of script file %s, status is: %s" % (scriptFileName, status))


  def setMinimumX(self, xMin):
    """
    Set the minimum value on the x-axis
    @param xMin:  minimum value on the x-axis
    @type  xMin:  int
    """
    self.minimumX = xMin


  def setMaximumX(self, xMax):
    """
    Set the maximum value on the x-axis
    @param xMax:  maximum value on the x-axis
    @type  xMax:  int
    """
    self.maximumX = xMax


  def setMinimumY(self, yMin):
    """
    Set the minimum value on the y-axis
    @param yMin:  minimum value on the y-axis
    @type  yMin:  int
    """
    self.minimumY = yMin


  def setMaximumY(self, yMax):
    """
    Set the maximum value on the y-axis
    @param yMax:  maximum value on the y-axis
    @type  yMax:  int
    """
    self.maximumY = yMax


  def setFill(self, fill):
    """
    Fill empty data with given value
    @param fill: the value to fill with
    @type  fill: int
    """
    self.fill = fill


  def setBuckets(self, bucket):
    """
    Cluster the data into buckets of given size
    @param bucket: the size of the buckets
    @type  bucket: int
    """
    self.bucket = bucket


  def setFormat(self, format):
    """
    Set the format of the picture; exit the program if it is not supported
    @param format: the format ("png", "pdf", "jpeg", "bmp" or "tiff")
    @type  format: string
    """
    if format not in ("png", "pdf", "jpeg", "bmp", "tiff"):
      sys.exit("Format '%s' is not supported by RPlotter" % (format))
    self.format = format


  def setImageSize(self, width, height):
    """
    Set the dimensions of the image produced
    @param width:  width of the image
    @type  width:  int
    @param height: height of the image
    @type  height: int
    """
    self.width  = width
    self.height = height


  def setLegend(self, legend):
    """
    Print a legend or not
    @param legend: print a legend
    @type  legend: boolean
    """
    self.legend = legend


  def setXLabel(self, label):
    """
    Print a label for the x-axis (underscores are replaced by spaces)
    @param label: the label
    @type  label: string
    """
    self.xLabel = label
    if self.xLabel is not None:
      self.xLabel = self.xLabel.replace("_", " ")


  def setYLabel(self, label):
    """
    Print a label for the y-axis (underscores are replaced by spaces)
    @param label: the label
    @type  label: string
    """
    self.yLabel = label
    if self.yLabel is not None:
      self.yLabel = self.yLabel.replace("_", " ")


  def setTitle(self, title):
    """
    Print a title for graph (underscores are replaced by spaces)
    @param title: a title
    @type  title: string
    """
    self.title = title
    if self.title is not None:
      self.title = self.title.replace("_", " ")


  def setLog(self, log):
    """
    Use log-scale for axes
    @param log: the axes to put in log scale: any combination of "x", "y" and "z" (e.g. "xy")
    @type  log: string
    """
    self.logX = "x" in log
    self.logY = "y" in log
    self.logZ = "z" in log


  def setBarplot(self, barplot):
    """
    Use barplot representation instead
    @param barplot: barplot representation
    @type  barplot: boolean
    """
    self.barplot = barplot


  def setPoints(self, points):
    """
    Use points cloud representation instead
    @param points: points cloud representation
    @type  points: boolean
    """
    self.points = points


  def setHeatPoints(self, heatPoints):
    """
    Use points cloud representation with color representing another variable instead
    @param heatPoints: colored points cloud representation
    @type  heatPoints: boolean
    """
    self.heatPoints = heatPoints


  def _addStyle(self, color):
    """
    Append the color and the line type of a new line
    @param color: an R color name, or None to pick the next color of the panel
    @type  color: string
    """
    if color is None:
      # the panel colors 1 .. nbColors-1 are cycled; once they are all used,
      # lines become dashed so that no two lines look the same
      colorNumber = len(self.colors) % (self.nbColors - 1) + 1
      lineType    = "dashed" if len(self.colors) >= self.nbColors - 1 else "solid"
      color       = "colorPanel[%d]" % (colorNumber)
    else:
      color    = "\"%s\"" % (self._rString(color))
      lineType = "solid"
    self.colors.append(color)
    self.types.append(lineType)


  def _updateRange(self, x, y):
    """
    Extend the observed ranges of the axes with a new point
    (0 is ignored for the minimum of a log-scaled axis)
    """
    if not self.logX or x != 0:
      self.xMin = x if self.xMin is None else min(self.xMin, x)
    self.xMax = x if self.xMax is None else max(self.xMax, x)
    if not self.logY or y != 0:
      self.yMin = y if self.yMin is None else min(self.yMin, y)
    self.yMax = y if self.yMax is None else max(self.yMax, y)


  def addLine(self, line, name = "", color = None):
    """
    Add a line, and write its data in a temporary data file
    Exit the program if a coordinate is None.
    @param line:  a line to plot: in 'points' or 'heatPoints' mode, a dict whose
                  values are (x, y) pairs; otherwise a dict mapping x to y
    @type  line:  dict
    @param name:  name of the line (used in the legend)
    @type  name:  string
    @param color: an R color name; if None, the next color of the panel is used
    @type  color: string
    """
    # prepare data
    plot = {}
    if self.points or self.heatPoints:
      values = line.values()
    elif self.fill is None:
      values = sorted(line.keys())
    else:
      # every integer x between the extreme keys; missing ones get self.fill
      values = range(min(line.keys()), max(line.keys()) + 1)

    for element in values:
      if self.points or self.heatPoints:
        x, y = element[0], element[1]
      else:
        x = element
        y = line[x] if x in line else self.fill

      # checked first: None cannot be compared with the bounds below
      if x is None:
        sys.exit("Problem! x is None. Aborting...")
      if self.minimumX is not None and x < self.minimumX:
        continue
      if self.maximumX is not None and x > self.maximumX:
        continue
      if y is None:
        sys.exit("Problem! y is None. Aborting...")

      self._updateRange(x, y)
      plot[x] = y

    # cluster the data into buckets
    # NOTE(review): buckets are computed from the raw line (minimumX, maximumX
    # and fill are ignored) and yMin/yMax are not updated with the bucket sums
    # -- confirm this is intended.
    if self.bucket is not None:
      bucketSize = int(self.bucket)
      # '//' floors the same way as Python 2's integer '/'
      buckets = dict(((value // bucketSize) * self.bucket, 0) for value in range(min(line.keys()), max(line.keys()) + 1))
      for distance, nb in line.items():
        buckets[(int(distance) // bucketSize) * self.bucket] += nb
      plot = buckets

    # write file
    dataFileName = self._dataFileName(len(self.lines))
    self._registerTmpFile(dataFileName)
    with open(dataFileName, "w") as dataHandle:
      for x in sorted(plot.keys()):
        dataHandle.write("%f\t%f\n" % (x, plot[x]))

    self.lines.append(line)
    self.names.append(name)
    self._addStyle(color)


  def addHeatLine(self, line, name = "", color = None):
    """
    Add the heat line, which gives the color of the points (from green for the
    lowest value to red for the highest); requires the 'heatPoints' mode.
    The given dict is not modified.
    @param line:  the line which gives the color of the points
    @type  line:  dict
    @param name:  name of the line
    @type  name:  string
    @param color: an R color name; if None, the next color of the panel is used
    @type  color: string
    """
    if not self.heatPoints:
      sys.exit("Error! Trying to add a heat point whereas not mentioned to earlier! Aborting.")

    dataFileName = self._dataFileName(len(self.lines))
    self._registerTmpFile(dataFileName)

    # work on a copy: the caller's dict must not be shifted by minLogValue
    heat        = dict(line)
    minimumHeat = min(heat.values())
    maximumHeat = max(heat.values())
    minLogValue = 0.00001

    if self.logZ:
      if minimumHeat == 0:
        # shift everything so that log10 is defined for the minimum
        for element in heat:
          heat[element] += minLogValue
        minimumHeat += minLogValue
        maximumHeat += minLogValue
      minimumHeat = math.log10(minimumHeat)
      maximumHeat = math.log10(maximumHeat)

    # all values equal: every point gets the "lowest" color instead of dividing by 0
    if maximumHeat == minimumHeat:
      coeff = 0.0
    else:
      coeff = 255.0 / (maximumHeat - minimumHeat)

    # NOTE(review): colors are written in the iteration order of this dict,
    # whereas addLine writes the points sorted by x -- confirm both orders match.
    with open(dataFileName, "w") as dataHandle:
      for element in line:
        value = heat[element]
        if self.logZ:
          value = math.log10(max(minLogValue, value))
        level = int((value - minimumHeat) * coeff)
        dataHandle.write("\"#%02X%02X00\"\n" % (level, 255 - level))

    self.names.append(name)
    self._addStyle(color)


  def plot(self):
    """
    Write the R script drawing the lines, and run it with R.
    Exit the program if R fails (the temporary files are then kept).
    """
    lineWidth = 1

    # restrict the observed ranges to the user-given bounds
    xMin = self.xMin
    if self.minimumX is not None:
      xMin = max(xMin, self.minimumX)
    xMax = self.xMax
    if self.maximumX is not None:
      xMax = min(xMax, self.maximumX)
    yMin = self.yMin
    if self.minimumY is not None:
      yMin = max(yMin, self.minimumY)
    yMax = self.yMax
    if self.maximumY is not None:
      yMax = min(yMax, self.maximumY)

    log = ""
    if self.logX:
      log += "x"
    if self.logY:
      log += "y"
    if log != "":
      log = ", log=\"%s\"" % (log)

    title = ""
    if self.title is not None:
      title = ", main = \"%s\"" % (self._rString(self.title))

    xLabel = self._rString(self.xLabel)
    yLabel = self._rString(self.yLabel)

    scriptFileName = self._scriptFileName()
    self._registerTmpFile(scriptFileName)
    # R CMD BATCH writes its output next to the script
    self._registerTmpFile("%sout" % (scriptFileName))
    with open(scriptFileName, "w") as scriptHandle:
      scriptHandle.write("library(RColorBrewer)\n")
      scriptHandle.write("colorPanel = brewer.pal(n=%d, name=\"Set1\")\n" % (self.nbColors))
      # the pdf device names its output argument "file", the others "filename"
      scriptHandle.write("%s(%s = \"%s\", width = %d, height = %d, bg = \"white\")\n" % (self.format, "filename" if self.format != "pdf" else "file", self._rString(self.fileName), self.width, self.height))

      if self.barplot:
        scriptHandle.write("data = scan(\"%s\", list(x = -666, y = -666))\n" % (self._dataFileName(0)))
        scriptHandle.write("barplot(data$y, name = data$x, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.names = 2, cex.lab = 2%s%s)\n" % (xLabel, yLabel, title, log))
      elif self.points:
        scriptHandle.write("data = scan(\"%s\", list(x = -666, y = -666))\n" % (self._dataFileName(0)))
        scriptHandle.write("plot(data$x, data$y, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.lab = 2%s%s)\n" % (xLabel, yLabel, title, log))
      elif self.heatPoints:
        if len(self.lines) != 1:
          sys.exit("Error! Bad number of input data! Aborting...")
        scriptHandle.write("data = scan(\"%s\", list(x = -666, y = -666))\n" % (self._dataFileName(0)))
        scriptHandle.write("heatData = scan(\"%s\", list(x = \"\"))\n" % (self._dataFileName(1)))
        scriptHandle.write("plot(data$x, data$y, col=heatData$x, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.lab = 2%s%s)\n" % (xLabel, yLabel, title, log))
      else:
        # %f for both bounds: %d would truncate floating-point maxima
        scriptHandle.write("plot(x = NA, y = NA, panel.first = grid(lwd = 1.0), xlab=\"%s\", ylab=\"%s\", xlim = c(%f, %f), ylim = c(%f, %f), cex.axis = 2, cex.lab = 2%s%s)\n" % (xLabel, yLabel, xMin, xMax, yMin, yMax, title, log))
        for index in range(len(self.lines)):
          scriptHandle.write("data = scan(\"%s\", list(x = -666.666, y = -666.666))\n" % (self._dataFileName(index)))
          scriptHandle.write("lines(x = data$x, y = data$y, col = %s, lty = \"%s\", lwd = %d)\n" % (self.colors[index], self.types[index], lineWidth))

        if self.legend:
          scriptHandle.write("legends    = c(%s)\n" % ", ".join(["\"%s\"" % self._rString(lineName) for lineName in self.names]))
          scriptHandle.write("colors     = c(%s)\n" % ", ".join(["%s" % lineColor for lineColor in self.colors]))
          scriptHandle.write("lineTypes  = c(%s)\n" % ", ".join(["\"%s\"" % lineType for lineType in self.types]))
          scriptHandle.write("legend(%f, %f, legend = legends, xjust = 0, yjust = 1, col = colors, lty = lineTypes, lwd = %d, cex = 1.5, ncol = 1, bg = \"white\")\n" % (self.xMax, self.yMax, lineWidth))

      scriptHandle.write("dev.off()\n")

    self._runR(scriptFileName)


  def getSpearmanRho(self):
    """
    Get the Spearman rho correlation of the first line, using R (Hmisc package)
    Exit the program if not in 'points' or 'bar' mode, or if R fails.
    @return: the correlation, or None if R reports NA or prints no "rho" line
    @rtype:  float
    """
    if not self.points and not self.barplot:
      sys.exit("Cannot compute Spearman rho correlation whereas not in 'points' or 'bar' mode.")

    scriptFileName = self._scriptFileName()
    outputFileName = "%sout" % (scriptFileName)
    self._registerTmpFile(scriptFileName)
    self._registerTmpFile(outputFileName)
    with open(scriptFileName, "w") as rScript:
      rScript.write("library(Hmisc)\n")
      rScript.write("data = scan(\"%s\", list(x = -0.000000, y = -0.000000))\n" % (self._dataFileName(0)))
      rScript.write("spearman(data$x, data$y)\n")

    self._runR(scriptFileName)

    # the value is printed on the line following the "rho" header
    nextLine = False
    with open(outputFileName) as outputRFile:
      for line in outputRFile:
        line = line.strip()
        if nextLine:
          return None if line == "NA" else float(line)
        if line == "rho":
          nextLine = True

    return None
