#attach "/home/wstein/psage/chow_heegner/psage/ellcurve/points/chow_heegner.py"
attach "chow_heegner.py"

def table_row_fiber(E, F, ch=True, B=150, **kwds):
    """
    Compute one row of the Chow-Heegner table for the pair of curves (E, F).

    INPUT:

    - ``E``, ``F`` -- elliptic curves or Cremona labels; each must have a
      Cremona label ending in '1' (the trailing '1' is dropped in the table).
    - ``ch`` -- if True, compute the point via
      ``ChowHeegner2(E, F).point_on_E_fiber_cardinality2(**kwds)`` and try to
      identify it with ``identify(..., B)``; if False, the point is left
      unknown and shown as '?'.
    - ``B`` -- passed to ``identify``; also the bound on the absolute value
      of the coefficients of the infinite-order generators when P is written
      as an integer combination of generators of E(Q).
    - ``**kwds`` -- forwarded to ``point_on_E_fiber_cardinality2`` and stored
      in the result under 'kwds'.

    OUTPUT:

    A dict with keys 'E', 'F', 'rE', 'rF', 'mE', 'mF', 'EQ', 'P', 'Ptex',
    'data', 'kwds', 'tex' (and 'time' when ``ch`` is True); 'tex' is the
    LaTeX table row.  If ``identify`` raises ValueError, the raw data from
    ``point_on_E_fiber_cardinality2`` is returned instead; callers in this
    file detect that case by the absence of the 'tex' key.
    """
    if isinstance(E, str): E = EllipticCurve(E)
    if isinstance(F, str): F = EllipticCurve(F)
    d = {}
    d['E'] = E.cremona_label()
    d['F'] = F.cremona_label()
    assert d['E'].endswith('1'), "E must have a Cremona label ending in '1'"
    assert d['F'].endswith('1'), "F must have a Cremona label ending in '1'"
    # The table shows the label without the trailing curve number '1'.
    d['E'] = d['E'][:-1]
    d['F'] = d['F'][:-1]
    d['rE'] = E.rank()
    d['rF'] = F.rank()

    # P stays None while the point is unknown.  data must be bound even
    # when ch is False because it is stored in d['data'] below.
    P = None
    data = None
    if ch:
        tm = cputime()
        data = ChowHeegner2(E, F).point_on_E_fiber_cardinality2(**kwds)
        d['time'] = cputime(tm)
        try:
            P = identify(data['P'], B)[0]
        except ValueError as msg:
            print(msg)
            return data

    d['mE'] = E.modular_degree()
    d['mF'] = F.modular_degree()

    # Generators of E(Q): the first rank(E) entries of E.gens(), followed by
    # the torsion generators.
    G = E.gens()
    EQ = [G[i] for i in range(d['rE'])]
    EQ += [x.element() for x in E.torsion_subgroup().gens()]
    orders = [x.order() for x in EQ]
    # In the table, order 0 marks a point of infinite order.
    d['EQ'] = [((x[0], x[1]), o if o < oo else 0) for x, o in zip(EQ, orders)]

    # Brute-force search for integer coefficients t with sum(t_i * EQ_i) == P:
    # coefficients in [-B, B] for infinite-order generators and in
    # [0, order) for torsion generators.  Infinite order is tested with
    # "== oo" so a torsion point whose order happens to equal B is not
    # mistaken for a free generator.
    d['P'] = '?'
    if P is not None:
        print("P = %s" % (P,))
        rng = [range(-B, B + 1) if o == oo else range(o) for o in orders]
        for t in cartesian_product_iterator(rng):
            if sum(c * Q for c, Q in zip(t, EQ)) == P:
                d['P'] = list(t)
                break
    d['data'] = data
    d['kwds'] = kwds

    if P is not None and P == 0:
        d['Ptex'] = '0'
    elif d['P'] == '?':
        # P is unknown (ch False) or was not found within the bound B;
        # formatting the placeholder as coefficients would give '?P_{1}'.
        d['Ptex'] = '?'
    else:
        d['Ptex'] = ('+'.join('%sP_{%s}' % ('' if n == 1 else n, i + 1)
                              for i, n in enumerate(d['P']) if n)).replace('+-', '-')

    tex = '\\CR{%s} & $%s$ & $%s$ & $%s$ & \\CR{%s} & $%s$ & $%s$ & $%s$ & \\\\\\hline'%(
          d['E'], d['rE'],
          ','.join('%s_{%s}'%v if v[1] else str(v[0]) for v in d['EQ']), d['mE'],
          d['F'], d['rF'], d['mF'], d['Ptex'])

    d['tex'] = tex
    return d




