# Start-up banner, printed once when the script is loaded (one line per entry).
print(*(
    '######################################################################################################',
    '############## You are using Modification and Annotation in Proteins (MAP) Script ####################',
    '##################### MAP Program, developed by Prof G. P. S. Raghava group. #########################',
    '############ Please cite: MAP; available at https://webs.iiitd.edu.in/raghava/maprepo/  ##############',
    '######################################################################################################',
), sep='\n')
 
import ast
import os
import sys
import pandas as pd
import requests
from io import StringIO
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from Bio.PDB import MMCIFIO
from Bio.PDB.DSSP import DSSP
import os
from Bio.PDB import MMCIFParser, NeighborSearch, Selection, is_aa, Select,PDBParser
from Bio.PDB.PDBList import PDBList
import Bio.PDB
from Bio.PDB.Structure import Structure
from Bio.PDB.Model import Model
from Bio.PDB.DSSP import make_dssp_dict
import multiprocessing as mp
from Bio.PDB.Polypeptide import is_aa
import argparse

import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning
# Silence known warnings so they do not flood the console:
#  - UserWarning from modules whose name matches 'Bio.PDB.DSSP' (the `module`
#    argument of filterwarnings is a regular expression matched against the
#    warning's module name);
#  - every Biopython PDBConstructionWarning.
# NOTE: filterwarnings inserts each filter at the front of the list, so the
# order of these two calls determines precedence if they ever overlap.
warnings.filterwarnings("ignore", category=UserWarning, module='Bio.PDB.DSSP')
warnings.filterwarnings("ignore", category=PDBConstructionWarning)

# ------------------------ COMMAND-LINE ARGUMENTS ------------------------ #
def _build_arg_parser():
    """Return the command-line parser for this script.

    Options are registered in table order, which also fixes the ``--help``
    layout: input CSV, epitope marker, output format, the three FASTA header
    fields (organism, description, sample name) and the output base name.
    """
    option_specs = (
        (("-i", "--input"),
         {"required": True,
          "help": "Input CSV file with PDB and antigen chain info"}),
        (("-m", "--marker"),
         {"default": "{Ab:Int}",
          "help": "Epitope marker to be inserted (default: '{Ab:Int}')"}),
        (("-f", "--format"),
         {"choices": ["c", "f"], "default": "c",
          "help": "Output format: 'c' for CSV, 'f' for FASTA (default: 'c')"}),
        (("-org", "--organism"),
         {"default": "Unknown organism",
          "help": "Organism for FASTA header (used if -f f)"}),
        (("-d", "--description"),
         {"default": "Unknown function",
          "help": "Functional description for FASTA header (used if -f f)"}),
        (("-n", "--name"),
         {"default": "Sample",
          "help": "Sample name for FASTA header (used if -f f)"}),
        (("-o", "--output"),
         {"type": str, "default": "Output",
          "help": "Base name for output files"}),
    )
    cli = argparse.ArgumentParser(description="Epitope annotation pipeline using RSA")
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli


# Parsed once at start-up; the rest of the script reads options from `args`.
parser = _build_arg_parser()
args = parser.parse_args()


def fetch_pdb_file(pdb_id, save_path, timeout=60):
    """Download an mmCIF entry from RCSB, save it to disk and return model 0.

    Args:
        pdb_id: PDB identifier; inserted verbatim into the RCSB download URL.
        save_path: Destination for the downloaded ``.cif`` text. Missing
            parent directories are created.
        timeout: Seconds ``requests`` may wait on the connection/each read
            before giving up, so a stalled server cannot hang the script.

    Returns:
        The first model (``structure[0]``) of the parsed mmCIF structure.

    Raises:
        Exception: If the server answers with a status other than HTTP 200.
        requests.RequestException: On network failure or timeout.
    """
    url = f'https://files.rcsb.org/download/{pdb_id}.cif'
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch PDB file for {pdb_id}")

    pdb_text = response.text

    # The caller's working directory (e.g. "dssp/") may not exist yet.
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Keep a copy on disk: DSSP is later run against this file path.
    with open(save_path, 'w', encoding='utf-8') as handle:
        handle.write(pdb_text)

    # Parse from the in-memory text rather than re-reading the file.
    structure = MMCIFParser().get_structure(pdb_id, StringIO(pdb_text))
    return structure[0]
        
def calculate_rsa(model, pdb_path, dssp_executable='mkdssp'):
    """Run DSSP on a model and tabulate per-residue RSA and secondary structure.

    Args:
        model: Biopython model handed to ``DSSP``.
        pdb_path: Path of the structure file DSSP reads for ``model``.
        dssp_executable: Name or path of the DSSP binary (default ``'mkdssp'``).

    Returns:
        DataFrame with columns ``Chain``, ``Residue``, ``RSA`` and ``SS``, one
        row per DSSP entry in DSSP key order. Entries whose residue code is
        ``'X'`` are skipped.
    """
    dssp = DSSP(model, pdb_path, dssp=dssp_executable)

    # Collect plain rows and build the frame once; growing a DataFrame with
    # .loc one row at a time copies it repeatedly (quadratic in residues).
    rows = []
    for chain_id, residue_id in dssp.keys():
        record = dssp[(chain_id, residue_id)]
        # Record fields used here: [1] residue code, [2] secondary structure,
        # [3] relative accessibility.
        residue, ss, rsa = record[1], record[2], record[3]
        if residue == 'X':
            continue
        rows.append([chain_id, residue, rsa, ss])

    return pd.DataFrame(rows, columns=['Chain', 'Residue', 'RSA', 'SS'])


