msaexplorer.explore

Explore module

This module contains the two classes MSA and Annotation, which read in the respective files and can be used to compute several statistics or serve as input for the draw module functions.
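
A minimal usage sketch (the file name and parameter values are hypothetical; the class and methods are defined below):

    from msaexplorer import explore

    # read an alignment in fasta format (a raw fasta string works as well)
    aln = explore.MSA('example_alignment.fasta')
    # optionally set a reference sequence and a zoom range
    aln.reference_id = list(aln.alignment.keys())[0]
    aln.zoom = (0, 500)
    # compute statistics or pass the object to the draw module functions
    consensus = aln.get_consensus(threshold=0.7)
    entropies = aln.calc_entropy()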

Classes

   1"""
   2# Explore module
   3
   4This module contains the two classes `MSA` and `Annotation`, which read in the respective files
   5and can be used to compute several statistics or serve as input for the `draw` module functions.
   6
   7## Classes
   8
   9"""
  10
  11# built-in
  12import os
  13import io
  14import math
  15import collections
  16import re
  17from typing import Callable, Dict
  18
  19# installed
  20import numpy as np
  21from numpy import ndarray
  22
  23# msaexplorer
  24from msaexplorer import config
  25
  26
  27def _get_line_iterator(source):
  28    """
  29    Allow reading from either a file path or a raw string.
  30    """
  31    if isinstance(source, str) and os.path.exists(source):
  32        return open(source, 'r')
  33    else:
  34        return io.StringIO(source)
  35
  36class MSA:
  37    """
  38    An alignment class that allows computation of several stats
  39    """
  40
  41    def __init__(self, alignment_string: str, reference_id: str = None, zoom_range: tuple | int = None):
  42        """
  43        Initialise an Alignment object.
  44        :param alignment_string: Path to alignment file or raw alignment string
  45        :param reference_id: reference id
  46        :param zoom_range: start and stop positions to zoom into the alignment
  47        """
  48        self._alignment = self._read_alignment(alignment_string)
  49        self._reference_id = self._validate_ref(reference_id, self._alignment)
  50        self._zoom = self._validate_zoom(zoom_range, self._alignment)
  51        self._aln_type = self._determine_aln_type(self._alignment)
  52
  53    # TODO: read in different alignment types
  54    # Static methods
  55    @staticmethod
  56    def _read_alignment(file_path: str) -> dict:
  57        """
  58        Parse MSA alignment file.
  59        :param file_path: path to alignment file
  60        :return: dictionary with ids as keys and sequences as values
  61        """
  62
  63        def add_seq(aln: dict, sequence_id: str, seq_list: list):
  64            """
  65            Add a complete sequence and check for non-allowed chars
  66            :param aln: alignment dictionary to build
  67            :param sequence_id: sequence id to add
  68            :param seq_list: list of sequence lines to join and add
  69            :return: alignment with the added sequence
  70            """
  71            final_seq = ''.join(seq_list).upper()
  72            # Check for non-allowed characters
  73            invalid_chars = set(final_seq) - set(config.POSSIBLE_CHARS)
  74            if invalid_chars:
  75                raise ValueError(
  76                    f"{sequence_id} contains invalid characters: {', '.join(invalid_chars)}. Allowed chars are: {config.POSSIBLE_CHARS}"
  77                )
  78            aln[sequence_id] = final_seq
  79
  80            return aln
  81
  82        alignment, seq_lines = {}, []
  83        seq_id = None
  84
  85        with _get_line_iterator(file_path) as file:
  86            for i, line in enumerate(file):
  87                line = line.strip()
  88                # initial check for fasta format
  89                if i == 0 and not line.startswith(">"):
  90                    raise ValueError('Alignment has to be in fasta format starting with >SeqID.')
  91                if line.startswith(">"):
  92                    if seq_id:
  93                        alignment = add_seq(alignment, seq_id, seq_lines)
  94                    # initialize a new sequence
  95                    seq_id, seq_lines = line[1:], []
  96                else:
  97                    seq_lines.append(line)
  98            # handle last sequence
  99            if seq_id:
 100                alignment = add_seq(alignment, seq_id, seq_lines)
 101        # final sanity checks
 102        if alignment:
 103            # alignment contains only one sequence:
 104            if len(alignment) < 2:
 105                raise ValueError("Alignment must contain more than one sequence.")
 106            # alignment sequences are not same length
 107            first_seq_len = len(next(iter(alignment.values())))
 108            for sequence_id, sequence in alignment.items():
 109                if len(sequence) != first_seq_len:
 110                    raise ValueError(
 111                        f"All alignment sequences must have the same length. Sequence '{sequence_id}' has length {len(sequence)}, expected {first_seq_len}."
 112                    )
 113            # all checks passed
 114            return alignment
 115        else:
 116            raise ValueError(f"Alignment file {file_path} does not contain any sequences in fasta format.")
 117
 118    @staticmethod
 119    def _validate_ref(reference: str | None, alignment: dict) -> str | None | ValueError:
 120        """
 121        Validate if the ref seq is indeed part of the alignment.
 122        :param reference: reference seq id
 123        :param alignment: alignment dict
 124        :return: validated reference
 125        """
 126        if reference in alignment.keys():
 127            return reference
 128        elif reference is None:
 129            return reference
 130        else:
 131            raise ValueError('Reference not in alignment.')
 132
 133    @staticmethod
 134    def _validate_zoom(zoom: tuple | int, original_aln: dict) -> ValueError | tuple | None:
 135        """
 136        Validates that the user-defined zoom range is within the start and end of the
 137        initial alignment.\n
 138        :param zoom: zoom range or zoom start
 139        :param original_aln: non-zoomed alignment dict
 140        :return: validated zoom range
 141        """
 142        if zoom is not None:
 143            aln_length = len(original_aln[list(original_aln.keys())[0]])
 144            # check if only one value is provided -> stop is the last alignment position
 145            if isinstance(zoom, int):
 146                if 0 <= zoom < aln_length:
 147                    return zoom, aln_length - 1
 148                else:
 149                    raise ValueError('Zoom start must be within the alignment length range.')
 150            # zoom must contain exactly two positions
 151            if len(zoom) != 2:
 152                raise ValueError('Zoom positions have to be (zoom_start, zoom_end).')
 153            # validate zoom start/stop
 154            for position in zoom:
 155                if not isinstance(position, int):
 156                    raise ValueError('Zoom positions have to be integers.')
 157                if position not in range(0, aln_length):
 158                    raise ValueError('Zoom position out of range')
 159
 160        return zoom
 161
 162    @staticmethod
 163    def _determine_aln_type(alignment) -> str:
 164        """
 165        Determine the most likely type of alignment:
 166        if at least 70% of the characters in the alignment are nucleotide
 167        characters, it is most likely a nucleotide alignment.
 168        :return: type of alignment
 169        """
 170        counter = 0
 171        for record in alignment:
 172            if 'U' in alignment[record]:
 173                return 'RNA'
 174            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 175        # determine which is the most likely type
 176        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 177            return 'DNA'
 178        else:
 179            return 'AA'
 180
 181    # Properties with setters
 182    @property
 183    def reference_id(self):
 184        return self._reference_id
 185
 186    @reference_id.setter
 187    def reference_id(self, ref_id: str):
 188        """
 189        Set and validate the reference id.
 190        """
 191        self._reference_id = self._validate_ref(ref_id, self.alignment)
 192
 193    @property
 194    def zoom(self) -> tuple:
 195        return self._zoom
 196
 197    @zoom.setter
 198    def zoom(self, zoom_pos: tuple | int):
 199        """
 200        Set and validate the zoom range.
 201        """
 202        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 203
 204    # Property without setters
 205    @property
 206    def aln_type(self) -> str:
 207        """
 208        The type of the alignment:
 209        RNA, DNA or AA
 210        """
 211        return self._aln_type
 212
 213    # On the fly properties without setters
 214    @property
 215    def length(self) -> int:
 216        return len(next(iter(self.alignment.values())))
 217
 218    @property
 219    def alignment(self) -> dict:
 220        """
 221        (zoomed) version of the alignment.
 222        """
 223        if self.zoom is not None:
 224            zoomed_aln = dict()
 225            for seq in self._alignment:
 226                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 227            return zoomed_aln
 228        else:
 229            return self._alignment
 230
 231    # functions for different alignment stats
 232    def get_reference_coords(self) -> tuple[int, int]:
 233        """
 234        Determine the start and end coordinates of the reference sequence
 235        defined as the first/last nucleotide in the reference sequence
 236        (excluding N and gaps).
 237
 238        :return: start, end
 239        """
 240        start, end = 0, self.length
 241
 242        if self.reference_id is None:
 243            return start, end
 244        else:
 245            # 5' --> 3'
 246            for start in range(self.length):
 247                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 248                    break
 249            # 3' --> 5'
 250            for end in range(self.length - 1, 0, -1):
 251                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 252                    break
 253
 254            return start, end
 255
 256    def get_consensus(self, threshold: float | None = None, use_ambig_nt: bool = False) -> str:
 257        """
 258        Creates a non-gapped consensus sequence.
 259
 260        :param threshold: Threshold for the consensus sequence. If use_ambig_nt = True, the ambiguous character that encodes
 261            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
 262            or 'X' (for AA alignments) is used if none of the characters reaches a cumulative frequency >= threshold.
 263        :param use_ambig_nt: Use an ambiguous nt character if none of the possible nt at an alignment position
 264            has a frequency above the defined threshold.
 265        :return: consensus sequence
 266        """
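            # Worked example (hypothetical column): counts {'A': 3, 'T': 1} in a 4-sequence
            # alignment with threshold=0.7 give a cutoff of 4 * 0.7 = 2.8; 'A' (count 3)
            # reaches it alone, so 'A' is kept. If several characters were needed to reach
            # the cutoff, 'N'/'X' or (with use_ambig_nt) an ambiguity code would be used.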
 267
 268        # helper functions
 269        def determine_counts(alignment_dict: dict, position: int) -> dict:
 270            """
 271            Count the number of each character at
 272            an index of the alignment and return a sorted dict.
 273            Handles ambiguous nucleotides in sequences
 274            as well as gaps.
 275            """
 276            nucleotide_list = []
 277
 278            # get all nucleotides
 279            for sequence in alignment_dict.items():
 280                nucleotide_list.append(sequence[1][position])
 281            # count occurrences of nucleotides
 282            counter = dict(collections.Counter(nucleotide_list))
 283            # get permutations of an ambiguous nucleotide
 284            to_delete = []
 285            temp_dict = {}
 286            for nucleotide in counter:
 287                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
 288                    to_delete.append(nucleotide)
 289                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
 290                    adjusted_freq = 1 / len(permutations)
 291                    for permutation in permutations:
 292                        if permutation in temp_dict:
 293                            temp_dict[permutation] += adjusted_freq
 294                        else:
 295                            temp_dict[permutation] = adjusted_freq
 296
 297            # drop ambiguous entries and add their adjusted freqs to the counter
 298            if to_delete:
 299                for i in to_delete:
 300                    counter.pop(i)
 301                for nucleotide in temp_dict:
 302                    if nucleotide in counter:
 303                        counter[nucleotide] += temp_dict[nucleotide]
 304                    else:
 305                        counter[nucleotide] = temp_dict[nucleotide]
 306
 307            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
 308
 309        def get_consensus_char(counts: dict, cutoff: float) -> list:
 310            """
 311            get a list of nucleotides for the consensus seq
 312            """
 313            n = 0
 314
 315            consensus_chars = []
 316            for char in counts:
 317                n += counts[char]
 318                consensus_chars.append(char)
 319                if n >= cutoff:
 320                    break
 321
 322            return consensus_chars
 323
 324        def get_ambiguous_char(nucleotides: list) -> str:
 325            """
 326            get ambiguous char from a list of nucleotides
 327            """
 328            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
 329                if set(permutations) == set(nucleotides):
 330                    break
 331
 332            return ambiguous
 333
 334        # check if params have been set correctly
 335        if threshold is not None:
 336            if threshold < 0 or threshold > 1:
 337                raise ValueError('Threshold must be between 0 and 1.')
 338        if self.aln_type == 'AA' and use_ambig_nt:
 339            raise ValueError('Ambiguous characters cannot be calculated for amino acid alignments.')
 340        if threshold is None and use_ambig_nt:
 341            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
 342
 343        alignment = self.alignment
 344        consensus = ''
 345
 346        if threshold is not None:
 347            consensus_cutoff = len(alignment) * threshold
 348        else:
 349            consensus_cutoff = 0
 350
 351        # build the consensus sequence
 352        for idx in range(self.length):
 353            char_counts = determine_counts(alignment, idx)
 354            consensus_chars = get_consensus_char(
 355                char_counts,
 356                consensus_cutoff
 357            )
 358            if threshold != 0:
 359                if len(consensus_chars) > 1:
 360                    if use_ambig_nt:
 361                        char = get_ambiguous_char(consensus_chars)
 362                    else:
 363                        if self.aln_type == 'AA':
 364                            char = 'X'
 365                        else:
 366                            char = 'N'
 367                    consensus = consensus + char
 368                else:
 369                    consensus = consensus + consensus_chars[0]
 370            else:
 371                consensus = consensus + consensus_chars[0]
 372
 373        return consensus
 374
 375    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 376        """
 377        **conserved ORF definition:**
 378            - conserved starts and stops
 379            - start, stop must be on the same frame
 380            - stop - start must be at least min_length
 381            - all ungapped seqs[start:stop] must have at least min_length
 382            - no ungapped seq may contain a stop codon between start and stop
 383
 384        Conservation is measured as the number of positions with identical characters divided by
 385        the length of the ORF slice of the alignment.
 386
 387        **Algorithm overview:**
 388            - check for conserved start and stop codons
 389            - iterate over all three frames
 390            - for each start, check the next sufficiently distant stop codon
 391            - check if all ungapped seqs between start and stop codon are >= min_length
 392            - check if no ungapped seq in the alignment has a stop codon
 393            - write to dictionary
 394            - classify as internal if the stop codon has already been written with a prior start
 395            - repeat for reverse complement
 396
 397        :return: ORF positions and internal ORF positions
 398        """
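        # Sketch of the returned structure (values are illustrative):
        # {'ORF_0': {'location': [[start, stop + 3]], 'frame': 0, 'strand': '+',
        #            'conservation': 98.5, 'internal': [...]}, ...}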
 399
 400        # helper functions
 401        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
 402            """
 403            Determine all start and stop codons within an alignment.
 404            :param alignment: alignment
 405            :param alignment_length: length of alignment
 406            :return: start and stop codon positions
 407            """
 408            starts = config.START_CODONS[self.aln_type]
 409            stops = config.STOP_CODONS[self.aln_type]
 410
 411            list_of_starts, list_of_stops = [], []
 412            ref = alignment[list(alignment.keys())[0]]
 413            for nt_position in range(alignment_length):
 414                if ref[nt_position:nt_position + 3] in starts:
 415                    conserved_start = True
 416                    for sequence in alignment:
 417                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
 418                            conserved_start = False
 419                            break
 420                    if conserved_start:
 421                        list_of_starts.append(nt_position)
 422
 423                if ref[nt_position:nt_position + 3] in stops:
 424                    conserved_stop = True
 425                    for sequence in alignment:
 426                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
 427                            conserved_stop = False
 428                            break
 429                    if conserved_stop:
 430                        list_of_stops.append(nt_position)
 431
 432            return list_of_starts, list_of_stops
 433
 434        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
 435            """
 436            Slice the sequences from the start to the stop codon and eliminate gaps.
 437            :param alignment: alignment
 438            :param start_pos: position of the start codon
 439            :param stop_pos: position of the stop codon
 440            :return: sliced, ungapped sequences
 441            """
 442            ungapped_seqs = []
 443            for seq_id in alignment:
 444                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
 445
 446            return ungapped_seqs
 447
 448        def additional_stops(ungapped_seqs: list) -> bool:
 449            """
 450            Check for the presence of an internal stop codon.
 451            :param ungapped_seqs: list of ungapped sequences
 452            :return: whether an additional stop codon is present (True/False)
 453            """
 454            stops = config.STOP_CODONS[self.aln_type]
 455
 456            for sliced_seq in ungapped_seqs:
 457                for position in range(0, len(sliced_seq) - 3, 3):
 458                    if sliced_seq[position:position + 3] in stops:
 459                        return True
 460            return False
 461
 462        def calculate_identity(identity_matrix: ndarray, aln_slice: list) -> float:
 463            sliced_array = identity_matrix[:, aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
 464            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100
 465
 466        # checks for arguments
 467        if self.aln_type == 'AA':
 468            raise TypeError('ORF search only for RNA/DNA alignments')
 469
 470        if identity_cutoff is not None:
 471            if identity_cutoff > 100 or identity_cutoff < 0:
 472                raise ValueError('conservation cutoff must be between 0 and 100')
 473
 474        if min_length <= 6 or min_length > self.length:
 475            raise ValueError(f'min_length must be > 6 and <= {self.length}')
 476
 477        # ini
 478        identities = self.calc_identity_alignment()
 479        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
 480        aln_len = self.length
 481
 482        orf_counter = 0
 483        orf_dict = {}
 484
 485        for aln, direction in zip(alignments, ['+', '-']):
 486            # check for starts and stops in the first seq and then check if these are present in all seqs
 487            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
 488            # check each frame
 489            for frame in (0, 1, 2):
 490                potential_starts = [x for x in conserved_starts if x % 3 == frame]
 491                potential_stops = [x for x in conserved_stops if x % 3 == frame]
 492                last_stop = -1
 493                for start in potential_starts:
 494                    # go to the next stop that is sufficiently far away in the alignment
 495                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
 496                    if not next_stops:
 497                        continue
 498                    next_stop = next_stops[0]
 499                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
 500                    # re-check the lengths of all ungapped seqs
 501                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
 502                    if not all(ungapped_seq_lengths):
 503                        continue
 504                    # if no stop codon between start and stop --> write to dictionary
 505                    if not additional_stops(ungapped_sliced_seqs):
 506                        if direction == '+':
 507                            positions = [start, next_stop + 3]
 508                        else:
 509                            positions = [aln_len - next_stop - 3, aln_len - start]
 510                        if last_stop != next_stop:
 511                            last_stop = next_stop
 512                            conservation = calculate_identity(identities, positions)
 513                            if identity_cutoff is not None and conservation < identity_cutoff:
 514                                continue
 515                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
 516                                                              'frame': frame,
 517                                                              'strand': direction,
 518                                                              'conservation': conservation,
 519                                                              'internal': []
 520                                                              }
 521                            orf_counter += 1
 522                        else:
 523                            if orf_dict:
 524                                orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
 525
 526        return orf_dict
 527
 528    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 529        """
 530        First calculates all ORFs and then searches from the 5' end for
 531        all non-overlapping ORFs on the fw strand and from the
 532        3' end for all non-overlapping ORFs on the rv strand.
 533
 534        **No overlap algorithm:**
 535            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
 536
 537            **frame 2:** -------[M------*]---------[M---*]--------
 538
 539            **frame 3:** [M---*]-----[M----------*]----------[M---
 540
 541            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
 542
 543            frame:    3      2           1      2       1
 544
 545        :return: dictionary with non-overlapping orfs
 546        """
 547        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
 548
 549        fw_orfs, rw_orfs = [], []
 550
 551        for orf in orf_dict:
 552            if orf_dict[orf]['strand'] == '+':
 553                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
 554            else:
 555                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
 556
 557        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
 558        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
 559        non_overlapping_orfs = []
 560        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
 561            previous_stop = -1 if strand == '+' else self.length + 1
 562            for orf in orf_list:
 563                if strand == '+' and orf[1][0] > previous_stop:
 564                    non_overlapping_orfs.append(orf[0])
 565                    previous_stop = orf[1][1]
 566                elif strand == '-' and orf[1][1] < previous_stop:
 567                    non_overlapping_orfs.append(orf[0])
 568                    previous_stop = orf[1][0]
 569
 570        non_overlap_dict = {}
 571        for orf in orf_dict:
 572            if orf in non_overlapping_orfs:
 573                non_overlap_dict[orf] = orf_dict[orf]
 574
 575        return non_overlap_dict
 576
 577    def calc_length_stats(self) -> dict:
 578        """
 579        Determine the stats for the length of the ungapped seqs in the alignment.
 580        :return: dictionary with length stats
 581        """
 582
 583        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 584
 585        return {'number of seq': len(self.alignment),
 586                'mean length': float(np.mean(seq_lengths)),
 587                'std length': float(np.std(seq_lengths)),
 588                'min length': int(np.min(seq_lengths)),
 589                'max length': int(np.max(seq_lengths))
 590                }
 591
 592    def calc_entropy(self) -> list:
 593        """
 594        Calculate the normalized Shannon's entropy for every position in an alignment:
 595
 596        - 1: high entropy
 597        - 0: low entropy
 598
 599        :return: Entropies at each position.
 600        """
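        # Per position the entropy is H = -sum(p_i * log_s(p_i)) over the character
        # frequencies p_i, with s = 4 (nt) or 20 (AA) states, which normalizes H to [0, 1].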
 601
 602        # helper functions
 603        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
 604            """
 605            Calculate the Shannon's entropy of a sequence,
 606            normalized between 0 and 1.
 607            :param character_list: characters at an alignment position
 608            :param states: number of potential characters that can be present
 609            :param aln_type: type of the alignment
 610            :returns: entropy
 611            """
 612            ent, n_chars = 0, len(character_list)
 613            # only one char is in the list
 614            if n_chars <= 1:
 615                return ent
 616            # calculate the number of unique chars and their counts
 617            chars, char_counts = np.unique(character_list, return_counts=True)
 618            char_counts = char_counts.astype(float)
 619            # ignore gaps for entropy calc
 620            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
 621            # correctly handle ambiguous chars
 622            index_to_drop = []
 623            for index, char in enumerate(chars):
 624                if char in config.AMBIG_CHARS[aln_type]:
 625                    index_to_drop.append(index)
 626                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
 627                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
 628                    # add the proportionate numbers to initial array
 629                    for amb_char, amb_count in zip(amb_chars, amb_counts):
 630                        if amb_char in chars:
 631                            char_counts[chars == amb_char] += amb_count
 632                        else:
 633                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
 634            # drop the ambiguous characters from array
 635            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
 636            # calc the entropy
 637            probs = char_counts / n_chars
 638            if np.count_nonzero(probs) <= 1:
 639                return ent
 640            for prob in probs:
 641                ent -= prob * math.log(prob, states)
 642
 643            return ent
 644
 645        aln = self.alignment
 646        entropies = []
 647
 648        if self.aln_type == 'AA':
 649            states = 20
 650        else:
 651            states = 4
 652        # iterate over alignment positions and the sequences
 653        for nuc_pos in range(self.length):
 654            pos = []
 655            for record in aln:
 656                pos.append(aln[record][nuc_pos])
 657            entropies.append(shannons_entropy(pos, states, self.aln_type))
 658
 659        return entropies
 660
 661    def calc_gc(self) -> list | TypeError:
 662        """
 663        Determine the GC content for every position in an nt alignment.
 664        :return: GC content for every position.
 665        :raises: TypeError for AA alignments
 666        """
 667        if self.aln_type == 'AA':
 668            raise TypeError("GC computation is not possible for amino acid alignments")
 669
 670        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
 671
 672        for position in range(self.length):
 673            nucleotides = str()
 674            for record in aln:
 675                nucleotides = nucleotides + aln[record][position]
 676            # ini a dict that defines which chars
 677            # to count and with which weight
 678            to_count = {
 679                'G': 1,
 680                'C': 1,
 681            }
 682            # handle ambig. nuc
 683            for char in amb_nucs:
 684                if char in nucleotides:
 685                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
 686
 687            gc.append(
 688                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
 689            )
 690
 691        return gc
 692
 693    def calc_coverage(self) -> list:
 694        """
 695        Determine the coverage of every position in an alignment.
 696        This is defined per position as:
 697            1 - fraction of '-' characters
 698
 699        :return: Coverage at each alignment position.
 700        """
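        # Illustrative: a column with 2 gaps among 10 sequences has coverage 1 - 2/10 = 0.8.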
 701        coverage, aln = [], self.alignment
 702
 703        for nuc_pos in range(self.length):
 704            pos = str()
 705            for record in aln.keys():
 706                pos = pos + aln[record][nuc_pos]
 707            coverage.append(1 - pos.count('-') / len(pos))
 708
 709        return coverage
 710
 711    def calc_reverse_complement_alignment(self) -> dict | TypeError:
 712        """
 713        Reverse complement the alignment.
 714        :return: Alignment (rv)
 715        """
 716        if self.aln_type == 'AA':
 717            raise TypeError('Reverse complement only for RNA or DNA.')
 718
 719        aln = self.alignment
 720        reverse_complement_dict = {}
 721
 722        for seq_id in aln:
 723            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 724
 725        return reverse_complement_dict
 726
 727    def calc_identity_alignment(self, encode_mismatches: bool = True, encode_mask: bool = False, encode_gaps: bool = True, encode_ambiguities: bool = False, encode_each_mismatch_char: bool = False) -> np.ndarray:
 728        """
 729        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 730
 731        :param encode_mismatches: encode mismatch as -1
 732        :param encode_mask: encode mask with value=-2 --> also in the reference
 733        :param encode_gaps: encode gaps with np.nan --> also in the reference
 734        :param encode_ambiguities: encode ambiguities with value=-3
 735        :param encode_each_mismatch_char: encode each mismatch character separately - the values represent the idx+1 positions of the characters in config.CHAR_COLORS for the respective alignment type
 736        :return: identity alignment
 737        """
 738
 739        aln = self.alignment
 740        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 741
 742        # convert alignment to array
 743        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 744        reference = np.array(list(ref))
 745        # ini matrix
 746        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 747
 748        is_identical = sequences == reference
 749
 750        if encode_gaps:
 751            is_gap = sequences == '-'
 752        else:
 753            is_gap = np.full(sequences.shape, False)
 754
 755        if encode_mask:
 756            if self.aln_type == 'AA':
 757                is_n_or_x = np.isin(sequences, ['X'])
 758            else:
 759                is_n_or_x = np.isin(sequences, ['N'])
 760        else:
 761            is_n_or_x = np.full(sequences.shape, False)
 762
 763        if encode_ambiguities:
 764            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 765        else:
 766            is_ambig = np.full(sequences.shape, False)
 767
 768        if encode_mismatches:
 769            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 770        else:
 771            is_mismatch = np.full(sequences.shape, False)
 772
 773        # encode every different character
 774        if encode_each_mismatch_char:
 775            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
 776                new_encoding = np.isin(sequences, [char]) & is_mismatch
 777                identity_matrix[new_encoding] = idx + 1
 778        # or encode different with a single value
 779        else:
 780            identity_matrix[is_mismatch] = -1  # mismatch
 781
 782        identity_matrix[is_gap] = np.nan  # gap
 783        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 784        identity_matrix[is_ambig] = -3  # ambiguities
 785
 786        return identity_matrix
 787
 788    def calc_similarity_alignment(self, matrix_type: str | None = None, normalize: bool = True) -> np.ndarray:
 789        """
 790        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
 791        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
 792        reference residue at each column. Gaps are encoded as np.nan.
 793
 794        The calculation follows these steps:
 795
 796        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
 797           a consensus sequence is generated to serve as the reference.
 798        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
 799           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
 800        3. **Per-Column Normalization (optional)**:
 801           - For each column in the alignment:
 802             - The residue in the reference sequence is treated as the baseline for that column.
 803             - The substitution scores for the reference residue are extracted from the substitution matrix.
 804             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
 805               residue.
 806           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
 807             while more dissimilar residues have lower scores.
 808        4. **Output**:
 809           - The normalized similarity scores are stored in a NumPy array.
 810           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
 811
 812        :param matrix_type: type of substitution matrix (defaults if not set - AA: BLOSUM65, RNA/DNA: TRANS)
 813        :param normalize: whether to normalize the similarity scores to the range [0, 1]
 814        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
 815            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
 816            Gaps and invalid residues are encoded as `np.nan`.
 817        :raises ValueError:
 818            If the specified substitution matrix is not available for the given alignment type.
 819        """
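        # Per-column normalization as applied below (when normalize=True):
        # normalized = (score - min_ref) / (max_ref - min_ref),
        # where min_ref/max_ref are the extreme substitution scores for the reference residue.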
 820
 821        aln = self.alignment
 822        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 823        if matrix_type is None:
 824            if self.aln_type == 'AA':
 825                matrix_type = 'BLOSUM65'
 826            else:
 827                matrix_type = 'TRANS'
 828        # load substitution matrix as dictionary
 829        try:
 830            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
 831        except KeyError:
 832            raise ValueError(
 833                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
 834            )
 835
 836        # set dtype and convert alignment to a NumPy array for vectorized processing
 837        dtype = np.dtype(float, metadata={'matrix': matrix_type})
 838        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 839        reference = np.array(list(ref))
 840        valid_chars = list(subs_matrix.keys())
 841        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
 842
 843        for j, ref_char in enumerate(reference):
 844            if ref_char not in valid_chars + ['-']:
 845                continue
 846            # Get local min and max for the reference residue
 847            if normalize and ref_char != '-':
 848                local_scores = subs_matrix[ref_char].values()
 849                local_min, local_max = min(local_scores), max(local_scores)
 850
 851            for i, char in enumerate(sequences[:, j]):
 852                if char not in valid_chars:
 853                    continue
 854                # classify the similarity as max if the reference has a gap
 855                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
 856                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
 857
 858        return similarity_array
 859
 860    def calc_position_matrix(self, matrix_type: str = 'PWM') -> np.ndarray | ValueError:
 861        """
 862        Calculates a position matrix of the specified type for the given alignment. The function
 863        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
 864        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
 865        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
 866
 867        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
 868            'PWM', and 'IC'. Defaults to 'PWM'.
 869        :type matrix_type: str
 870        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
 871        :return: A numpy array representing the calculated position matrix of the specified type.
 872        :rtype: np.ndarray
 873        """
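        # How the matrices relate (as computed below): PPM = (PFM + pseudo-count) / n_sequences;
        # PWM = log2(column-normalized (PFM + pseudo-count) * n_characters);
        # IC = column-wise sum of the column-normalized probabilities * PWM, in bits per position.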
 874
 875        # ini
 876        aln = self.alignment
 877        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
 878            raise ValueError('Matrix_type must be PFM, PPM, IC or PWM.')
 879        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())[:-1]
 880        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 881
 882        # calc position frequency matrix
 883        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
 884        if matrix_type == 'PFM':
 885            return pfm
 886
 887        # calc position probability matrix (probability)
 888        pseudo_count = 0.0001  # to avoid 0 values
 889        pfm = pfm + pseudo_count
 890        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
 891        ppm = pfm/len(aln.keys())  # calculate the frequency based on row number
 892        if matrix_type == 'PPM':
 893            return ppm
 894
 895        # calc position weight matrix (log-likelihood)
 896        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
 897        if matrix_type == 'PWM':
 898            return pwm
 899
 900        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
 901        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
 902        if matrix_type == 'IC':
 903            return ic
 904
 905        return None
 906
 907    def calc_percent_recovery(self) -> dict:
 908        """
 909        Recovery per sequence either compared to the majority consensus seq
 910        or the reference seq.\n
 911        Defined as:\n
 912
 913        `(1 - (count of N/X/- characters at ungapped reference positions / number of ungapped reference positions)) * 100`
 914
 915        This is highly similar to how Nextclade calculates recovery over the reference.
 916
 917        :return: dict
 918        """
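        # Worked example (hypothetical numbers): a reference with 100 ungapped positions
        # and a query with 5 'N'/'X'/'-' characters at those positions yields
        # (1 - 5/100) * 100 = 95 % recovery.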
 919
 920        aln = self.alignment
 921
 922        if self.reference_id is not None:
 923            ref = aln[self.reference_id]
 924        else:
 925            ref = self.get_consensus()  # majority consensus
 926
 927        if not any(char != '-' for char in ref):
 928            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 929
 930
 931        # count 'N', 'X' and '-' chars in non-gapped regions
 932        recovery_over_ref = dict()
 933
 934        # Get positions of non-gap characters in the reference
 935        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 936        cumulative_length = len(non_gap_positions)
 937
 938        # Calculate recovery
 939        for seq_id in aln:
 940            if seq_id == self.reference_id:
 941                continue
 942            seq = aln[seq_id]
 943            count_invalid = sum(
 944                seq[pos] == '-' or
 945                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 946                for pos in non_gap_positions
 947            )
 948            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 949
 950        return recovery_over_ref
 951
 952    def calc_character_frequencies(self) -> dict:
 953        """
 954        Calculate the character frequencies in the alignment:
 955        The frequencies are counted by seq and in total. The
 956        percentage of non-gap characters in the alignment is
 957        relative to the total number of non-gap characters.
 958        The gap percentage is relative to the sequence length.
 959
 960        The output is a nested dictionary.
 961
 962        :return: Character frequencies
 963        """
 964
 965        aln, aln_length = self.alignment, self.length
 966
 967        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 968
 969        for seq_id in aln:
 970            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 971            unique_chars = set(aln[seq_id])
 972            for char in unique_chars:
 973                if char == '-':
 974                    continue
 975                # add characters to dictionaries
 976                if char not in freqs[seq_id]:
 977                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 978                if char not in freqs['total']:
 979                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
 980                # count non-gap chars
 981                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
 982                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
 983                all_chars += freqs[seq_id][char]['counts']
 984            # normalize counts
 985            for char in freqs[seq_id]:
 986                if char == '-':
 987                    continue
 988                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
 989                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
 990            # count gaps
 991            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
 992            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
 993            # normalize gap counts
 994            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
 995            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
 996
 997        # normalize the total counts
 998        for char in freqs['total']:
 999            for value in freqs['total'][char]:
1000                if value == '% of alignment' or value == '% of non-gapped':
1001                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
1002
1003        return freqs
1004
1005    def calc_pairwise_identity_matrix(self, distance_type: str = 'ghd') -> ndarray:
1006        """
1007        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
1008
1009        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1010        \ndistance = matches / alignment_length * 100
1011
1012        **2) lhd (local hamming distance)**: Restrict the alignment to the region that does not start or end with a gap in either sequence:
1013        \ndistance = matches / length of the trimmed region * 100
1014
1015        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1016        \ndistance = matches / (matches + mismatches) * 100
1017
1018        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1019        \ndistance = matches / gap_compressed_alignment_length * 100
1020
1021        :return: array with pairwise distances.
1022        """
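        # Illustrative comparison (hypothetical pair): seq1 = 'AT--CG', seq2 = 'ATGGCG'
        # ghd: 4 matching columns / 6 columns * 100 ≈ 66.7
        # lhd: no terminal gaps to trim, so identical to ghd here (≈ 66.7)
        # ged: only the 4 gap-free columns count -> 4 / 4 * 100 = 100.0
        # gcd: the gap stretch compresses to one mismatch -> 4 / 5 * 100 = 80.0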
1023
1024        def hamming_distance(seq1: str, seq2: str) -> int:  # note: counts matches, not differences
1025            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))
1026
1027        def ghd(seq1: str, seq2: str) -> float:
1028            return hamming_distance(seq1, seq2) / self.length * 100
1029
1030        def lhd(seq1: str, seq2: str) -> float:
1031            # Trim gaps from both sides
1032            i, j = 0, self.length - 1
1033            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1034                i += 1
1035            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1036                j -= 1
1037            if i > j:
1038                return 0.0
1039
1040            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1041            matches = sum(c1 == c2 for c1, c2 in zip(seq1_, seq2_))
1042            length = j - i + 1
1043            return (matches / length) * 100 if length > 0 else 0.0
1044
1045        def ged(seq1: str, seq2: str) -> float:
1046
1047            matches, mismatches = 0, 0
1048
1049            for c1, c2 in zip(seq1, seq2):
1050                if c1 != '-' and c2 != '-':
1051                    if c1 == c2:
1052                        matches += 1
1053                    else:
1054                        mismatches += 1
1055            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1056
1057        def gcd(seq1: str, seq2: str) -> float:
1058            matches = 0
1059            mismatches = 0
1060            in_gap = False
1061
1062            for char1, char2 in zip(seq1, seq2):
1063                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1064                    continue
1065                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1066                    if not in_gap:  # Start of a new gap stretch
1067                        mismatches += 1
1068                        in_gap = True
1069                else:  # No gaps
1070                    in_gap = False
1071                    if char1 == char2:  # Matching characters
1072                        matches += 1
1073                    else:  # Mismatched characters
1074                        mismatches += 1
1075
1076            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1077
1078
1079        # Map distance type to corresponding function
1080        distance_functions: Dict[str, Callable[[str, str], float]] = {
1081            'ghd': ghd,
1082            'lhd': lhd,
1083            'ged': ged,
1084            'gcd': gcd
1085        }
1086
1087        if distance_type not in distance_functions:
1088            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1089
1090        # Compute pairwise distances
1091        aln = self.alignment
1092        distance_func = distance_functions[distance_type]
1093        distance_matrix = np.zeros((len(aln), len(aln)))
1094
1095        sequences = list(aln.values())
1096        n = len(sequences)
1097        for i in range(n):
1098            seq1 = sequences[i]
1099            for j in range(i, n):
1100                seq2 = sequences[j]
1101                dist = distance_func(seq1, seq2)
1102                distance_matrix[i, j] = dist
1103                distance_matrix[j, i] = dist
1104
1105        return distance_matrix
1106
1107    def get_snps(self, include_ambig: bool = False) -> dict:
1108        """
1109        Calculate snps similar to snp-sites (output is comparable):
1110        https://github.com/sanger-pathogens/snp-sites
1111        Importantly, a SNP is only considered if at least one of the variant characters is not an ambiguous character.
1112        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1113
1114        :param include_ambig: Include ambiguous snps (default: False)
1115        :return: dictionary containing snp positions and their variants including their frequency.
1116        """
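        # Sketch of the returned structure (keys and values illustrative):
        # {'#CHROM': 'consensus', 'POS': {pos: {'ref': 'A',
        #     'ALT': {'G': {'AF': 0.25, 'SEQ_ID': ['seq_3']}}}}}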
1117        aln = self.alignment
1118        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
1119        aln = {x: aln[x] for x in aln.keys() if x != self.reference_id}
1120        seq_ids = list(aln.keys())
1121        snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}
1122
1123        for pos in range(self.length):
1124            reference_char = ref[pos]
1125            if not include_ambig:
1126                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1127                    continue
1128            alt_chars, snps = [], []
1129            for i, seq_id in enumerate(aln.keys()):
1130                alt_chars.append(aln[seq_id][pos])
1131                if reference_char != aln[seq_id][pos]:
1132                    snps.append(i)
1133            if not snps:
1134                continue
1135            if include_ambig:
1136                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1137                    continue
1138            else:
1139                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1140                if not snps:
1141                    continue
1142            if pos not in snp_dict['POS']:
1143                snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': {}}
1144            for snp in snps:
1145                if alt_chars[snp] not in snp_dict['POS'][pos]['ALT']:
1146                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]] = {
1147                        'AF': 1,
1148                        'SEQ_ID': [seq_ids[snp]]
1149                    }
1150                else:
1151                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['AF'] += 1
1152                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['SEQ_ID'].append(seq_ids[snp])
1153            # calculate AF
1154            if pos in snp_dict['POS']:
1155                for alt in snp_dict['POS'][pos]['ALT']:
1156                    snp_dict['POS'][pos]['ALT'][alt]['AF'] /= len(aln)
1157
1158        return snp_dict
1159
1160    def calc_transition_transversion_score(self) -> list:
1161        """
1162        Based on the snp positions, calculates a transition/transversion score.
1163        A positive score means a higher ratio of transitions, a negative score
1164        a higher ratio of transversions.
1165        :return: list
1166        """
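        # Transitions (scored positively): A<->G and C<->T/U; every other substitution
        # is a transversion and is scored negatively, each weighted by its allele frequency.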
1167
1168        if self.aln_type == 'AA':
1169            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1170
1171        # ini
1172        snps = self.get_snps()
1173        score = [0]*self.length
1174
1175        for pos in snps['POS']:
1176            # sum the contributions of all alternative characters at this position
1177            for alt in snps['POS'][pos]['ALT']:
1178                # check the type of substitution
1179                if snps['POS'][pos]['ref'] + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1180                    score[pos] += snps['POS'][pos]['ALT'][alt]['AF']
1181                else:
1182                    score[pos] -= snps['POS'][pos]['ALT'][alt]['AF']
1183
1184        return score
1185
1186
1187class Annotation:
1188    """
1189    An annotation class that reads in gff, gb or bed files and adjusts their locations to those of the MSA.
1190    """
1191
1192    def __init__(self, aln: MSA, annotation_path: str):
1193        """
1194        The annotation class. Lets you parse multiple standard formats
1195        which might be used for annotating an alignment. The main purpose
1196        is to parse the annotation file and adapt the locations of diverse
1197        features to the locations within the alignment, considering the
1198        respective alignment positions. Importantly, the IDs of the annotation
1199        and the MSA have to at least partly match.
1200
1201        :param aln: MSA class
1202        :param annotation_path: path to annotation file (gb, bed, gff) or raw string
1203
1204        """
1205
1206        self.ann_type, self._seq_id, self.locus, self.features = self._parse_annotation(annotation_path, aln)  # read annotation
1207        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1208        self._position_map = self._build_position_map()  # build a position map
1209        self._map_to_alignment()  # adapt feature locations
1210
1211    @staticmethod
1212    def _MSA_validation_and_seq_extraction(aln: MSA, seq_id: str) -> str:
1213        """
1214        extract gapped sequence from MSA that corresponds to annotation
1215        :param aln: MSA class
1216        :param seq_id: sequence id to extract
1217        :return: gapped sequence
1218        """
1219        if not isinstance(aln, MSA):
1220            raise ValueError('Alignment has to be an MSA class. Use explore.MSA() to read in the alignment.')
1221        else:
1222            return aln._alignment[seq_id]
1223
1224    @staticmethod
1225    def _parse_annotation(annotation_path: str, aln: MSA) -> tuple[str, str, str, Dict]:
1226
1227        def detect_annotation_type(file_path: str) -> str:
1228            """
1229            Detect the type of annotation file (GenBank, GFF, or BED) based
1230            on the first relevant line (excluding empty and #)
1231
1232            :param file_path: Path to the annotation file.
1233            :return: The detected file type ('gb', 'gff', or 'bed').
1234
1235            :raises ValueError: If the file type cannot be determined.
1236            """
1237
1238            with _get_line_iterator(file_path) as file:
1239                for line in file:
1240                    # skip empty lines and comments
1241                    if not line.strip() or line.startswith('#'):
1242                        continue
1243                    # genbank
1244                    if line.startswith('LOCUS'):
1245                        return 'gb'
1246                    # gff
1247                    if len(line.split('\t')) == 9:
1248                        # Check for expected values
1249                        columns = line.split('\t')
1250                        if columns[6] in ['+', '-', '.'] and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$', columns[4]):
1251                            return 'gff'
1252                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
1253                    fields = line.split('\t')
1254                    if len(fields) >= 3 and re.match(r'^\d+$', fields[1]) and re.match(r'^\d+$', fields[2]):
1255                        return 'bed'
1256                    # only read in the first line
1257                    break
1258
1259            raise ValueError(
1260                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")
1261
1262        def parse_gb(file_path) -> dict:
1263            """
1264            parse a GenBank file to a dictionary - primarily retained is the information
1265            for qualifiers as these will be used for plotting.
1266
1267            :param file_path: path to GenBank file
1268            :return: nested dictionary
1269
1270            """
1271
1272            def sanitize_gb_location(string: str) -> tuple[list, str]:
1273                """
1274                see: https://www.insdc.org/submitting-standards/feature-table/
1275                """
1276                strand = '+'
1277                locations = []
1278                # check the direction of the annotation
1279                if 'complement' in string:
1280                    strand = '-'
1281                # sanitize operators
1282                for operator in ['complement(', 'join(', 'order(']:
1283                    string = string.replace(operator, '')
1284                # sanitize possible chars for splitting start stop -
1285                # however in the future might not simply do this
1286                # as some useful information is retained
1287                for char in ['>', '<', ')']:
1288                    string = string.replace(char, '')
1289                # check if we have multiple locations, e.g. due to splicing
1290                if ',' in string:
1291                    raw_locations = string.split(',')
1292                else:
1293                    raw_locations = [string]
1294                # try to split start and stop
1295                for location in raw_locations:
1296                    for sep in ['..', '.', '^']:
1297                        if sep in location:
1298                            sanitized_locations = [int(x) for x in location.split(sep)]
1299                            sanitized_locations[0] = sanitized_locations[0] - 1  # enforce 0-based starts
1300                            locations.append(sanitized_locations)
1301                            break
1302
1303                return locations, strand
1304
1305
1306            records = {}
1307            with _get_line_iterator(file_path) as file:
1308                record = None
1309                in_features = False
1310                counter_dict = {}
1311                for line in file:
1312                    line = line.rstrip()
1313                    parts = line.split()
1314                    # extract the locus id
1315                    if line.startswith('LOCUS'):
1316                        if record:
1317                            records[record['locus']] = record
1318                        record = {
1319                            'locus': parts[1],
1320                            'features': {}
1321                        }
1322
1323                    elif line.startswith('FEATURES'):
1324                        in_features = True
1325
1326                    # ignore the sequence info
1327                    elif line.startswith('ORIGIN'):
1328                        in_features = False
1329
1330                    # now write useful feature information to dictionary
1331                    elif in_features:
1332                        if not line.strip():
1333                            continue
1334                        if line[5] != ' ':
1335                            location_line = True  # remember that we are in a location for multi-line locations
1336                            feature_type, qualifier = parts[0], parts[1]
1337                            if feature_type not in record['features']:
1338                                record['features'][feature_type] = {}
1339                                counter_dict[feature_type] = 0
1340                            locations, strand = sanitize_gb_location(qualifier)
1341                            record['features'][feature_type][counter_dict[feature_type]] = {
1342                                'location': locations,
1343                                'strand': strand
1344                            }
1345                            counter_dict[feature_type] += 1
1346                        else:
1347                            # edge case for multi-line locations
1348                            if location_line and not line.strip().startswith('/'):
1349                                locations, strand = sanitize_gb_location(parts[0])
1350                                for loc in locations:
1351                                    record['features'][feature_type][counter_dict[feature_type]]['location'].append(loc)
1352                            else:
1353                                location_line = False
1354                                try:
1355                                    qualifier_type, qualifier = line.strip().split('=', 1)  # split only on the first '='
1356                                except ValueError:  # continuation of a multi-line qualifier (e.g. translation)
1357                                    qualifier = qualifier + line.strip()
1358
1359                                qualifier_type, qualifier = qualifier_type.lstrip('/'), qualifier.strip('"')
1360                                last_index = counter_dict[feature_type] - 1
1361                                record['features'][feature_type][last_index][qualifier_type] = qualifier
1362
1363            records[record['locus']] = record
1364
1365            return records
1366
1367        def parse_gff(file_path) -> dict:
1368            """
1369            Parse a GFF3 (General Feature Format) file into a dictionary structure.
1370
1371            :param file_path: path to GFF3 file
1372            :return: nested dictionary
1373
1374            """
1375            records = {}
1376            with _get_line_iterator(file_path) as file:
1377                previous_id, previous_feature = None, None
1378                for line in file:
1379                    if line.startswith('#') or not line.strip():
1380                        continue
1381                    parts = line.strip().split('\t')
1382                    seqid, source, feature_type, start, end, score, strand, phase, attributes = parts
1383                    # ensure that region and source features are not named differently for gff and gb
1384                    if feature_type == 'region':
1385                        feature_type = 'source'
1386                    if seqid not in records:
1387                        records[seqid] = {'locus': seqid, 'features': {}}
1388                    if feature_type not in records[seqid]['features']:
1389                        records[seqid]['features'][feature_type] = {}
1390
1391                    feature_id = len(records[seqid]['features'][feature_type])
1392                    feature = {
1393                        'strand': strand,
1394                    }
1395
1396                    # Parse attributes into key-value pairs
1397                    for attr in attributes.split(';'):
1398                        if '=' in attr:
1399                            key, value = attr.split('=', 1)
1400                            feature[key.strip()] = value.strip()
1401
1402                    # check if features are the same --> possible splicing
1403                    if previous_id is not None and previous_feature == feature:
1404                        records[seqid]['features'][feature_type][previous_id]['location'].append([int(start)-1, int(end)])
1405                    else:
1406                        records[seqid]['features'][feature_type][feature_id] = feature
1407                        records[seqid]['features'][feature_type][feature_id]['location'] = [[int(start) - 1, int(end)]]
1408                    # set new previous id and feature -> build a new dict since 'location' is referenced in the
1409                    # current feature and is the only differing key if the next feature has the same entries
1410                    previous_id, previous_feature = feature_id, {key:value for key, value in feature.items() if key != 'location'}
1411
1412            return records
1413
1414        def parse_bed(file_path) -> dict:
1415            """
1416            Parse a BED file into a dictionary structure.
1417
1418            :param file_path: path to BED file
1419            :return: nested dictionary
1420
1421            """
1422            records = {}
1423            with _get_line_iterator(file_path) as file:
1424                for line in file:
1425                    if line.startswith('#') or not line.strip():
1426                        continue
1427                    parts = line.strip().split('\t')
1428                    chrom, start, end, *optional = parts
1429
1430                    if chrom not in records:
1431                        records[chrom] = {'locus': chrom, 'features': {}}
1432                    feature_type = 'region'
1433                    if feature_type not in records[chrom]['features']:
1434                        records[chrom]['features'][feature_type] = {}
1435
1436                    feature_id = len(records[chrom]['features'][feature_type])
1437                    feature = {
1438                        'location': [[int(start), int(end)]],  # BED is already 0-based half-open - keep as is
1439                        'strand': '+',  # assume '+' if not present
1440                    }
1441
1442                    # Handle optional columns (name, score, strand) --> ignore 7-12
1443                    if len(optional) >= 1:
1444                        feature['name'] = optional[0]
1445                    if len(optional) >= 2:
1446                        feature['score'] = optional[1]
1447                    if len(optional) >= 3:
1448                        feature['strand'] = optional[2]
1449
1450                    records[chrom]['features'][feature_type][feature_id] = feature
1451
1452            return records
1453
1454        parse_functions: Dict[str, Callable[[str], dict]] = {
1455            'gb': parse_gb,
1456            'bed': parse_bed,
1457            'gff': parse_gff,
1458        }
1459        # determine the annotation content -> should be standard formatted
1460        annotation_type = detect_annotation_type(annotation_path)
1464
1465        # read in the annotation
1466        annotations = parse_functions[annotation_type](annotation_path)
1467
1468        # sanity check whether one of the annotation ids and alignment ids match
1469        annotation_found = False
1470        for annotation in annotations.keys():
1471            for aln_id in aln.alignment.keys():
1472                aln_id_sanitized = aln_id.split(' ')[0]
1473                # check in both directions
1474                if aln_id_sanitized in annotation or annotation in aln_id_sanitized:
1475                    annotation_found = True
1476                    break
1477            # break the outer loop too, so annotation/aln_id keep pointing at the matched pair
1478            if annotation_found:
1479                break
1480
1481        if not annotation_found:
1482            raise ValueError(f'the annotations of {annotation_path} do not match any ids in the MSA')
1483
1484        # return only the annotation that has been found, the respective type and the seq_id to map to
1485        return annotation_type, aln_id, annotations[annotation]['locus'], annotations[annotation]['features']
1486
1487
1488    def _build_position_map(self) -> Dict[int, int]:
1489        """
1490        build a position map from a sequence.
1491
1492        :return: mapping of ungapped (genomic) positions to gapped (alignment) positions
1493        """
1494
1495        position_map = {}
1496        genomic_pos = 0
1497        for aln_index, char in enumerate(self._gapped_seq):
1498            if char != '-':
1499                position_map[genomic_pos] = aln_index
1500                genomic_pos += 1
1501        # ensure the last genomic position is included
1502        position_map[genomic_pos] = len(self._gapped_seq)
1503
1504        return position_map
1505
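To illustrate what `_build_position_map` computes, here is the same logic applied to a small gapped sequence (a standalone sketch, not part of the source):

    gapped = 'AT-GC'
    position_map, genomic_pos = {}, 0
    for aln_index, char in enumerate(gapped):
        if char != '-':
            position_map[genomic_pos] = aln_index
            genomic_pos += 1
    position_map[genomic_pos] = len(gapped)  # ensure the end position is included
    # position_map == {0: 0, 1: 1, 2: 3, 3: 4, 4: 5}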
1506
1507    def _map_to_alignment(self):
1508        """
1509        Adjust all feature locations to alignment positions
1510        """
1511
1512        def map_location(position_map: Dict[int, int], locations: list) -> list:
1513            """
1514            Map genomic locations to alignment positions using a precomputed position map.
1515
1516            :param position_map: Mapping from ungapped (genomic) to gapped (alignment) positions
1517            :param locations: List of genomic start and end positions.
1518            :return: List of adjusted alignment positions.
1519            """
1520
1521            aligned_locs = []
1522            for start, end in locations:
1523                try:
1524                    aligned_start = position_map[start]
1525                    aligned_end = position_map[end]
1526                    aligned_locs.append([aligned_start, aligned_end])
1527                except KeyError:
1528                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
1529
1530            return aligned_locs
1531
1532        for feature_type, features in self.features.items():
1533            for feature_id, feature_data in features.items():
1534                original_locations = feature_data['location']
1535                aligned_locations = map_location(self._position_map, original_locations)
1536                feature_data['location'] = aligned_locations
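A minimal end-to-end sketch (not part of the source) for attaching an annotation to an alignment, assuming hypothetical files `my_alignment.fasta` and `my_features.gff`:

    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')           # hypothetical file path
    ann = explore.Annotation(aln, 'my_features.gff')  # GenBank, GFF and BED are auto-detected
    # ann.features now holds feature locations mapped to alignment columns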
class MSA:
  37class MSA:
  38    """
  39    An alignment class that allows computation of several stats
  40    """
  41
  42    def __init__(self, alignment_string: str, reference_id: str = None, zoom_range: tuple | int = None):
  43        """
  44        Initialise an Alignment object.
  45        :param alignment_string: Path to alignment file or raw alignment string
  46        :param reference_id: reference id
  47        :param zoom_range: start and stop positions to zoom into the alignment
  48        """
  49        self._alignment = self._read_alignment(alignment_string)
  50        self._reference_id = self._validate_ref(reference_id, self._alignment)
  51        self._zoom = self._validate_zoom(zoom_range, self._alignment)
  52        self._aln_type = self._determine_aln_type(self._alignment)
  53
  54    # TODO: read in different alignment types
  55    # Static methods
  56    @staticmethod
  57    def _read_alignment(file_path: str) -> dict:
  58        """
  59        Parse MSA alignment file.
  60        :param file_path: path to alignment file
  61        :return: dictionary with ids as keys and sequences as values
  62        """
  63
  64        def add_seq(aln: dict, sequence_id: str, seq_list: list):
  65            """
  66            Add a complete sequence and check for non-allowed chars
  67            :param aln: alignment dictionary to build
  68            :param sequence_id: sequence id to add
  69            :param seq_list: sequences to add
  70            :return: alignment with added sequences
  71            """
  72            final_seq = ''.join(seq_list).upper()
  73            # Check for non-allowed characters
  74            invalid_chars = set(final_seq) - set(config.POSSIBLE_CHARS)
  75            if invalid_chars:
  76                raise ValueError(
  77                    f"{sequence_id} contains invalid characters: {', '.join(invalid_chars)}. Allowed chars are: {config.POSSIBLE_CHARS}"
  78                )
  79            aln[sequence_id] = final_seq
  80
  81            return aln
  82
  83        alignment, seq_lines = {}, []
  84        seq_id = None
  85
  86        with _get_line_iterator(file_path) as file:
  87            for i, line in enumerate(file):
  88                line = line.strip()
  89                # initial check for fasta format
  90                if i == 0 and not line.startswith(">"):
  91                    raise ValueError('Alignment has to be in fasta format starting with >SeqID.')
  92                if line.startswith(">"):
  93                    if seq_id:
  94                        alignment = add_seq(alignment, seq_id, seq_lines)
  95                    # initialize a new sequence
  96                    seq_id, seq_lines = line[1:], []
  97                else:
  98                    seq_lines.append(line)
  99            # handle last sequence
 100            if seq_id:
 101                alignment = add_seq(alignment, seq_id, seq_lines)
 102        # final sanity checks
 103        if alignment:
 104            # alignment contains only one sequence:
 105            if len(alignment) < 2:
 106                raise ValueError("Alignment must contain more than one sequence.")
 107            # alignment sequences are not same length
 108            first_seq_len = len(next(iter(alignment.values())))
 109            for sequence_id, sequence in alignment.items():
 110                if len(sequence) != first_seq_len:
 111                    raise ValueError(
 112                        f"All alignment sequences must have the same length. Sequence '{sequence_id}' has length {len(sequence)}, expected {first_seq_len}."
 113                    )
 114            # all checks passed
 115            return alignment
 116        else:
 117            raise ValueError(f"Alignment file {file_path} does not contain any sequences in fasta format.")
 118
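Because `_get_line_iterator` accepts raw strings as well as paths, an MSA can be built directly from an in-memory FASTA string, provided all characters are in `config.POSSIBLE_CHARS` (a sketch, not part of the source):

    from msaexplorer import explore

    aln = explore.MSA('>seq1\nATG-CT\n>seq2\nATGACT')
    print(aln.alignment)  # {'seq1': 'ATG-CT', 'seq2': 'ATGACT'}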
 119    @staticmethod
 120    def _validate_ref(reference: str | None, alignment: dict) -> str | None:
 121        """
 122        Validate if the ref seq is indeed part of the alignment.
 123        :param reference: reference seq id
 124        :param alignment: alignment dict
 125        :return: validated reference
 126        """
 127        if reference in alignment.keys():
 128            return reference
 129        elif reference is None:
 130            return reference
 131        else:
 132            raise ValueError('Reference not in alignment.')
 133
 134    @staticmethod
 135    def _validate_zoom(zoom: tuple | int, original_aln: dict) -> tuple | None:
 136        """
 137        Validates that the user-defined zoom range is within the start and end of the
 138        initial alignment.\n
 139        :param zoom: zoom range or zoom start
 140        :param original_aln: non-zoomed alignment dict
 141        :return: validated zoom range
 142        """
 143        if zoom is not None:
 144            aln_length = len(original_aln[list(original_aln.keys())[0]])
 145            # check if only one value is provided -> stop is alignment length
 146            if isinstance(zoom, int):
 147                if 0 <= zoom < aln_length:
 148                    return zoom, aln_length - 1
 149                else:
 150                    raise ValueError('Zoom start must be within the alignment length range.')
 151            # check if more than 2 values are provided
 152            if len(zoom) != 2:
 153                raise ValueError('Zoom positions have to be (zoom_start, zoom_end)')
 154            # validate zoom start/stop
 155            for position in zoom:
 156                if not isinstance(position, int):
 157                    raise ValueError('Zoom positions have to be integers.')
 158                if position not in range(0, aln_length):
 159                    raise ValueError('Zoom position out of range')
 160
 161        return zoom
 162
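Zooming can be set at construction time or later via the `zoom` setter; both paths run through `_validate_zoom` (a sketch assuming a hypothetical alignment file):

    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta', zoom_range=(100, 500))  # hypothetical file
    aln.zoom = 250          # a single int start zooms to (250, alignment length - 1)
    aln.zoom = (300, 400)   # explicit (start, stop) tuple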
 163    @staticmethod
 164    def _determine_aln_type(alignment) -> str:
 165        """
 166        Determine the most likely type of alignment.
 167        If at least 70% of the chars in the alignment are nucleotide
 168        chars, it is most likely a nt alignment.
 169        :return: type of alignment
 170        """
 171        counter = 0
 172        for record in alignment:
 173            if 'U' in alignment[record]:
 174                return 'RNA'
 175            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 176        # determine which is the most likely type
 177        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 178            return 'DNA'
 179        else:
 180            return 'AA'
 181
 182    # Properties with setters
 183    @property
 184    def reference_id(self):
 185        return self._reference_id
 186
 187    @reference_id.setter
 188    def reference_id(self, ref_id: str):
 189        """
 190        Set and validate the reference id.
 191        """
 192        self._reference_id = self._validate_ref(ref_id, self.alignment)
 193
 194    @property
 195    def zoom(self) -> tuple:
 196        return self._zoom
 197
 198    @zoom.setter
 199    def zoom(self, zoom_pos: tuple | int):
 200        """
 201        Validate the user-defined zoom range.
 202        """
 203        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 204
 205    # Property without setters
 206    @property
 207    def aln_type(self) -> str:
 208        """
 209        define the aln type:
 210        RNA, DNA or AA
 211        """
 212        return self._aln_type
 213
 214    # On the fly properties without setters
 215    @property
 216    def length(self) -> int:
 217        return len(next(iter(self.alignment.values())))
 218
 219    @property
 220    def alignment(self) -> dict:
 221        """
 222        (zoomed) version of the alignment.
 223        """
 224        if self.zoom is not None:
 225            zoomed_aln = dict()
 226            for seq in self._alignment:
 227                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 228            return zoomed_aln
 229        else:
 230            return self._alignment
 231
 232    # functions for different alignment stats
 233    def get_reference_coords(self) -> tuple[int, int]:
 234        """
 235        Determine the start and end coordinates of the reference sequence
 236        defined as the first/last nucleotide in the reference sequence
 237        (excluding N and gaps).
 238
 239        :return: start, end
 240        """
 241        start, end = 0, self.length
 242
 243        if self.reference_id is None:
 244            return start, end
 245        else:
 246            # 5' --> 3'
 247            for start in range(self.length):
 248                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 249                    break
 250            # 3' --> 5'
 251            for end in range(self.length - 1, 0, -1):
 252                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 253                    break
 254
 255            return start, end
 256
 257    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
 258        """
 259        Creates a non-gapped consensus sequence.
 260
 261        :param threshold: Threshold for the consensus sequence. If use_ambig_nt = True, the ambiguous char that encodes
 262            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
 263            or 'X' (for AA alignments) is used if none of the characters reach a cumulative frequency >= threshold.
 264        :param use_ambig_nt: Use an ambiguous nt character if none of the possible nt at an alignment position
 265            has a frequency above the defined threshold.
 266        :return: consensus sequence
 267        """
 268
 269        # helper functions
 270        def determine_counts(alignment_dict: dict, position: int) -> dict:
 271            """
 272            count the number of each char at
 273            an idx of the alignment. return a sorted dict.
 274            handles ambiguous nucleotides in sequences.
 275            also handles gaps.
 276            """
 277            nucleotide_list = []
 278
 279            # get all nucleotides
 280            for sequence in alignment_dict.items():
 281                nucleotide_list.append(sequence[1][position])
 282            # count occurrences of nucleotides
 283            counter = dict(collections.Counter(nucleotide_list))
 284            # get permutations of an ambiguous nucleotide
 285            to_delete = []
 286            temp_dict = {}
 287            for nucleotide in counter:
 288                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
 289                    to_delete.append(nucleotide)
 290                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
 291                    adjusted_freq = 1 / len(permutations)
 292                    for permutation in permutations:
 293                        if permutation in temp_dict:
 294                            temp_dict[permutation] += adjusted_freq
 295                        else:
 296                            temp_dict[permutation] = adjusted_freq
 297
 298            # drop ambiguous entries and add adjusted freqs to the counts
 299            if to_delete:
 300                for i in to_delete:
 301                    counter.pop(i)
 302                for nucleotide in temp_dict:
 303                    if nucleotide in counter:
 304                        counter[nucleotide] += temp_dict[nucleotide]
 305                    else:
 306                        counter[nucleotide] = temp_dict[nucleotide]
 307
 308            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
 309
 310        def get_consensus_char(counts: dict, cutoff: float) -> list:
 311            """
 312            get a list of nucleotides for the consensus seq
 313            """
 314            n = 0
 315
 316            consensus_chars = []
 317            for char in counts:
 318                n += counts[char]
 319                consensus_chars.append(char)
 320                if n >= cutoff:
 321                    break
 322
 323            return consensus_chars
 324
 325        def get_ambiguous_char(nucleotides: list) -> str:
 326            """
 327            get ambiguous char from a list of nucleotides
 328            """
 329            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
 330                if set(permutations) == set(nucleotides):
 331                    return ambiguous
 332            # fall back to 'N' if no ambiguous char encodes exactly this set of nucleotides
 333            return 'N'
 334
 335        # check if params have been set correctly
 336        if threshold is not None:
 337            if threshold < 0 or threshold > 1:
 338                raise ValueError('Threshold must be between 0 and 1.')
 339        if self.aln_type == 'AA' and use_ambig_nt:
 340            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
 341        if threshold is None and use_ambig_nt:
 342            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
 343
 344        alignment = self.alignment
 345        consensus = str()
 346
 347        if threshold is not None:
 348            consensus_cutoff = len(alignment) * threshold
 349        else:
 350            consensus_cutoff = 0
 351
 352        # build the consensus sequence
 353        for idx in range(self.length):
 354            char_counts = determine_counts(alignment, idx)
 355            consensus_chars = get_consensus_char(
 356                char_counts,
 357                consensus_cutoff
 358            )
 359            if threshold != 0:
 360                if len(consensus_chars) > 1:
 361                    if use_ambig_nt:
 362                        char = get_ambiguous_char(consensus_chars)
 363                    else:
 364                        if self.aln_type == 'AA':
 365                            char = 'X'
 366                        else:
 367                            char = 'N'
 368                    consensus = consensus + char
 369                else:
 370                    consensus = consensus + consensus_chars[0]
 371            else:
 372                consensus = consensus + consensus_chars[0]
 373
 374        return consensus
 375
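A sketch of the threshold and ambiguity options on a toy alignment (not part of the source):

    from msaexplorer import explore

    aln = explore.MSA('>s1\nATGC\n>s2\nATGC\n>s3\nATGA')
    aln.get_consensus()                                  # plain majority consensus
    aln.get_consensus(threshold=0.9)                     # 'N' where no char reaches the cumulative cutoff
    aln.get_consensus(threshold=0.9, use_ambig_nt=True)  # ambiguous chars instead of 'N'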
 376    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 377        """
 378        **conserved ORF definition:**
 379            - conserved starts and stops
 380            - start, stop must be on the same frame
 381            - stop - start must be at least min_length
 382            - all ungapped seqs[start:stop] must have at least min_length
 383            - no ungapped seq can have a Stop in between Start Stop
 384
 385        Conservation is measured as the number of positions with identical characters divided by
 386        the length of the ORF slice of the alignment.
 387
 388        **Algorithm overview:**
 389            - check for conserved start and stop codons
 390            - iterate over all three frames
 391            - check each start and next sufficiently far away stop codon
 392            - check if all ungapped seqs between start and stop codon are >= min_length
 393            - check if no ungapped seq in the alignment has a stop codon
 394            - write to dictionary
 395            - classify as internal if the stop codon has already been written with a prior start
 396            - repeat for reverse complement
 397
 398        :return: ORF positions and internal ORF positions
 399        """
 400
 401        # helper functions
 402        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
 403            """
 404            Determine all start and stop codons within an alignment.
 405            :param alignment: alignment
 406            :param alignment_length: length of alignment
 407            :return: start and stop codon positions
 408            """
 409            starts = config.START_CODONS[self.aln_type]
 410            stops = config.STOP_CODONS[self.aln_type]
 411
 412            list_of_starts, list_of_stops = [], []
 413            ref = alignment[list(alignment.keys())[0]]
 414            for nt_position in range(alignment_length):
 415                if ref[nt_position:nt_position + 3] in starts:
 416                    conserved_start = True
 417                    for sequence in alignment:
 418                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
 419                            conserved_start = False
 420                            break
 421                    if conserved_start:
 422                        list_of_starts.append(nt_position)
 423
 424                if ref[nt_position:nt_position + 3] in stops:
 425                    conserved_stop = True
 426                    for sequence in alignment:
 427                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
 428                            conserved_stop = False
 429                            break
 430                    if conserved_stop:
 431                        list_of_stops.append(nt_position)
 432
 433            return list_of_starts, list_of_stops
 434
 435        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
 436            """
 437            get sequences sliced from start codon to stop codon and eliminate gaps
 438            :param alignment: alignment
 439            :param start_pos: start codon
 440            :param stop_pos: stop codon
 441            :return: sliced sequences
 442            """
 443            ungapped_seqs = []
 444            for seq_id in alignment:
 445                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
 446
 447            return ungapped_seqs
 448
 449        def additional_stops(ungapped_seqs: list) -> bool:
 450            """
 451            Checks for the presence of a stop codon
 452            :param ungapped_seqs: list of ungapped sequences
 453            :return: Additional stop codons (True/False)
 454            """
 455            stops = config.STOP_CODONS[self.aln_type]
 456
 457            for sliced_seq in ungapped_seqs:
 458                for position in range(0, len(sliced_seq) - 3, 3):
 459                    if sliced_seq[position:position + 3] in stops:
 460                        return True
 461            return False
 462
 463        def calculate_identity(identity_matrix: ndarray, aln_slice: list) -> float:
 464            sliced_array = identity_matrix[:, aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
 465            return np.sum(np.all(sliced_array == 1, axis=0)) / (aln_slice[1] - aln_slice[0]) * 100
 466
 467        # checks for arguments
 468        if self.aln_type == 'AA':
 469            raise TypeError('ORF search only for RNA/DNA alignments')
 470
 471        if identity_cutoff is not None:
 472            if identity_cutoff > 100 or identity_cutoff < 0:
 473                raise ValueError('conservation cutoff must be between 0 and 100')
 474
 475        if min_length <= 6 or min_length > self.length:
 476            raise ValueError(f'min_length must be between 6 and {self.length}')
 477
 478        # ini
 479        identities = self.calc_identity_alignment()
 480        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
 481        aln_len = self.length
 482
 483        orf_counter = 0
 484        orf_dict = {}
 485
 486        for aln, direction in zip(alignments, ['+', '-']):
 487            # check for starts and stops in the first seq and then check if these are present in all seqs
 488            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
 489            # check each frame
 490            for frame in (0, 1, 2):
 491                potential_starts = [x for x in conserved_starts if x % 3 == frame]
 492                potential_stops = [x for x in conserved_stops if x % 3 == frame]
 493                last_stop = -1
 494                for start in potential_starts:
 495                    # go to the next stop that is sufficiently far away in the alignment
 496                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
 497                    if not next_stops:
 498                        continue
 499                    next_stop = next_stops[0]
 500                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
 501                    # re-check the lengths of all ungapped seqs
 502                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
 503                    if not all(ungapped_seq_lengths):
 504                        continue
 505                    # if no stop codon between start and stop --> write to dictionary
 506                    if not additional_stops(ungapped_sliced_seqs):
 507                        if direction == '+':
 508                            positions = [start, next_stop + 3]
 509                        else:
 510                            positions = [aln_len - next_stop - 3, aln_len - start]
 511                        if last_stop != next_stop:
 512                            last_stop = next_stop
 513                            conservation = calculate_identity(identities, positions)
 514                            if identity_cutoff is not None and conservation < identity_cutoff:
 515                                continue
 516                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
 517                                                              'frame': frame,
 518                                                              'strand': direction,
 519                                                              'conservation': conservation,
 520                                                              'internal': []
 521                                                              }
 522                            orf_counter += 1
 523                        else:
 524                            if orf_dict:
 525                                orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
 526
 527        return orf_dict
 528
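A usage sketch (not part of the source), assuming a hypothetical nucleotide alignment file:

    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')  # hypothetical file
    orfs = aln.get_conserved_orfs(min_length=300, identity_cutoff=90)
    for name, orf in orfs.items():
        print(name, orf['location'], orf['strand'], orf['frame'], orf['conservation'])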
 529    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 530        """
 531        First calculates all ORFs and then searches from the 5' end for
 532        all non-overlapping ORFs on the fw strand and from the 3' end for
 533        all non-overlapping ORFs on the rw strand.
 534
 535        **No overlap algorithm:**
 536            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
 537
 538            **frame 2:** -------[M------*]---------[M---*]--------
 539
 540            **frame 3:** [M---*]-----[M----------*]----------[M---
 541
 542            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
 543
 544            frame:    3      2           1      2       1
 545
 546        :return: dictionary with non-overlapping orfs
 547        """
 548        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
 549
 550        fw_orfs, rw_orfs = [], []
 551
 552        for orf in orf_dict:
 553            if orf_dict[orf]['strand'] == '+':
 554                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
 555            else:
 556                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
 557
 558        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
 559        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
 560        non_overlapping_orfs = []
 561        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
 562            previous_stop = -1 if strand == '+' else self.length + 1
 563            for orf in orf_list:
 564                if strand == '+' and orf[1][0] > previous_stop:
 565                    non_overlapping_orfs.append(orf[0])
 566                    previous_stop = orf[1][1]
 567                elif strand == '-' and orf[1][1] < previous_stop:
 568                    non_overlapping_orfs.append(orf[0])
 569                    previous_stop = orf[1][0]
 570
 571        non_overlap_dict = {}
 572        for orf in orf_dict:
 573            if orf in non_overlapping_orfs:
 574                non_overlap_dict[orf] = orf_dict[orf]
 575
 576        return non_overlap_dict
 577
 578    def calc_length_stats(self) -> dict:
 579        """
 580        Determine the stats for the length of the ungapped seqs in the alignment.
 581        :return: dictionary with length stats
 582        """
 583
 584        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 585
 586        return {'number of seq': len(self.alignment),
 587                'mean length': float(np.mean(seq_lengths)),
 588                'std length': float(np.std(seq_lengths)),
 589                'min length': int(np.min(seq_lengths)),
 590                'max length': int(np.max(seq_lengths))
 591                }
 592
 593    def calc_entropy(self) -> list:
 594        """
 595        Calculate the normalized Shannon's entropy for every position in an alignment:
 596
 597        - 1: high entropy
 598        - 0: low entropy
 599
 600        :return: Entropies at each position.
 601        """
 602
 603        # helper functions
 604        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
 605            """
 606            Calculate the Shannon's entropy of a sequence,
 607            normalized between 0 and 1.
 608            :param character_list: characters at an alignment position
 609            :param states: number of potential characters that can be present
 610            :param aln_type: type of the alignment
 611            :returns: entropy
 612            """
 613            ent, n_chars = 0, len(character_list)
 614            # only one char is in the list
 615            if n_chars <= 1:
 616                return ent
 617            # calculate the number of unique chars and their counts
 618            chars, char_counts = np.unique(character_list, return_counts=True)
 619            char_counts = char_counts.astype(float)
 620            # ignore gaps for entropy calc
 621            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
 622            # correctly handle ambiguous chars
 623            index_to_drop = []
 624            for index, char in enumerate(chars):
 625                if char in config.AMBIG_CHARS[aln_type]:
 626                    index_to_drop.append(index)
 627                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
 628                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
 629                    # add the proportionate numbers to initial array
 630                    for amb_char, amb_count in zip(amb_chars, amb_counts):
 631                        if amb_char in chars:
 632                            char_counts[chars == amb_char] += amb_count
 633                        else:
 634                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
 635            # drop the ambiguous characters from array
 636            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
 637            # calc the entropy
 638            probs = char_counts / n_chars
 639            if np.count_nonzero(probs) <= 1:
 640                return ent
 641            for prob in probs:
 642                ent -= prob * math.log(prob, states)
 643
 644            return ent
 645
 646        aln = self.alignment
 647        entropies = []
 648
 649        if self.aln_type == 'AA':
 650            states = 20
 651        else:
 652            states = 4
 653        # iterate over alignment positions and the sequences
 654        for nuc_pos in range(self.length):
 655            pos = []
 656            for record in aln:
 657                pos.append(aln[record][nuc_pos])
 658            entropies.append(shannons_entropy(pos, states, self.aln_type))
 659
 660        return entropies
 661
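A sketch (hypothetical file): the returned list aligns index-wise with alignment columns, so flagging variable columns is straightforward:

    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')  # hypothetical file
    entropies = aln.calc_entropy()
    variable_columns = [i for i, ent in enumerate(entropies) if ent > 0.5]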
 662    def calc_gc(self) -> list:
 663        """
 664        Determine the GC content for every position in an nt alignment.
 665        :return: GC content for every position.
 666        :raises: TypeError for AA alignments
 667        """
 668        if self.aln_type == 'AA':
 669            raise TypeError("GC computation is not possible for amino acid alignments")
 670
 671        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
 672
 673        for position in range(self.length):
 674            nucleotides = str()
 675            for record in aln:
 676                nucleotides = nucleotides + aln[record][position]
 677            # ini dict with the chars to count and the
 678            # frequency with which each counts towards GC
 679            to_count = {
 680                'G': 1,
 681                'C': 1,
 682            }
 683            # handle ambig. nuc
 684            for char in amb_nucs:
 685                if char in nucleotides:
 686                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
 687
 688            gc.append(
 689                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
 690            )
 691
 692        return gc
 693
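A sketch (hypothetical file), e.g. for a quick alignment-wide mean GC:

    import numpy as np
    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')  # hypothetical nt alignment
    gc = aln.calc_gc()
    print(f'mean GC: {np.mean(gc):.2f}')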
 694    def calc_coverage(self) -> list:
 695        """
 696        Determine the coverage of every position in an alignment.
 697        This is defined as:
 698            1 - fraction of '-' characters at that position
 699
 700        :return: Coverage at each alignment position.
 701        """
 702        coverage, aln = [], self.alignment
 703
 704        for nuc_pos in range(self.length):
 705            pos = str()
 706            for record in aln.keys():
 707                pos = pos + aln[record][nuc_pos]
 708            coverage.append(1 - pos.count('-') / len(pos))
 709
 710        return coverage
 711
 712    def calc_reverse_complement_alignment(self) -> dict:
 713        """
 714        Reverse complement the alignment.
 715        :return: Alignment (rv)
 716        """
 717        if self.aln_type == 'AA':
 718            raise TypeError('Reverse complement only for RNA or DNA.')
 719
 720        aln = self.alignment
 721        reverse_complement_dict = {}
 722
 723        for seq_id in aln:
 724            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 725
 726        return reverse_complement_dict
 727
 728    def calc_identity_alignment(self, encode_mismatches: bool = True, encode_mask: bool = False, encode_gaps: bool = True, encode_ambiguities: bool = False, encode_each_mismatch_char: bool = False) -> np.ndarray:
 729        """
 730        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 731
 732        :param encode_mismatches: encode mismatch as -1
 733        :param encode_mask: encode mask with value=-2 --> also in the reference
 734        :param encode_gaps: encode gaps with np.nan --> also in the reference
 735        :param encode_ambiguities: encode ambiguities with value=-3
 736        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.DNA_colors, config.RNA_colors or config.AA_colors
 737        :return: identity alignment
 738        """
 739
 740        aln = self.alignment
 741        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 742
 743        # convert alignment to array
 744        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 745        reference = np.array(list(ref))
 746        # ini matrix
 747        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 748
 749        is_identical = sequences == reference
 750
 751        if encode_gaps:
 752            is_gap = sequences == '-'
 753        else:
 754            is_gap = np.full(sequences.shape, False)
 755
 756        if encode_mask:
 757            if self.aln_type == 'AA':
 758                is_n_or_x = np.isin(sequences, ['X'])
 759            else:
 760                is_n_or_x = np.isin(sequences, ['N'])
 761        else:
 762            is_n_or_x = np.full(sequences.shape, False)
 763
 764        if encode_ambiguities:
 765            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 766        else:
 767            is_ambig = np.full(sequences.shape, False)
 768
 769        if encode_mismatches:
 770            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 771        else:
 772            is_mismatch = np.full(sequences.shape, False)
 773
 774        # encode every different character
 775        if encode_each_mismatch_char:
 776            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
 777                new_encoding = np.isin(sequences, [char]) & is_mismatch
 778                identity_matrix[new_encoding] = idx + 1
 779        # or encode different with a single value
 780        else:
 781            identity_matrix[is_mismatch] = -1  # mismatch
 782
 783        identity_matrix[is_gap] = np.nan  # gap
 784        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 785        identity_matrix[is_ambig] = -3  # ambiguities
 786
 787        return identity_matrix
 788
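A sketch (hypothetical file) consuming the encodings described above:

    import numpy as np
    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')   # hypothetical file
    ident = aln.calc_identity_alignment()     # shape: (number of sequences, alignment length)
    mismatches_per_column = np.sum(ident == -1, axis=0)
    gap_fraction_per_column = np.mean(np.isnan(ident), axis=0)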
 789    def calc_similarity_alignment(self, matrix_type: str | None = None, normalize: bool = True) -> np.ndarray:
 790        """
 791        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
 792        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
 793        reference residue at each column. Gaps are encoded as np.nan.
 794
 795        The calculation follows these steps:
 796
 797        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
 798           a consensus sequence is generated to serve as the reference.
 799        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
 800           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
 801        3. **Per-Column Normalization (optional)**:
 802           - For each column in the alignment:
 803             - The residue in the reference sequence is treated as the baseline for that column.
 804             - The substitution scores for the reference residue are extracted from the substitution matrix.
 805             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
 806               residue.
 807           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
 808             while more dissimilar residues have lower scores.
 809        4. **Output**:
 810           - The normalized similarity scores are stored in a NumPy array.
 811           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
 812
 813        :param matrix_type: type of similarity score (if not set - AA: BLOSUM65, RNA/DNA: TRANS)
 814        :param normalize: whether to normalize the similarity scores to range [0, 1]
 815        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
 816            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
 817            Gaps and invalid residues are encoded as `np.nan`.
 818        :raises ValueError:
 819            If the specified substitution matrix is not available for the given alignment type.
 820        """
 821
 822        aln = self.alignment
 823        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 824        if matrix_type is None:
 825            if self.aln_type == 'AA':
 826                matrix_type = 'BLOSUM65'
 827            else:
 828                matrix_type = 'TRANS'
 829        # load substitution matrix as dictionary
 830        try:
 831            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
 832        except KeyError:
 833            raise ValueError(
 834                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
 835            )
 836
 837        # set dtype and convert alignment to a NumPy array for vectorized processing
 838        dtype = np.dtype(float, metadata={'matrix': matrix_type})
 839        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 840        reference = np.array(list(ref))
 841        valid_chars = list(subs_matrix.keys())
 842        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
 843
 844        for j, ref_char in enumerate(reference):
 845            if ref_char not in valid_chars + ['-']:
 846                continue
 847            # Get local min and max for the reference residue
 848            if normalize and ref_char != '-':
 849                local_scores = subs_matrix[ref_char].values()
 850                local_min, local_max = min(local_scores), max(local_scores)
 851
 852            for i, char in enumerate(sequences[:, j]):
 853                if char not in valid_chars:
 854                    continue
 855                # classify the similarity as max if the reference has a gap
 856                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
 857                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
 858
 859        return similarity_array
 860
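A sketch (hypothetical file): averaging the normalized scores per column gives a simple conservation profile:

    import numpy as np
    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')  # hypothetical file
    sim = aln.calc_similarity_alignment()    # values in [0, 1], gaps are np.nan
    column_conservation = np.nanmean(sim, axis=0)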
 861    def calc_position_matrix(self, matrix_type: str = 'PWM') -> np.ndarray:
 862        """
 863        Calculates a position matrix of the specified type for the given alignment. The function
 864        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
 865        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
 866        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
 867
 868        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
 869            'PWM', and 'IC'. Defaults to 'PWM'.
 870        :type matrix_type: str
 871        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
 872        :return: A numpy array representing the calculated position matrix of the specified type.
 873        :rtype: np.ndarray
 874        """
 875
 876        # ini
 877        aln = self.alignment
 878        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
 879            raise ValueError('Matrix_type must be PFM, PPM, IC or PWM.')
 880        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())[:-1]
 881        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 882
 883        # calc position frequency matrix
 884        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
 885        if matrix_type == 'PFM':
 886            return pfm
 887
 888        # calc position probability matrix (probability)
 889        pseudo_count = 0.0001  # to avoid 0 values
 890        pfm = pfm + pseudo_count
 891        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
 892        ppm = pfm/len(aln.keys())  # calculate the frequency based on row number
 893        if matrix_type == 'PPM':
 894            return ppm
 895
 896        # calc position weight matrix (log-likelihood)
 897        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
 898        if matrix_type == 'PWM':
 899            return pwm
 900
 901        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
 902        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
 903        if matrix_type == 'IC':
 904            return ic
 905
 906        return None
 907
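A sketch (hypothetical file): as noted in the source comments, the IC vector can scale a PPM column-wise for sequence-logo style plots:

    from msaexplorer import explore

    aln = explore.MSA('my_alignment.fasta')  # hypothetical file
    ppm = aln.calc_position_matrix('PPM')    # shape: (number of characters, alignment length)
    ic = aln.calc_position_matrix('IC')      # one value per column (bits)
    logo_heights = ppm * ic                  # broadcasts the per-column IC over characters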
 908    def calc_percent_recovery(self) -> dict:
 909        """
 910        Recovery per sequence either compared to the majority consensus seq
 911        or the reference seq.\n
 912        Defined as:\n
 913
 914        `(1 - (count of N/X/- at ungapped reference positions) / (number of ungapped reference positions)) * 100`
 915
 916        This is highly similar to how nextclade calculates recovery over reference.
 917
 918        :return: dict
 919        """
 920
 921        aln = self.alignment
 922
 923        if self.reference_id is not None:
 924            ref = aln[self.reference_id]
 925        else:
 926            ref = self.get_consensus()  # majority consensus
 927
 928        if not any(char != '-' for char in ref):
 929            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 930
 931
 932        # count 'N', 'X' and '-' chars in non-gapped regions
 933        recovery_over_ref = dict()
 934
 935        # Get positions of non-gap characters in the reference
 936        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 937        cumulative_length = len(non_gap_positions)
 938
 939        # Calculate recovery
 940        for seq_id in aln:
 941            if seq_id == self.reference_id:
 942                continue
 943            seq = aln[seq_id]
 944            count_invalid = sum(
 945                seq[pos] == '-' or
 946                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 947                for pos in non_gap_positions
 948            )
 949            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 950
 951        return recovery_over_ref
 952
 953    def calc_character_frequencies(self) -> dict:
 954        """
 955        Calculate character frequencies in the alignment:
 956        frequencies are counted per sequence and in total. The
 957        percentage of each non-gap character is relative to the
 958        total number of non-gap characters in that sequence.
 959        The gap percentage is relative to the alignment length.
 960
 961        The output is a nested dictionary.
 962
 963        :return: Character frequencies
 964        """
 965
 966        aln, aln_length = self.alignment, self.length
 967
 968        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 969
 970        for seq_id in aln:
 971            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 972            unique_chars = set(aln[seq_id])
 973            for char in unique_chars:
 974                if char == '-':
 975                    continue
 976                # add characters to dictionaries
 977                if char not in freqs[seq_id]:
 978                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 979                if char not in freqs['total']:
 980                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
 981                # count non-gap chars
 982                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
 983                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
 984                all_chars += freqs[seq_id][char]['counts']
 985            # normalize counts
 986            for char in freqs[seq_id]:
 987                if char == '-':
 988                    continue
 989                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
 990                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
 991            # count gaps
 992            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
 993            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
 994            # normalize gap counts
 995            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
 996            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
 997
 998        # normalize the total counts
 999        for char in freqs['total']:
1000            for value in freqs['total'][char]:
1001                if value == '% of alignment' or value == '% of non-gapped':
1002                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
1003
1004        return freqs
1005
1006    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> ndarray:
1007        """
1008        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
1009
1010        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1011        \ndistance = matches / alignment_length * 100
1012
 1013        **2) lhd (local hamming distance)**: Restrict the alignment to the region where neither sequence starts or ends with a gap:
 1014        \ndistance = matches / length of the gap-trimmed alignment region * 100
1015
1016        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1017        \ndistance = matches / (matches + mismatches) * 100
1018
1019        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1020        \ndistance = matches / gap_compressed_alignment_length * 100
1021
1022        :return: array with pairwise distances.
1023        """
1024
1025        def hamming_distance(seq1: str, seq2: str) -> int:
 1026            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))  # counts matching positions, not mismatches
1027
1028        def ghd(seq1: str, seq2: str) -> float:
1029            return hamming_distance(seq1, seq2) / self.length * 100
1030
 1031        def lhd(seq1: str, seq2: str) -> float:
1032            # Trim gaps from both sides
1033            i, j = 0, self.length - 1
1034            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1035                i += 1
1036            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1037                j -= 1
1038            if i > j:
1039                return 0.0
1040
1041            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1042            matches = sum(c1 == c2 for c1, c2 in zip(seq1_, seq2_))
1043            length = j - i + 1
1044            return (matches / length) * 100 if length > 0 else 0.0
1045
1046        def ged(seq1: str, seq2: str) -> float:
1047
1048            matches, mismatches = 0, 0
1049
1050            for c1, c2 in zip(seq1, seq2):
1051                if c1 != '-' and c2 != '-':
1052                    if c1 == c2:
1053                        matches += 1
1054                    else:
1055                        mismatches += 1
1056            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1057
1058        def gcd(seq1: str, seq2: str) -> float:
1059            matches = 0
1060            mismatches = 0
1061            in_gap = False
1062
1063            for char1, char2 in zip(seq1, seq2):
1064                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1065                    continue
1066                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1067                    if not in_gap:  # Start of a new gap stretch
1068                        mismatches += 1
1069                        in_gap = True
1070                else:  # No gaps
1071                    in_gap = False
1072                    if char1 == char2:  # Matching characters
1073                        matches += 1
1074                    else:  # Mismatched characters
1075                        mismatches += 1
1076
1077            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1078
1079
1080        # Map distance type to corresponding function
1081        distance_functions: Dict[str, Callable[[str, str], float]] = {
1082            'ghd': ghd,
1083            'lhd': lhd,
1084            'ged': ged,
1085            'gcd': gcd
1086        }
1087
1088        if distance_type not in distance_functions:
1089            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1090
1091        # Compute pairwise distances
1092        aln = self.alignment
1093        distance_func = distance_functions[distance_type]
1094        distance_matrix = np.zeros((len(aln), len(aln)))
1095
1096        sequences = list(aln.values())
1097        n = len(sequences)
1098        for i in range(n):
1099            seq1 = sequences[i]
1100            for j in range(i, n):
1101                seq2 = sequences[j]
1102                dist = distance_func(seq1, seq2)
1103                distance_matrix[i, j] = dist
1104                distance_matrix[j, i] = dist
1105
1106        return distance_matrix
1107
1108    def get_snps(self, include_ambig:bool=False) -> dict:
1109        """
1110        Calculate snps similar to snp-sites (output is comparable):
1111        https://github.com/sanger-pathogens/snp-sites
 1112        Importantly, SNPs are only considered if at least one of the variant characters is not ambiguous.
1113        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1114
1115        :param include_ambig: Include ambiguous snps (default: False)
1116        :return: dictionary containing snp positions and their variants including their frequency.
1117        """
1118        aln = self.alignment
1119        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
1120        aln = {x: aln[x] for x in aln.keys() if x != self.reference_id}
1121        seq_ids = list(aln.keys())
1122        snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}
1123
1124        for pos in range(self.length):
1125            reference_char = ref[pos]
1126            if not include_ambig:
1127                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1128                    continue
1129            alt_chars, snps = [], []
1130            for i, seq_id in enumerate(aln.keys()):
1131                alt_chars.append(aln[seq_id][pos])
1132                if reference_char != aln[seq_id][pos]:
1133                    snps.append(i)
1134            if not snps:
1135                continue
1136            if include_ambig:
1137                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1138                    continue
1139            else:
1140                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1141                if not snps:
1142                    continue
 1143            if pos not in snp_dict['POS']:
1144                snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': {}}
1145            for snp in snps:
1146                if alt_chars[snp] not in snp_dict['POS'][pos]['ALT']:
1147                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]] = {
1148                        'AF': 1,
1149                        'SEQ_ID': [seq_ids[snp]]
1150                    }
1151                else:
1152                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['AF'] += 1
1153                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['SEQ_ID'].append(seq_ids[snp])
1154            # calculate AF
1155            if pos in snp_dict['POS']:
1156                for alt in snp_dict['POS'][pos]['ALT']:
1157                    snp_dict['POS'][pos]['ALT'][alt]['AF'] /= len(aln)
1158
1159        return snp_dict
1160
1161    def calc_transition_transversion_score(self) -> list:
1162        """
 1163        Based on the snp positions, calculates a transition/transversion score.
 1164        A positive score means a higher ratio of transitions and a negative score
 1165        a higher ratio of transversions.
 1166        :return: list of scores, one per alignment position
1167        """
1168
1169        if self.aln_type == 'AA':
1170            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1171
1172        # ini
1173        snps = self.get_snps()
1174        score = [0]*self.length
1175
1176        for pos in snps['POS']:
 1177            # add AF for transitions, subtract AF for transversions (per alt allele)
1178            for alt in snps['POS'][pos]['ALT']:
1179                # check the type of substitution
1180                if snps['POS'][pos]['ref'] + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1181                    score[pos] += snps['POS'][pos]['ALT'][alt]['AF']
1182                else:
1183                    score[pos] -= snps['POS'][pos]['ALT'][alt]['AF']
1184
1185        return score

An alignment class that allows computation of several stats

MSA( alignment_string: str, reference_id: str = None, zoom_range: tuple | int = None)
42    def __init__(self, alignment_string: str, reference_id: str = None, zoom_range: tuple | int = None):
43        """
44        Initialise an Alignment object.
45        :param alignment_string: Path to alignment file or raw alignment string
46        :param reference_id: reference id
47        :param zoom_range: start and stop positions to zoom into the alignment
48        """
49        self._alignment = self._read_alignment(alignment_string)
50        self._reference_id = self._validate_ref(reference_id, self._alignment)
51        self._zoom = self._validate_zoom(zoom_range, self._alignment)
52        self._aln_type = self._determine_aln_type(self._alignment)

Initialise an Alignment object.

Parameters
  • alignment_string: Path to alignment file or raw alignment string
  • reference_id: reference id
  • zoom_range: start and stop positions to zoom into the alignment
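
A minimal usage sketch (illustrative, not part of the module); it assumes a small aligned FASTA passed as a raw string and that zoom_range takes a (start, stop) tuple as documented above:

```python
from msaexplorer.explore import MSA

# tiny aligned FASTA as a raw string; a path to a fasta file works as well
aln = ">seq1\nATG-CT\n>seq2\nATGACT\n>seq3\nATGACT\n"

msa = MSA(aln)                                        # no reference, no zoom
sub = MSA(aln, reference_id='seq1', zoom_range=(0, 3))

print(msa.length)             # 6
print(sub.alignment['seq1'])  # 'ATG' (zoomed slice)
```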
reference_id
183    @property
184    def reference_id(self):
185        return self._reference_id

Set and validate the reference id.

zoom: tuple
194    @property
195    def zoom(self) -> tuple:
196        return self._zoom

Validates the user-defined zoom range.

aln_type: str
206    @property
207    def aln_type(self) -> str:
208        """
209        define the aln type:
210        RNA, DNA or AA
211        """
212        return self._aln_type

define the aln type: RNA, DNA or AA

length: int
215    @property
216    def length(self) -> int:
217        return len(next(iter(self.alignment.values())))

alignment: dict
219    @property
220    def alignment(self) -> dict:
221        """
222        (zoomed) version of the alignment.
223        """
224        if self.zoom is not None:
225            zoomed_aln = dict()
226            for seq in self._alignment:
227                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
228            return zoomed_aln
229        else:
230            return self._alignment

(zoomed) version of the alignment.

def get_reference_coords(self) -> tuple[int, int]:
233    def get_reference_coords(self) -> tuple[int, int]:
234        """
235        Determine the start and end coordinates of the reference sequence
236        defined as the first/last nucleotide in the reference sequence
237        (excluding N and gaps).
238
239        :return: start, end
240        """
241        start, end = 0, self.length
242
243        if self.reference_id is None:
244            return start, end
245        else:
246            # 5' --> 3'
247            for start in range(self.length):
248                if self.alignment[self.reference_id][start] not in ['-', 'N']:
249                    break
250            # 3' --> 5'
251            for end in range(self.length - 1, -1, -1):
252                if self.alignment[self.reference_id][end] not in ['-', 'N']:
253                    break
254
255            return start, end

Determine the start and end coordinates of the reference sequence defined as the first/last nucleotide in the reference sequence (excluding N and gaps).

Returns

start, end
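
For illustration (a sketch, assuming 'N' is an allowed character in config.POSSIBLE_CHARS):

```python
msa = MSA(">ref\nNNATGCNN\n>s2\nAAATGCAA\n", reference_id='ref')
print(msa.get_reference_coords())  # (2, 5): first/last ref position that is neither '-' nor 'N'
```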

def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
257    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
258        """
259        Creates a non-gapped consensus sequence.
260
261        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
262            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
263            or 'X' (for AA alignments) is used if none of the characters reach a cumulative frequency >= threshold.
264        :param use_ambig_nt: Use an ambiguous character if none of the possible nt at an alignment position
265            has a frequency above the defined threshold.
266        :return: consensus sequence
267        """
268
269        # helper functions
270        def determine_counts(alignment_dict: dict, position: int) -> dict:
271            """
272            count the number of each char at
273            an idx of the alignment. return sorted dic.
274            handles ambiguous nucleotides in sequences.
275            also handles gaps.
276            """
277            nucleotide_list = []
278
279            # get all nucleotides
280            for sequence in alignment_dict.items():
281                nucleotide_list.append(sequence[1][position])
282            # count occurrences of nucleotides
283            counter = dict(collections.Counter(nucleotide_list))
284            # get permutations of an ambiguous nucleotide
285            to_delete = []
286            temp_dict = {}
287            for nucleotide in counter:
288                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
289                    to_delete.append(nucleotide)
290                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
291                    adjusted_freq = 1 / len(permutations)
292                    for permutation in permutations:
293                        if permutation in temp_dict:
294                            temp_dict[permutation] += adjusted_freq
295                        else:
296                            temp_dict[permutation] = adjusted_freq
297
298            # drop ambiguous entries and add adjusted freqs to the counter
299            if to_delete:
300                for i in to_delete:
301                    counter.pop(i)
302                for nucleotide in temp_dict:
303                    if nucleotide in counter:
304                        counter[nucleotide] += temp_dict[nucleotide]
305                    else:
306                        counter[nucleotide] = temp_dict[nucleotide]
307
308            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
309
310        def get_consensus_char(counts: dict, cutoff: float) -> list:
311            """
312            get a list of nucleotides for the consensus seq
313            """
314            n = 0
315
316            consensus_chars = []
317            for char in counts:
318                n += counts[char]
319                consensus_chars.append(char)
320                if n >= cutoff:
321                    break
322
323            return consensus_chars
324
325        def get_ambiguous_char(nucleotides: list) -> str:
326            """
327            get ambiguous char from a list of nucleotides
328            """
329            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
330                if set(permutations) == set(nucleotides):
331                    break
332
333            return ambiguous
334
335        # check if params have been set correctly
336        if threshold is not None:
337            if threshold < 0 or threshold > 1:
338                raise ValueError('Threshold must be between 0 and 1.')
339        if self.aln_type == 'AA' and use_ambig_nt:
340            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
341        if threshold is None and use_ambig_nt:
342            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
343
344        alignment = self.alignment
345        consensus = str()
346
347        if threshold is not None:
348            consensus_cutoff = len(alignment) * threshold
349        else:
350            consensus_cutoff = 0
351
352        # build consensus sequence
353        for idx in range(self.length):
354            char_counts = determine_counts(alignment, idx)
355            consensus_chars = get_consensus_char(
356                char_counts,
357                consensus_cutoff
358            )
359            if threshold != 0:
360                if len(consensus_chars) > 1:
361                    if use_ambig_nt:
362                        char = get_ambiguous_char(consensus_chars)
363                    else:
364                        if self.aln_type == 'AA':
365                            char = 'X'
366                        else:
367                            char = 'N'
368                    consensus = consensus + char
369                else:
370                    consensus = consensus + consensus_chars[0]
371            else:
372                consensus = consensus + consensus_chars[0]
373
374        return consensus

Creates a non-gapped consensus sequence.

Parameters
  • threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments) or 'X' (for AA alignments) is used if none of the characters reach a cumulative frequency >= threshold.
  • use_ambig_nt: Use an ambiguous character if none of the possible nt at an alignment position has a frequency above the defined threshold.
Returns

consensus sequence
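
A short illustrative example; the ambiguous output assumes config.AMBIG_CHARS follows IUPAC codes (e.g. 'R' for A/G):

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nATAC\n")
print(msa.get_consensus())                                  # majority consensus: 'ATGC'
print(msa.get_consensus(threshold=0.9, use_ambig_nt=True))  # 'ATRC' if {'A', 'G'} maps to 'R'
```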

def get_conserved_orfs( self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
376    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
377        """
378        **conserved ORF definition:**
379            - conserved starts and stops
380            - start, stop must be on the same frame
381            - stop - start must be at least min_length
382            - all ungapped seqs[start:stop] must be at least min_length long
383            - no ungapped seq may contain an additional stop codon between start and stop
384
385        Conservation is measured by the number of positions with identical characters divided by
386        the length of the orf slice of the alignment.
387
388        **Algorithm overview:**
389            - check for conserved start and stop codons
390            - iterate over all three frames
391            - check each start and next sufficiently far away stop codon
392            - check if all ungapped seqs between start and stop codon are >= min_length
393            - check if no ungapped seq in the alignment has a stop codon
394            - write to dictionary
395            - classify as internal if the stop codon has already been written with a prior start
396            - repeat for reverse complement
397
398        :return: ORF positions and internal ORF positions
399        """
400
401        # helper functions
402        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
403            """
404            Determine all start and stop codons within an alignment.
405            :param alignment: alignment
406            :param alignment_length: length of alignment
407            :return: start and stop codon positions
408            """
409            starts = config.START_CODONS[self.aln_type]
410            stops = config.STOP_CODONS[self.aln_type]
411
412            list_of_starts, list_of_stops = [], []
413            ref = alignment[list(alignment.keys())[0]]
414            for nt_position in range(alignment_length):
415                if ref[nt_position:nt_position + 3] in starts:
416                    conserved_start = True
417                    for sequence in alignment:
418                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
419                            conserved_start = False
420                            break
421                    if conserved_start:
422                        list_of_starts.append(nt_position)
423
424                if ref[nt_position:nt_position + 3] in stops:
425                    conserved_stop = True
426                    for sequence in alignment:
427                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
428                            conserved_stop = False
429                            break
430                    if conserved_stop:
431                        list_of_stops.append(nt_position)
432
433            return list_of_starts, list_of_stops
434
435        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
436            """
437            get ungapped sequences starting and stop codons and eliminate gaps
438            :param alignment: alignment
439            :param start_pos: start codon
440            :param stop_pos: stop codon
441            :return: sliced sequences
442            """
443            ungapped_seqs = []
444            for seq_id in alignment:
445                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
446
447            return ungapped_seqs
448
449        def additional_stops(ungapped_seqs: list) -> bool:
450            """
451            Checks for the presence of a stop codon
452            :param ungapped_seqs: list of ungapped sequences
453            :return: Additional stop codons (True/False)
454            """
455            stops = config.STOP_CODONS[self.aln_type]
456
457            for sliced_seq in ungapped_seqs:
458                for position in range(0, len(sliced_seq) - 3, 3):
459                    if sliced_seq[position:position + 3] in stops:
460                        return True
461            return False
462
463        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
464            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
465            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100
466
467        # checks for arguments
468        if self.aln_type == 'AA':
469            raise TypeError('ORF search only for RNA/DNA alignments')
470
471        if identity_cutoff is not None:
472            if identity_cutoff > 100 or identity_cutoff < 0:
473                raise ValueError('conservation cutoff must be between 0 and 100')
474
475        if min_length <= 6 or min_length > self.length:
476            raise ValueError(f'min_length must be greater than 6 and at most {self.length}')
477
478        # ini
479        identities = self.calc_identity_alignment()
480        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
481        aln_len = self.length
482
483        orf_counter = 0
484        orf_dict = {}
485
486        for aln, direction in zip(alignments, ['+', '-']):
487            # check for starts and stops in the first seq and then check if these are present in all seqs
488            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
489            # check each frame
490            for frame in (0, 1, 2):
491                potential_starts = [x for x in conserved_starts if x % 3 == frame]
492                potential_stops = [x for x in conserved_stops if x % 3 == frame]
493                last_stop = -1
494                for start in potential_starts:
495                    # go to the next stop that is sufficiently far away in the alignment
496                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
497                    if not next_stops:
498                        continue
499                    next_stop = next_stops[0]
500                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
501                    # re-check the lengths of all ungapped seqs
502                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
503                    if not all(ungapped_seq_lengths):
504                        continue
505                    # if no stop codon between start and stop --> write to dictionary
506                    if not additional_stops(ungapped_sliced_seqs):
507                        if direction == '+':
508                            positions = [start, next_stop + 3]
509                        else:
510                            positions = [aln_len - next_stop - 3, aln_len - start]
511                        if last_stop != next_stop:
512                            last_stop = next_stop
513                            conservation = calculate_identity(identities, positions)
514                            if identity_cutoff is not None and conservation < identity_cutoff:
515                                continue
516                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
517                                                              'frame': frame,
518                                                              'strand': direction,
519                                                              'conservation': conservation,
520                                                              'internal': []
521                                                              }
522                            orf_counter += 1
523                        else:
524                            if orf_dict:
525                                orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
526
527        return orf_dict

conserved ORF definition:
  • conserved starts and stops
  • start, stop must be on the same frame
  • stop - start must be at least min_length
  • all ungapped seqs[start:stop] must be at least min_length long
  • no ungapped seq may contain an additional stop codon between start and stop

Conservation is measured by the number of positions with identical characters divided by the length of the orf slice of the alignment.

Algorithm overview:
  • check for conserved start and stop codons
  • iterate over all three frames
  • check each start and next sufficiently far away stop codon
  • check if all ungapped seqs between start and stop codon are >= min_length
  • check if no ungapped seq in the alignment has a stop codon
  • write to dictionary
  • classify as internal if the stop codon has already been written with a prior start
  • repeat for reverse complement

Returns

ORF positions and internal ORF positions
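
A toy sketch with a 12 nt alignment (assumes the usual ATG/TAA codons in config.START_CODONS and config.STOP_CODONS); values are approximate:

```python
# conserved ATG...TAA in frame 0 with a single mismatch in between
msa = MSA(">s1\nATGAAACCCTAA\n>s2\nATGAAACCCTAA\n>s3\nATGAAGCCCTAA\n")
orfs = msa.get_conserved_orfs(min_length=9)
# roughly: {'ORF_0': {'location': [[0, 12]], 'frame': 0, 'strand': '+',
#                     'conservation': 91.7, 'internal': []}}
```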

def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff: float = None) -> dict:
529    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> dict:
530        """
531        First calculates all conserved ORFs and then searches from the 5' end for
532        all non-overlapping orfs on the fw strand and from the
533        3' end for all non-overlapping orfs on the rw strand.
534
535        **No overlap algorithm:**
536            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
537
538            **frame 2:** -------[M------*]---------[M---*]--------
539
540            **frame 3:** [M---*]-----[M----------*]----------[M---
541
542            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
543
544            frame:    3      2           1      2       1
545
546        :return: dictionary with non-overlapping orfs
547        """
548        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
549
550        fw_orfs, rw_orfs = [], []
551
552        for orf in orf_dict:
553            if orf_dict[orf]['strand'] == '+':
554                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
555            else:
556                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
557
558        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
559        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
560        non_overlapping_orfs = []
561        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
562            previous_stop = -1 if strand == '+' else self.length + 1
563            for orf in orf_list:
564                if strand == '+' and orf[1][0] > previous_stop:
565                    non_overlapping_orfs.append(orf[0])
566                    previous_stop = orf[1][1]
567                elif strand == '-' and orf[1][1] < previous_stop:
568                    non_overlapping_orfs.append(orf[0])
569                    previous_stop = orf[1][0]
570
571        non_overlap_dict = {}
572        for orf in orf_dict:
573            if orf in non_overlapping_orfs:
574                non_overlap_dict[orf] = orf_dict[orf]
575
576        return non_overlap_dict

First calculates all conserved ORFs and then searches from the 5' end for all non-overlapping orfs on the fw strand and from the 3' end for all non-overlapping orfs on the rw strand.

No overlap algorithm:

frame 1: -[M------*]--- ----[M--*]---------[M-----

frame 2: -------[M------*]---------[M---*]--------

frame 3: [M---*]-----[M----------*]----------[M---

results: [M---*][M------*]--[M--*]-[M---*]-[M-----

frame:    3      2           1      2       1
Returns

dictionary with non-overlapping orfs

def calc_length_stats(self) -> dict:
578    def calc_length_stats(self) -> dict:
579        """
580        Determine the stats for the length of the ungapped seqs in the alignment.
581        :return: dictionary with length stats
582        """
583
584        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
585
586        return {'number of seq': len(self.alignment),
587                'mean length': float(np.mean(seq_lengths)),
588                'std length': float(np.std(seq_lengths)),
589                'min length': int(np.min(seq_lengths)),
590                'max length': int(np.max(seq_lengths))
591                }

Determine the stats for the length of the ungapped seqs in the alignment.

Returns

dictionary with length stats
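
Quick illustrative example (std is the population standard deviation, as computed by np.std):

```python
msa = MSA(">s1\nATG-C\n>s2\nATGCC\n")
print(msa.calc_length_stats())
# {'number of seq': 2, 'mean length': 4.5, 'std length': 0.5,
#  'min length': 4, 'max length': 5}
```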

def calc_entropy(self) -> list:
593    def calc_entropy(self) -> list:
594        """
595        Calculate the normalized Shannon's entropy for every position in an alignment:
596
597        - 1: high entropy
598        - 0: low entropy
599
600        :return: Entropies at each position.
601        """
602
603        # helper functions
604        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
605            """
606            Calculate the shannon's entropy of a sequence and
607            normalized between 0 and 1.
608            :param character_list: characters at an alignment position
609            :param states: number of potential characters that can be present
610            :param aln_type: type of the alignment
611            :returns: entropy
612            """
613            ent, n_chars = 0, len(character_list)
614            # only one char is in the list
615            if n_chars <= 1:
616                return ent
617            # calculate the number of unique chars and their counts
618            chars, char_counts = np.unique(character_list, return_counts=True)
619            char_counts = char_counts.astype(float)
620            # ignore gaps for entropy calc
621            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
622            # correctly handle ambiguous chars
623            index_to_drop = []
624            for index, char in enumerate(chars):
625                if char in config.AMBIG_CHARS[aln_type]:
626                    index_to_drop.append(index)
627                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
628                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
629                    # add the proportionate numbers to initial array
630                    for amb_char, amb_count in zip(amb_chars, amb_counts):
631                        if amb_char in chars:
632                            char_counts[chars == amb_char] += amb_count
633                        else:
634                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
635            # drop the ambiguous characters from array
636            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
637            # calc the entropy
638            probs = char_counts / n_chars
639            if np.count_nonzero(probs) <= 1:
640                return ent
641            for prob in probs:
642                ent -= prob * math.log(prob, states)
643
644            return ent
645
646        aln = self.alignment
647        entropies = []
648
649        if self.aln_type == 'AA':
650            states = 20
651        else:
652            states = 4
653        # iterate over alignment positions and the sequences
654        for nuc_pos in range(self.length):
655            pos = []
656            for record in aln:
657                pos.append(aln[record][nuc_pos])
658            entropies.append(shannons_entropy(pos, states, self.aln_type))
659
660        return entropies

Calculate the normalized Shannon's entropy for every position in an alignment:

  • 1: high entropy
  • 0: low entropy
Returns

Entropies at each position.
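
Illustrative sketch (values rounded); only the third column is variable:

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nATAC\n")
print(msa.calc_entropy())  # one value per column, roughly [0.0, 0.0, 0.46, 0.0]
```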

def calc_gc(self) -> list:
662    def calc_gc(self) -> list:
663        """
664        Determine the GC content for every position in an nt alignment.
665        :return: GC content for every position.
666        :raises: TypeError for AA alignments
667        """
668        if self.aln_type == 'AA':
669            raise TypeError("GC computation is not possible for amino acid alignments")
670
671        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
672
673        for position in range(self.length):
674            nucleotides = str()
675            for record in aln:
676                nucleotides = nucleotides + aln[record][position]
677            # dict of the chars to count and the GC fraction
678            # that each char contributes
679            to_count = {
680                'G': 1,
681                'C': 1,
682            }
683            # handle ambig. nuc
684            for char in amb_nucs:
685                if char in nucleotides:
686                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
687
688            gc.append(
689                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
690            )
691
692        return gc

Determine the GC content for every position in an nt alignment.

Returns

GC content for every position.

Raises
  • TypeError for AA alignments
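
For example (values rounded):

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nATAC\n")
print(msa.calc_gc())  # per-column GC fraction, roughly [0.0, 0.0, 0.67, 1.0]
```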
def calc_coverage(self) -> list:
694    def calc_coverage(self) -> list:
695        """
696        Determine the coverage of every position in an alignment.
697        This is defined as:
698            1 - (fraction of '-' characters at a position)
699
700        :return: Coverage at each alignment position.
701        """
702        coverage, aln = [], self.alignment
703
704        for nuc_pos in range(self.length):
705            pos = str()
706            for record in aln.keys():
707                pos = pos + aln[record][nuc_pos]
708            coverage.append(1 - pos.count('-') / len(pos))
709
710        return coverage

Determine the coverage of every position in an alignment. This is defined as: 1 - (fraction of '-' characters at a position)

Returns

Coverage at each alignment position.
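
For example:

```python
msa = MSA(">s1\nATGC\n>s2\nAT-C\n>s3\nAT-C\n")
print(msa.calc_coverage())  # fraction of non-gap chars per column: [1.0, 1.0, 0.33..., 1.0]
```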

def calc_reverse_complement_alignment(self) -> dict:
712    def calc_reverse_complement_alignment(self) -> dict:
713        """
714        Reverse complement the alignment.
715        :return: reverse complemented alignment
716        """
717        if self.aln_type == 'AA':
718            raise TypeError('Reverse complement only for RNA or DNA.')
719
720        aln = self.alignment
721        reverse_complement_dict = {}
722
723        for seq_id in aln:
724            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
725
726        return reverse_complement_dict

Reverse complement the alignment.

Returns

reverse complemented alignment

def calc_identity_alignment( self, encode_mismatches: bool = True, encode_mask: bool = False, encode_gaps: bool = True, encode_ambiguities: bool = False, encode_each_mismatch_char: bool = False) -> numpy.ndarray:
728    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> np.ndarray:
729        """
730        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
731
732        :param encode_mismatches: encode mismatch as -1
733        :param encode_mask: encode mask with value=-2 --> also in the reference
734        :param encode_gaps: encode gaps with np.nan --> also in the reference
735        :param encode_ambiguities: encode ambiguities with value=-3
736        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of the characters in config.CHAR_COLORS[aln_type]['standard']
737        :return: identity alignment
738        """
739
740        aln = self.alignment
741        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
742
743        # convert alignment to array
744        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
745        reference = np.array(list(ref))
746        # ini matrix
747        identity_matrix = np.full(sequences.shape, 0, dtype=float)
748
749        is_identical = sequences == reference
750
751        if encode_gaps:
752            is_gap = sequences == '-'
753        else:
754            is_gap = np.full(sequences.shape, False)
755
756        if encode_mask:
757            if self.aln_type == 'AA':
758                is_n_or_x = np.isin(sequences, ['X'])
759            else:
760                is_n_or_x = np.isin(sequences, ['N'])
761        else:
762            is_n_or_x = np.full(sequences.shape, False)
763
764        if encode_ambiguities:
765            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
766        else:
767            is_ambig = np.full(sequences.shape, False)
768
769        if encode_mismatches:
770            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
771        else:
772            is_mismatch = np.full(sequences.shape, False)
773
774        # encode every different character
775        if encode_each_mismatch_char:
776            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
777                new_encoding = np.isin(sequences, [char]) & is_mismatch
778                identity_matrix[new_encoding] = idx + 1
779        # or encode different with a single value
780        else:
781            identity_matrix[is_mismatch] = -1  # mismatch
782
783        identity_matrix[is_gap] = np.nan  # gap
784        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
785        identity_matrix[is_ambig] = -3  # ambiguities
786
787        return identity_matrix

Converts alignment to identity array (identical=0) compared to majority consensus or reference:

Parameters
  • encode_mismatches: encode mismatch as -1
  • encode_mask: encode mask with value=-2 --> also in the reference
  • encode_gaps: encode gaps with np.nan --> also in the reference
  • encode_ambiguities: encode ambiguities with value=-3
  • encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of the characters in config.CHAR_COLORS[aln_type]['standard']
Returns

identity alignment
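
An illustrative sketch using the default encodings; no reference is set, so rows are compared to the majority consensus ('ATGC' here):

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nA-AC\n")
print(msa.calc_identity_alignment())
# one row per sequence: 0 = identical, -1 = mismatch, nan = gap
# [[ 0.  0.  0.  0.]
#  [ 0.  0.  0.  0.]
#  [ 0. nan -1.  0.]]
```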

def calc_similarity_alignment( self, matrix_type: str | None = None, normalize: bool = True) -> numpy.ndarray:
789    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> np.ndarray:
790        """
791        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
792        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
793        reference residue at each column. Gaps are encoded as np.nan.
794
795        The calculation follows these steps:
796
797        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
798           a consensus sequence is generated to serve as the reference.
799        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
800           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
801        3. **Per-Column Normalization (optional)**:
802           - For each column in the alignment:
803             - The residue in the reference sequence is treated as the baseline for that column.
804             - The substitution scores for the reference residue are extracted from the substitution matrix.
805             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
806               residue.
807           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
808             while more dissimilar residues have lower scores.
809        4. **Output**:
810           - The normalized similarity scores are stored in a NumPy array.
811           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
812
813        :param matrix_type: type of similarity matrix (defaults - AA: BLOSUM65, RNA/DNA: TRANS)
814        :param normalize: whether to normalize the similarity scores to the range [0, 1]
815        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
816            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
817            Gaps and invalid residues are encoded as `np.nan`.
818        :raises ValueError:
819            If the specified substitution matrix is not available for the given alignment type.
820        """
821
822        aln = self.alignment
823        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
824        if matrix_type is None:
825            if self.aln_type == 'AA':
826                matrix_type = 'BLOSUM65'
827            else:
828                matrix_type = 'TRANS'
829        # load substitution matrix as dictionary
830        try:
831            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
832        except KeyError:
833            raise ValueError(
834                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
835            )
836
837        # set dtype and convert alignment to a NumPy array for vectorized processing
838        dtype = np.dtype(float, metadata={'matrix': matrix_type})
839        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
840        reference = np.array(list(ref))
841        valid_chars = list(subs_matrix.keys())
842        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
843
844        for j, ref_char in enumerate(reference):
845            if ref_char not in valid_chars + ['-']:
846                continue
847            # Get local min and max for the reference residue
848            if normalize and ref_char != '-':
849                local_scores = subs_matrix[ref_char].values()
850                local_min, local_max = min(local_scores), max(local_scores)
851
852            for i, char in enumerate(sequences[:, j]):
853                if char not in valid_chars:
854                    continue
855                # classify the similarity as max if the reference has a gap
856                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
857                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
858
859        return similarity_array

Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the reference residue at each column. Gaps are encoded as np.nan.

The calculation follows these steps:

  1. Reference Sequence: If a reference sequence is provided (via self.reference_id), it is used. Otherwise, a consensus sequence is generated to serve as the reference.
  2. Substitution Matrix: The similarity between residues is determined using a substitution matrix, such as BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
  3. Per-Column Normalization (optional):
    • For each column in the alignment:
      • The residue in the reference sequence is treated as the baseline for that column.
      • The substitution scores for the reference residue are extracted from the substitution matrix.
      • The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
    • This ensures that identical residues (or those with high similarity to the reference) have high scores, while more dissimilar residues have lower scores.
  4. Output:
    • The normalized similarity scores are stored in a NumPy array.
    • Gaps (if any) or residues not present in the substitution matrix are encoded as np.nan.
Parameters
  • matrix_type: type of similarity matrix (defaults - AA: BLOSUM65, RNA/DNA: TRANS)
  • normalize: whether to normalize the similarity scores to range [0, 1]
Returns

A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity). Gaps and invalid residues are encoded as np.nan.

Raises
  • ValueError: If the specified substitution matrix is not available for the given alignment type.
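
A brief sketch; the exact scores depend on the substitution matrices shipped in config.SUBS_MATRICES:

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nA-AC\n")
sim = msa.calc_similarity_alignment()  # nt default matrix: 'TRANS'
print(sim.shape)                       # (3, 4); gaps are np.nan, identical residues typically score 1.0
```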

def calc_position_matrix(self, matrix_type: str = 'PWM') -> numpy.ndarray:
861    def calc_position_matrix(self, matrix_type:str='PWM') -> np.ndarray:
862        """
863        Calculates a position matrix of the specified type for the given alignment. The function
864        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
865        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
866        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
867
868        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
869            'PWM', and 'IC'. Defaults to 'PWM'.
870        :type matrix_type: str
871        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
872        :return: A numpy array representing the calculated position matrix of the specified type.
873        :rtype: np.ndarray
874        """
875
876        # ini
877        aln = self.alignment
878        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
879            raise ValueError("matrix_type must be one of 'PFM', 'PPM', 'PWM' or 'IC'.")
880        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())[:-1]
881        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
882
883        # calc position frequency matrix
884        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
885        if matrix_type == 'PFM':
886            return pfm
887
888        # calc position probability matrix (probability)
889        pseudo_count = 0.0001  # to avoid 0 values
890        pfm = pfm + pseudo_count
891        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
892        ppm = pfm/len(aln.keys())  # frequency relative to the number of sequences
893        if matrix_type == 'PPM':
894            return ppm
895
896        # calc position weight matrix (log-likelihood)
897        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
898        if matrix_type == 'PWM':
899            return pwm
900
901        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
902        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
903        if matrix_type == 'IC':
904            return ic
905
906        return None

Calculates a position matrix of the specified type for the given alignment. The function supports generating matrices of types Position Frequency Matrix (PFM), Position Probability Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.

Parameters
  • matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM', 'PWM', and 'IC'. Defaults to 'PWM'.
Raises
  • ValueError: If the provided matrix_type is not one of the accepted values.
Returns

A numpy array representing the calculated position matrix of the specified type.
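
For example:

```python
msa = MSA(">s1\nATGC\n>s2\nATGC\n>s3\nATAC\n")
pfm = msa.calc_position_matrix('PFM')  # raw counts, one row per character
ic = msa.calc_position_matrix('IC')    # one information content value per column (bits)
```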

def calc_percent_recovery(self) -> dict:
908    def calc_percent_recovery(self) -> dict:
909        """
 910        Recovery per sequence, compared either to the majority consensus seq
 911        or to the reference seq.\n
 912        Defined as:\n
 913
 914        `(1 - (number of N/X/- chars in ungapped ref regions / length of ungapped ref regions)) * 100`
915
916        This is highly similar to how nextclade calculates recovery over reference.
917
 918        :return: dictionary with seq ids as keys and percent recovery as values
919        """
920
921        aln = self.alignment
922
923        if self.reference_id is not None:
924            ref = aln[self.reference_id]
925        else:
926            ref = self.get_consensus()  # majority consensus
927
928        if not any(char != '-' for char in ref):
929            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
930
931
932        # count 'N', 'X' and '-' chars in non-gapped regions
933        recovery_over_ref = dict()
934
935        # Get positions of non-gap characters in the reference
936        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
937        cumulative_length = len(non_gap_positions)
938
939        # Calculate recovery
940        for seq_id in aln:
941            if seq_id == self.reference_id:
942                continue
943            seq = aln[seq_id]
944            count_invalid = sum(
945                seq[pos] == '-' or
946                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
947                for pos in non_gap_positions
948            )
949            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
950
951        return recovery_over_ref

Recovery per sequence, compared either to the majority consensus seq or to the reference seq.

Defined as:

(1 - (number of N/X/- chars in ungapped ref regions / length of ungapped ref regions)) * 100

This is highly similar to how nextclade calculates recovery over reference.

Returns

dictionary with seq ids as keys and percent recovery as values
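
An illustrative sketch (assumes 'N' is an allowed character); the reference itself is skipped in the output:

```python
msa = MSA(">ref\nATGCATGC\n>s2\nATGCNNGC\n>s3\nAT--ATGC\n", reference_id='ref')
print(msa.calc_percent_recovery())  # {'s2': 75.0, 's3': 75.0}
```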

def calc_character_frequencies(self) -> dict:
 953    def calc_character_frequencies(self) -> dict:
 954        """
 955        Calculate character frequencies in the alignment:
 956        frequencies are counted per sequence and in total. The
 957        percentage of each non-gap character is relative to the
 958        total number of non-gap characters in that sequence.
 959        The gap percentage is relative to the alignment length.
 960
 961        The output is a nested dictionary.
 962
 963        :return: Character frequencies
 964        """
 965
 966        aln, aln_length = self.alignment, self.length
 967
 968        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 969
 970        for seq_id in aln:
 971            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 972            unique_chars = set(aln[seq_id])
 973            for char in unique_chars:
 974                if char == '-':
 975                    continue
 976                # add characters to dictionaries
 977                if char not in freqs[seq_id]:
 978                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 979                if char not in freqs['total']:
 980                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
 981                # count non-gap chars
 982                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
 983                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
 984                all_chars += freqs[seq_id][char]['counts']
 985            # normalize counts
 986            for char in freqs[seq_id]:
 987                if char == '-':
 988                    continue
 989                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
 990                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
 991            # count gaps
 992            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
 993            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
 994            # normalize gap counts
 995            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
 996            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
 997
 998        # normalize the total counts
 999        for char in freqs['total']:
1000            for value in freqs['total'][char]:
1001                if value == '% of alignment' or value == '% of non-gapped':
1002                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
1003
1004        return freqs

Calculate character frequencies in the alignment: frequencies are counted per sequence and in total. The percentage of each non-gap character is relative to the total number of non-gap characters in that sequence. The gap percentage is relative to the alignment length.

The output is a nested dictionary.

Returns

Character frequencies
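
A minimal usage sketch (the alignment string and sequence IDs are illustrative; MSA also accepts a file path):

    from msaexplorer import explore

    msa = explore.MSA('>seq1\nATGC-\n>seq2\nATGCA\n')
    freqs = msa.calc_character_frequencies()
    # nested dict: freqs[<seq id or 'total'>][<char>] -> counts and percentages
    for seq_id, by_char in freqs.items():
        for char, values in by_char.items():
            print(seq_id, char, values)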

def calc_pairwise_identity_matrix(self, distance_type: str = 'ghd') -> numpy.ndarray:
1006    def calc_pairwise_identity_matrix(self, distance_type: str = 'ghd') -> ndarray:
1007        """
1008        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
1009
1010        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1011        \ndistance = matches / alignment_length * 100
1012
1013        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1014        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1015
1016        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1017        \ndistance = matches / (matches + mismatches) * 100
1018
1019        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1020        \ndistance = matches / gap_compressed_alignment_length * 100
1021
1022        :return: array with pairwise distances.
1023        """
1024
1025        def hamming_distance(seq1: str, seq2: str) -> int:
1026            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))  # counts matching positions (identity), not differences
1027
1028        def ghd(seq1: str, seq2: str) -> float:
1029            return hamming_distance(seq1, seq2) / self.length * 100
1030
1031        def lhd(seq1, seq2):
1032            # Trim gaps from both sides
1033            i, j = 0, self.length - 1
1034            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1035                i += 1
1036            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1037                j -= 1
1038            if i > j:
1039                return 0.0
1040
1041            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1042            matches = sum(c1 == c2 for c1, c2 in zip(seq1_, seq2_))
1043            length = j - i + 1
1044            return (matches / length) * 100 if length > 0 else 0.0
1045
1046        def ged(seq1: str, seq2: str) -> float:
1047
1048            matches, mismatches = 0, 0
1049
1050            for c1, c2 in zip(seq1, seq2):
1051                if c1 != '-' and c2 != '-':
1052                    if c1 == c2:
1053                        matches += 1
1054                    else:
1055                        mismatches += 1
1056            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1057
1058        def gcd(seq1: str, seq2: str) -> float:
1059            matches = 0
1060            mismatches = 0
1061            in_gap = False
1062
1063            for char1, char2 in zip(seq1, seq2):
1064                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1065                    continue
1066                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1067                    if not in_gap:  # Start of a new gap stretch
1068                        mismatches += 1
1069                        in_gap = True
1070                else:  # No gaps
1071                    in_gap = False
1072                    if char1 == char2:  # Matching characters
1073                        matches += 1
1074                    else:  # Mismatched characters
1075                        mismatches += 1
1076
1077            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1078
1079
1080        # Map distance type to corresponding function
1081        distance_functions: Dict[str, Callable[[str, str], float]] = {
1082            'ghd': ghd,
1083            'lhd': lhd,
1084            'ged': ged,
1085            'gcd': gcd
1086        }
1087
1088        if distance_type not in distance_functions:
1089            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1090
1091        # Compute pairwise distances
1092        aln = self.alignment
1093        distance_func = distance_functions[distance_type]
1094        distance_matrix = np.zeros((len(aln), len(aln)))
1095
1096        sequences = list(aln.values())
1097        n = len(sequences)
1098        for i in range(n):
1099            seq1 = sequences[i]
1100            for j in range(i, n):
1101                seq2 = sequences[j]
1102                dist = distance_func(seq1, seq2)
1103                distance_matrix[i, j] = dist
1104                distance_matrix[j, i] = dist
1105
1106        return distance_matrix

Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:

1) ghd (global hamming distance): At each alignment position, check if characters match:

distance = matches / alignment_length * 100

2) lhd (local hamming distance): Restrict the alignment to the region in both sequences that does not start or end with gaps:

distance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100

3) ged (gap excluded distance): All gaps are excluded from the alignment:

distance = matches / (matches + mismatches) * 100

4) gcd (gap compressed distance): All consecutive gaps are compressed to one mismatch:

distance = matches / gap_compressed_alignment_length * 100

Parameters
  • distance_type: one of 'ghd', 'lhd', 'ged' or 'gcd' (default: 'ghd')

Returns

array with pairwise distances.
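
To see how these definitions can disagree on the same pair of sequences, here is a minimal sketch (the sequences are illustrative; the expected values follow directly from the formulas above, printed floats are unrounded):

    from msaexplorer import explore

    msa = explore.MSA('>a\n--ATGC\n>b\nCGATGC\n')
    for metric in ('ghd', 'lhd', 'ged', 'gcd'):
        print(metric, msa.calc_pairwise_identity_matrix(metric)[0, 1])
    # ghd: 4 matches / 6 alignment positions          -> ~66.7
    # lhd: 4 matches / 4 gap-trimmed positions        -> 100.0
    # ged: 4 matches / 4 gap-free comparisons         -> 100.0
    # gcd: 4 matches / (4 matches + 1 compressed gap) -> 80.0
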
def get_snps(self, include_ambig: bool = False) -> dict:
1108    def get_snps(self, include_ambig: bool = False) -> dict:
1109        """
1110        Calculate SNPs similar to snp-sites (the output is comparable):
1111        https://github.com/sanger-pathogens/snp-sites
1112        Importantly, a position is only considered if at least one of its variant characters is not ambiguous.
1113        The SNPs are called against a majority consensus sequence, or against the reference if one has been set.
1114
1115        :param include_ambig: include ambiguous SNPs (default: False)
1116        :return: dictionary containing SNP positions and their variants including their frequency.
1117        """
1118        aln = self.alignment
1119        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
1120        aln = {x: aln[x] for x in aln.keys() if x != self.reference_id}
1121        seq_ids = list(aln.keys())
1122        snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}
1123
1124        for pos in range(self.length):
1125            reference_char = ref[pos]
1126            if not include_ambig:
1127                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1128                    continue
1129            alt_chars, snps = [], []
1130            for i, seq_id in enumerate(aln.keys()):
1131                alt_chars.append(aln[seq_id][pos])
1132                if reference_char != aln[seq_id][pos]:
1133                    snps.append(i)
1134            if not snps:
1135                continue
1136            if include_ambig:
1137                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1138                    continue
1139            else:
1140                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1141                if not snps:
1142                    continue
1143            if pos not in snp_dict['POS']:
1144                snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': {}}
1145            for snp in snps:
1146                if alt_chars[snp] not in snp_dict['POS'][pos]['ALT']:
1147                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]] = {
1148                        'AF': 1,
1149                        'SEQ_ID': [seq_ids[snp]]
1150                    }
1151                else:
1152                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['AF'] += 1
1153                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['SEQ_ID'].append(seq_ids[snp])
1154            # calculate AF
1155            if pos in snp_dict['POS']:
1156                for alt in snp_dict['POS'][pos]['ALT']:
1157                    snp_dict['POS'][pos]['ALT'][alt]['AF'] /= len(aln)
1158
1159        return snp_dict

Calculate SNPs similar to snp-sites (the output is comparable): https://github.com/sanger-pathogens/snp-sites. Importantly, a position is only considered if at least one of its variant characters is not ambiguous. The SNPs are called against a majority consensus sequence, or against the reference if one has been set.

Parameters
  • include_ambig: include ambiguous SNPs (default: False)
Returns

dictionary containing SNP positions and their variants including their frequency.
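
A minimal usage sketch (sequence IDs are illustrative; this assumes, as stated above, that the majority consensus is used when no reference has been set):

    from msaexplorer import explore

    msa = explore.MSA('>a\nATGC\n>b\nATCC\n>c\nATCC\n')
    snps = msa.get_snps()
    # position 2 differs in sequence 'a' (G vs consensus C), giving roughly:
    # {'#CHROM': 'consensus', 'POS': {2: {'ref': 'C', 'ALT': {'G': {'AF': 0.33, 'SEQ_ID': ['a']}}}}}
    print(snps)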

def calc_transition_transversion_score(self) -> list:
1161    def calc_transition_transversion_score(self) -> list:
1162        """
1163        Based on the SNP positions, calculates a transition/transversion score.
1164        A positive score at a position means a higher proportion of transitions,
1165        a negative score a higher proportion of transversions.
1166        :return: list of per-position scores
1167        """
1168
1169        if self.aln_type == 'AA':
1170            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1171
1172        # ini
1173        snps = self.get_snps()
1174        score = [0]*self.length
1175
1176        for pos in snps['POS']:
1177            # each ALT's allele frequency is added to or subtracted from the score
1178            for alt in snps['POS'][pos]['ALT']:
1179                # check the type of substitution
1180                if snps['POS'][pos]['ref'] + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1181                    score[pos] += snps['POS'][pos]['ALT'][alt]['AF']
1182                else:
1183                    score[pos] -= snps['POS'][pos]['ALT'][alt]['AF']
1184
1185        return score

Based on the SNP positions, calculates a transition/transversion score. A positive score at a position means a higher proportion of transitions, a negative score a higher proportion of transversions.

Returns

list of per-position scores
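
A minimal usage sketch (again assuming the majority consensus as the comparison sequence):

    from msaexplorer import explore

    msa = explore.MSA('>a\nATGA\n>b\nACGC\n>c\nACGC\n')
    print(msa.calc_transition_transversion_score())
    # position 1: C->T in 'a' is a transition   -> +1/3 (its allele frequency)
    # position 3: C->A in 'a' is a transversion -> -1/3
    # -> [0, 0.33.., 0, -0.33..]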

class Annotation:
1188class Annotation:
1189    """
1190    An annotation class that reads in GFF, GenBank (gb) or BED files and adjusts their feature locations to those of the MSA.
1191    """
1192
1193    def __init__(self, aln: MSA, annotation_path: str):
1194        """
1195        The annotation class. Lets you parse multiple standard formats
1196        which might be used for annotating an alignment. The main purpose
1197        is to parse the annotation file and adapt the locations of diverse
1198        features to the corresponding positions within the alignment.
1199        Importantly, the IDs of the annotation and of the MSA have to
1200        at least partly match.
1201
1202        :param aln: MSA class
1203        :param annotation_path: path to annotation file (gb, bed, gff) or raw string
1204
1205        """
1206
1207        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation_path, aln)  # read annotation
1208        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1209        self._position_map = self._build_position_map()  # build a position map
1210        self._map_to_alignment()  # adapt feature locations
1211
1212    @staticmethod
1213    def _MSA_validation_and_seq_extraction(aln: MSA, seq_id: str) -> str:
1214        """
1215        extract gapped sequence from MSA that corresponds to annotation
1216        :param aln: MSA class
1217        :param seq_id: sequence id to extract
1218        :return: gapped sequence
1219        """
1220        if not isinstance(aln, MSA):
1221            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')
1222        else:
1223            return aln._alignment[seq_id]
1224
1225    @staticmethod
1226    def _parse_annotation(annotation_path: str, aln: MSA) -> tuple[str, str, str, Dict]:
1227
1228        def detect_annotation_type(file_path: str) -> str:
1229            """
1230            Detect the type of annotation file (GenBank, GFF, or BED) based
1231            on the first relevant line (excluding empty and #)
1232
1233            :param file_path: Path to the annotation file.
1234            :return: The detected file type ('gb', 'gff', or 'bed').
1235
1236            :raises ValueError: If the file type cannot be determined.
1237            """
1238
1239            with _get_line_iterator(file_path) as file:
1240                for line in file:
1241                    # skip empty lines and comments
1242                    if not line.strip() or line.startswith('#'):
1243                        continue
1244                    # genbank
1245                    if line.startswith('LOCUS'):
1246                        return 'gb'
1247                    # gff
1248                    if len(line.split('\t')) == 9:
1249                        # Check for expected values
1250                        columns = line.split('\t')
1251                        if columns[6] in ['+', '-', '.'] and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$',columns[4]):
1252                            return 'gff'
1253                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
1254                    fields = line.split('\t')
1255                    if len(fields) >= 3 and re.match(r'^\d+$', fields[1]) and re.match(r'^\d+$', fields[2]):
1256                        return 'bed'
1257                    # only read in the first line
1258                    break
1259
1260            raise ValueError(
1261                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")
1262
1263        def parse_gb(file_path) -> dict:
1264            """
1265            Parse a GenBank file into a dictionary - primarily the information
1266            for qualifiers is retained, as it will be used for plotting.
1267
1268            :param file_path: path to GenBank file
1269            :return: nested dictionary
1270
1271            """
1272
1273            def sanitize_gb_location(string: str) -> tuple[list, str]:
1274                """
1275                see: https://www.insdc.org/submitting-standards/feature-table/
1276                """
1277                strand = '+'
1278                locations = []
1279                # check the direction of the annotation
1280                if 'complement' in string:
1281                    strand = '-'
1282                # sanitize operators
1283                for operator in ['complement(', 'join(', 'order(']:
1284                    string = string.replace(operator, '')
1285                # sanitize possible chars for splitting start stop -
1286                # however in the future might not simply do this
1287                # as some useful information is retained
1288                for char in ['>', '<', ')']:
1289                    string = string.replace(char, '')
1290                # check if we have multiple location e.g. due to splicing
1291                if ',' in string:
1292                    raw_locations = string.split(',')
1293                else:
1294                    raw_locations = [string]
1295                # try to split start and stop
1296                for location in raw_locations:
1297                    for sep in ['..', '.', '^']:
1298                        if sep in location:
1299                            sanitized_locations = [int(x) for x in location.split(sep)]
1300                            sanitized_locations[0] = sanitized_locations[0] - 1  # enforce 0-based starts
1301                            locations.append(sanitized_locations)
1302                            break
1303
1304                return locations, strand
1305
1306
1307            records = {}
1308            with _get_line_iterator(file_path) as file:
1309                record = None
1310                in_features = False
1311                counter_dict = {}
1312                for line in file:
1313                    line = line.rstrip()
1314                    parts = line.split()
1315                    # extract the locus id
1316                    if line.startswith('LOCUS'):
1317                        if record:
1318                            records[record['locus']] = record
1319                        record = {
1320                            'locus': parts[1],
1321                            'features': {}
1322                        }
1323
1324                    elif line.startswith('FEATURES'):
1325                        in_features = True
1326
1327                    # ignore the sequence info
1328                    elif line.startswith('ORIGIN'):
1329                        in_features = False
1330
1331                    # now write useful feature information to dictionary
1332                    elif in_features:
1333                        if not line.strip():
1334                            continue
1335                        if line[5] != ' ':
1336                            location_line = True  # remember that we are in a location for multi-line locations
1337                            feature_type, qualifier = parts[0], parts[1]
1338                            if feature_type not in record['features']:
1339                                record['features'][feature_type] = {}
1340                                counter_dict[feature_type] = 0
1341                            locations, strand = sanitize_gb_location(qualifier)
1342                            record['features'][feature_type][counter_dict[feature_type]] = {
1343                                'location': locations,
1344                                'strand': strand
1345                            }
1346                            counter_dict[feature_type] += 1
1347                        else:
1348                            # edge case for multi-line locations
1349                            if location_line and not line.strip().startswith('/'):
1350                                locations, strand = sanitize_gb_location(parts[0])
1351                                for loc in locations:
1352                                    record['features'][feature_type][counter_dict[feature_type]]['location'].append(loc)
1353                            else:
1354                                location_line = False
1355                                try:
1356                                    qualifier_type, qualifier = parts[0].split('=')
1357                                except ValueError:  # we are in the coding sequence
1358                                    qualifier = qualifier + parts[0]
1359
1360                                qualifier_type, qualifier = qualifier_type.lstrip('/'), qualifier.strip('"')
1361                                last_index = counter_dict[feature_type] - 1
1362                                record['features'][feature_type][last_index][qualifier_type] = qualifier
1363
1364            records[record['locus']] = record
1365
1366            return records
1367
1368        def parse_gff(file_path) -> dict:
1369            """
1370            Parse a GFF3 (General Feature Format) file into a dictionary structure.
1371
1372            :param file_path: path to GFF3 file
1373            :return: nested dictionary
1374
1375            """
1376            records = {}
1377            with _get_line_iterator(file_path) as file:
1378                previous_id, previous_feature = None, None
1379                for line in file:
1380                    if line.startswith('#') or not line.strip():
1381                        continue
1382                    parts = line.strip().split('\t')
1383                    seqid, source, feature_type, start, end, score, strand, phase, attributes = parts
1384                    # ensure that region and source features are not named differently for gff and gb
1385                    if feature_type == 'region':
1386                        feature_type = 'source'
1387                    if seqid not in records:
1388                        records[seqid] = {'locus': seqid, 'features': {}}
1389                    if feature_type not in records[seqid]['features']:
1390                        records[seqid]['features'][feature_type] = {}
1391
1392                    feature_id = len(records[seqid]['features'][feature_type])
1393                    feature = {
1394                        'strand': strand,
1395                    }
1396
1397                    # Parse attributes into key-value pairs
1398                    for attr in attributes.split(';'):
1399                        if '=' in attr:
1400                            key, value = attr.split('=', 1)
1401                            feature[key.strip()] = value.strip()
1402
1403                    # check if feature are the same --> possible splicing
1404                    if previous_id is not None and previous_feature == feature:
1405                        records[seqid]['features'][feature_type][previous_id]['location'].append([int(start)-1, int(end)])
1406                    else:
1407                        records[seqid]['features'][feature_type][feature_id] = feature
1408                        records[seqid]['features'][feature_type][feature_id]['location'] = [[int(start) - 1, int(end)]]
1409                    # set new previous id and features -> new dict as 'location' is pointed in current feature and this
1410                    # is the only key different if next feature has the same entries
1411                    previous_id, previous_feature = feature_id, {key:value for key, value in feature.items() if key != 'location'}
1412
1413            return records
1414
1415        def parse_bed(file_path) -> dict:
1416            """
1417            Parse a BED file into a dictionary structure.
1418
1419            :param file_path: path to BED file
1420            :return: nested dictionary
1421
1422            """
1423            records = {}
1424            with _get_line_iterator(file_path) as file:
1425                for line in file:
1426                    if line.startswith('#') or not line.strip():
1427                        continue
1428                    parts = line.strip().split('\t')
1429                    chrom, start, end, *optional = parts
1430
1431                    if chrom not in records:
1432                        records[chrom] = {'locus': chrom, 'features': {}}
1433                    feature_type = 'region'
1434                    if feature_type not in records[chrom]['features']:
1435                        records[chrom]['features'][feature_type] = {}
1436
1437                    feature_id = len(records[chrom]['features'][feature_type])
1438                    feature = {
1439                        'location': [[int(start), int(end)]],  # BED is already 0-based half-open, no conversion needed
1440                        'strand': '+',  # assume '+' if not present
1441                    }
1442
1443                    # Handle optional columns (name, score, strand) --> ignore 7-12
1444                    if len(optional) >= 1:
1445                        feature['name'] = optional[0]
1446                    if len(optional) >= 2:
1447                        feature['score'] = optional[1]
1448                    if len(optional) >= 3:
1449                        feature['strand'] = optional[2]
1450
1451                    records[chrom]['features'][feature_type][feature_id] = feature
1452
1453            return records
1454
1455        parse_functions: Dict[str, Callable[[str], dict]] = {
1456            'gb': parse_gb,
1457            'bed': parse_bed,
1458            'gff': parse_gff,
1459        }
1460        # determine the annotation content -> should be standard formatted
1461        try:
1462            annotation_type = detect_annotation_type(annotation_path)
1463        except ValueError as err:
1464            raise err
1465
1466        # read in the annotation
1467        annotations = parse_functions[annotation_type](annotation_path)
1468
1469        # sanity check whether one of the annotation ids matches one of the alignment ids
1470        annotation_found = False
1471        for annotation in annotations.keys():
1472            for aln_id in aln.alignment.keys():
1473                aln_id_sanitized = aln_id.split(' ')[0]
1474                # check in both directions
1475                if aln_id_sanitized in annotation or annotation in aln_id_sanitized:
1476                    annotation_found = True
1477                    break
1478            if annotation_found:
1479                # also leave the outer loop so annotation and aln_id keep the matching values
1480                break
1481
1482        if not annotation_found:
1483            raise ValueError(f'the annotations of {annotation_path} do not match any ids in the MSA')
1484
1485        # return only the annotation that has been found, the respective type and the seq_id to map to
1486        return annotation_type, aln_id, annotations[annotation]['locus'], annotations[annotation]['features']
1487
1488
1489    def _build_position_map(self) -> Dict[int, int]:
1490        """
1491        Build a position map for the gapped sequence.
1492
1493        :return: mapping of genomic (ungapped) positions to alignment (gapped) positions
1494        """
1495
1496        position_map = {}
1497        genomic_pos = 0
1498        for aln_index, char in enumerate(self._gapped_seq):
1499            if char != '-':
1500                position_map[genomic_pos] = aln_index
1501                genomic_pos += 1
1502        # ensure the last genomic position is included
1503        position_map[genomic_pos] = len(self._gapped_seq)
1504
1505        return position_map
1506
1507
1508    def _map_to_alignment(self):
1509        """
1510        Adjust all feature locations to alignment positions
1511        """
1512
1513        def map_location(position_map: Dict[int, int], locations: list) -> list:
1514            """
1515            Map genomic locations to alignment positions using a precomputed position map.
1516
1517            :param position_map: Mapping of ungapped (genomic) positions to gapped (alignment) positions
1518            :param locations: List of genomic start and end positions.
1519            :return: List of adjusted alignment positions.
1520            """
1521
1522            aligned_locs = []
1523            for start, end in locations:
1524                try:
1525                    aligned_start = position_map[start]
1526                    aligned_end = position_map[end]
1527                    aligned_locs.append([aligned_start, aligned_end])
1528                except KeyError:
1529                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
1530
1531            return aligned_locs
1532
1533        for feature_type, features in self.features.items():
1534            for feature_id, feature_data in features.items():
1535                original_locations = feature_data['location']
1536                aligned_locations = map_location(self._position_map, original_locations)
1537                feature_data['location'] = aligned_locations

An annotation class that reads in GFF, GenBank (gb) or BED files and adjusts their feature locations to those of the MSA.

Annotation(aln: MSA, annotation_path: str)
1193    def __init__(self, aln: MSA, annotation_path: str):
1194        """
1195        The annotation class. Lets you parse multiple standard formats
1196        which might be used for annotating an alignment. The main purpose
1197        is to parse the annotation file and adapt the locations of diverse
1198        features to the corresponding positions within the alignment.
1199        Importantly, the IDs of the annotation and of the MSA have to
1200        at least partly match.
1201
1202        :param aln: MSA class
1203        :param annotation_path: path to annotation file (gb, bed, gff) or raw string
1204
1205        """
1206
1207        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation_path, aln)  # read annotation
1208        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1209        self._position_map = self._build_position_map()  # build a position map
1210        self._map_to_alignment()  # adapt feature locations

The annotation class. Lets you parse multiple standard formats which might be used for annotating an alignment. The main purpose is to parse the annotation file and adapt the locations of diverse features to the corresponding positions within the alignment. Importantly, the IDs of the annotation and of the MSA have to at least partly match.

Parameters
  • aln: MSA class
  • annotation_path: path to annotation file (gb, bed, gff) or raw string
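
A minimal usage sketch (the sequence IDs and the raw BED string are illustrative; both the alignment and the annotation can be passed as file paths or raw strings):

    from msaexplorer import explore

    aln = explore.MSA('>seq1\nATG--CATGC\n>seq2\nATGGTCATGC\n')
    bed = 'seq1\t2\t6\tgene1\n'
    ann = explore.Annotation(aln, bed)
    # the feature end shifts past the two-gap stretch in seq1:
    print(ann.features['region'][0]['location'])  # [[2, 8]]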