def cases(N1,N2):
    """
    Print the candidate (E, F) pairs with conductor between N1 and N2.

    For every optimal curve E of rank 1 whose conductor lies in [N1, N2]
    (inclusive), and every other optimal curve F of the same conductor,
    print one line: E's label, E's rank, F's label, F's rank and F's
    modular degree, separated by spaces.
    """
    for E in cremona_optimal_curves(srange(N1, N2 + 1)):
        if E.rank() != 1:
            continue
        others = (C for C in cremona_optimal_curves([E.conductor()]) if C != E)
        for F in others:
            print("%s %s %s %s %s" % (E.cremona_label(), E.rank(),
                                      F.cremona_label(), F.rank(),
                                      F.modular_degree()))

def do_all(N1, N2):
    """
    Compute table rows for all candidate pairs with conductor in [N1, N2].

    For every optimal curve E of rank 1 with conductor in [N1, N2]
    (inclusive) and every other optimal curve F of the same conductor, call
    ``table_row_fiber(E, F, equiv_prec=20)``.  A successful LaTeX row is
    appended to 'answers.txt'; otherwise the pair "E F" is appended to
    'fails.txt'.  Both files are flushed after every write so partial
    progress survives an interrupted run, and are closed on exit.
    """
    with open('answers.txt', 'a') as answers, open('fails.txt', 'a') as fails:
        for E in cremona_optimal_curves(srange(N1, N2 + 1)):
            if E.rank() != 1:
                continue
            for F in cremona_optimal_curves([E.conductor()]):
                if F == E:
                    continue
                lE = E.cremona_label()
                lF = F.cremona_label()
                print("%s %s %s %s %s" % (lE, E.rank(), lF, F.rank(), F.modular_degree()))
                sys.stdout.flush()
                t = cputime()
                z = table_row_fiber(E, F, equiv_prec=20)
                print("TIME (%s, %s): %s" % (lE, lF, cputime(t)))
                # table_row_fiber returns a dict without 'tex' on failure.
                if 'tex' in z:
                    print(z['tex'])
                    answers.write(z['tex'] + '\n')
                    answers.flush()
                else:
                    print("** FAIL (%s, %s)" % (lE, lF))
                    fails.write('%s %s\n' % (lE, lF))
                    fails.flush()
                sys.stdout.flush()
                        


def retry1_fails():
    """
    Retry the pairs listed in 'fails.txt' with deg1=1500, min_imag=1e-5.

    Each line of 'fails.txt' starts with two curve labels "E F" (any extra
    fields are ignored; blank lines are skipped).  Successful LaTeX rows are
    appended to 'answers-retry1.txt', followed by a TeX comment recording the
    parameters.  Pairs that still fail are appended to 'fails-retry1.txt'
    together with the parameters that were tried.
    """
    with open('fails.txt') as todo:
        lines = todo.readlines()
    with open('answers-retry1.txt', 'a') as answers, \
            open('fails-retry1.txt', 'a') as fails:
        for X in lines:
            fields = X.split()
            if not fields:
                continue  # tolerate blank lines in the input
            E, F = fields[:2]
            print("%s %s" % (E, F))
            sys.stdout.flush()
            t = cputime()
            z = table_row_fiber(E, F, equiv_prec=20, deg1=1500, min_imag=1e-5)
            print("TIME (%s, %s): %s" % (E, F, cputime(t)))
            if 'tex' in z:
                print(z['tex'])
                answers.write(z['tex'] + '% deg1=1500, min_imag=1e-5\n')
                answers.flush()
            else:
                print("** FAIL (%s, %s)" % (E, F))
                fails.write('%s %s 1500 1e-5\n' % (E, F))
                fails.flush()
            sys.stdout.flush()