import gemmi

def extract_chain(pdb_file, chain_id, output_file):
    """Write a copy of ``pdb_file`` that keeps only the residues of ``chain_id``.

    In every model, each chain whose ``name`` differs from ``chain_id`` has all
    of its residues deleted (the emptied chain objects stay in the model). The
    result is written with gemmi's ``write_minimal_pdb``.

    Args:
        pdb_file: Structure path readable by ``gemmi.read_structure``.
        chain_id: Name of the chain to keep.
        output_file: Destination path for the written structure.
    """
    structure = gemmi.read_structure(pdb_file)
    for mdl in structure:
        # Residues are cleared in place instead of removing chains, presumably
        # to avoid mutating the model's chain list while iterating over it.
        for unwanted in (ch for ch in mdl if ch.name != chain_id):
            del unwanted[:]
    structure.write_minimal_pdb(output_file)
    
# Function to determine epitope residues based on changes in solvent accessibility
def identify_epitope_residues(rsa_unbound, rsa_bound, threshold=0.05):
    """Flag residues whose relative solvent accessibility differs between two tables.

    Rows are paired by position (not by index label), so both tables must list
    the same residues in the same order. Columns are read by position, matching
    the layout produced by ``calculate_rsa``: 1 = residue, 2 = RSA,
    3 = secondary structure.

    NOTE: ``process_row`` passes the bound (complex) table first and the
    unbound table second, the reverse of these parameter names. The difference
    test is symmetric, but ``Residue`` and ``Secondary Structure`` always come
    from the FIRST argument.

    Args:
        rsa_unbound: First RSA table; supplies the residue and SS columns.
        rsa_bound: Second RSA table, same length and order as the first.
        threshold: A residue is epitopic when ``|RSA_1 - RSA_2| > threshold``.

    Returns:
        DataFrame with columns ``Residue``, ``Epitopic`` (1 or 0) and
        ``Secondary Structure``, one row per input row.

    Raises:
        ValueError: If the two tables differ in length, or an RSA value cannot
            be converted to float.
    """
    # Positional pairing is only meaningful when both tables cover the same
    # residues; a length mismatch would otherwise misalign or drop residues.
    if len(rsa_unbound) != len(rsa_bound):
        raise ValueError(
            f"RSA tables differ in length ({len(rsa_unbound)} vs {len(rsa_bound)}); "
            "residues cannot be paired"
        )

    first_rsa = rsa_unbound.iloc[:, 2].astype(float).to_numpy()
    second_rsa = rsa_bound.iloc[:, 2].astype(float).to_numpy()
    epitopic = (abs(first_rsa - second_rsa) > threshold).astype(int)

    return pd.DataFrame({
        'Residue': rsa_unbound.iloc[:, 1].to_list(),
        'Epitopic': epitopic,
        'Secondary Structure': rsa_unbound.iloc[:, 3].to_list(),
    })

def process_row(row, workdir="dssp"):
    """Compute per-residue epitope labels for one (PDB ID, antigen chain) row.

    Downloads the full mmCIF entry and computes DSSP RSA for the antigen chain
    inside the complex (bound). It then writes the antigen chain on its own,
    recomputes RSA (unbound), and passes both tables to
    ``identify_epitope_residues``. Temporary structure files are always
    removed, whether or not processing succeeds.

    Args:
        row: Mapping or pandas Series with keys ``'pdb'`` and
            ``'antigen_chain'``. Only the first non-blank character of the
            chain value is used.
        workdir: Directory for temporary structure files (created if missing).

    Returns:
        ``(residues, epitopic_flags, secondary_structure)`` as lists, or
        ``(0, 0, 0)`` when the chain value is unusable or any step fails.
    """
    pdb_path = None
    pdb_path_unbound = None
    try:
        pdb_id = str(row['pdb']).strip()
        antigen_chain_id = row['antigen_chain']

        # NaN/None or blank chain values cannot be processed.
        if not isinstance(antigen_chain_id, str) or not antigen_chain_id.strip():
            return 0, 0, 0
        # NOTE(review): only the first character is kept, so multi-character
        # chain IDs are truncated — confirm this matches the input format.
        antigen_chain_id = antigen_chain_id.strip()[0]

        os.makedirs(workdir, exist_ok=True)
        pdb_path = os.path.join(workdir, f"{pdb_id}.cif")
        pdb_path_unbound = os.path.join(workdir, f"{pdb_id}_antigen.pdb")

        # Bound state: RSA over the full complex, restricted to the antigen chain.
        model = fetch_pdb_file(pdb_id, pdb_path)
        rsa_data = calculate_rsa(model, pdb_path)
        rsa_data = rsa_data[rsa_data['Chain'] == antigen_chain_id]

        # Unbound state: the antigen chain written and analysed on its own.
        extract_chain(pdb_path, antigen_chain_id, pdb_path_unbound)
        structure_unbound = PDBParser(QUIET=True).get_structure("Antigen", pdb_path_unbound)
        rsa_data_unbound = calculate_rsa(structure_unbound[0], pdb_path_unbound)

        dat = identify_epitope_residues(rsa_data, rsa_data_unbound)
        return (
            dat["Residue"].to_list(),
            dat["Epitopic"].to_list(),
            dat["Secondary Structure"].to_list(),
        )

    except Exception as e:
        # Per-row failures are reported and skipped so one bad entry does not
        # abort the whole batch.
        print(f"Error processing row: {e}")
        return 0, 0, 0

    finally:
        # Best-effort cleanup on both success and failure paths.
        for path in (pdb_path, pdb_path_unbound):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as err:
                    print(f"Warning: could not remove {path}: {err}")
    

