from rdkit import Chem
import re
import pandas as pd
import warnings

# Install a catch-all 'ignore' filter: warnings are silenced process-wide.
warnings.filterwarnings('ignore')

## MAP -> HELM monomer lookup
# Read the MAP monomer library and index its HELM symbols by MAP denotation.
# Sorting the keys in descending (reverse lexicographic) order places every
# denotation ahead of its own prefixes (assuming string denotations), e.g.
# 'Ab' before 'A', so a scan that takes the first matching key picks the
# longest one.
# NOTE(review): the CSV path is resolved relative to the current working
# directory, not to this file.
df1 = pd.read_csv('data/MAP_momomers_lib.csv')
map_to_helm_dict = (
    df1.set_index('MAP_denotion')['Symbol']
    .sort_index(ascending=False)
    .to_dict()
)



##MAP to HELM sequence
def process_HELM_seq(helm_seq, ID):
    """Wrap a cyclized HELM monomer string into full HELM peptide notation.

    *helm_seq* holds a ``{cyc:<start>-<end>}`` header followed by the
    ``.``-separated monomer chain (everything after the first ``}``).
    ``<start>`` is ``N`` (N-terminus) or a residue number; ``<end>`` is ``C``
    (C-terminus) or a residue number. The result is a single-polymer HELM
    string whose connection section links the chain to itself:

    * ``N`` -> ``1:R1`` and ``C`` -> ``<len>:R2``; a numeric partner of a
      terminus letter is a side chain (``R3``).
    * Purely numeric headers: a start of ``1`` uses ``R1`` unless the end is
      the last residue; an end equal to the chain length uses ``R2`` unless
      the start is ``1``; everything else is side chain to side chain
      (``R3``-``R3``).

    Args:
        helm_seq: header plus monomer chain, e.g. ``'{cyc:N-C}A.C.D'``.
        ID: polymer number used to build the ``PEPTIDE<ID>`` name.

    Returns:
        The HELM string, e.g.
        ``'PEPTIDE1{A.C.D}$PEPTIDE1,PEPTIDE1,1:R1-3:R2$$$'``.

    Raises:
        ValueError: if *helm_seq* has no ``{...}`` header or the header has
            no ``-`` separating the two positions.
    """
    start = helm_seq.index('{')
    end = helm_seq.index('}')
    chain = helm_seq[end + 1:]
    seq_len = len(chain.split('.'))
    last = str(seq_len)

    # Take the whole token after ':' (not just the last character) so that
    # multi-digit positions such as '12' survive. Strip stray spaces and a
    # trailing '.' that monomer tokenization can leave behind.
    positions = helm_seq[start + 1:end].rpartition(':')[2]
    start_pos, sep, end_pos = positions.partition('-')
    if not sep:
        raise ValueError(f'malformed cyclization header in {helm_seq!r}')
    start_pos = start_pos.strip(' .')
    end_pos = end_pos.strip(' .')

    if start_pos == 'N' or end_pos == 'C':
        # Explicit terminus letters: N-terminus is R1 of monomer 1, the
        # C-terminus is R2 of the last monomer; a numeric partner is R3.
        start_att = '1:R1' if start_pos == 'N' else f'{start_pos}:R3'
        end_att = f'{last}:R2' if end_pos == 'C' else f'{end_pos}:R3'
    elif start_pos != '1' and end_pos == last:
        start_att, end_att = f'{start_pos}:R3', f'{last}:R2'
    elif start_pos == '1' and end_pos != last:
        start_att, end_att = '1:R1', f'{end_pos}:R3'
    else:
        start_att, end_att = f'{start_pos}:R3', f'{end_pos}:R3'

    return (
        f'PEPTIDE{ID}{{{chain}}}'
        f'$PEPTIDE{ID},PEPTIDE{ID},{start_att}-{end_att}$$$'
    )

def convert_map_to_helm_sequence(map_str, ID, monomer_dict=None):
    """Translate a MAP-notation peptide into a HELM-style monomer string.

    The result starts with the cyclization header from *map_str*, normalized
    to ``{cyc:<start>-<end>}`` and copied verbatim (it is positional metadata,
    not monomers). The ``.``-separated HELM monomers follow it. Any
    ``{nt:...}`` annotations are moved in front of the remaining sequence
    before tokenization. This header-then-chain layout is what
    ``process_HELM_seq`` parses.

    Tokenization is greedy: at each position the longest matching key of
    *monomer_dict* is replaced by its HELM symbol, and multi-character
    symbols are wrapped in ``[...]``. Characters that match no key are
    copied through unchanged.

    Args:
        map_str: peptide in MAP notation. It may contain one
            ``{cyc:X-Y}`` annotation (X is ``N`` or a number, Y is ``C`` or
            a number) and ``{nt:...}`` annotations.
        ID: unused; kept for signature compatibility with existing callers.
        monomer_dict: MAP-denotation -> HELM-symbol mapping. Defaults to the
            module-level ``map_to_helm_dict``.

    Returns:
        The HELM monomer string. It is prefixed by the cyclization header
        when *map_str* has one; otherwise it is just the monomer chain.
    """
    if monomer_dict is None:
        monomer_dict = map_to_helm_dict

    nterm_pattern = r'\{nt:[^}]+\}'
    cyc_pattern = r'\{cyc:\s*([N]|\d+)-([C]|\d+)\}'

    nterm_modifications = re.findall(nterm_pattern, map_str)
    map_str = re.sub(nterm_pattern, '', map_str)
    cyc_match = re.search(cyc_pattern, map_str)
    map_str = re.sub(cyc_pattern, '', map_str)

    # Emit the header directly instead of tokenizing it. Running it through
    # the monomer lookup would, e.g., turn the 'C' of '{cyc:N-C}' into 'C.'.
    header = f'{{cyc:{cyc_match[1]}-{cyc_match[2]}}}' if cyc_match else ''
    body = ''.join(nterm_modifications) + map_str

    # Try longer keys first so a denotation wins over any of its prefixes.
    # Empty keys would match without advancing (endless loop), and non-string
    # keys cannot be matched against text, so both are skipped.
    keys = sorted(
        (k for k in monomer_dict if isinstance(k, str) and k),
        key=len,
        reverse=True,
    )

    tokens = []
    i = 0
    while i < len(body):
        for key in keys:
            if body.startswith(key, i):
                val = monomer_dict[key]
                tokens.append(f'[{val}].' if len(val) > 1 else f'{val}.')
                i += len(key)
                break
        else:
            # No monomer matches here: pass the character through unchanged.
            tokens.append(body[i])
            i += 1

    # Drop the separator left after the final monomer.
    return (header + ''.join(tokens)).rstrip('.')