import math

def superimpose(coord0, coord1, natm, tol=0.0001, max_iter=1000):
    """Least-squares superimpose ``coord1`` onto ``coord0`` and return the RMSD.

    Coordinates are indexed ``coord[axis][atom]`` with ``axis`` in 0..2, and
    only the first ``natm`` atoms take part in the fit.

    The optimal rotation is found with the iterative (Jacobi-style) rotation
    scheme ported from the original FORTRAN code. Each step rotates two rows
    of the correlation matrix about one axis so as to maximise its trace.
    Sweeps over the three axes are repeated until a complete sweep applies
    no rotation, or until ``max_iter`` single-axis steps have been taken.

    Args:
        coord0: Target coordinates; not modified.
        coord1: Predicted coordinates; overwritten in place with the
            superimposed coordinates (expressed in the target frame).
        natm: Number of atoms to fit.
        tol: A rotation about an axis is skipped when
            ``|sig| < tol * |gam|`` (already converged for that axis).
        max_iter: Maximum number of single-axis rotation steps.

    Returns:
        ``(err, mtx, vec)`` where ``err`` is the RMSD over the ``natm`` atoms
        after fitting. ``mtx`` (3x3 nested list) and ``vec`` (3-list) map a
        prediction-frame point ``p`` into the target frame as
        ``mtx . p + vec``.

        ``err`` is ``-1`` when ``natm <= 0``, or when any fitted atom in
        either set has an x coordinate > 998. The latter presumably marks
        missing coordinates (e.g. 999.0); confirm with callers. In that case
        ``mtx`` and ``vec`` are all zeros and ``coord1`` is left untouched.
    """
    mtx = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    vec = [0.0, 0.0, 0.0]
    if natm <= 0:
        # Nothing to fit; previously this divided by zero.
        return (-1, mtx, vec)
    for i in range(natm):
        if max(coord0[0][i], coord1[0][i]) > 998:
            return (-1, mtx, vec)

    # Center both sets on the origin.
    xc0 = _centroid(coord0, natm)
    xc1 = _centroid(coord1, natm)
    x0 = [[coord0[j][i] - xc0[j] for i in range(natm)] for j in range(3)]
    x1 = [[coord1[j][i] - xc1[j] for i in range(natm)] for j in range(3)]

    # Correlation matrix aa[i][j] = sum_k x1[i][k] * x0[j][k].
    aa = [[sum(x1[i][k] * x0[j][k] for k in range(natm)) for j in range(3)]
          for i in range(3)]

    # Iterative rotation scheme: stop only after a full sweep over all three
    # axes applies no rotation (or after max_iter steps). Stopping as soon as
    # the first axis of a sweep is converged would miss rotations that only
    # the other axes can find.
    rot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    steps = 0
    converged = False
    while not converged and steps < max_iter:
        converged = True
        for ix in range(3):
            if steps >= max_iter:
                break
            steps += 1
            if _jacobi_rotate(aa, rot, ix, tol):
                converged = False

    # Rotate the centred prediction and measure the deviation from the target.
    rotated = [[sum(rot[j][k] * x1[k][i] for k in range(3)) for i in range(natm)]
               for j in range(3)]
    sq = sum((x0[j][i] - rotated[j][i]) ** 2
             for i in range(natm) for j in range(3))
    err = math.sqrt(sq / natm)

    # Translate the fitted prediction onto the target's centre (in place).
    for i in range(natm):
        for j in range(3):
            coord1[j][i] = rotated[j][i] + xc0[j]

    # Build the transform: p -> rot . p + (centre of target - rot . centre of prediction).
    for i in range(3):
        t = xc0[i]
        for j in range(3):
            mtx[i][j] = rot[i][j]
            t = t - rot[i][j] * xc1[j]
        vec[i] = t
    return (err, mtx, vec)


def _centroid(coord, natm):
    """Return the mean position of the first ``natm`` atoms of ``coord[axis][atom]``."""
    return [sum(coord[j][i] for i in range(natm)) / float(natm) for j in range(3)]


def _jacobi_rotate(aa, rot, ix, tol):
    """Rotate rows ``iy``/``iz`` of ``aa`` and ``rot`` about axis ``ix`` in place.

    The plane rotation is chosen to maximise ``aa[iy][iy] + aa[iz][iz]``.
    Returns True if a rotation was applied, and False if this axis is
    already converged.
    """
    iy = (ix + 1) % 3
    iz = 3 - ix - iy
    sig = aa[iz][iy] - aa[iy][iz]
    gam = aa[iy][iy] + aa[iz][iz]
    sg = math.sqrt(sig * sig + gam * gam)
    # sg == 0 must be tested first: then sig == gam == 0 and the tol test is False.
    if sg == 0 or math.fabs(sig) < tol * math.fabs(gam):
        return False
    sg = 1.0 / sg
    for mat in (aa, rot):
        for k in range(3):
            bb = gam * mat[iy][k] + sig * mat[iz][k]
            cc = gam * mat[iz][k] - sig * mat[iy][k]
            mat[iy][k] = bb * sg
            mat[iz][k] = cc * sg
    return True

def writepdb(ifilename, frag_pos, mtx, vec, cf, x, chainID, chain, ofilename):
    """Write the transformed ATOM records of selected residues to a new file.

    Reads ``ifilename`` line by line and keeps only records whose first six
    characters are ``"ATOM  "``. It further keeps only those whose residue
    number (columns 23-26) is in ``frag_pos``. For each kept record:

    * the x/y/z fields (columns 31-54) are replaced by ``move(coord, mtx, vec)``;
    * ``x`` and ``cf`` are written as ``%6.2f`` into columns 55-60 and 61-66
      (the PDB occupancy and temperature-factor fields);
    * everything after column 66 in the input record is dropped;
    * column 22 (chain identifier) is replaced by ``chain``.

    Every other line (HETATM, TER, END, ...) is omitted from ``ofilename``.
    ``chainID`` is currently unused (the chain filter is disabled) and is
    kept only for interface compatibility.
    """
    coord = [0.0, 0.0, 0.0]
    # Context managers guarantee both files are closed even if a malformed
    # record makes int()/float() raise part-way through.
    with open(ifilename, "r") as ifile, open(ofilename, "w") as ofile:
        for aline in ifile:
            if aline[0:6] != "ATOM  ":
                continue
            ires = aline[22:26]
            if int(ires) not in frag_pos:
                continue
            coord[0] = float(aline[30:38])
            coord[1] = float(aline[38:46])
            coord[2] = float(aline[46:54])
            coord_moved = move(coord, mtx, vec)
            # Truncating at column 30 also drops the trailing newline, which is re-added below.
            aline = aline[0:30] + "%8.3f%8.3f%8.3f%6.2f%6.2f" % (coord_moved[0], coord_moved[1], coord_moved[2], x, cf)
            aline = aline[0:21] + chain + aline[22:]
            ofile.write(aline + "\n")
    return

def move(x, mtx, vec):
    """Apply the rigid-body transform ``mtx . x + vec`` to a 3-D point.

    Args:
        x: Sequence (list, tuple, ...) whose first three items are the
            coordinates. Any further items are copied through unchanged.
        mtx: 3x3 nested sequence (row-major rotation matrix).
        vec: 3-sequence translation.

    Returns:
        A new list. ``x`` itself is never modified, even if slicing it would
        return a view. Tuple input is supported (the previous ``x[:]`` copy
        could not be item-assigned).
    """
    ret = list(x)
    for j in range(3):
        ret[j] = sum(mtx[j][i] * x[i] for i in range(3)) + vec[j]
    return ret