def retry2_fails():
    """
    Retry the pairs listed in 'fails-retry1.txt' with deg1=3000, min_imag=1e-5/2.

    Each input line starts with two curve labels "E F" (any extra fields
    are ignored; blank lines are skipped).  Successful LaTeX rows are
    appended to 'answers-retry2.txt', followed by a TeX comment recording the
    parameters.  Pairs that still fail are appended to 'fails-retry2.txt'
    together with the parameters that were tried.
    """
    with open('fails-retry1.txt') as todo:
        lines = todo.readlines()
    with open('answers-retry2.txt', 'a') as answers, \
            open('fails-retry2.txt', 'a') as fails:
        for X in lines:
            fields = X.split()
            if not fields:
                continue  # tolerate blank lines in the input
            E, F = fields[:2]
            print("%s %s" % (E, F))
            sys.stdout.flush()
            t = cputime()
            z = table_row_fiber(E, F, equiv_prec=20, deg1=3000, min_imag=1e-5/2)
            print("TIME (%s, %s): %s" % (E, F, cputime(t)))
            if 'tex' in z:
                print(z['tex'])
                answers.write(z['tex'] + '% deg1=3000, min_imag=1e-5/2\n')
                answers.flush()
            else:
                print("** FAIL (%s, %s)" % (E, F))
                fails.write('%s %s 3000 1e-5/2\n' % (E, F))
                fails.flush()
            sys.stdout.flush()


def retry3_fails():
    """
    Retry the pairs listed in 'fails-retry2.txt' with deg1=3500, min_imag=1e-5/4.

    Each input line starts with two curve labels "E F" (any extra fields
    are ignored; blank lines are skipped).  Successful LaTeX rows are
    appended to 'answers-retry3.txt', followed by a TeX comment recording the
    parameters.  Pairs that still fail are appended to 'fails-retry3.txt'
    together with the parameters that were tried.
    """
    with open('fails-retry2.txt') as todo:
        lines = todo.readlines()
    with open('answers-retry3.txt', 'a') as answers, \
            open('fails-retry3.txt', 'a') as fails:
        for X in lines:
            fields = X.split()
            if not fields:
                continue  # tolerate blank lines in the input
            E, F = fields[:2]
            print("%s %s" % (E, F))
            sys.stdout.flush()
            t = cputime()
            z = table_row_fiber(E, F, equiv_prec=20, deg1=3500, min_imag=1e-5/4)
            print("TIME (%s, %s): %s" % (E, F, cputime(t)))
            if 'tex' in z:
                print(z['tex'])
                answers.write(z['tex'] + '% deg1=3500, min_imag=1e-5/4\n')
                answers.flush()
            else:
                print("** FAIL (%s, %s)" % (E, F))
                # Record the min_imag actually used above (was logged as 1e-5/5).
                fails.write('%s %s 3500 1e-5/4\n' % (E, F))
                fails.flush()
            sys.stdout.flush()

def sgn(n):
    """Return the sign character of n: '+' if n > 0, '-' otherwise (including 0)."""
    if n > 0:
        return '+'
    return '-'

def add_local_roots(s):
    """
    Echo the LaTeX file ``s`` to stdout, adding a column of local root numbers.

    Lines containing both 'CR' and 'hline' are treated as table rows.  For
    such a row, the text between the first '{' and the first '}' is the
    Cremona label of E.  The signs of E's root numbers at the primes dividing
    its conductor (in the order of ``prime_divisors()``) are joined into a
    string such as '+-' and inserted as a new '&' column right after that
    '}'.  All other lines are printed unchanged.  The file is closed when
    done.
    """
    with open(s) as f:
        for X in f:
            if not ('CR' in X and 'hline' in X):
                print(X.rstrip('\n'))
                continue
            i = X.find('{')
            j = X.find('}')
            label = X[i+1:j]
            E = EllipticCurve(label)
            eps = ''.join([sgn(E.root_number(p)) for p in E.conductor().prime_divisors()])
            print(X[:j+1] + '& ' + eps + X[j+1:].rstrip('\n'))
        
        
        

##########################################################
# Double-checking/replicating the table in the paper as
# quickly as possible, from scratch, on a multi-core machine.
##########################################################
# 1. Extract key information from table.
# 2. Recompute it in parallel using the same parameters,
#    except for the starting real numbers z.
# 3. Confirm that we get the same answers.  Record
#    any that are different (there should be none).
##########################################################
#
# sage: t = extract_table_from_paper()
# sage: time recompute_table_in_parallel(t, z=0.1, ncpu=6)
# sage: time recompute_table_in_parallel(t, z=0.2, ncpu=6)
#