# ------------------------ MAIN ------------------------ #

# ------------------------ LOAD INPUT FILE ------------------------ #
# The first CSV row is consumed as a header and then replaced; only the first
# two columns (PDB ID, antigen chain) are used. Every cell is read as text so
# that IDs such as "1E10" or numeric chain names such as "1" are not coerced
# to numbers (a numeric chain would fail process_row's string check).
try:
    df = pd.read_csv(args.input, dtype=str)
    # Extra columns are dropped rather than breaking the rename below.
    df = df.iloc[:, :2].copy()
    df.columns = ['pdb', 'antigen_chain']
    print("🔄 Ignored original headers. Set columns to: ['pdb', 'antigen_chain']")
except Exception as e:
    print(f"❌ Failed to read input file: {e}")
    sys.exit(1)

# Result columns start empty; rows that fail keep None (an empty CSV cell).
df['Residues'], df['Epitopic'], df['Secondary Structure'] = None, None, None

# Process each structure in turn; process_row returns (0, 0, 0) on failure.
results = [process_row(row) for _, row in tqdm(df.iterrows(), total=len(df))]

# read_csv yields a RangeIndex, so the enumeration position is the row label.
for i, (residues, epitope, ss) in enumerate(results):
    if residues != 0:
        df.at[i, 'Residues'] = residues
        df.at[i, 'Epitopic'] = epitope
        df.at[i, 'Secondary Structure'] = ss

# Intermediate table read back (and finally deleted) by the later stages.
df.to_csv('out1.csv', index=False)




# ------------------------ ANNOTATE SEQUENCES ------------------------ #

# Reload the intermediate CSV without a header: row 0 then holds the column
# names and columns are addressed by position (2 = Residues, 3 = Epitopic).
file_path = "out1.csv"
df = pd.read_csv(file_path, header=None)

# One entry per data row: every residue flagged 1 is followed by the epitope
# marker. Rows whose stored lists cannot be parsed (e.g. failed rows, whose
# cells are empty and load as NaN) become "ERROR".
formatted_sequences = []
for residues_cell, epitopic_cell in zip(df.iloc[1:, 2], df.iloc[1:, 3]):
    try:
        # Cells hold Python list literals written by DataFrame.to_csv.
        residues = ast.literal_eval(residues_cell)
        epitopic = ast.literal_eval(epitopic_cell)
        sequence = "".join(
            res + args.marker if epi == 1 else res
            for res, epi in zip(residues, epitopic)
        )
    except (ValueError, SyntaxError, TypeError):
        sequence = "ERROR"
    formatted_sequences.append(sequence)

# ------------------------ OUTPUT FORMAT HANDLING ------------------------ #
# `-o` is a base name: the matching extension is appended unless it is
# already present, and the message reports the file actually written.

if args.format == "f":
    output_file = args.output if args.output.endswith(".fasta") else args.output + ".fasta"
    with open(output_file, "w", encoding="utf-8") as f:
        # One record per input row, numbered from 1 in input order.
        for i, seq in enumerate(formatted_sequences, start=1):
            header = f">{args.name}_{i} {{org:{args.organism}}} {{func:{args.description}}}"
            f.write(header + "\n" + seq + "\n")
    print(f"✅ FASTA file saved as: {output_file}")

else:
    output_file = args.output if args.output.endswith(".csv") else args.output + ".csv"
    # Row 0 of df is the header row read back from out1.csv, so the new
    # column (index 5) starts with its own title and the file is written
    # without a separate pandas header.
    df[5] = ["Annotated Sequence", *formatted_sequences]
    df.to_csv(output_file, index=False, header=False)
    print(f"✅ CSV file saved as: {output_file}")

# Remove the intermediate file written by the loading stage.
os.remove('out1.csv')
