from rdkit import Chem
import re
import pandas as pd
import warnings

# Silence every Python warning for the rest of the process: an 'ignore'
# filter with no message/category/module restriction matches all warnings.
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# HELM -> MAP conversion
# ---------------------------------------------------------------------------
# MAP monomer library table; the code below relies on its 'Symbol' (HELM
# monomer symbol) and 'MAP_denotion' columns.
# NOTE(review): the path spells "momomers" — presumably it matches the file on
# disk, so do not "fix" the spelling here alone.
df1 = pd.read_csv('data/MAP_momomers_lib.csv')

# Reverse lookup MAP denotation -> HELM symbol. Keys are inserted in
# descending order of the MAP denotation; if a denotation repeats, the later
# entry in that order overwrites the earlier one.
# NOTE(review): the reason for descending order is not visible here —
# presumably so longer denotations sharing a prefix are tried first; confirm.
map_to_helm_dict = (
    df1.set_index('MAP_denotion')['Symbol']
    .sort_index(ascending=False)
    .to_dict()
)

# Terminal-modification tokens that may appear inside a MAP string.
_NTERM_MOD_PATTERN = re.compile(r'\{nt:[^}]+\}')
_CTERM_MOD_PATTERN = re.compile(r'\{ct:[^}]+\}')


def _symbol_to_map_lookup(library):
    """Return a dict mapping each 'Symbol' in *library* to its 'MAP_denotion'.

    When a symbol appears in several rows the first row wins, which matches a
    first-match scan of the table.
    """
    lookup = {}
    for symbol, map_denotion in zip(library['Symbol'], library['MAP_denotion']):
        lookup.setdefault(symbol, map_denotion)
    return lookup


def _move_terminal_mods_to_end(map_sequence):
    """Remove all {nt:...} then all {ct:...} tokens and append them in that order."""
    nterm_mods = _NTERM_MOD_PATTERN.findall(map_sequence)
    remainder = _NTERM_MOD_PATTERN.sub('', map_sequence)
    cterm_mods = _CTERM_MOD_PATTERN.findall(remainder)
    remainder = _CTERM_MOD_PATTERN.sub('', remainder)
    return remainder + ''.join(nterm_mods) + ''.join(cterm_mods)


def helm_to_map(helm, library=None):
    """Convert a HELM peptide string into MAP notation.

    Only the first ``{...}`` monomer list in *helm* is read. Each
    ``.``-separated monomer (surrounding square brackets stripped) is
    translated through the library's ``Symbol`` -> ``MAP_denotion`` columns;
    monomers missing from the library are silently skipped. Any ``{nt:...}``
    and ``{ct:...}`` tokens in the result are moved to the end, N-terminal
    tokens first.

    If the second ``$``-separated section (the connection list) is non-empty,
    its last comma-separated field, e.g. ``1:R1-3:R2``, is read as a
    cyclisation and appended as ``{cyc:N-C}`` when it links monomer 1 to the
    last monomer, otherwise as ``{cyc:<i>-<j>}``.

    Args:
        helm: HELM string, e.g.
            ``'PEPTIDE1{A.C.K}$PEPTIDE1,PEPTIDE1,1:R1-3:R2$$$V2.0'``.
        library: Monomer table with ``Symbol`` and ``MAP_denotion`` columns.
            Defaults to the module-level ``df1`` loaded from the MAP library CSV.

    Returns:
        The MAP string; ``''`` if *helm* contains fewer than two ``$``
        characters; or ``'ERROR: <message>'`` if anything raises while parsing.
    """
    try:
        # Resolved inside the try so a missing global is reported like any
        # other failure.
        if library is None:
            library = df1

        start = helm.index('{') + 1
        end = helm.index('}')
        elements = [elem.strip('[]') for elem in helm[start:end].split('.')]
        num_elements = len(elements)

        # Build the symbol lookup once per call instead of scanning the whole
        # table for every monomer.
        lookup = _symbol_to_map_lookup(library)
        map_format = ''.join(lookup[elem] for elem in elements if elem in lookup)

        sections = helm.split('$')
        if len(sections) <= 2:
            # NOTE(review): the translated sequence is discarded when the
            # connection section cannot be located; kept for backward
            # compatibility — confirm callers rely on the empty string.
            return ''

        sequence = _move_terminal_mods_to_end(map_format)
        connections = sections[1]
        if not connections:
            # Linear peptide: no cyclisation suffix.
            return sequence

        # Only the last comma-separated field is inspected, e.g. "1:R1-3:R2".
        last_connection = connections.split(',')[-1]
        first_part = last_connection.split(':')[0]
        second_part = last_connection.split(':')[1].split('-')[1]
        # NOTE(review): attachment-point labels (R1/R2/R3...) are not checked,
        # so any link between monomer 1 and the last monomer is reported as
        # N-C — confirm this is intended.
        if int(first_part) == 1 and int(second_part) == num_elements:
            cyc_string = 'N-C'
        else:
            cyc_string = f'{first_part}-{second_part}'
        return f'{sequence}{{cyc:{cyc_string}}}'
    except Exception as e:
        # Failures are reported to the caller as a string, not raised.
        return f"ERROR: {e}"