def extract_table_from_paper(filename='chow_heegner_talk.tex'):
    """
    Parse the Chow-Heegner table rows out of the LaTeX source ``filename``.

    A line is a table row when it contains both 'CR' and 'hline' and is not
    a TeX comment (first non-blank character '%').  Each row must split on
    '&' into exactly 10 fields.  Presumably this is the layout of
    ``table_row_fiber``'s tex output after ``add_local_roots`` inserts the
    root-number column, with a footnote code in the last column; otherwise
    the unpacking raises ValueError.

    OUTPUT:

    A list of dicts with keys:
    - 'E', 'F' -- the labels inside ``\\CR{...}``;
    - 'P' -- the point column without surrounding '$';
    - 'code' -- the text of the last column before its first backslash,
      stripped (e.g. '' or '(1)').
    """
    v = []
    with open(filename) as f:
        for X in f:
            if 'CR' not in X or 'hline' not in X or X.lstrip().startswith('%'):
                continue
            E, _, _, _, _, F, _, _, P, code = X.split('&')
            cut = code.find('\\')
            if cut >= 0:
                code = code[:cut]
            v.append({'E': E.strip()[4:-1],     # strip '\CR{' and '}'
                      'F': F.strip()[4:-1],
                      'P': P.strip().strip('$'),
                      'code': code.strip()})
    return v

def recompute_table_in_parallel(table, z=0.1, ncpu=12):
    """
    Recompute every row of ``table`` in parallel and compare with the paper.

    INPUT:

    - ``table`` -- list of dicts with keys 'E', 'F', 'P', 'code', as returned
      by ``extract_table_from_paper``.
    - ``z`` -- forwarded to ``table_row_fiber`` as ``z=``; also stored as
      ``str(z)`` in each record, so runs with different ``z`` are cached
      separately.
    - ``ncpu`` -- number of parallel workers for Sage's ``@parallel``.

    Each row's computation parameters are chosen by its footnote code.
    Successful results are cached in the nosqlite database 'db' (collection
    ``data``), with 'same_answer' recording whether the recomputed Ptex
    equals the paper's P.  A row already cached for (E, F, z) is returned
    without recomputation.  Rows whose point could not be identified are
    reported but not cached, so a later run retries them.  Every result
    yielded by the parallel iterator is printed.

    A row with an unknown footnote code raises ValueError in its worker.
    """
    from nosqlite import Client

    # Computation parameters for each footnote code used in the table.
    params = {
        '':    dict(min_imag=1e-4, deg1=500, equiv_prec=20),
        '(1)': dict(min_imag=1e-5, deg1=1500, equiv_prec=20),
        '(2)': dict(min_imag=1e-5/2, deg1=3000, equiv_prec=20),
    }

    @parallel(ncpu)
    def f(E, F, P, code):
        coll = Client('db').db.data
        print('E= %s' % (E,))
        cached = list(coll.find(E=E, F=F, z=str(z)))
        if len(cached) > 0:
            print(cached[0])
            return cached[0]
        if code not in params:
            raise ValueError("unknown table code %r for (%s, %s)" % (code, E, F))
        print("Computing...  %s" % ((E, F, P, code),))
        sys.stdout.flush()
        d = table_row_fiber(E, F, z=z, **params[code])
        if 'tex' not in d:
            # identify() failed: table_row_fiber returned its raw data rather
            # than a table row, so there is no Ptex to compare.
            print("** FAIL (%s, %s)" % (E, F))
            return {'E': E, 'F': F, 'z': str(z), 'same_answer': False}
        d.pop('data', None)
        d['z'] = str(z)
        d['same_answer'] = (d['Ptex'] == P)
        coll.insert(d)
        return d

    for X in f([(t['E'], t['F'], t['P'], t['code']) for t in table]):
        print(X)
        
        

        
def ai(E):
    """
    Print three invariants of E, one per line.

    The three values are the primes where the mod-p Galois representation
    is not surjective (``galois_representation().non_surjective()``), the
    Tamagawa numbers, and the modular degree.  ``E`` may be an elliptic
    curve or a label accepted by ``EllipticCurve``.
    """
    curve = EllipticCurve(E) if isinstance(E, str) else E
    print(curve.galois_representation().non_surjective())
    print(curve.tamagawa_numbers())
    print(curve.modular_degree())
