import sys
import matplotlib.animation as animation
import constellation as constellations
import pulse_shaper
from timing_error_compensation import compensate_timing_error
import numpy as np
import matplotlib.pyplot as plt
from sampling import upsample
import timing_error_detection


def _shifted_symbol_samples(us_symbols, half_shaper, shift, us_factor):
  """Return one sample per symbol from `us_symbols`, offset by `shift` samples.

  `half_shaper` samples are trimmed from each end (presumably the
  pulse-shaper transient -- confirm), and the rest is decimated by
  `us_factor` starting at index `half_shaper + shift`.

  An explicit stop index is used instead of the negative-stop slice
  `[h+s:-h+s]`. That form silently returns an empty or truncated array once
  `shift >= half_shaper`, because `-h+s` becomes 0 or positive.
  `shift` is expected to satisfy `-half_shaper <= shift`; smaller values
  would make the start index negative.
  """
  start = half_shaper + shift
  stop = len(us_symbols) - half_shaper + shift
  return us_symbols[start:stop:us_factor]


def s_curve(snr, bits_per_symbol, roll_off, avglen, gardner, sq_gardner, zc, mm, mmm, filename):
  """Plot the S-curves of the selected timing-error detectors and save them.

  The signal chain runs in four steps:
    1. Map 2**15 random symbols to BPSK (1 bit/symbol) or QAM.
    2. Upsample by 32 and shape with `pulse_shaper.sinc`.
    3. Add complex Gaussian noise.
    4. For each integer sample shift in [-16, 15], decimate back to one
       sample per symbol and average each enabled TED's output.
  The negated mean TED outputs are plotted against the shift, expressed in
  symbol periods.

  Args:
    snr: SNR in dB (anything `float()` accepts). Each real/imag noise
      component has standard deviation 1/sqrt(2*snr_lin). This assumes a
      unit-power signal after shaping -- TODO confirm against pulse_shaper.
    bits_per_symbol: Bits per symbol; 1 selects BPSK, otherwise QAM.
    roll_off: Roll-off factor forwarded to the Gardner / zero-crossing TEDs.
    avglen: Averaging length forwarded to every TED.
    gardner, sq_gardner, zc, mm, mmm: Truthy flags that enable, in order,
      the Gardner, squared Gardner, zero-crossing, Mueller-Muller and
      modified Mueller-Muller TEDs.
    filename: Path the figure is written to via `Figure.savefig`.
  """
  num_bits_per_symbol = int(bits_per_symbol)
  # Previously ignored in favour of a hard-coded 0.1; now honoured.
  roll_off = float(roll_off)
  # float() instead of int(): accepts fractional dB values without truncation.
  snr_db = float(snr)
  avglen = int(avglen)

  order = int(2**num_bits_per_symbol)
  symbol_indices = np.random.randint(0, order, int(2**15))

  snr_lin = 10**(snr_db/10)

  if num_bits_per_symbol == 1:
    symbols = constellations.bpsk(symbol_indices)
    constellation = constellations.bpsk(np.arange(order))
  else:
    symbols = constellations.qam(symbol_indices, num_bits_per_symbol)
    constellation = constellations.qam(np.arange(order), num_bits_per_symbol)

  us_factor = 32
  span_in_symbols = 32
  us_symbols = pulse_shaper.sinc(upsample(symbols, us_factor), 1.0, us_factor, span_in_symbols)
  # Alternative shaping path, kept for reference:
  # us_symbols = pulse_shaper.raised_cosine_shaping(upsample(symbols, us_factor), 1.0, us_factor, span_in_symbols, roll_off)
  # The real part is drawn before the imaginary part, as before, so seeded
  # runs are reproducible.
  noise_std = 1/np.sqrt(2*snr_lin)
  noise = (np.random.normal(0, noise_std, us_symbols.shape)
           + 1j*np.random.normal(0, noise_std, us_symbols.shape))
  us_symbols = us_symbols + noise

  # Round the shaper length up to an odd number (presumably to match an
  # odd-tap filter -- confirm against pulse_shaper).
  shaper_length = us_factor * span_in_symbols
  shaper_length = shaper_length + (shaper_length + 1) % 2
  half_shaper = shaper_length // 2

  # Integer sample shifts spanning one symbol period: -us_factor/2 .. us_factor/2 - 1.
  timing_errors = np.arange(us_factor) - us_factor//2

  # (enabled flag, plot label, TED applied to the decimated stream).
  # The table order is the plot order.
  ted_specs = [
    (gardner, 'Gardner TED',
     lambda x: timing_error_detection.gardner_ted(x, avglen, "rc", False, roll_off)),
    (sq_gardner, 'Squared Gardner TED',
     lambda x: timing_error_detection.gardner_ted(x, avglen, "rc", True, roll_off)),
    (zc, 'Zero-Crossing TED',
     lambda x: timing_error_detection.zero_crossing_new_ted(x, constellation, avglen, "rc", False, roll_off)),
    (mm, 'Mueller-Muller TED',
     lambda x: timing_error_detection.mueller_muller_ted(x, constellation, avglen, False)),
    (mmm, 'Modified Mueller-Muller TED',
     lambda x: timing_error_detection.mod_mueller_muller_ted(x, constellation, avglen, False)),
  ]
  enabled = [(label, ted) for flag, label, ted in ted_specs if flag]
  curves = {label: np.zeros((len(timing_errors),)) for label, _ in enabled}

  for i, terr in enumerate(timing_errors):
    # The decimated stream is the same for every TED, so slice it once per shift.
    shifted = _shifted_symbol_samples(us_symbols, half_shaper, terr, us_factor)
    for label, ted in enabled:
      curves[label][i] = np.mean(ted(shifted))

  # Shift expressed in symbol periods.
  offsets = timing_errors / us_factor

  fig, ax = plt.subplots(figsize=(12.8,4.8))
  try:
    for label, _ in enabled:
      # Curves are plotted negated, following the original convention.
      ax.plot(offsets, -curves[label], label=label)
    if enabled:
      # Avoids matplotlib's "no artists with labels" warning.
      ax.legend()
    ax.set_xlim([offsets[0], offsets[-1]])
    ax.hlines(0, offsets[0], offsets[-1], linestyles='dashed', color="#323232", alpha=0.5)
    ax.set_xlabel(r"Time Shift ($\tau/T_\mathrm{S}$)")
    ax.set_ylabel("TED")
    fig.savefig(filename)
  finally:
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

