import matplotlib.pyplot as plt
import numpy as np

def demo_mutual_approx(bpsk, qam4, qam16, qam64, qam256, qam1024, EsNo_start, EsNo_end,
                       mutInf_end, plotOverEbNo, filename, num_points=50):
    """Plot closed-form approximations of the bitwise-soft-demapping rate on AWGN.

    For every enabled constellation with ``m`` bits per symbol the achievable
    rate is approximated as ``m * (1 - sum_k a_k * exp(-b_k * Es/N0))`` using
    tabulated coefficients ``a_k``/``b_k``. The curves are drawn together with
    the Shannon limit ``log2(1 + Es/N0)``, and the region above that limit is
    shaded as unachievable. The figure is saved to ``filename`` and then closed.

    Args:
        bpsk, qam4, qam16, qam64, qam256, qam1024: truthy to include the
            corresponding constellation.
        EsNo_start: lowest Es/N0 in dB. Ignored when ``plotOverEbNo`` is truthy:
            the sweep then always starts at -20 dB, which puts the left end of
            the Eb/N0 axis near the Shannon limit (about -1.6 dB).
        EsNo_end: highest Es/N0 in dB.
        mutInf_end: upper y-axis limit in bits per channel use.
        plotOverEbNo: if truthy, the x axis is Eb/N0 = Es/N0 / rate (in dB)
            instead of Es/N0.
        filename: destination passed to ``Figure.savefig``.
        num_points: number of Es/N0 samples in the sweep (default 50, the
            ``np.linspace`` default used previously).
    """
    eb_no = bool(plotOverEbNo)
    snr_start = -20 if eb_no else EsNo_start
    snr_db = np.linspace(snr_start, EsNo_end, num_points)
    snr_linear = 10 ** (snr_db / 10)
    shannon = np.log2(1 + snr_linear)

    def to_x_axis(rate):
        # Eb/N0 [dB] = Es/N0 [dB] - 10*log10(bits per channel use).
        return snr_db - 10 * np.log10(rate) if eb_no else snr_db

    # 4-QAM is two independent BPSK streams (I and Q), each with half the symbol
    # energy. So BPSK at a given Es/N0 uses the 4-QAM coefficients with b doubled.
    qam4_a = [0.052867, 0.585823, 0.361334]
    qam4_b = [2.294948, 0.522143, 0.805668]

    # (bits per symbol, enabled flag, legend label, a_k, b_k) in drawing order.
    # Highest order is drawn first, which also fixes the legend order.
    curves = (
        (10, qam1024, "1024-QAM",
         [0.139684, 0.174438, 0.278827, 0.042557, 0.109014, 0.100422, 0.154803],
         [0.131906, 0.011354, 0.00159, 1.520725, 0.457703, 0.002928, 0.038759]),
        (8, qam256, "256-QAM",
         [0.122949, 0.16287, 0.352661, 0.044999, 0.204491, 0.111848],
         [0.529204, 0.160656, 0.006443, 1.630036, 0.047423, 0.012854]),
        (6, qam64, "64-QAM",
         [0.140371, 0.434353, 0.051814, 0.145847, 0.227242],
         [0.050582, 0.026002, 1.688834, 0.603848, 0.185743]),
        (4, qam16, "16-QAM",
         [0.219856, 0.05748, 0.220352, 0.502069],
         [0.179376, 1.906477, 0.700161, 0.106817]),
        (2, qam4, "4-QAM", qam4_a, qam4_b),
        (1, bpsk, "BPSK", qam4_a, [2 * coef for coef in qam4_b]),
    )

    # Use a dedicated figure so a figure the caller already has open is neither
    # drawn on nor closed by this function.
    fig, ax = plt.subplots()
    try:
        x_shannon = to_x_axis(shannon)
        # Empty proxy artist: it only contributes the legend entry for the
        # shaded (unachievable) area.
        ax.plot([], [], 's', color='lightblue', label='Unachievable region')
        ax.plot(x_shannon, shannon, label='Shannon Limit')
        ax.fill_between(x_shannon, 100, shannon, where=shannon <= 100, color='lightblue')
        ax.set_xlim(np.min(x_shannon), np.max(x_shannon))
        ax.set_ylim((0, mutInf_end))

        for m, enabled, label, a, b in curves:
            if not enabled:
                continue
            # Row i: sum_k a_k * exp(-b_k * snr_linear[i]).
            decay = np.sum(np.asarray(a) * np.exp(-np.outer(snr_linear, b)), axis=1)
            mutual_info = m * (1 - decay)
            ax.plot(to_x_axis(mutual_info), mutual_info, label=label)

        ax.set_title("Achievable rates for AWGN channel with bitwise soft demapping")
        ax.set_xlabel('Eb/N0 (dB)' if eb_no else 'Es/N0 (dB)')
        ax.set_ylabel('mutual information (bits/channel use)')
        ax.legend()
        ax.grid()
        fig.savefig(filename)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)

