msaexplorer.explore

Explore module

This module contains the two classes MSA and Annotation which are used to read in the respective files and can be used to compute several statistics or be used as the input for the draw module functions.

Classes

   1"""
   2# Explore module
   3
   4This module contains the two classes `MSA` and `Annotation` which are used to read in the respective files
   5and can be used to compute several statistics or be used as the input for the `draw` module functions.
   6
   7## Classes
   8
   9"""
  10
  11# built-in
  12import math
  13import collections
  14import re
  15from typing import Callable, Dict
  16
  17# installed
  18import numpy as np
  19from numpy import ndarray
  20from Bio import SeqIO
  21from Bio.Align import MultipleSeqAlignment
  22from Bio.SeqIO.InsdcIO import GenBankIterator
  23
  24# msaexplorer
  25from msaexplorer import config
  26from msaexplorer._data_classes import PairwiseDistance, AlignmentStats, LengthStats, OpenReadingFrame, OrfCollection, SingleNucleotidePolymorphism, VariantCollection
  27from msaexplorer._helpers import _get_line_iterator, _create_distance_calculation_function_mapping, _read_alignment
  28
  29
  30class MSA:
  31    """
  32    An alignment class that allows computation of several stats. Supported inputs are file paths to alignments in "fasta",
  33    "clustal", "phylip", "stockholm", "nexus" formats, raw alignment strings, or Bio.Align.MultipleSeqAlignment for
  34    compatibility with Biopython.
  35    """
  36
  37    # dunder methods
    def __init__(self, alignment_string: str | MultipleSeqAlignment, reference_id: str | None = None, zoom_range: tuple | int | None = None):
        """
        Initialize an MSA object.

        :param alignment_string: path to an alignment file, a raw alignment string
            or a Bio.Align.MultipleSeqAlignment object
        :param reference_id: id of the reference sequence (must be a key of the
            alignment) or None for no reference
        :param zoom_range: (start, stop) positions to zoom into the alignment, or a
            single int start (the stop is then the alignment length)
        """
        # parsed alignment: dict mapping sequence id -> aligned sequence string
        self._alignment = _read_alignment(alignment_string)
        # validated against the alignment keys; None means "no reference set"
        self._reference_id = self._validate_ref(reference_id, self._alignment)
        # validated (start, stop) tuple or None for the full alignment
        self._zoom = self._validate_zoom(zoom_range, self._alignment)
        # 'DNA', 'RNA' or 'AA' -- determined heuristically from the characters
        self._aln_type = self._determine_aln_type(self._alignment)
  49
  50    def __len__(self) -> int:
  51        return len(self.sequence_ids)
  52
  53    def __getitem__(self, key: str) -> str:
  54        return self.alignment[key]
  55
  56    def __contains__(self, item: str) -> bool:
  57        return item in self.sequence_ids
  58
  59    def __iter__(self):
  60        return iter(self.alignment)
  61
  62    # Static methods
  63    @staticmethod
  64    def _validate_ref(reference: str | None, alignment: dict) -> str | None | ValueError:
  65        """
  66        Validate if the ref seq is indeed part of the alignment.
  67        :param reference: reference seq id
  68        :param alignment: alignment dict
  69        :return: validated reference
  70        """
  71        if reference in alignment.keys():
  72            return reference
  73        elif reference is None:
  74            return reference
  75        else:
  76            raise ValueError('Reference not in alignment.')
  77
    @staticmethod
    def _validate_zoom(zoom: tuple | int | None, original_aln: dict) -> tuple | None:
        """
        Validate that the user defined zoom range lies within the start/end of the
        initial alignment.

        :param zoom: (zoom_start, zoom_end) tuple, or a single int start position
            (the end then defaults to the alignment length); None disables zooming
        :param original_aln: non-zoomed alignment dict
        :return: validated zoom as a (start, end) tuple, or None
        :raises ValueError: for non-int positions or positions outside the alignment
        """
        if zoom is not None:
            # length of the first sequence; all sequences in an alignment share it
            aln_length = len(original_aln[list(original_aln.keys())[0]])
            # a single int means "zoom from here to the end of the alignment"
            if isinstance(zoom, int):
                if 0 <= zoom < aln_length:
                    return zoom, aln_length
                else:
                    raise ValueError('Zoom start must be within the alignment length range.')
            # check if more than 2 values are provided
            if len(zoom) != 2:
                raise ValueError('Zoom position have to be (zoom_start, zoom_end)')
            start, end = zoom
            # validate zoom start/stop for Python slicing semantics [start:end)
            if type(start) is not int or type(end) is not int:
                raise ValueError('Zoom positions have to be integers.')
            if not (0 <= start < aln_length):
                raise ValueError('Zoom position out of range')
            if not (0 < end <= aln_length):
                raise ValueError('Zoom position out of range')
            if start >= end:
                raise ValueError('Zoom position have to be (zoom_start, zoom_end) with zoom_start < zoom_end')

        return zoom
 110
 111    @staticmethod
 112    def _determine_aln_type(alignment) -> str:
 113        """
 114        Determine the most likely type of alignment
 115        if 70% of chars in the alignment are nucleotide
 116        chars it is most likely a nt alignment
 117        :return: type of alignment
 118        """
 119        counter = int()
 120        for record in alignment:
 121            if 'U' in alignment[record]:
 122                return 'RNA'
 123            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 124        # determine which is the most likely type
 125        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 126            return 'DNA'
 127        else:
 128            return 'AA'
 129
 130    # Properties with setters
    @property
    def reference_id(self) -> str | None:
        """Id of the reference sequence, or None if no reference is set."""
        return self._reference_id

    @reference_id.setter
    def reference_id(self, ref_id: str):
        """
        Set and validate the reference id against the alignment keys.
        """
        self._reference_id = self._validate_ref(ref_id, self.alignment)
 141
    @property
    def zoom(self) -> tuple | None:
        """Validated (start, stop) zoom range, or None for the full alignment."""
        return self._zoom

    @zoom.setter
    def zoom(self, zoom_pos: tuple | int):
        """
        Validate and set the user defined zoom range.
        """
        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 152
 153    # Property without setters
    @property
    def aln_type(self) -> str:
        """
        Type of the alignment ('RNA', 'DNA' or 'AA'),
        determined heuristically on initialization.
        """
        return self._aln_type
 161
    @property
    def sequence_ids(self) -> list:
        """List of sequence ids (alignment keys)."""
        return list(self.alignment.keys())
 165
 166    # On the fly properties without setters
    @property
    def length(self) -> int:
        """Number of columns of the (zoomed) alignment (length of the first sequence)."""
        return len(next(iter(self.alignment.values())))
 170
 171    @property
 172    def alignment(self) -> dict:
 173        """
 174        (zoomed) version of the alignment.
 175        """
 176        if self.zoom is not None:
 177            zoomed_aln = dict()
 178            for seq in self._alignment:
 179                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 180            return zoomed_aln
 181        else:
 182            return self._alignment
 183
    def _create_position_stat_result(self, stat_name: str, values: list | ndarray) -> AlignmentStats:
        """
        Build a shared dataclass for position-wise statistics.

        :param stat_name: name of the statistic (e.g. 'entropy', 'coverage')
        :param values: one value per alignment column
        :return: AlignmentStats with positions in original (un-zoomed) coordinates
        """
        values_array = np.asarray(values, dtype=float)
        if self.zoom is None:
            positions = np.arange(self.length, dtype=int)
        else:
            # shift positions so they refer to the un-zoomed alignment coordinates
            positions = np.arange(self.zoom[0], self.zoom[0] + self.length, dtype=int)

        return AlignmentStats(
            stat_name=stat_name,
            positions=positions,
            values=values_array,
        )
 199
 200    def _to_array(self) -> ndarray:
 201        """convert alignment to a numpy array"""
 202        return np.array([list(self.alignment[seq_id]) for seq_id in self.sequence_ids])
 203
    def _get_reference_seq(self) -> str:
        """Return the reference sequence if a reference id is set, else the majority consensus."""
        return self.alignment[self.reference_id] if self.reference_id is not None else self.get_consensus()
 207
 208    def get_reference_coords(self) -> tuple[int, int]:
 209        """
 210        Determine the start and end coordinates of the reference sequence
 211        defined as the first/last nucleotide in the reference sequence
 212        (excluding N and gaps).
 213
 214        :return: Start, End
 215        """
 216        start, end = 0, self.length
 217
 218        if self.reference_id is None:
 219            return start, end
 220        else:
 221            # 5' --> 3'
 222            for start in range(self.length):
 223                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 224                    break
 225            # 3' --> 5'
 226            for end in range(self.length - 1, 0, -1):
 227                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 228                    break
 229
 230            return start, end
 231
    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
        """
        Creates a non-gapped consensus sequence.

        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
            or 'X' (for aa alignments) is used if none of the characters reach a cumulative frequency >= threshold.
        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
            has a frequency above the defined threshold.
        :return: consensus sequence
        :raises ValueError: if threshold is outside [0, 1], if use_ambig_nt is used
            on an AA alignment, or if use_ambig_nt is set without a threshold
        """

        # helper functions
        def determine_counts(alignment_dict: dict, position: int) -> dict:
            """
            Count the number of each char at an idx of the alignment and return a
            dict sorted by descending count. Ambiguous nucleotides are replaced by
            fractional counts of the nucleotides they encode. Gaps are counted as-is.
            """
            nucleotide_list = []

            # get all nucleotides
            for sequence in alignment_dict.items():
                nucleotide_list.append(sequence[1][position])
            # count occurences of nucleotides
            counter = dict(collections.Counter(nucleotide_list))
            # get permutations of an ambiguous nucleotide
            to_delete = []
            temp_dict = {}
            for nucleotide in counter:
                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
                    to_delete.append(nucleotide)
                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
                    # each encoded nucleotide gets an equal share
                    # NOTE(review): the share ignores how often the ambiguous char
                    # occurs at this position (counter[nucleotide]) -- verify intended
                    adjusted_freq = 1 / len(permutations)
                    for permutation in permutations:
                        if permutation in temp_dict:
                            temp_dict[permutation] += adjusted_freq
                        else:
                            temp_dict[permutation] = adjusted_freq

            # drop ambiguous entries and add adjusted freqs to the real nucleotides
            if to_delete:
                for i in to_delete:
                    counter.pop(i)
                for nucleotide in temp_dict:
                    if nucleotide in counter:
                        counter[nucleotide] += temp_dict[nucleotide]
                    else:
                        counter[nucleotide] = temp_dict[nucleotide]

            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))

        def get_consensus_char(counts: dict, cutoff: float) -> list:
            """
            Collect chars (most frequent first) until their cumulative count
            reaches the cutoff.
            """
            n = 0

            consensus_chars = []
            for char in counts:
                n += counts[char]
                consensus_chars.append(char)
                if n >= cutoff:
                    break

            return consensus_chars

        def get_ambiguous_char(nucleotides: list) -> str:
            """
            Get the ambiguous char that encodes exactly the given set of nucleotides.
            """
            # NOTE(review): if no code matches the set, the last entry of the
            # ambiguity table is silently returned -- confirm all sets are covered
            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
                if set(permutations) == set(nucleotides):
                    break

            return ambiguous

        # check if params have been set correctly
        if threshold is not None:
            if threshold < 0 or threshold > 1:
                raise ValueError('Threshold must be between 0 and 1.')
        if self.aln_type == 'AA' and use_ambig_nt:
            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
        if threshold is None and use_ambig_nt:
            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')

        alignment = self.alignment
        consensus = str()

        # cumulative count a column must reach; 0 means "take the most frequent char"
        if threshold is not None:
            consensus_cutoff = len(alignment) * threshold
        else:
            consensus_cutoff = 0

        # built consensus sequences
        for idx in range(self.length):
            char_counts = determine_counts(alignment, idx)
            consensus_chars = get_consensus_char(
                char_counts,
                consensus_cutoff
            )
            if threshold != 0:
                if len(consensus_chars) > 1:
                    if use_ambig_nt:
                        char = get_ambiguous_char(consensus_chars)
                    else:
                        if self.aln_type == 'AA':
                            char = 'X'
                        else:
                            char = 'N'
                    consensus = consensus + char
                else:
                    consensus = consensus + consensus_chars[0]
            else:
                consensus = consensus + consensus_chars[0]

        return consensus
 350
    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> OrfCollection:
        """
        Search conserved open reading frames (ORFs) in the alignment.

        **conserved ORF definition:**
            - conserved starts and stops
            - start, stop must be on the same frame
            - stop - start must be at least min_length
            - all ungapped seqs[start:stop] must have at least min_length
            - no ungapped seq can have a Stop in between Start Stop

        Conservation is measured by the number of positions with identical characters divided by
        orf slice of the alignment.

        **Algorithm overview:**
            - check for conserved start and stop codons
            - iterate over all three frames
            - check each start and next sufficiently far away stop codon
            - check if all ungapped seqs between start and stop codon are >= min_length
            - check if no ungapped seq in the alignment has a stop codon
            - write to dictionary
            - classify as internal if the stop codon has already been written with a prior start
            - repeat for reverse complement

        :param min_length: minimum ORF length in alignment columns (> 6 and <= alignment length)
        :param identity_cutoff: if set, ORFs whose conservation (%) is below it are discarded
        :return: ORF positions and internal ORF positions
        :raises TypeError: for AA alignments
        :raises ValueError: for out-of-range identity_cutoff or min_length
        """

        # helper functions
        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
            """
            Determine all start and stop codons within an alignment.
            :param alignment: alignment
            :param alignment_length: length of alignment
            :return: start and stop codon positions
            """
            starts = config.START_CODONS[self.aln_type]
            stops = config.STOP_CODONS[self.aln_type]

            list_of_starts, list_of_stops = [], []
            # define one sequence (first) as reference (it does not matter which one)
            ref = alignment[self.sequence_ids[0]]
            for nt_position in range(alignment_length):
                if ref[nt_position:nt_position + 3] in starts:
                    conserved_start = True
                    # conserved only if the ungapped continuation of every
                    # sequence also begins with a start codon
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
                            conserved_start = False
                            break
                    if conserved_start:
                        list_of_starts.append(nt_position)

                if ref[nt_position:nt_position + 3] in stops:
                    conserved_stop = True
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
                            conserved_stop = False
                            break
                    if conserved_stop:
                        list_of_stops.append(nt_position)

            return list_of_starts, list_of_stops

        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
            """
            get ungapped sequences starting and stop codons and eliminate gaps
            :param alignment: alignment
            :param start_pos: start codon
            :param stop_pos: stop codon
            :return: sliced sequences
            """
            ungapped_seqs = []
            for seq_id in alignment:
                # +3 includes the stop codon itself in the slice
                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))

            return ungapped_seqs

        def additional_stops(ungapped_seqs: list) -> bool:
            """
            Checks for the presence of a stop codon
            :param ungapped_seqs: list of ungapped sequences
            :return: Additional stop codons (True/False)
            """
            stops = config.STOP_CODONS[self.aln_type]

            # scan codon-wise, excluding the terminal codon of the slice
            for sliced_seq in ungapped_seqs:
                for position in range(0, len(sliced_seq) - 3, 3):
                    if sliced_seq[position:position + 3] in stops:
                        return True
            return False

        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
            """Percent of columns within aln_slice where all sequences are identical."""
            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100

        # checks for arguments
        if self.aln_type == 'AA':
            raise TypeError('ORF search only for RNA/DNA alignments')

        if identity_cutoff is not None:
            if identity_cutoff > 100 or identity_cutoff < 0:
                raise ValueError('conservation cutoff must be between 0 and 100')

        if min_length <= 6 or min_length > self.length:
            raise ValueError(f'min_length must be between 6 and {self.length}')

        # ini
        identities = self.calc_identity_alignment()
        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
        aln_len = self.length

        orf_counter = 0
        # use mutable dicts during construction and convert to dataclass at the end for immutability
        temp_orfs: list[dict] = []

        for aln, direction in zip(alignments, ['+', '-']):
            # check for starts and stops in the first seq and then check if these are present in all seqs
            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
            # check each frame
            for frame in (0, 1, 2):
                potential_starts = [x for x in conserved_starts if x % 3 == frame]
                potential_stops = [x for x in conserved_stops if x % 3 == frame]
                last_stop = -1
                for start in potential_starts:
                    # go to the next stop that is sufficiently far away in the alignment;
                    # min_length > 6 guarantees any qualifying stop lies after the start
                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
                    if not next_stops:
                        continue
                    next_stop = next_stops[0]
                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
                    # re-check the lengths of all ungapped seqs
                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
                    if not all(ungapped_seq_lengths):
                        continue
                    # if no stop codon between start and stop --> write to dictionary
                    if not additional_stops(ungapped_sliced_seqs):
                        if direction == '+':
                            positions = (start, next_stop + 3)
                        else:
                            # map reverse-complement coordinates back to the forward strand
                            positions = (aln_len - next_stop - 3, aln_len - start)
                        if last_stop != next_stop:
                            last_stop = next_stop
                            conservation = calculate_identity(identities, list(positions))
                            if identity_cutoff is not None and conservation < identity_cutoff:
                                continue
                            temp_orfs.append({
                                'orf_id': f'ORF_{orf_counter}',
                                'location': [positions],
                                'frame': frame,
                                'strand': direction,
                                'conservation': conservation,
                                'internal': [],
                            })
                            orf_counter += 1
                        else:
                            # same stop as the previous ORF: record start as internal
                            if temp_orfs:
                                temp_orfs[-1]['internal'].append(positions)

        # convert mutable intermediate dicts to frozen dataclasses
        orf_list = [
            OpenReadingFrame(
                orf_id=t['orf_id'],
                location=tuple(t['location']),
                frame=t['frame'],
                strand=t['strand'],
                conservation=t['conservation'],
                internal=tuple(t['internal']),
            )
            for t in temp_orfs
        ]
        return OrfCollection(orfs=tuple(orf_list))
 519
    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> OrfCollection:
        """
        First calculates all ORFs and then searches from 5'
        all non-overlapping orfs in the fw strand and from the
        3' all non-overlapping orfs in the rw strand.

        **No overlap algorithm:**
            **frame 1:** -[M------*]--- ----[M--*]---------[M-----

            **frame 2:** -------[M------*]---------[M---*]--------

            **frame 3:** [M---*]-----[M----------*]----------[M---

            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----

            frame:    3      2           1      2       1

        :param min_length: minimum ORF length, forwarded to get_conserved_orfs
        :param identity_cutoff: conservation cutoff, forwarded to get_conserved_orfs
        :return: OrfCollection with non-overlapping orfs
        """
        all_orfs = self.get_conserved_orfs(min_length, identity_cutoff)

        fw_orfs, rw_orfs = [], []

        # split ORFs by strand, keeping (id, (start, stop)) pairs
        # NOTE(review): assumes iterating an OrfCollection yields orf ids and
        # that it supports id-based indexing -- confirm against OrfCollection
        for orf in all_orfs:
            orf_obj = all_orfs[orf]
            if orf_obj.strand == '+':
                fw_orfs.append((orf, orf_obj.location[0]))
            else:
                rw_orfs.append((orf, orf_obj.location[0]))

        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
        non_overlapping_ids = []
        # greedy sweep: 5'->3' on the fw strand, 3'->5' on the rw strand
        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
            previous_stop = -1 if strand == '+' else self.length + 1
            for orf in orf_list:
                if strand == '+' and orf[1][0] >= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][1]
                elif strand == '-' and orf[1][1] <= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][0]

        return OrfCollection(orfs=tuple(all_orfs[orf_id] for orf_id in all_orfs if orf_id in non_overlapping_ids))
 564
 565    def calc_length_stats(self) -> LengthStats:
 566        """
 567        Determine the stats for the length of the ungapped seqs in the alignment.
 568        :return: dataclass with length stats
 569        """
 570
 571        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 572
 573        return LengthStats(
 574            n_sequences=len(self.alignment),
 575            mean_length=float(np.mean(seq_lengths)),
 576            std_length=float(np.std(seq_lengths)),
 577            min_length=int(np.min(seq_lengths)),
 578            max_length=int(np.max(seq_lengths)),
 579        )
 580
    def calc_entropy(self) -> AlignmentStats:
        """
        Calculate the normalized shannon's entropy for every position in an alignment:

        - 1: high entropy
        - 0: low entropy

        :return: Entropies at each position.
        """

        # helper functions
        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
            """
            Calculate the shannon's entropy of a sequence and
            normalized between 0 and 1.
            :param character_list: characters at an alignment position
            :param states: number of potential characters that can be present (4 nt / 20 aa)
            :param aln_type: type of the alignment
            :returns: entropy
            """
            ent, n_chars = 0, len(character_list)
            # only one char is in the list
            if n_chars <= 1:
                return ent
            # calculate the number of unique chars and their counts
            chars, char_counts = np.unique(character_list, return_counts=True)
            char_counts = char_counts.astype(float)
            # ignore gaps for entropy calc
            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
            # correctly handle ambiguous chars: redistribute them over the
            # nucleotides they encode
            # NOTE(review): the redistributed fractions ignore how many times the
            # ambiguous char occurs at this position -- verify intended
            index_to_drop = []
            for index, char in enumerate(chars):
                if char in config.AMBIG_CHARS[aln_type]:
                    index_to_drop.append(index)
                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
                    # add the proportionate numbers to initial array
                    for amb_char, amb_count in zip(amb_chars, amb_counts):
                        if amb_char in chars:
                            char_counts[chars == amb_char] += amb_count
                        else:
                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
            # drop the ambiguous characters from array
            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
            # calc the entropy
            # NOTE(review): probabilities are normalized by the total number of
            # characters incl. gaps (n_chars), not the retained counts -- confirm intended
            probs = char_counts / n_chars
            if np.count_nonzero(probs) <= 1:
                return ent
            for prob in probs:
                ent -= prob * math.log(prob, states)

            return ent

        aln = self.alignment
        entropys = []

        # number of possible states normalizes the entropy into [0, 1]
        if self.aln_type == 'AA':
            states = 20
        else:
            states = 4
        # iterate over alignment positions and the sequences
        for nuc_pos in range(self.length):
            pos = []
            for record in aln:
                pos.append(aln[record][nuc_pos])
            entropys.append(shannons_entropy(pos, states, self.aln_type))

        return self._create_position_stat_result('entropy', entropys)
 649
    def calc_gc(self) -> AlignmentStats | TypeError:
        """
        Determine the GC content for every position in an nt alignment.
        Ambiguous nucleotides contribute the G/C fraction of the nucleotides
        they encode; gaps count as non-GC characters.

        :return: GC content for every position.
        :raises: TypeError for AA alignments
        """
        if self.aln_type == 'AA':
            raise TypeError("GC computation is not possible for aminoacid alignment")

        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]

        for position in range(self.length):
            nucleotides = str()
            for record in aln:
                nucleotides = nucleotides + aln[record][position]
            # ini dict with chars that occur and which ones to
            # count in which freq
            to_count = {
                'G': 1,
                'C': 1,
            }
            # handle ambig. nuc: weight each by its G/C share
            for char in amb_nucs:
                if char in nucleotides:
                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])

            gc.append(
                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
            )

        return self._create_position_stat_result('gc', gc)
 681
 682    def calc_coverage(self) -> AlignmentStats:
 683        """
 684        Determine the coverage of every position in an alignment.
 685        This is defined as:
 686            1 - cumulative length of '-' characters
 687
 688        :return: Coverage at each alignment position.
 689        """
 690        coverage, aln = [], self.alignment
 691
 692        for nuc_pos in range(self.length):
 693            pos = str()
 694            for record in self.sequence_ids:
 695                pos = pos + aln[record][nuc_pos]
 696            coverage.append(1 - pos.count('-') / len(pos))
 697
 698        return self._create_position_stat_result('coverage', coverage)
 699
 700    def calc_gap_frequency(self) -> AlignmentStats:
 701        """
 702        Determine the gap frequency for every position in an alignment. This is the inverted coverage.
 703        """
 704        coverage = self.calc_coverage()
 705
 706        return self._create_position_stat_result('gap frequency', 1 - coverage.values)
 707
 708    def calc_reverse_complement_alignment(self) -> dict | TypeError:
 709        """
 710        Reverse complement the alignment.
 711        :return: Alignment (rv)
 712        """
 713        if self.aln_type == 'AA':
 714            raise TypeError('Reverse complement only for RNA or DNA.')
 715
 716        aln = self.alignment
 717        reverse_complement_dict = {}
 718
 719        for seq_id in aln:
 720            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 721
 722        return reverse_complement_dict
 723
    def calc_numerical_alignment(self, encode_mask:bool=False, encode_ambiguities:bool=False) -> ndarray:
        """
        Transforms the alignment to numerical values. Ambiguities are encoded as -3, mask as -2 and the
        remaining chars with the idx + 1 of config.CHAR_COLORS[self.aln_type]['standard'].

        :param encode_mask: encode masked characters ('X' for AA, otherwise 'N') as -2
        :param encode_ambiguities: encode ambiguous characters (excluding 'N', 'X' and '-') as -3
        :returns: numerical matrix with the alignment's shape; characters not covered by any rule remain np.nan
        """

        sequences = self._to_array()
        # ini matrix
        numerical_matrix = np.full(sequences.shape, np.nan, dtype=float)
        # first encode mask
        if encode_mask:
            if self.aln_type == 'AA':
                is_n_or_x = np.isin(sequences, ['X'])
            else:
                is_n_or_x = np.isin(sequences, ['N'])
            numerical_matrix[is_n_or_x] = -2
        # next encode ambig chars
        if encode_ambiguities:
            numerical_matrix[np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])] = -3
        # next convert each char into their respective values
        # NOTE: standard chars are assigned last, so a standard char always wins over mask/ambiguity encoding
        for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
            numerical_matrix[np.isin(sequences, [char])] = idx + 1

        return numerical_matrix
 752
 753    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> ndarray:
 754        """
 755        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 756
 757        :param encode_mismatches: encode mismatch as -1
 758        :param encode_mask: encode mask with value=-2 --> also in the reference
 759        :param encode_gaps: encode gaps with np.nan --> also in the reference
 760        :param encode_ambiguities: encode ambiguities with value=-3
 761        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.CHAR_COLORS[self.aln_type]['standard']
 762        :return: identity alignment
 763        """
 764
 765        ref = self._get_reference_seq()
 766
 767        # convert alignment to array
 768        sequences = self._to_array()
 769        reference = np.array(list(ref))
 770        # ini matrix
 771        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 772
 773        is_identical = sequences == reference
 774
 775        if encode_gaps:
 776            is_gap = sequences == '-'
 777        else:
 778            is_gap = np.full(sequences.shape, False)
 779
 780        if encode_mask:
 781            if self.aln_type == 'AA':
 782                is_n_or_x = np.isin(sequences, ['X'])
 783            else:
 784                is_n_or_x = np.isin(sequences, ['N'])
 785        else:
 786            is_n_or_x = np.full(sequences.shape, False)
 787
 788        if encode_ambiguities:
 789            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 790        else:
 791            is_ambig = np.full(sequences.shape, False)
 792
 793        if encode_mismatches:
 794            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 795        else:
 796            is_mismatch = np.full(sequences.shape, False)
 797
 798        # encode every different character
 799        if encode_each_mismatch_char:
 800            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
 801                new_encoding = np.isin(sequences, [char]) & is_mismatch
 802                identity_matrix[new_encoding] = idx + 1
 803        # or encode different with a single value
 804        else:
 805            identity_matrix[is_mismatch] = -1  # mismatch
 806
 807        identity_matrix[is_gap] = np.nan  # gap
 808        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 809        identity_matrix[is_ambig] = -3  # ambiguities
 810
 811        return identity_matrix
 812
 813    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> ndarray:
 814        """
 815        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
 816        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
 817        reference residue at each column. Gaps are encoded as np.nan.
 818
 819        The calculation follows these steps:
 820
 821        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
 822           a consensus sequence is generated to serve as the reference.
 823        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
 824           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
 825        3. **Per-Column Normalization (optional)**:
 826
 827        For each column in the alignment:
 828            - The residue in the reference sequence is treated as the baseline for that column.
 829            - The substitution scores for the reference residue are extracted from the substitution matrix.
 830            - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
 831            - This ensures that identical residues (or those with high similarity to the reference) have high scores,
 832            while more dissimilar residues have lower scores.
 833        4. **Output**:
 834
 835           - The normalized similarity scores are stored in a NumPy array.
 836           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
 837
 838        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
 839        :param: normalize: whether to normalize the similarity scores to range [0, 1]
 840        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
 841            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
 842            Gaps and invalid residues are encoded as `np.nan`.
 843        :raise: ValueError
 844            If the specified substitution matrix is not available for the given alignment type.
 845        """
 846
 847        ref = self._get_reference_seq()
 848
 849        if matrix_type is None:
 850            if self.aln_type == 'AA':
 851                matrix_type = 'BLOSUM65'
 852            else:
 853                matrix_type = 'TRANS'
 854        # load substitution matrix as dictionary
 855        try:
 856            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
 857        except KeyError:
 858            raise ValueError(
 859                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
 860            )
 861
 862        # set dtype and convert alignment to a NumPy array for vectorized processing
 863        dtype = np.dtype(float, metadata={'matrix': matrix_type})
 864        sequences = self._to_array()
 865        reference = np.array(list(ref))
 866        valid_chars = list(subs_matrix.keys())
 867        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
 868
 869        for j, ref_char in enumerate(reference):
 870            if ref_char not in valid_chars + ['-']:
 871                continue
 872            # Get local min and max for the reference residue
 873            if normalize and ref_char != '-':
 874                local_scores = subs_matrix[ref_char].values()
 875                local_min, local_max = min(local_scores), max(local_scores)
 876
 877            for i, char in enumerate(sequences[:, j]):
 878                if char not in valid_chars:
 879                    continue
 880                # classify the similarity as max if the reference has a gap
 881                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
 882                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
 883
 884        return similarity_array
 885
 886    def calc_position_matrix(self, matrix_type:str='PWM') -> None | ndarray:
 887        """
 888        Calculates a position matrix of the specified type for the given alignment. The function
 889        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
 890        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
 891        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
 892
 893        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
 894            'PWM', and 'IC'. Defaults to 'PWM'.
 895        :type matrix_type: str
 896        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
 897        :return: A numpy array representing the calculated position matrix of the specified type.
 898        :rtype: np.ndarray
 899        """
 900
 901        # ini
 902        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
 903            raise ValueError('Matrix_type must be PFM, PPM, IC or PWM.')
 904        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())
 905        sequences = self._to_array()
 906
 907        # calc position frequency matrix
 908        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
 909        if matrix_type == 'PFM':
 910            return pfm
 911
 912        # calc position probability matrix (probability)
 913        pseudo_count = 0.0001  # to avoid 0 values
 914        pfm = pfm + pseudo_count
 915        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
 916        ppm = pfm/len(self.sequence_ids)  # calculate the frequency based on row number
 917        if matrix_type == 'PPM':
 918            return ppm
 919
 920        # calc position weight matrix (log-likelihood)
 921        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
 922        if matrix_type == 'PWM':
 923            return pwm
 924
 925        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
 926        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
 927        if matrix_type == 'IC':
 928            return ic
 929
 930    def calc_percent_recovery(self) -> dict[str, float]:
 931        """
 932        Recovery per sequence either compared to the majority consensus seq
 933        or the reference seq.\n
 934        Defined as:\n
 935
 936        `(1 - sum(N/X/- characters in ungapped ref regions))*100`
 937
 938        This is highly similar to how nextclade calculates recovery over reference.
 939
 940        :return: dict
 941        """
 942
 943        aln = self.alignment
 944        ref = self._get_reference_seq()
 945
 946        if not any(char != '-' for char in ref):
 947            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 948
 949
 950        # count 'N', 'X' and '-' chars in non-gapped regions
 951        recovery_over_ref = dict()
 952
 953        # Get positions of non-gap characters in the reference
 954        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 955        cumulative_length = len(non_gap_positions)
 956
 957        # Calculate recovery
 958        for seq_id in self.sequence_ids:
 959            if seq_id == self.reference_id:
 960                continue
 961            seq = aln[seq_id]
 962            count_invalid = sum(
 963                seq[pos] == '-' or
 964                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 965                for pos in non_gap_positions
 966            )
 967            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 968
 969        return recovery_over_ref
 970
    def calc_character_frequencies(self) -> dict:
        """
        Calculate the percentage characters in the alignment:
        The frequencies are counted by seq and in total. The
        percentage of non-gap characters in the alignment is
        relative to the total number of non-gap characters.
        The gap percentage is relative to the sequence length.

        The 'total' percentages are the mean of the per-sequence
        percentages (divided by the number of sequences at the end).

        The output is a nested dictionary:
        {seq_id | 'total': {char: {'counts': int, '% of non-gapped' | '% of alignment': float}}}

        :return: Character frequencies
        """

        aln, aln_length = self.alignment, self.length

        # 'total' aggregates counts/percentages over all sequences; gaps get their own key
        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}

        for seq_id in aln:
            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
            # iterate over unique chars so each character is counted exactly once per sequence
            unique_chars = set(aln[seq_id])
            for char in unique_chars:
                if char == '-':
                    continue
                # add characters to dictionaries
                if char not in freqs[seq_id]:
                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
                if char not in freqs['total']:
                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
                # count non-gap chars
                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
                all_chars += freqs[seq_id][char]['counts']
            # normalize counts (per sequence, relative to its non-gap characters)
            for char in freqs[seq_id]:
                if char == '-':
                    continue
                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
            # count gaps
            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
            # normalize gap counts (relative to the alignment length)
            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']

        # normalize the total counts: accumulated percentages become the mean over all sequences
        for char in freqs['total']:
            for value in freqs['total'][char]:
                if value == '% of alignment' or value == '% of non-gapped':
                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)

        return freqs
1023
1024    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> PairwiseDistance:
1025        """
1026        Calculate pairwise identities for an alignment. Different options are implemented:
1027
1028        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1029        \ndistance = matches / alignment_length * 100
1030
1031        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1032        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1033
1034        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1035        \ndistance = (1 - mismatches / total) * 100
1036
1037        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1038        \ndistance = matches / gap_compressed_alignment_length * 100
1039
1040        RNA/DNA only:
1041
1042        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
1043        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
1044        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)
1045
1046        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
1047        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
1048        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total
1049
1050        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
1051        :return: array with pairwise distances.
1052        """
1053
1054        distance_functions = _create_distance_calculation_function_mapping()
1055
1056        if distance_type not in distance_functions:
1057            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1058
1059        aln = self.alignment
1060
1061        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
1062            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")
1063
1064        # Compute pairwise distances
1065        distance_func = distance_functions[distance_type]
1066        distance_matrix = np.zeros((len(aln), len(aln)))
1067
1068        sequences = list(aln.values())
1069        n = len(sequences)
1070        for i in range(n):
1071            seq1 = sequences[i]
1072            for j in range(i, n):
1073                seq2 = sequences[j]
1074                dist = distance_func(seq1, seq2, self.length)
1075                distance_matrix[i, j] = dist
1076                distance_matrix[j, i] = dist
1077
1078        return PairwiseDistance(
1079            reference_id=None,
1080            sequence_ids=self.sequence_ids,
1081            distances=distance_matrix
1082        )
1083
    def calc_pairwise_distance_to_reference(self, distance_type:str='ghd') -> PairwiseDistance:
        """
        Calculate pairwise identities between reference and all sequences in the alignment. Same computation as calc_pairwise_identity_matrix but compared to a single sequence. Different options are implemented:

        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
        \ndistance = matches / alignment_length * 100

        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100

        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
        \ndistance = (1 - mismatches / total) * 100

        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
        \ndistance = matches / gap_compressed_alignment_length * 100

        RNA/DNA only:

        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)

        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total

        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
        :return: dataclass with reference label, sequence ids and pairwise distances.
        :raises ValueError: for unknown distance types or nucleotide-only models on AA alignments
        """

        distance_functions = _create_distance_calculation_function_mapping()

        if distance_type not in distance_functions:
            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")

        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")

        # Compute pairwise distances
        distance_func = distance_functions[distance_type]
        aln = self.alignment
        ref_id = self.reference_id

        # without an explicit reference, distances are computed against the consensus sequence
        ref_seq = aln[ref_id] if ref_id is not None else self.get_consensus()
        distances = []
        distance_names = []

        # the reference itself is excluded from the comparison
        for seq_id in aln:
            if seq_id == ref_id:
                continue
            distance_names.append(seq_id)
            distances.append(distance_func(ref_seq, aln[seq_id], self.length))

        return PairwiseDistance(
            reference_id=ref_id if ref_id is not None else 'consensus',
            sequence_ids=distance_names,
            distances=np.array(distances)
        )
1143
1144    def get_snps(self, include_ambig:bool=False) -> VariantCollection:
1145        """
1146        Calculate snps similar to snp-sites (output is comparable):
1147        https://github.com/sanger-pathogens/snp-sites
1148        Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character.
1149        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1150
1151        :param include_ambig: Include ambiguous snps (default: False)
1152        :return: dataclass containing SNP positions and their variants including frequency.
1153        """
1154        aln = self.alignment
1155        ref = self._get_reference_seq()
1156        aln = {x: aln[x] for x in self.sequence_ids if x != self.reference_id}
1157        seq_ids = list(aln.keys())
1158        chrom = self.reference_id if self.reference_id is not None else 'consensus'
1159        snp_positions = {}
1160        aln_size = len(aln)
1161
1162        for pos in range(self.length):
1163            reference_char = ref[pos]
1164            if not include_ambig:
1165                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1166                    continue
1167            alt_chars, snps = [], []
1168            for i, seq_id in enumerate(seq_ids):
1169                alt_char = aln[seq_id][pos]
1170                alt_chars.append(alt_char)
1171                if reference_char != alt_char:
1172                    snps.append(i)
1173            if not snps:
1174                continue
1175            # Filter out ambiguous snps if not included
1176            if include_ambig:
1177                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1178                    continue
1179            else:
1180                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1181                if not snps:
1182                    continue
1183            # Build allele dict with counts
1184            alt_dict = {}
1185            for snp_idx in snps:
1186                alt_char = alt_chars[snp_idx]
1187                if alt_char not in alt_dict:
1188                    alt_dict[alt_char] = {'count': 0, 'seq_ids': []}
1189                alt_dict[alt_char]['count'] += 1
1190                alt_dict[alt_char]['seq_ids'].append(seq_ids[snp_idx])
1191            
1192            # Convert to final format: alt_char -> (frequency, seq_ids_tuple)
1193            alt = {
1194                alt_char: (data['count'] / aln_size, tuple(data['seq_ids']))
1195                for alt_char, data in alt_dict.items()
1196            }
1197            snp_positions[pos] = SingleNucleotidePolymorphism(ref=reference_char, alt=alt)
1198
1199        return VariantCollection(chrom=chrom, positions=snp_positions)
1200
1201    def calc_transition_transversion_score(self) -> AlignmentStats:
1202        """
1203        Based on the snp positions, calculates a transition/transversions score.
1204        A positive score means higher ratio of transitions and negative score means
1205        a higher ratio of transversions.
1206        :return: list
1207        """
1208
1209        if self.aln_type == 'AA':
1210            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1211
1212        # ini
1213        snps = self.get_snps()
1214        score = [0]*self.length
1215
1216        for pos, snp in snps.positions.items():
1217            for alt, (af, _) in snp.alt.items():
1218                # check the type of substitution
1219                if snp.ref + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1220                    score[pos] += af
1221                else:
1222                    score[pos] -= af
1223
1224        return self._create_position_stat_result('ts tv score', score)
1225
1226
1227class Annotation:
1228    """
1229    An annotation class that allows to read in gff, gb, or bed files and adjust its locations to that of the MSA.
1230    """
1231
1232    def __init__(self, aln: MSA, annotation: str | GenBankIterator):
1233        """
1234        The annotation class. Lets you parse multiple standard formats
1235        which might be used for annotating an alignment. The main purpose
1236        is to parse the annotation file and adapt the locations of diverse
1237        features to the locations within the alignment, considering the
1238        respective alignment positions. Importantly, IDs of the alignment
1239        and the MSA have to partly match.
1240
1241        :param aln: MSA class
1242        :param annotation: path to file (gb, bed, gff) or raw string or GenBankIterator from biopython
1243
1244        """
1245
1246        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation, aln)  # read annotation
1247        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1248        self._position_map = self._build_position_map()  # build a position map
1249        self._map_to_alignment()  # adapt feature locations
1250
1251    @staticmethod
1252    def _MSA_validation_and_seq_extraction(aln: MSA, seq_id: str) -> str:
1253        """
1254        extract gapped sequence from MSA that corresponds to annotation
1255        :param aln: MSA class
1256        :param seq_id: sequence id to extract
1257        :return: gapped sequence
1258        """
1259        if not isinstance(aln, MSA):
1260            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')
1261        else:
1262            return aln._alignment[seq_id]
1263
1264    @staticmethod
1265    def _parse_annotation(annotation: str | GenBankIterator, aln: MSA) -> tuple[str, str, str, Dict]:
1266
1267        def detect_annotation_type(handle: str | GenBankIterator) -> str:
1268            """
1269            Detect the type of annotation file (GenBank, GFF, or BED) based
1270            on the first relevant line (excluding empty and #). Also recognizes
1271            Bio.SeqIO iterators as GenBank format.
1272
1273            :param handle: Path to the annotation file or Bio.SeqIO iterator for genbank records read with biopython.
1274            :return: The detected file type ('gb', 'gff', or 'bed').
1275
1276            :raises ValueError: If the file type cannot be determined.
1277            """
1278            # Check if input is a SeqIO iterator
1279            if isinstance(handle, GenBankIterator):
1280                return 'gb'
1281
1282            with _get_line_iterator(handle) as h:
1283                for line in h:
1284                    # skip empty lines and comments
1285                    if not line.strip() or line.startswith('#'):
1286                        continue
1287                   # genbank
1288                    if line.startswith('LOCUS'):
1289                        return 'gb'
1290                    # gff
1291                    if len(line.split('\t')) == 9:
1292                        # Check for expected values
1293                        columns = line.split('\t')
1294                        if columns[6] in ['+', '-', '.'] and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$',columns[4]):
1295                            return 'gff'
1296                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
1297                    fields = line.split('\t')
1298                    if len(fields) >= 3 and re.match(r'^\d+$', fields[1]) and re.match(r'^\d+$', fields[2]):
1299                        return 'bed'
1300                    # only read in the first line
1301                    break
1302
1303            raise ValueError(
1304                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")
1305
1306        def parse_gb(file: str | GenBankIterator) -> dict:
1307            """
1308            Parse a GenBank file into the same dictionary structure used by the annotation pipeline.
1309
1310            :param file: path to genbank file, raw string, or Bio.SeqIO iterator
1311            :return: nested dictionary
1312            """
1313            records = {}
1314
1315            # Check if input is a GenBankIterator
1316            if isinstance(file, GenBankIterator):
1317                # Direct GenBankIterator input
1318                seq_records = list(file)
1319            else:
1320                # File path or string input
1321                with _get_line_iterator(file) as handle:
1322                    seq_records = list(SeqIO.parse(handle, "genbank"))
1323
1324            for seq_record in seq_records:
1325                locus = seq_record.name if seq_record.name else seq_record.id
1326                feature_container = {}
1327                feature_counter = {}
1328
1329                for feature in seq_record.features:
1330                    feature_type = feature.type
1331                    if feature_type not in feature_container:
1332                        feature_container[feature_type] = {}
1333                        feature_counter[feature_type] = 0
1334
1335                    current_idx = feature_counter[feature_type]
1336                    feature_counter[feature_type] += 1
1337
1338                    # Support simple and compound feature locations uniformly.
1339                    parts = feature.location.parts if hasattr(feature.location, "parts") else [feature.location]
1340                    locations = []
1341                    for part in parts:
1342                        try:
1343                            locations.append([int(part.start), int(part.end)])
1344                        except (TypeError, ValueError):
1345                            continue
1346
1347                    strand = '-' if feature.location.strand == -1 else '+'
1348                    parsed_feature = {
1349                        'location': locations,
1350                        'strand': strand,
1351                    }
1352
1353                    # Keep qualifier keys unchanged and flatten values to strings.
1354                    for qualifier_type, qualifier_values in feature.qualifiers.items():
1355                        if not qualifier_values:
1356                            continue
1357                        if isinstance(qualifier_values, list):
1358                            parsed_feature[qualifier_type] = qualifier_values[0] if len(qualifier_values) == 1 else ' '.join(str(x) for x in qualifier_values)
1359                        else:
1360                            parsed_feature[qualifier_type] = str(qualifier_values)
1361
1362                    feature_container[feature_type][current_idx] = parsed_feature
1363
1364                records[locus] = {
1365                    'locus': locus,
1366                    'features': feature_container,
1367                }
1368
1369            return records
1370
        def parse_gff(file_path) -> dict:
            """
            Parse a GFF3 (General Feature Format) file into a dictionary structure.

            :param file_path: path to GFF3 file
            :return: nested dictionary keyed by seqid -> {'locus': ..., 'features': ...}

            """
            records = {}
            with _get_line_iterator(file_path) as file:
                previous_id, previous_feature = None, None
                for line in file:
                    # skip comment/pragma lines and blank lines
                    if line.startswith('#') or not line.strip():
                        continue
                    # GFF3 defines exactly nine tab-separated columns
                    parts = line.strip().split('\t')
                    seqid, source, feature_type, start, end, score, strand, phase, attributes = parts
                    # ensure that region and source features are not named differently for gff and gb
                    if feature_type == 'region':
                        feature_type = 'source'
                    if seqid not in records:
                        records[seqid] = {'locus': seqid, 'features': {}}
                    if feature_type not in records[seqid]['features']:
                        records[seqid]['features'][feature_type] = {}

                    # next free index within this feature type
                    feature_id = len(records[seqid]['features'][feature_type])
                    feature = {
                        'strand': strand,
                    }

                    # Parse attributes into key-value pairs
                    for attr in attributes.split(';'):
                        if '=' in attr:
                            key, value = attr.split('=', 1)
                            feature[key.strip()] = value.strip()

                    # check if feature are the same --> possible splicing
                    # (identical attributes/strand as the previous line -> append an extra part)
                    if previous_id is not None and previous_feature == feature:
                        records[seqid]['features'][feature_type][previous_id]['location'].append([int(start)-1, int(end)])
                    else:
                        # GFF is 1-based inclusive; start is converted to 0-based half-open here
                        records[seqid]['features'][feature_type][feature_id] = feature
                        records[seqid]['features'][feature_type][feature_id]['location'] = [[int(start) - 1, int(end)]]
                    # set new previous id and features -> new dict as 'location' is pointed in current feature and this
                    # is the only key different if next feature has the same entries
                    previous_id, previous_feature = feature_id, {key:value for key, value in feature.items() if key != 'location'}

            return records
1417
1418        def parse_bed(file_path) -> dict:
1419            """
1420            Parse a BED file into a dictionary structure.
1421
1422            :param file_path: path to genebank file
1423            :return: nested dictionary
1424
1425            """
1426            records = {}
1427            with _get_line_iterator(file_path) as file:
1428                for line in file:
1429                    if line.startswith('#') or not line.strip():
1430                        continue
1431                    parts = line.strip().split('\t')
1432                    chrom, start, end, *optional = parts
1433
1434                    if chrom not in records:
1435                        records[chrom] = {'locus': chrom, 'features': {}}
1436                    feature_type = 'region'
1437                    if feature_type not in records[chrom]['features']:
1438                        records[chrom]['features'][feature_type] = {}
1439
1440                    feature_id = len(records[chrom]['features'][feature_type])
1441                    feature = {
1442                        'location': [[int(start), int(end)]],  # BED uses 0-based start, convert to 1-based
1443                        'strand': '+',  # assume '+' if not present
1444                    }
1445
1446                    # Handle optional columns (name, score, strand) --> ignore 7-12
1447                    if len(optional) >= 1:
1448                        feature['name'] = optional[0]
1449                    if len(optional) >= 2:
1450                        feature['score'] = optional[1]
1451                    if len(optional) >= 3:
1452                        feature['strand'] = optional[2]
1453
1454                    records[chrom]['features'][feature_type][feature_id] = feature
1455
1456            return records
1457
1458        parse_functions: Dict[str, Callable[[str], dict]] = {
1459            'gb': parse_gb,
1460            'bed': parse_bed,
1461            'gff': parse_gff,
1462        }
1463        # determine the annotation content -> should be standard formatted
1464        try:
1465            annotation_type = detect_annotation_type(annotation)
1466        except ValueError as err:
1467            raise err
1468
1469        # read in the annotation
1470        annotations = parse_functions[annotation_type](annotation)
1471
1472        # sanity check whether one of the annotation ids and alignment ids match
1473        annotation_found = False
1474        for annotation in annotations.keys():
1475            for aln_id in aln:
1476                aln_id_sanitized = aln_id.split(' ')[0]
1477                # check in both directions
1478                if aln_id_sanitized in annotation:
1479                    annotation_found = True
1480                    break
1481                if annotation in aln_id_sanitized:
1482                    annotation_found = True
1483                    break
1484
1485        if not annotation_found:
1486            raise ValueError(f'the annotations of {annotation} do not match any ids in the MSA')
1487
1488        # return only the annotation that has been found, the respective type and the seq_id to map to
1489        return annotation_type, aln_id, annotations[annotation]['locus'], annotations[annotation]['features']
1490
1491
1492    def _build_position_map(self) -> Dict[int, int]:
1493        """
1494        build a position map from a sequence.
1495
1496        :return genomic position: gapped position
1497        """
1498
1499        position_map = {}
1500        genomic_pos = 0
1501        for aln_index, char in enumerate(self._gapped_seq):
1502            if char != '-':
1503                position_map[genomic_pos] = aln_index
1504                genomic_pos += 1
1505        # ensure the last genomic position is included
1506        position_map[genomic_pos] = len(self._gapped_seq)
1507
1508        return position_map
1509
1510
1511    def _map_to_alignment(self):
1512        """
1513        Adjust all feature locations to alignment positions
1514        """
1515
1516        def map_location(position_map: Dict[int, int], locations: list) -> list:
1517            """
1518            Map genomic locations to alignment positions using a precomputed position map.
1519
1520            :param position_map: Positions mapped from gapped to ungapped
1521            :param locations: List of genomic start and end positions.
1522            :return: List of adjusted alignment positions.
1523            """
1524
1525            aligned_locs = []
1526            for start, end in locations:
1527                try:
1528                    aligned_start = position_map[start]
1529                    aligned_end = position_map[end]
1530                    aligned_locs.append([aligned_start, aligned_end])
1531                except KeyError:
1532                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
1533
1534            return aligned_locs
1535
1536        for feature_type, features in self.features.items():
1537            for feature_id, feature_data in features.items():
1538                original_locations = feature_data['location']
1539                aligned_locations = map_location(self._position_map, original_locations)
1540                feature_data['location'] = aligned_locations
class MSA:
  31class MSA:
  32    """
  33    An alignment class that allows computation of several stats. Supported inputs are file paths to alignments in "fasta",
  34    "clustal", "phylip", "stockholm", "nexus" formats, raw alignment strings, or Bio.Align.MultipleSeqAlignment for
  35    compatibility with Biopython.
  36    """
  37
  38    # dunder methods
    def __init__(self, alignment_string: str | MultipleSeqAlignment, reference_id: str = None, zoom_range: tuple | int = None):
        """
        Initialize an Alignment object.
        :param alignment_string: Path to alignment file or raw alignment string
        :param reference_id: reference id
        :param zoom_range: start and stop positions to zoom into the alignment
        :raises ValueError: if the reference id or zoom range is invalid
        """
        # parse the input into the internal {seq_id: sequence} representation
        self._alignment = _read_alignment(alignment_string)
        # ensure the reference (if given) is actually part of the alignment
        self._reference_id = self._validate_ref(reference_id, self._alignment)
        # ensure the zoom range (if given) lies within the alignment bounds
        self._zoom = self._validate_zoom(zoom_range, self._alignment)
        # classify the alignment as DNA, RNA or AA
        self._aln_type = self._determine_aln_type(self._alignment)
  50
  51    def __len__(self) -> int:
  52        return len(self.sequence_ids)
  53
  54    def __getitem__(self, key: str) -> str:
  55        return self.alignment[key]
  56
  57    def __contains__(self, item: str) -> bool:
  58        return item in self.sequence_ids
  59
  60    def __iter__(self):
  61        return iter(self.alignment)
  62
  63    # Static methods
  64    @staticmethod
  65    def _validate_ref(reference: str | None, alignment: dict) -> str | None | ValueError:
  66        """
  67        Validate if the ref seq is indeed part of the alignment.
  68        :param reference: reference seq id
  69        :param alignment: alignment dict
  70        :return: validated reference
  71        """
  72        if reference in alignment.keys():
  73            return reference
  74        elif reference is None:
  75            return reference
  76        else:
  77            raise ValueError('Reference not in alignment.')
  78
  79    @staticmethod
  80    def _validate_zoom(zoom: tuple | int, original_aln: dict) -> ValueError | tuple | None:
  81        """
  82        Validates if the user defined zoom range is within the start, end of the initial
  83        alignment.\n
  84        :param zoom: zoom range or zoom start
  85        :param original_aln: non-zoomed alignment dict
  86        :return: validated zoom range
  87        """
  88        if zoom is not None:
  89            aln_length = len(original_aln[list(original_aln.keys())[0]])
  90            # check if only over value is provided -> stop is alignment length
  91            if isinstance(zoom, int):
  92                if 0 <= zoom < aln_length:
  93                    return zoom, aln_length
  94                else:
  95                    raise ValueError('Zoom start must be within the alignment length range.')
  96            # check if more than 2 values are provided
  97            if len(zoom) != 2:
  98                raise ValueError('Zoom position have to be (zoom_start, zoom_end)')
  99            start, end = zoom
 100            # validate zoom start/stop for Python slicing semantics [start:end)
 101            if type(start) is not int or type(end) is not int:
 102                raise ValueError('Zoom positions have to be integers.')
 103            if not (0 <= start < aln_length):
 104                raise ValueError('Zoom position out of range')
 105            if not (0 < end <= aln_length):
 106                raise ValueError('Zoom position out of range')
 107            if start >= end:
 108                raise ValueError('Zoom position have to be (zoom_start, zoom_end) with zoom_start < zoom_end')
 109
 110        return zoom
 111
 112    @staticmethod
 113    def _determine_aln_type(alignment) -> str:
 114        """
 115        Determine the most likely type of alignment
 116        if 70% of chars in the alignment are nucleotide
 117        chars it is most likely a nt alignment
 118        :return: type of alignment
 119        """
 120        counter = int()
 121        for record in alignment:
 122            if 'U' in alignment[record]:
 123                return 'RNA'
 124            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 125        # determine which is the most likely type
 126        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 127            return 'DNA'
 128        else:
 129            return 'AA'
 130
 131    # Properties with setters
    @property
    def reference_id(self) -> str:
        """Id of the current reference sequence (None if no reference is set)."""
        return self._reference_id
 135
    @reference_id.setter
    def reference_id(self, ref_id: str):
        """
        Set and validate the reference id.
        """
        # raises ValueError if ref_id is not part of the alignment
        self._reference_id = self._validate_ref(ref_id, self.alignment)
 142
    @property
    def zoom(self) -> tuple:
        """Current zoom range (start, stop), or None if the alignment is not zoomed."""
        return self._zoom
 146
    @zoom.setter
    def zoom(self, zoom_pos: tuple | int):
        """
        Validate and set the user-defined zoom range.
        """
        # validated against the full (non-zoomed) alignment
        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 153
 154    # Property without setters
    @property
    def aln_type(self) -> str:
        """
        Type of the alignment: 'RNA', 'DNA' or 'AA'
        (determined once at construction time).
        """
        return self._aln_type
 162
 163    @property
 164    def sequence_ids(self) -> list:
 165        return list(self.alignment.keys())
 166
 167    # On the fly properties without setters
 168    @property
 169    def length(self) -> int:
 170        return len(next(iter(self.alignment.values())))
 171
 172    @property
 173    def alignment(self) -> dict:
 174        """
 175        (zoomed) version of the alignment.
 176        """
 177        if self.zoom is not None:
 178            zoomed_aln = dict()
 179            for seq in self._alignment:
 180                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 181            return zoomed_aln
 182        else:
 183            return self._alignment
 184
 185    def _create_position_stat_result(self, stat_name: str, values: list | ndarray) -> AlignmentStats:
 186        """
 187        Build a shared dataclass for position-wise statistics.
 188        """
 189        values_array = np.asarray(values, dtype=float)
 190        if self.zoom is None:
 191            positions = np.arange(self.length, dtype=int)
 192        else:
 193            positions = np.arange(self.zoom[0], self.zoom[0] + self.length, dtype=int)
 194
 195        return AlignmentStats(
 196            stat_name=stat_name,
 197            positions=positions,
 198            values=values_array,
 199        )
 200
 201    def _to_array(self) -> ndarray:
 202        """convert alignment to a numpy array"""
 203        return np.array([list(self.alignment[seq_id]) for seq_id in self.sequence_ids])
 204
 205    def _get_reference_seq(self) -> str:
 206        """get the sequence of the reference sequence or majority consensus"""
 207        return self.alignment[self.reference_id] if self.reference_id is not None else self.get_consensus()
 208
 209    def get_reference_coords(self) -> tuple[int, int]:
 210        """
 211        Determine the start and end coordinates of the reference sequence
 212        defined as the first/last nucleotide in the reference sequence
 213        (excluding N and gaps).
 214
 215        :return: Start, End
 216        """
 217        start, end = 0, self.length
 218
 219        if self.reference_id is None:
 220            return start, end
 221        else:
 222            # 5' --> 3'
 223            for start in range(self.length):
 224                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 225                    break
 226            # 3' --> 5'
 227            for end in range(self.length - 1, 0, -1):
 228                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 229                    break
 230
 231            return start, end
 232
    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
        """
        Creates a non-gapped consensus sequence.

        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
            or 'X' (for as alignments) is used if none of the characters reach a cumulative frequency >= threshold.
        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
            has a frequency above the defined threshold.
        :return: consensus sequence
        :raises ValueError: if threshold is outside [0, 1], if use_ambig_nt is requested for
            an AA alignment, or if use_ambig_nt is set without a threshold.
        """

        # helper functions
        def determine_counts(alignment_dict: dict, position: int) -> dict:
            """
            count the number of each char at
            an idx of the alignment. return sorted dic.
            handles ambiguous nucleotides in sequences.
            also handles gaps.
            """
            nucleotide_list = []

            # get all nucleotides
            for sequence in alignment_dict.items():
                nucleotide_list.append(sequence[1][position])
            # count occurences of nucleotides
            counter = dict(collections.Counter(nucleotide_list))
            # get permutations of an ambiguous nucleotide
            to_delete = []
            temp_dict = {}
            for nucleotide in counter:
                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
                    to_delete.append(nucleotide)
                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
                    # each ambiguous char contributes fractionally to its expansions
                    adjusted_freq = 1 / len(permutations)
                    for permutation in permutations:
                        if permutation in temp_dict:
                            temp_dict[permutation] += adjusted_freq
                        else:
                            temp_dict[permutation] = adjusted_freq

            # drop ambiguous entries and add adjusted freqs to
            if to_delete:
                for i in to_delete:
                    counter.pop(i)
                for nucleotide in temp_dict:
                    if nucleotide in counter:
                        counter[nucleotide] += temp_dict[nucleotide]
                    else:
                        counter[nucleotide] = temp_dict[nucleotide]

            # sorted by frequency, highest first
            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))

        def get_consensus_char(counts: dict, cutoff: float) -> list:
            """
            get a list of nucleotides for the consensus seq
            (most frequent chars until their cumulative count reaches cutoff;
            a cutoff of 0 always yields just the majority char)
            """
            n = 0

            consensus_chars = []
            for char in counts:
                n += counts[char]
                consensus_chars.append(char)
                if n >= cutoff:
                    break

            return consensus_chars

        def get_ambiguous_char(nucleotides: list) -> str:
            """
            get ambiguous char from a list of nucleotides
            """
            # NOTE(review): if no entry of AMBIG_CHARS matches the set exactly, the loop
            # falls through and the last iterated char is returned — confirm that
            # config.AMBIG_CHARS covers every possible nucleotide combination
            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
                if set(permutations) == set(nucleotides):
                    break

            return ambiguous

        # check if params have been set correctly
        if threshold is not None:
            if threshold < 0 or threshold > 1:
                raise ValueError('Threshold must be between 0 and 1.')
        if self.aln_type == 'AA' and use_ambig_nt:
            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
        if threshold is None and use_ambig_nt:
            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')

        alignment = self.alignment
        consensus = str()

        # cumulative count a consensus char set must reach at each column
        if threshold is not None:
            consensus_cutoff = len(alignment) * threshold
        else:
            consensus_cutoff = 0

        # built consensus sequences
        for idx in range(self.length):
            char_counts = determine_counts(alignment, idx)
            consensus_chars = get_consensus_char(
                char_counts,
                consensus_cutoff
            )
            # note: threshold is None also takes this branch, but then
            # consensus_chars always holds a single (majority) char
            if threshold != 0:
                if len(consensus_chars) > 1:
                    if use_ambig_nt:
                        char = get_ambiguous_char(consensus_chars)
                    else:
                        if self.aln_type == 'AA':
                            char = 'X'
                        else:
                            char = 'N'
                    consensus = consensus + char
                else:
                    consensus = consensus + consensus_chars[0]
            else:
                consensus = consensus + consensus_chars[0]

        return consensus
 351
    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> OrfCollection:
        """
        **conserved ORF definition:**
            - conserved starts and stops
            - start, stop must be on the same frame
            - stop - start must be at least min_length
            - all ungapped seqs[start:stop] must have at least min_length
            - no ungapped seq can have a Stop in between Start Stop

        Conservation is measured by the number of positions with identical characters divided by
        orf slice of the alignment.

        **Algorithm overview:**
            - check for conserved start and stop codons
            - iterate over all three frames
            - check each start and next sufficiently far away stop codon
            - check if all ungapped seqs between start and stop codon are >= min_length
            - check if no ungapped seq in the alignment has a stop codon
            - write to dictionary
            - classify as internal if the stop codon has already been written with a prior start
            - repeat for reverse complement

        :param min_length: minimum ORF length in alignment columns; must be > 6 and <= alignment length
        :param identity_cutoff: optional percent identity in [0, 100] an ORF must reach to be reported
        :return: ORF positions and internal ORF positions
        :raises TypeError: for amino acid alignments
        :raises ValueError: for out-of-range min_length or identity_cutoff
        """

        # helper functions
        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
            """
            Determine all start and stop codons within an alignment.
            :param alignment: alignment
            :param alignment_length: length of alignment
            :return: start and stop codon positions
            """
            starts = config.START_CODONS[self.aln_type]
            stops = config.STOP_CODONS[self.aln_type]

            list_of_starts, list_of_stops = [], []
            # define one sequence (first) as reference (it does not matter which one)
            ref = alignment[self.sequence_ids[0]]
            for nt_position in range(alignment_length):
                if ref[nt_position:nt_position + 3] in starts:
                    conserved_start = True
                    # the first 3 ungapped chars from this column must be a start in every seq
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
                            conserved_start = False
                            break
                    if conserved_start:
                        list_of_starts.append(nt_position)

                if ref[nt_position:nt_position + 3] in stops:
                    conserved_stop = True
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
                            conserved_stop = False
                            break
                    if conserved_stop:
                        list_of_stops.append(nt_position)

            return list_of_starts, list_of_stops

        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
            """
            get ungapped sequences starting and stop codons and eliminate gaps
            :param alignment: alignment
            :param start_pos: start codon
            :param stop_pos: stop codon
            :return: sliced sequences
            """
            ungapped_seqs = []
            for seq_id in alignment:
                # +3 so the stop codon itself is included in the slice
                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))

            return ungapped_seqs

        def additional_stops(ungapped_seqs: list) -> bool:
            """
            Checks for the presence of a stop codon
            :param ungapped_seqs: list of ungapped sequences
            :return: Additional stop codons (True/False)
            """
            stops = config.STOP_CODONS[self.aln_type]

            for sliced_seq in ungapped_seqs:
                # scan in-frame codons, excluding the terminal stop codon
                for position in range(0, len(sliced_seq) - 3, 3):
                    if sliced_seq[position:position + 3] in stops:
                        return True
            return False

        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
            """
            Percent of fully conserved columns within aln_slice.
            :param identity_matrix: per-position identity matrix (identical = 0, different = -1)
            :param aln_slice: [start, stop] alignment columns
            :return: percent identity
            """
            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100

        # checks for arguments
        if self.aln_type == 'AA':
            raise TypeError('ORF search only for RNA/DNA alignments')

        if identity_cutoff is not None:
            if identity_cutoff > 100 or identity_cutoff < 0:
                raise ValueError('conservation cutoff must be between 0 and 100')

        if min_length <= 6 or min_length > self.length:
            raise ValueError(f'min_length must be between 6 and {self.length}')

        # ini
        identities = self.calc_identity_alignment()
        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
        aln_len = self.length

        orf_counter = 0
        # use mutable dicts during construction and convert to dataclass at the end for immutability
        temp_orfs: list[dict] = []

        for aln, direction in zip(alignments, ['+', '-']):
            # check for starts and stops in the first seq and then check if these are present in all seqs
            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
            # check each frame
            for frame in (0, 1, 2):
                potential_starts = [x for x in conserved_starts if x % 3 == frame]
                potential_stops = [x for x in conserved_stops if x % 3 == frame]
                last_stop = -1
                for start in potential_starts:
                    # go to the next stop that is sufficiently far away in the alignment
                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
                    if not next_stops:
                        continue
                    next_stop = next_stops[0]
                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
                    # re-check the lengths of all ungapped seqs
                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
                    if not all(ungapped_seq_lengths):
                        continue
                    # if no stop codon between start and stop --> write to dictionary
                    if not additional_stops(ungapped_sliced_seqs):
                        if direction == '+':
                            positions = (start, next_stop + 3)
                        else:
                            # map reverse-complement coordinates back onto the forward strand
                            positions = (aln_len - next_stop - 3, aln_len - start)
                        if last_stop != next_stop:
                            last_stop = next_stop
                            conservation = calculate_identity(identities, list(positions))
                            if identity_cutoff is not None and conservation < identity_cutoff:
                                continue
                            temp_orfs.append({
                                'orf_id': f'ORF_{orf_counter}',
                                'location': [positions],
                                'frame': frame,
                                'strand': direction,
                                'conservation': conservation,
                                'internal': [],
                            })
                            orf_counter += 1
                        else:
                            # same stop as the previous ORF -> internal start of that ORF
                            if temp_orfs:
                                temp_orfs[-1]['internal'].append(positions)

        # convert mutable intermediate dicts to frozen dataclasses
        orf_list = [
            OpenReadingFrame(
                orf_id=t['orf_id'],
                location=tuple(t['location']),
                frame=t['frame'],
                strand=t['strand'],
                conservation=t['conservation'],
                internal=tuple(t['internal']),
            )
            for t in temp_orfs
        ]
        return OrfCollection(orfs=tuple(orf_list))
 520
    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> OrfCollection:
        """
        First calculates all ORFs and then searches from 5'
        all non-overlapping orfs in the fw strand and from the
        3' all non-overlapping orfs in the rw strand.

        **No overlap algorithm:**
            **frame 1:** -[M------*]--- ----[M--*]---------[M-----

            **frame 2:** -------[M------*]---------[M---*]--------

            **frame 3:** [M---*]-----[M----------*]----------[M---

            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----

            frame:    3      2           1      2       1

        :param min_length: minimum ORF length, forwarded to get_conserved_orfs
        :param identity_cutoff: minimum percent identity, forwarded to get_conserved_orfs
        :return: OrfContainer with non-overlapping orfs
        """
        all_orfs = self.get_conserved_orfs(min_length, identity_cutoff)

        fw_orfs, rw_orfs = [], []

        # NOTE(review): assumes iterating an OrfCollection yields ids that can be used
        # to index it (all_orfs[orf]) — confirm against _data_classes
        for orf in all_orfs:
            orf_obj = all_orfs[orf]
            if orf_obj.strand == '+':
                fw_orfs.append((orf, orf_obj.location[0]))
            else:
                rw_orfs.append((orf, orf_obj.location[0]))

        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
        non_overlapping_ids = []
        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
            # greedy scan: keep an ORF if it starts after the last kept ORF ended
            previous_stop = -1 if strand == '+' else self.length + 1
            for orf in orf_list:
                if strand == '+' and orf[1][0] >= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][1]
                elif strand == '-' and orf[1][1] <= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][0]

        return OrfCollection(orfs=tuple(all_orfs[orf_id] for orf_id in all_orfs if orf_id in non_overlapping_ids))
 565
 566    def calc_length_stats(self) -> LengthStats:
 567        """
 568        Determine the stats for the length of the ungapped seqs in the alignment.
 569        :return: dataclass with length stats
 570        """
 571
 572        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 573
 574        return LengthStats(
 575            n_sequences=len(self.alignment),
 576            mean_length=float(np.mean(seq_lengths)),
 577            std_length=float(np.std(seq_lengths)),
 578            min_length=int(np.min(seq_lengths)),
 579            max_length=int(np.max(seq_lengths)),
 580        )
 581
    def calc_entropy(self) -> AlignmentStats:
        """
        Calculate the normalized shannon's entropy for every position in an alignment:

        - 1: high entropy
        - 0: low entropy

        Gaps are excluded from the computation and ambiguous characters are
        redistributed proportionally over the characters they encode.

        :return: Entropies at each position.
        """

        # helper functions
        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
            """
            Calculate the shannon's entropy of a sequence and
            normalized between 0 and 1.
            :param character_list: characters at an alignment position
            :param states: number of potential characters that can be present
            :param aln_type: type of the alignment
            :returns: entropy
            """
            ent, n_chars = 0, len(character_list)
            # only one char is in the list
            if n_chars <= 1:
                return ent
            # calculate the number of unique chars and their counts
            chars, char_counts = np.unique(character_list, return_counts=True)
            char_counts = char_counts.astype(float)
            # ignore gaps for entropy calc
            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
            # correctly handle ambiguous chars: each ambiguity contributes a
            # fractional count to every character it can encode
            index_to_drop = []
            for index, char in enumerate(chars):
                if char in config.AMBIG_CHARS[aln_type]:
                    index_to_drop.append(index)
                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
                    # add the proportionate numbers to initial array
                    for amb_char, amb_count in zip(amb_chars, amb_counts):
                        if amb_char in chars:
                            char_counts[chars == amb_char] += amb_count
                        else:
                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
            # drop the ambiguous characters from array
            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
            # calc the entropy
            # NOTE(review): probabilities are normalized by the full column size
            # n_chars (which still includes gaps), not by the remaining counts —
            # confirm this is intended
            probs = char_counts / n_chars
            if np.count_nonzero(probs) <= 1:
                return ent
            for prob in probs:
                ent -= prob * math.log(prob, states)

            return ent

        aln = self.alignment
        entropys = []

        # log base for normalization: 20 amino acids vs 4 nucleotides
        if self.aln_type == 'AA':
            states = 20
        else:
            states = 4
        # iterate over alignment positions and the sequences
        for nuc_pos in range(self.length):
            pos = []
            for record in aln:
                pos.append(aln[record][nuc_pos])
            entropys.append(shannons_entropy(pos, states, self.aln_type))

        return self._create_position_stat_result('entropy', entropys)
 650
 651    def calc_gc(self) -> AlignmentStats | TypeError:
 652        """
 653        Determine the GC content for every position in an nt alignment.
 654        :return: GC content for every position.
 655        :raises: TypeError for AA alignments
 656        """
 657        if self.aln_type == 'AA':
 658            raise TypeError("GC computation is not possible for aminoacid alignment")
 659
 660        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
 661
 662        for position in range(self.length):
 663            nucleotides = str()
 664            for record in aln:
 665                nucleotides = nucleotides + aln[record][position]
 666            # ini dict with chars that occur and which ones to
 667            # count in which freq
 668            to_count = {
 669                'G': 1,
 670                'C': 1,
 671            }
 672            # handle ambig. nuc
 673            for char in amb_nucs:
 674                if char in nucleotides:
 675                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
 676
 677            gc.append(
 678                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
 679            )
 680
 681        return self._create_position_stat_result('gc', gc)
 682
 683    def calc_coverage(self) -> AlignmentStats:
 684        """
 685        Determine the coverage of every position in an alignment.
 686        This is defined as:
 687            1 - cumulative length of '-' characters
 688
 689        :return: Coverage at each alignment position.
 690        """
 691        coverage, aln = [], self.alignment
 692
 693        for nuc_pos in range(self.length):
 694            pos = str()
 695            for record in self.sequence_ids:
 696                pos = pos + aln[record][nuc_pos]
 697            coverage.append(1 - pos.count('-') / len(pos))
 698
 699        return self._create_position_stat_result('coverage', coverage)
 700
 701    def calc_gap_frequency(self) -> AlignmentStats:
 702        """
 703        Determine the gap frequency for every position in an alignment. This is the inverted coverage.
 704        """
 705        coverage = self.calc_coverage()
 706
 707        return self._create_position_stat_result('gap frequency', 1 - coverage.values)
 708
 709    def calc_reverse_complement_alignment(self) -> dict | TypeError:
 710        """
 711        Reverse complement the alignment.
 712        :return: Alignment (rv)
 713        """
 714        if self.aln_type == 'AA':
 715            raise TypeError('Reverse complement only for RNA or DNA.')
 716
 717        aln = self.alignment
 718        reverse_complement_dict = {}
 719
 720        for seq_id in aln:
 721            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 722
 723        return reverse_complement_dict
 724
    def calc_numerical_alignment(self, encode_mask:bool=False, encode_ambiguities:bool=False) -> ndarray:
        """
        Transforms the alignment to numerical values. Ambiguities are encoded as -3, mask as -2 and the
        remaining chars with the idx + 1 of config.CHAR_COLORS[self.aln_type]['standard'].

        :param encode_ambiguities: encode ambiguities as -3
        :param encode_mask: encode mask as -2
        :return: numerical matrix (characters that are not encoded stay np.nan)
        """

        sequences = self._to_array()
        # ini matrix
        numerical_matrix = np.full(sequences.shape, np.nan, dtype=float)
        # first encode mask ('X' for amino acids, 'N' for nucleotides)
        if encode_mask:
            if self.aln_type == 'AA':
                is_n_or_x = np.isin(sequences, ['X'])
            else:
                is_n_or_x = np.isin(sequences, ['N'])
            numerical_matrix[is_n_or_x] = -2
        # next encode ambig chars (excluding mask and gap characters)
        if encode_ambiguities:
            numerical_matrix[np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])] = -3
        # next convert each char into their respective values
        for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
            numerical_matrix[np.isin(sequences, [char])] = idx + 1

        return numerical_matrix
 753
 754    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> ndarray:
 755        """
 756        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 757
 758        :param encode_mismatches: encode mismatch as -1
 759        :param encode_mask: encode mask with value=-2 --> also in the reference
 760        :param encode_gaps: encode gaps with np.nan --> also in the reference
 761        :param encode_ambiguities: encode ambiguities with value=-3
 762        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.CHAR_COLORS[self.aln_type]['standard']
 763        :return: identity alignment
 764        """
 765
 766        ref = self._get_reference_seq()
 767
 768        # convert alignment to array
 769        sequences = self._to_array()
 770        reference = np.array(list(ref))
 771        # ini matrix
 772        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 773
 774        is_identical = sequences == reference
 775
 776        if encode_gaps:
 777            is_gap = sequences == '-'
 778        else:
 779            is_gap = np.full(sequences.shape, False)
 780
 781        if encode_mask:
 782            if self.aln_type == 'AA':
 783                is_n_or_x = np.isin(sequences, ['X'])
 784            else:
 785                is_n_or_x = np.isin(sequences, ['N'])
 786        else:
 787            is_n_or_x = np.full(sequences.shape, False)
 788
 789        if encode_ambiguities:
 790            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 791        else:
 792            is_ambig = np.full(sequences.shape, False)
 793
 794        if encode_mismatches:
 795            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 796        else:
 797            is_mismatch = np.full(sequences.shape, False)
 798
 799        # encode every different character
 800        if encode_each_mismatch_char:
 801            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
 802                new_encoding = np.isin(sequences, [char]) & is_mismatch
 803                identity_matrix[new_encoding] = idx + 1
 804        # or encode different with a single value
 805        else:
 806            identity_matrix[is_mismatch] = -1  # mismatch
 807
 808        identity_matrix[is_gap] = np.nan  # gap
 809        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 810        identity_matrix[is_ambig] = -3  # ambiguities
 811
 812        return identity_matrix
 813
    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> ndarray:
        """
        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
        reference residue at each column. Gaps are encoded as np.nan.

        The calculation follows these steps:

        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
           a consensus sequence is generated to serve as the reference.
        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
        3. **Per-Column Normalization (optional)**:

        For each column in the alignment:
            - The residue in the reference sequence is treated as the baseline for that column.
            - The substitution scores for the reference residue are extracted from the substitution matrix.
            - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
            - This ensures that identical residues (or those with high similarity to the reference) have high scores,
            while more dissimilar residues have lower scores.
        4. **Output**:

           - The normalized similarity scores are stored in a NumPy array.
           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.

        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
        :param: normalize: whether to normalize the similarity scores to range [0, 1]
        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
            Gaps and invalid residues are encoded as `np.nan`.
        :raise: ValueError
            If the specified substitution matrix is not available for the given alignment type.
        """

        ref = self._get_reference_seq()

        # pick the default matrix for the alignment type when none was requested
        if matrix_type is None:
            if self.aln_type == 'AA':
                matrix_type = 'BLOSUM65'
            else:
                matrix_type = 'TRANS'
        # load substitution matrix as dictionary
        try:
            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
        except KeyError:
            raise ValueError(
                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
            )

        # set dtype and convert alignment to a NumPy array for vectorized processing
        # (the dtype metadata records which matrix produced the scores)
        dtype = np.dtype(float, metadata={'matrix': matrix_type})
        sequences = self._to_array()
        reference = np.array(list(ref))
        valid_chars = list(subs_matrix.keys())
        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)

        for j, ref_char in enumerate(reference):
            # skip columns whose reference residue is unknown to the matrix
            if ref_char not in valid_chars + ['-']:
                continue
            # Get local min and max for the reference residue
            # NOTE(review): if all scores in subs_matrix[ref_char] are equal,
            # local_max == local_min and the normalization below divides by
            # zero — confirm the configured matrices guarantee a spread
            if normalize and ref_char != '-':
                local_scores = subs_matrix[ref_char].values()
                local_min, local_max = min(local_scores), max(local_scores)

            for i, char in enumerate(sequences[:, j]):
                if char not in valid_chars:
                    continue
                # classify the similarity as max if the reference has a gap
                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score

        return similarity_array
 886
 887    def calc_position_matrix(self, matrix_type:str='PWM') -> None | ndarray:
 888        """
 889        Calculates a position matrix of the specified type for the given alignment. The function
 890        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
 891        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
 892        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
 893
 894        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
 895            'PWM', and 'IC'. Defaults to 'PWM'.
 896        :type matrix_type: str
 897        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
 898        :return: A numpy array representing the calculated position matrix of the specified type.
 899        :rtype: np.ndarray
 900        """
 901
 902        # ini
 903        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
 904            raise ValueError('Matrix_type must be PFM, PPM, IC or PWM.')
 905        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())
 906        sequences = self._to_array()
 907
 908        # calc position frequency matrix
 909        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
 910        if matrix_type == 'PFM':
 911            return pfm
 912
 913        # calc position probability matrix (probability)
 914        pseudo_count = 0.0001  # to avoid 0 values
 915        pfm = pfm + pseudo_count
 916        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
 917        ppm = pfm/len(self.sequence_ids)  # calculate the frequency based on row number
 918        if matrix_type == 'PPM':
 919            return ppm
 920
 921        # calc position weight matrix (log-likelihood)
 922        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
 923        if matrix_type == 'PWM':
 924            return pwm
 925
 926        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
 927        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
 928        if matrix_type == 'IC':
 929            return ic
 930
 931    def calc_percent_recovery(self) -> dict[str, float]:
 932        """
 933        Recovery per sequence either compared to the majority consensus seq
 934        or the reference seq.\n
 935        Defined as:\n
 936
 937        `(1 - sum(N/X/- characters in ungapped ref regions))*100`
 938
 939        This is highly similar to how nextclade calculates recovery over reference.
 940
 941        :return: dict
 942        """
 943
 944        aln = self.alignment
 945        ref = self._get_reference_seq()
 946
 947        if not any(char != '-' for char in ref):
 948            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 949
 950
 951        # count 'N', 'X' and '-' chars in non-gapped regions
 952        recovery_over_ref = dict()
 953
 954        # Get positions of non-gap characters in the reference
 955        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 956        cumulative_length = len(non_gap_positions)
 957
 958        # Calculate recovery
 959        for seq_id in self.sequence_ids:
 960            if seq_id == self.reference_id:
 961                continue
 962            seq = aln[seq_id]
 963            count_invalid = sum(
 964                seq[pos] == '-' or
 965                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 966                for pos in non_gap_positions
 967            )
 968            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 969
 970        return recovery_over_ref
 971
    def calc_character_frequencies(self) -> dict:
        """
        Calculate the percentage characters in the alignment:
        The frequencies are counted by seq and in total. The
        percentage of non-gap characters in the alignment is
        relative to the total number of non-gap characters.
        The gap percentage is relative to the sequence length.

        The output is a nested dictionary.

        :return: Character frequencies
        """

        aln, aln_length = self.alignment, self.length

        # 'total' aggregates over all sequences; each seq_id gets its own sub-dict
        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}

        for seq_id in aln:
            # all_chars tracks this sequence's total non-gap character count
            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
            unique_chars = set(aln[seq_id])
            for char in unique_chars:
                if char == '-':
                    continue
                # add characters to dictionaries
                if char not in freqs[seq_id]:
                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
                if char not in freqs['total']:
                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
                # count non-gap chars
                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
                all_chars += freqs[seq_id][char]['counts']
            # normalize counts (per sequence, relative to its non-gap characters)
            for char in freqs[seq_id]:
                if char == '-':
                    continue
                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
            # count gaps
            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
            # normalize gap counts (relative to the alignment length)
            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']

        # normalize the total counts
        # NOTE(review): the totals are the mean of per-sequence percentages
        # (sum divided by number of sequences), not a percentage recomputed
        # from the pooled counts — confirm this averaging is intended
        for char in freqs['total']:
            for value in freqs['total'][char]:
                if value == '% of alignment' or value == '% of non-gapped':
                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)

        return freqs
1024
1025    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> PairwiseDistance:
1026        """
1027        Calculate pairwise identities for an alignment. Different options are implemented:
1028
1029        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1030        \ndistance = matches / alignment_length * 100
1031
1032        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1033        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1034
1035        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1036        \ndistance = (1 - mismatches / total) * 100
1037
1038        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1039        \ndistance = matches / gap_compressed_alignment_length * 100
1040
1041        RNA/DNA only:
1042
1043        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
1044        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
1045        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)
1046
1047        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
1048        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
1049        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total
1050
1051        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
1052        :return: array with pairwise distances.
1053        """
1054
1055        distance_functions = _create_distance_calculation_function_mapping()
1056
1057        if distance_type not in distance_functions:
1058            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1059
1060        aln = self.alignment
1061
1062        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
1063            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")
1064
1065        # Compute pairwise distances
1066        distance_func = distance_functions[distance_type]
1067        distance_matrix = np.zeros((len(aln), len(aln)))
1068
1069        sequences = list(aln.values())
1070        n = len(sequences)
1071        for i in range(n):
1072            seq1 = sequences[i]
1073            for j in range(i, n):
1074                seq2 = sequences[j]
1075                dist = distance_func(seq1, seq2, self.length)
1076                distance_matrix[i, j] = dist
1077                distance_matrix[j, i] = dist
1078
1079        return PairwiseDistance(
1080            reference_id=None,
1081            sequence_ids=self.sequence_ids,
1082            distances=distance_matrix
1083        )
1084
1085    def calc_pairwise_distance_to_reference(self, distance_type:str='ghd') -> PairwiseDistance:
1086        """
1087        Calculate pairwise identities between reference and all sequences in the alignment. Same computation as calc_pairwise_identity_matrix but compared to a single sequence. Different options are implemented:
1088
1089        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1090        \ndistance = matches / alignment_length * 100
1091
1092        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1093        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1094
1095        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1096        \ndistance = (1 - mismatches / total) * 100
1097
1098        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1099        \ndistance = matches / gap_compressed_alignment_length * 100
1100
1101        RNA/DNA only:
1102
1103        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
1104        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
1105        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)
1106
1107        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
1108        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
1109        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total
1110
1111        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
1112        :return: array with pairwise distances.
1113        :return: dataclass with reference label, sequence ids and pairwise distances.
1114        """
1115
1116        distance_functions = _create_distance_calculation_function_mapping()
1117
1118        if distance_type not in distance_functions:
1119            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1120
1121        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
1122            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")
1123
1124        # Compute pairwise distances
1125        distance_func = distance_functions[distance_type]
1126        aln = self.alignment
1127        ref_id = self.reference_id
1128
1129        ref_seq = aln[ref_id] if ref_id is not None else self.get_consensus()
1130        distances = []
1131        distance_names = []
1132
1133        for seq_id in aln:
1134            if seq_id == ref_id:
1135                continue
1136            distance_names.append(seq_id)
1137            distances.append(distance_func(ref_seq, aln[seq_id], self.length))
1138
1139        return PairwiseDistance(
1140            reference_id=ref_id if ref_id is not None else 'consensus',
1141            sequence_ids=distance_names,
1142            distances=np.array(distances)
1143        )
1144
1145    def get_snps(self, include_ambig:bool=False) -> VariantCollection:
1146        """
1147        Calculate snps similar to snp-sites (output is comparable):
1148        https://github.com/sanger-pathogens/snp-sites
1149        Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character.
1150        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1151
1152        :param include_ambig: Include ambiguous snps (default: False)
1153        :return: dataclass containing SNP positions and their variants including frequency.
1154        """
1155        aln = self.alignment
1156        ref = self._get_reference_seq()
1157        aln = {x: aln[x] for x in self.sequence_ids if x != self.reference_id}
1158        seq_ids = list(aln.keys())
1159        chrom = self.reference_id if self.reference_id is not None else 'consensus'
1160        snp_positions = {}
1161        aln_size = len(aln)
1162
1163        for pos in range(self.length):
1164            reference_char = ref[pos]
1165            if not include_ambig:
1166                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1167                    continue
1168            alt_chars, snps = [], []
1169            for i, seq_id in enumerate(seq_ids):
1170                alt_char = aln[seq_id][pos]
1171                alt_chars.append(alt_char)
1172                if reference_char != alt_char:
1173                    snps.append(i)
1174            if not snps:
1175                continue
1176            # Filter out ambiguous snps if not included
1177            if include_ambig:
1178                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1179                    continue
1180            else:
1181                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1182                if not snps:
1183                    continue
1184            # Build allele dict with counts
1185            alt_dict = {}
1186            for snp_idx in snps:
1187                alt_char = alt_chars[snp_idx]
1188                if alt_char not in alt_dict:
1189                    alt_dict[alt_char] = {'count': 0, 'seq_ids': []}
1190                alt_dict[alt_char]['count'] += 1
1191                alt_dict[alt_char]['seq_ids'].append(seq_ids[snp_idx])
1192            
1193            # Convert to final format: alt_char -> (frequency, seq_ids_tuple)
1194            alt = {
1195                alt_char: (data['count'] / aln_size, tuple(data['seq_ids']))
1196                for alt_char, data in alt_dict.items()
1197            }
1198            snp_positions[pos] = SingleNucleotidePolymorphism(ref=reference_char, alt=alt)
1199
1200        return VariantCollection(chrom=chrom, positions=snp_positions)
1201
    def calc_transition_transversion_score(self) -> AlignmentStats:
        """
        Based on the snp positions, calculates a transition/transversion score.
        A positive score means a higher ratio of transitions and a negative score
        a higher ratio of transversions.

        :return: dataclass with per-position ts/tv scores
        :raises TypeError: for AA alignments
        """

        if self.aln_type == 'AA':
            raise TypeError('TS/TV scoring only for RNA/DNA alignments')

        # ini
        snps = self.get_snps()
        score = [0]*self.length

        for pos, snp in snps.positions.items():
            for alt, (af, _) in snp.alt.items():
                # check the type of substitution: transitions
                # (purine<->purine, pyrimidine<->pyrimidine) add the allele
                # frequency, transversions subtract it
                if snp.ref + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
                    score[pos] += af
                else:
                    score[pos] -= af

        return self._create_position_stat_result('ts tv score', score)

An alignment class that allows computation of several stats. Supported inputs are file paths to alignments in "fasta", "clustal", "phylip", "stockholm", "nexus" formats, raw alignment strings, or Bio.Align.MultipleSeqAlignment for compatibility with Biopython.

MSA(alignment_string: str | Bio.Align.MultipleSeqAlignment, reference_id: str = None, zoom_range: tuple | int = None)
39    def __init__(self, alignment_string: str | MultipleSeqAlignment, reference_id: str = None, zoom_range: tuple | int = None):
40        """
41        Initialize an Alignment object.
42        :param alignment_string: Path to alignment file or raw alignment string
43        :param reference_id: reference id
44        :param zoom_range: start and stop positions to zoom into the alignment
45        """
46        self._alignment = _read_alignment(alignment_string)
47        self._reference_id = self._validate_ref(reference_id, self._alignment)
48        self._zoom = self._validate_zoom(zoom_range, self._alignment)
49        self._aln_type = self._determine_aln_type(self._alignment)

Initialize an Alignment object.

Parameters
  • alignment_string: Path to alignment file or raw alignment string
  • reference_id: reference id
  • zoom_range: start and stop positions to zoom into the alignment
reference_id: str
132    @property
133    def reference_id(self) -> str:
134        return self._reference_id

Set and validate the reference id.

zoom: tuple
143    @property
144    def zoom(self) -> tuple:
145        return self._zoom

Validate the user-defined zoom range.

aln_type: str
155    @property
156    def aln_type(self) -> str:
157        """
158        define the aln type:
159        RNA, DNA or AA
160        """
161        return self._aln_type

define the aln type: RNA, DNA or AA

sequence_ids: list
163    @property
164    def sequence_ids(self) -> list:
165        return list(self.alignment.keys())
length: int
168    @property
169    def length(self) -> int:
170        return len(next(iter(self.alignment.values())))
alignment: dict
172    @property
173    def alignment(self) -> dict:
174        """
175        (zoomed) version of the alignment.
176        """
177        if self.zoom is not None:
178            zoomed_aln = dict()
179            for seq in self._alignment:
180                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
181            return zoomed_aln
182        else:
183            return self._alignment

(zoomed) version of the alignment.

def get_reference_coords(self) -> tuple[int, int]:
209    def get_reference_coords(self) -> tuple[int, int]:
210        """
211        Determine the start and end coordinates of the reference sequence
212        defined as the first/last nucleotide in the reference sequence
213        (excluding N and gaps).
214
215        :return: Start, End
216        """
217        start, end = 0, self.length
218
219        if self.reference_id is None:
220            return start, end
221        else:
222            # 5' --> 3'
223            for start in range(self.length):
224                if self.alignment[self.reference_id][start] not in ['-', 'N']:
225                    break
226            # 3' --> 5'
227            for end in range(self.length - 1, 0, -1):
228                if self.alignment[self.reference_id][end] not in ['-', 'N']:
229                    break
230
231            return start, end

Determine the start and end coordinates of the reference sequence defined as the first/last nucleotide in the reference sequence (excluding N and gaps).

Returns

Start, End

def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
        """
        Creates a non-gapped consensus sequence.

        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
            or 'X' (for AA alignments) is used if none of the characters reach a cumulative frequency >= threshold.
        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
            has a frequency above the defined threshold.
        :return: consensus sequence
        :raises ValueError: if threshold is outside [0, 1], if use_ambig_nt is combined with an AA
            alignment, or if use_ambig_nt is set without a threshold
        """

        # helper functions
        def determine_counts(alignment_dict: dict, position: int) -> dict:
            """
            count the number of each char at
            an idx of the alignment. return sorted dic.
            handles ambiguous nucleotides in sequences.
            also handles gaps.
            """
            nucleotide_list = []

            # get all nucleotides
            for sequence in alignment_dict.items():
                nucleotide_list.append(sequence[1][position])
            # count occurences of nucleotides
            counter = dict(collections.Counter(nucleotide_list))
            # get permutations of an ambiguous nucleotide
            to_delete = []
            temp_dict = {}
            for nucleotide in counter:
                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
                    to_delete.append(nucleotide)
                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
                    # each possible base of the ambiguous char gets an equal share of its count
                    adjusted_freq = 1 / len(permutations)
                    for permutation in permutations:
                        if permutation in temp_dict:
                            temp_dict[permutation] += adjusted_freq
                        else:
                            temp_dict[permutation] = adjusted_freq

            # drop ambiguous entries and add the adjusted (fractional) freqs to the counts
            if to_delete:
                for i in to_delete:
                    counter.pop(i)
                for nucleotide in temp_dict:
                    if nucleotide in counter:
                        counter[nucleotide] += temp_dict[nucleotide]
                    else:
                        counter[nucleotide] = temp_dict[nucleotide]

            # sorted by frequency, most common first
            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))

        def get_consensus_char(counts: dict, cutoff: float) -> list:
            """
            get a list of nucleotides for the consensus seq
            (most frequent chars until their cumulative count reaches the cutoff)
            """
            n = 0

            consensus_chars = []
            for char in counts:
                n += counts[char]
                consensus_chars.append(char)
                if n >= cutoff:
                    break

            return consensus_chars

        def get_ambiguous_char(nucleotides: list) -> str:
            """
            get ambiguous char from a list of nucleotides
            """
            # NOTE(review): if no ambiguous char matches the set exactly, the last
            # iterated key is returned -- confirm config.AMBIG_CHARS covers all subsets
            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
                if set(permutations) == set(nucleotides):
                    break

            return ambiguous

        # check if params have been set correctly
        if threshold is not None:
            if threshold < 0 or threshold > 1:
                raise ValueError('Threshold must be between 0 and 1.')
        if self.aln_type == 'AA' and use_ambig_nt:
            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
        if threshold is None and use_ambig_nt:
            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')

        alignment = self.alignment
        consensus = str()

        # cumulative count a column's chars must reach; 0 means "take the majority char"
        if threshold is not None:
            consensus_cutoff = len(alignment) * threshold
        else:
            consensus_cutoff = 0

        # build consensus sequence column by column
        for idx in range(self.length):
            char_counts = determine_counts(alignment, idx)
            consensus_chars = get_consensus_char(
                char_counts,
                consensus_cutoff
            )
            # note: threshold=None also enters this branch (None != 0), but with a
            # cutoff of 0 get_consensus_char always returns exactly one char
            if threshold != 0:
                if len(consensus_chars) > 1:
                    if use_ambig_nt:
                        char = get_ambiguous_char(consensus_chars)
                    else:
                        if self.aln_type == 'AA':
                            char = 'X'
                        else:
                            char = 'N'
                    consensus = consensus + char
                else:
                    consensus = consensus + consensus_chars[0]
            else:
                consensus = consensus + consensus_chars[0]

        return consensus

Creates a non-gapped consensus sequence.

Parameters
  • threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments) or 'X' (for AA alignments) is used if none of the characters reach a cumulative frequency >= threshold.
  • use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position has a frequency above the defined threshold.
Returns

consensus sequence

def get_conserved_orfs( self, min_length: int = 100, identity_cutoff: float | None = None) -> msaexplorer._data_classes.OrfCollection:
    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> OrfCollection:
        """
        **conserved ORF definition:**
            - conserved starts and stops
            - start, stop must be on the same frame
            - stop - start must be at least min_length
            - all ungapped seqs[start:stop] must have at least min_length
            - no ungapped seq can have a Stop in between Start Stop

        Conservation is measured by the number of positions with identical characters divided by
        orf slice of the alignment.

        **Algorithm overview:**
            - check for conserved start and stop codons
            - iterate over all three frames
            - check each start and next sufficiently far away stop codon
            - check if all ungapped seqs between start and stop codon are >= min_length
            - check if no ungapped seq in the alignment has a stop codon
            - write to dictionary
            - classify as internal if the stop codon has already been written with a prior start
            - repeat for reverse complement

        :param min_length: minimal ORF length in alignment columns (must be > 6 and <= alignment length)
        :param identity_cutoff: optional percent identity in [0, 100] an ORF slice must reach to be kept
        :return: ORF positions and internal ORF positions
        :raises TypeError: for AA alignments
        :raises ValueError: for out-of-range min_length or identity_cutoff
        """

        # helper functions
        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
            """
            Determine all start and stop codons within an alignment.
            :param alignment: alignment
            :param alignment_length: length of alignment
            :return: start and stop codon positions
            """
            starts = config.START_CODONS[self.aln_type]
            stops = config.STOP_CODONS[self.aln_type]

            list_of_starts, list_of_stops = [], []
            # define one sequence (first) as reference (it does not matter which one)
            ref = alignment[self.sequence_ids[0]]
            for nt_position in range(alignment_length):
                if ref[nt_position:nt_position + 3] in starts:
                    conserved_start = True
                    # the codon must be the next 3 ungapped chars in every sequence
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
                            conserved_start = False
                            break
                    if conserved_start:
                        list_of_starts.append(nt_position)

                if ref[nt_position:nt_position + 3] in stops:
                    conserved_stop = True
                    for sequence in alignment:
                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
                            conserved_stop = False
                            break
                    if conserved_stop:
                        list_of_stops.append(nt_position)

            return list_of_starts, list_of_stops

        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
            """
            get ungapped sequences starting and stop codons and eliminate gaps
            :param alignment: alignment
            :param start_pos: start codon
            :param stop_pos: stop codon
            :return: sliced sequences
            """
            ungapped_seqs = []
            for seq_id in alignment:
                # +3 so the stop codon itself is part of the slice
                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))

            return ungapped_seqs

        def additional_stops(ungapped_seqs: list) -> bool:
            """
            Checks for the presence of a stop codon
            :param ungapped_seqs: list of ungapped sequences
            :return: Additional stop codons (True/False)
            """
            stops = config.STOP_CODONS[self.aln_type]

            for sliced_seq in ungapped_seqs:
                # scan internal codons only; the terminal stop codon (at len-3) is excluded
                for position in range(0, len(sliced_seq) - 3, 3):
                    if sliced_seq[position:position + 3] in stops:
                        return True
            return False

        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
            # percentage of columns in the slice where all sequences are identical
            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100

        # checks for arguments
        if self.aln_type == 'AA':
            raise TypeError('ORF search only for RNA/DNA alignments')

        if identity_cutoff is not None:
            if identity_cutoff > 100 or identity_cutoff < 0:
                raise ValueError('conservation cutoff must be between 0 and 100')

        if min_length <= 6 or min_length > self.length:
            raise ValueError(f'min_length must be between 6 and {self.length}')

        # ini
        identities = self.calc_identity_alignment()
        # search forward strand first, then the reverse complement
        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
        aln_len = self.length

        orf_counter = 0
        # use mutable dicts during construction and convert to dataclass at the end for immutability
        temp_orfs: list[dict] = []

        for aln, direction in zip(alignments, ['+', '-']):
            # check for starts and stops in the first seq and then check if these are present in all seqs
            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
            # check each frame
            for frame in (0, 1, 2):
                potential_starts = [x for x in conserved_starts if x % 3 == frame]
                potential_stops = [x for x in conserved_stops if x % 3 == frame]
                last_stop = -1
                for start in potential_starts:
                    # go to the next stop that is sufficiently far away in the alignment
                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
                    if not next_stops:
                        continue
                    next_stop = next_stops[0]
                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
                    # re-check the lengths of all ungapped seqs
                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
                    if not all(ungapped_seq_lengths):
                        continue
                    # if no stop codon between start and stop --> write to dictionary
                    if not additional_stops(ungapped_sliced_seqs):
                        if direction == '+':
                            positions = (start, next_stop + 3)
                        else:
                            # map reverse-complement coordinates back onto the forward strand
                            positions = (aln_len - next_stop - 3, aln_len - start)
                        if last_stop != next_stop:
                            last_stop = next_stop
                            conservation = calculate_identity(identities, list(positions))
                            if identity_cutoff is not None and conservation < identity_cutoff:
                                continue
                            temp_orfs.append({
                                'orf_id': f'ORF_{orf_counter}',
                                'location': [positions],
                                'frame': frame,
                                'strand': direction,
                                'conservation': conservation,
                                'internal': [],
                            })
                            orf_counter += 1
                        else:
                            # same stop reached from a later start --> internal ORF of the previous entry
                            if temp_orfs:
                                temp_orfs[-1]['internal'].append(positions)

        # convert mutable intermediate dicts to frozen dataclasses
        orf_list = [
            OpenReadingFrame(
                orf_id=t['orf_id'],
                location=tuple(t['location']),
                frame=t['frame'],
                strand=t['strand'],
                conservation=t['conservation'],
                internal=tuple(t['internal']),
            )
            for t in temp_orfs
        ]
        return OrfCollection(orfs=tuple(orf_list))

conserved ORF definition: - conserved starts and stops - start, stop must be on the same frame - stop - start must be at least min_length - all ungapped seqs[start:stop] must have at least min_length - no ungapped seq can have a Stop in between Start Stop

Conservation is measured by the number of positions with identical characters divided by orf slice of the alignment.

Algorithm overview: - check for conserved start and stop codons - iterate over all three frames - check each start and next sufficiently far away stop codon - check if all ungapped seqs between start and stop codon are >= min_length - check if no ungapped seq in the alignment has a stop codon - write to dictionary - classify as internal if the stop codon has already been written with a prior start - repeat for reverse complement

Returns

ORF positions and internal ORF positions

def get_non_overlapping_conserved_orfs( self, min_length: int = 100, identity_cutoff: float = None) -> msaexplorer._data_classes.OrfCollection:
    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> OrfCollection:
        """
        First calculates all ORFs and then searches from 5'
        all non-overlapping orfs in the fw strand and from the
        3' all non-overlapping orfs in the rv strand.

        **No overlap algorithm:**
            **frame 1:** -[M------*]--- ----[M--*]---------[M-----

            **frame 2:** -------[M------*]---------[M---*]--------

            **frame 3:** [M---*]-----[M----------*]----------[M---

            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----

            frame:    3      2           1      2       1

        :param min_length: minimal ORF length, forwarded to `get_conserved_orfs`
        :param identity_cutoff: identity cutoff, forwarded to `get_conserved_orfs`
        :return: OrfContainer with non-overlapping orfs
        """
        all_orfs = self.get_conserved_orfs(min_length, identity_cutoff)

        fw_orfs, rw_orfs = [], []

        # NOTE(review): OrfCollection is iterated/indexed like a mapping of
        # orf ids -> ORF objects here -- confirm it implements __iter__/__getitem__
        for orf in all_orfs:
            orf_obj = all_orfs[orf]
            if orf_obj.strand == '+':
                fw_orfs.append((orf, orf_obj.location[0]))
            else:
                rw_orfs.append((orf, orf_obj.location[0]))

        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
        non_overlapping_ids = []
        # greedy sweep: 5'->3' on the fw strand, 3'->5' on the rv strand
        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
            previous_stop = -1 if strand == '+' else self.length + 1
            for orf in orf_list:
                if strand == '+' and orf[1][0] >= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][1]
                elif strand == '-' and orf[1][1] <= previous_stop:
                    non_overlapping_ids.append(orf[0])
                    previous_stop = orf[1][0]

        return OrfCollection(orfs=tuple(all_orfs[orf_id] for orf_id in all_orfs if orf_id in non_overlapping_ids))

First calculates all ORFs and then searches from 5' all non-overlapping orfs in the fw strand and from the 3' all non-overlapping orfs in the rw strand.

No overlap algorithm: frame 1: -[M------*]--- ----[M--*]---------[M-----

**frame 2:** -------[M------*]---------[M---*]--------

**frame 3:** [M---*]-----[M----------*]----------[M---

**results:** [M---*][M------*]--[M--*]-[M---*]-[M-----

frame:    3      2           1      2       1
Returns

OrfContainer with non-overlapping orfs

def calc_length_stats(self) -> msaexplorer._data_classes.LengthStats:
566    def calc_length_stats(self) -> LengthStats:
567        """
568        Determine the stats for the length of the ungapped seqs in the alignment.
569        :return: dataclass with length stats
570        """
571
572        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
573
574        return LengthStats(
575            n_sequences=len(self.alignment),
576            mean_length=float(np.mean(seq_lengths)),
577            std_length=float(np.std(seq_lengths)),
578            min_length=int(np.min(seq_lengths)),
579            max_length=int(np.max(seq_lengths)),
580        )

Determine the stats for the length of the ungapped seqs in the alignment.

Returns

dataclass with length stats

def calc_entropy(self) -> msaexplorer._data_classes.AlignmentStats:
    def calc_entropy(self) -> AlignmentStats:
        """
        Calculate the normalized shannon's entropy for every position in an alignment:

        - 1: high entropy
        - 0: low entropy

        :return: Entropies at each position.
        """

        # helper functions
        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
            """
            Calculate the shannon's entropy of a sequence and
            normalized between 0 and 1.
            :param character_list: characters at an alignment position
            :param states: number of potential characters that can be present
            :param aln_type: type of the alignment
            :returns: entropy
            """
            ent, n_chars = 0, len(character_list)
            # only one char is in the list
            if n_chars <= 1:
                return ent
            # calculate the number of unique chars and their counts
            chars, char_counts = np.unique(character_list, return_counts=True)
            char_counts = char_counts.astype(float)
            # ignore gaps for entropy calc
            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
            # correctly handle ambiguous chars by distributing their counts
            # proportionally over the bases they encode
            index_to_drop = []
            for index, char in enumerate(chars):
                if char in config.AMBIG_CHARS[aln_type]:
                    index_to_drop.append(index)
                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
                    # add the proportionate numbers to initial array
                    for amb_char, amb_count in zip(amb_chars, amb_counts):
                        if amb_char in chars:
                            char_counts[chars == amb_char] += amb_count
                        else:
                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
            # drop the ambiguous characters from array
            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
            # calc the entropy
            # NOTE(review): probabilities are normalized by the full column size
            # (n_chars includes dropped gaps), so gappy columns score lower -- confirm intended
            probs = char_counts / n_chars
            if np.count_nonzero(probs) <= 1:
                return ent
            # log base = number of states, which normalizes the entropy to [0, 1]
            for prob in probs:
                ent -= prob * math.log(prob, states)

            return ent

        aln = self.alignment
        entropys = []

        if self.aln_type == 'AA':
            states = 20
        else:
            states = 4
        # iterate over alignment positions and the sequences
        for nuc_pos in range(self.length):
            pos = []
            for record in aln:
                pos.append(aln[record][nuc_pos])
            entropys.append(shannons_entropy(pos, states, self.aln_type))

        return self._create_position_stat_result('entropy', entropys)

Calculate the normalized shannon's entropy for every position in an alignment:

  • 1: high entropy
  • 0: low entropy
Returns

Entropies at each position.

def calc_gc(self) -> msaexplorer._data_classes.AlignmentStats | TypeError:
651    def calc_gc(self) -> AlignmentStats | TypeError:
652        """
653        Determine the GC content for every position in an nt alignment.
654        :return: GC content for every position.
655        :raises: TypeError for AA alignments
656        """
657        if self.aln_type == 'AA':
658            raise TypeError("GC computation is not possible for aminoacid alignment")
659
660        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
661
662        for position in range(self.length):
663            nucleotides = str()
664            for record in aln:
665                nucleotides = nucleotides + aln[record][position]
666            # ini dict with chars that occur and which ones to
667            # count in which freq
668            to_count = {
669                'G': 1,
670                'C': 1,
671            }
672            # handle ambig. nuc
673            for char in amb_nucs:
674                if char in nucleotides:
675                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
676
677            gc.append(
678                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
679            )
680
681        return self._create_position_stat_result('gc', gc)

Determine the GC content for every position in an nt alignment.

Returns

GC content for every position.

Raises
  • TypeError for AA alignments
def calc_coverage(self) -> msaexplorer._data_classes.AlignmentStats:
683    def calc_coverage(self) -> AlignmentStats:
684        """
685        Determine the coverage of every position in an alignment.
686        This is defined as:
687            1 - cumulative length of '-' characters
688
689        :return: Coverage at each alignment position.
690        """
691        coverage, aln = [], self.alignment
692
693        for nuc_pos in range(self.length):
694            pos = str()
695            for record in self.sequence_ids:
696                pos = pos + aln[record][nuc_pos]
697            coverage.append(1 - pos.count('-') / len(pos))
698
699        return self._create_position_stat_result('coverage', coverage)

Determine the coverage of every position in an alignment. This is defined as: 1 - cumulative length of '-' characters

Returns

Coverage at each alignment position.

def calc_gap_frequency(self) -> msaexplorer._data_classes.AlignmentStats:
701    def calc_gap_frequency(self) -> AlignmentStats:
702        """
703        Determine the gap frequency for every position in an alignment. This is the inverted coverage.
704        """
705        coverage = self.calc_coverage()
706
707        return self._create_position_stat_result('gap frequency', 1 - coverage.values)

Determine the gap frequency for every position in an alignment. This is the inverted coverage.

def calc_reverse_complement_alignment(self) -> dict | TypeError:
709    def calc_reverse_complement_alignment(self) -> dict | TypeError:
710        """
711        Reverse complement the alignment.
712        :return: Alignment (rv)
713        """
714        if self.aln_type == 'AA':
715            raise TypeError('Reverse complement only for RNA or DNA.')
716
717        aln = self.alignment
718        reverse_complement_dict = {}
719
720        for seq_id in aln:
721            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
722
723        return reverse_complement_dict

Reverse complement the alignment.

Returns

Alignment (rv)

def calc_numerical_alignment( self, encode_mask: bool = False, encode_ambiguities: bool = False) -> numpy.ndarray:
725    def calc_numerical_alignment(self, encode_mask:bool=False, encode_ambiguities:bool=False) -> ndarray:
726        """
727        Transforms the alignment to numerical values. Ambiguities are encoded as -3, mask as -2 and the
728        remaining chars with the idx + 1 of config.CHAR_COLORS[self.aln_type]['standard'].
729
730        :param encode_ambiguities: encode ambiguities as -2
731        :param encode_mask: encode mask with as -3
732        :returns matrix
733        """
734
735        sequences = self._to_array()
736        # ini matrix
737        numerical_matrix = np.full(sequences.shape, np.nan, dtype=float)
738        # first encode mask
739        if encode_mask:
740            if self.aln_type == 'AA':
741                is_n_or_x = np.isin(sequences, ['X'])
742            else:
743                is_n_or_x = np.isin(sequences, ['N'])
744            numerical_matrix[is_n_or_x] = -2
745        # next encode ambig chars
746        if encode_ambiguities:
747            numerical_matrix[np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])] = -3
748        # next convert each char into their respective values
749        for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
750            numerical_matrix[np.isin(sequences, [char])] = idx + 1
751
752        return numerical_matrix

Transforms the alignment to numerical values. Ambiguities are encoded as -3, mask as -2 and the remaining chars with the idx + 1 of config.CHAR_COLORS[self.aln_type]['standard'].

Parameters
  • encode_ambiguities: encode ambiguities as -3
  • encode_mask: encode mask ('X' for AA, 'N' otherwise) as -2 :returns matrix
def calc_identity_alignment( self, encode_mismatches: bool = True, encode_mask: bool = False, encode_gaps: bool = True, encode_ambiguities: bool = False, encode_each_mismatch_char: bool = False) -> numpy.ndarray:
    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> ndarray:
        """
        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n

        :param encode_mismatches: encode mismatch as -1
        :param encode_mask: encode mask with value=-2 --> also in the reference
        :param encode_gaps: encode gaps with np.nan --> also in the reference
        :param encode_ambiguities: encode ambiguities with value=-3
        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.CHAR_COLORS[self.aln_type]['standard']
        :return: identity alignment
        """

        # reference (or consensus) the alignment is compared against
        ref = self._get_reference_seq()

        # convert alignment to array
        sequences = self._to_array()
        reference = np.array(list(ref))
        # ini matrix: 0 everywhere means "identical to reference" by default
        identity_matrix = np.full(sequences.shape, 0, dtype=float)

        is_identical = sequences == reference

        # each disabled flag becomes an all-False mask so the combination logic below stays uniform
        if encode_gaps:
            is_gap = sequences == '-'
        else:
            is_gap = np.full(sequences.shape, False)

        if encode_mask:
            if self.aln_type == 'AA':
                is_n_or_x = np.isin(sequences, ['X'])
            else:
                is_n_or_x = np.isin(sequences, ['N'])
        else:
            is_n_or_x = np.full(sequences.shape, False)

        if encode_ambiguities:
            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
        else:
            is_ambig = np.full(sequences.shape, False)

        # mismatches exclude gaps, mask and ambiguous positions
        if encode_mismatches:
            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
        else:
            is_mismatch = np.full(sequences.shape, False)

        # encode every different character
        if encode_each_mismatch_char:
            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]['standard']):
                new_encoding = np.isin(sequences, [char]) & is_mismatch
                identity_matrix[new_encoding] = idx + 1
        # or encode different with a single value
        else:
            identity_matrix[is_mismatch] = -1  # mismatch

        # order matters: gaps, mask and ambiguities overwrite any mismatch encoding above
        identity_matrix[is_gap] = np.nan  # gap
        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
        identity_matrix[is_ambig] = -3  # ambiguities

        return identity_matrix

Converts alignment to identity array (identical=0) compared to majority consensus or reference:

Parameters
  • encode_mismatches: encode mismatch as -1
  • encode_mask: encode mask with value=-2 --> also in the reference
  • encode_gaps: encode gaps with np.nan --> also in the reference
  • encode_ambiguities: encode ambiguities with value=-3
  • encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.CHAR_COLORS[self.aln_type]['standard']
Returns

identity alignment

def calc_similarity_alignment( self, matrix_type: str | None = None, normalize: bool = True) -> numpy.ndarray:
814    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> ndarray:
815        """
816        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
817        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
818        reference residue at each column. Gaps are encoded as np.nan.
819
820        The calculation follows these steps:
821
822        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
823           a consensus sequence is generated to serve as the reference.
824        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
825           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
826        3. **Per-Column Normalization (optional)**:
827
828        For each column in the alignment:
829            - The residue in the reference sequence is treated as the baseline for that column.
830            - The substitution scores for the reference residue are extracted from the substitution matrix.
831            - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
832            - This ensures that identical residues (or those with high similarity to the reference) have high scores,
833            while more dissimilar residues have lower scores.
834        4. **Output**:
835
836           - The normalized similarity scores are stored in a NumPy array.
837           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
838
839        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
840        :param: normalize: whether to normalize the similarity scores to range [0, 1]
841        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
842            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
843            Gaps and invalid residues are encoded as `np.nan`.
844        :raise: ValueError
845            If the specified substitution matrix is not available for the given alignment type.
846        """
847
848        ref = self._get_reference_seq()
849
850        if matrix_type is None:
851            if self.aln_type == 'AA':
852                matrix_type = 'BLOSUM65'
853            else:
854                matrix_type = 'TRANS'
855        # load substitution matrix as dictionary
856        try:
857            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
858        except KeyError:
859            raise ValueError(
860                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
861            )
862
863        # set dtype and convert alignment to a NumPy array for vectorized processing
864        dtype = np.dtype(float, metadata={'matrix': matrix_type})
865        sequences = self._to_array()
866        reference = np.array(list(ref))
867        valid_chars = list(subs_matrix.keys())
868        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
869
870        for j, ref_char in enumerate(reference):
871            if ref_char not in valid_chars + ['-']:
872                continue
873            # Get local min and max for the reference residue
874            if normalize and ref_char != '-':
875                local_scores = subs_matrix[ref_char].values()
876                local_min, local_max = min(local_scores), max(local_scores)
877
878            for i, char in enumerate(sequences[:, j]):
879                if char not in valid_chars:
880                    continue
881                # classify the similarity as max if the reference has a gap
882                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
883                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
884
885        return similarity_array

Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the reference residue at each column. Gaps are encoded as np.nan.

The calculation follows these steps:

  1. Reference Sequence: If a reference sequence is provided (via self.reference_id), it is used. Otherwise, a consensus sequence is generated to serve as the reference.
  2. Substitution Matrix: The similarity between residues is determined using a substitution matrix, such as BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
  3. Per-Column Normalization (optional):

For each column in the alignment:
  • The residue in the reference sequence is treated as the baseline for that column.
  • The substitution scores for the reference residue are extracted from the substitution matrix.
  • The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
  • This ensures that identical residues (or those with high similarity to the reference) have high scores, while more dissimilar residues have lower scores.

  4. Output:

  • The normalized similarity scores are stored in a NumPy array.
  • Gaps (if any) or residues not present in the substitution matrix are encoded as np.nan.
Parameters
  • matrix_type: type of similarity score (if not set - AA: BLOSUM65, RNA/DNA: TRANS)
  • normalize: whether to normalize the similarity scores to range [0, 1]
Returns

A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity). Gaps and invalid residues are encoded as np.nan.

Raises
  • ValueError: if the specified substitution matrix is not available for the given alignment type.

def calc_position_matrix(self, matrix_type: str = 'PWM') -> None | numpy.ndarray:
887    def calc_position_matrix(self, matrix_type:str='PWM') -> None | ndarray:
888        """
889        Calculates a position matrix of the specified type for the given alignment. The function
890        supports generating matrices of types Position Frequency Matrix (PFM), Position Probability
891        Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates
892        the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.
893
894        :param matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM',
895            'PWM', and 'IC'. Defaults to 'PWM'.
896        :type matrix_type: str
897        :raises ValueError: If the provided `matrix_type` is not one of the accepted values.
898        :return: A numpy array representing the calculated position matrix of the specified type.
899        :rtype: np.ndarray
900        """
901
902        # ini
903        if matrix_type not in ['PFM', 'PPM', 'IC', 'PWM']:
904            raise ValueError('Matrix_type must be PFM, PPM, IC or PWM.')
905        possible_chars = list(config.CHAR_COLORS[self.aln_type]['standard'].keys())
906        sequences = self._to_array()
907
908        # calc position frequency matrix
909        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars])
910        if matrix_type == 'PFM':
911            return pfm
912
913        # calc position probability matrix (probability)
914        pseudo_count = 0.0001  # to avoid 0 values
915        pfm = pfm + pseudo_count
916        ppm_non_char_excluded = pfm/np.sum(pfm, axis=0)  # use this for pwm/ic calculation
917        ppm = pfm/len(self.sequence_ids)  # calculate the frequency based on row number
918        if matrix_type == 'PPM':
919            return ppm
920
921        # calc position weight matrix (log-likelihood)
922        pwm = np.log2(ppm_non_char_excluded * len(possible_chars))
923        if matrix_type == 'PWM':
924            return pwm
925
926        # calc information content per position (in bits) - can be used to scale a ppm for sequence logos
927        ic = np.sum(ppm_non_char_excluded * pwm, axis=0)
928        if matrix_type == 'IC':
929            return ic

Calculates a position matrix of the specified type for the given alignment. The function supports generating matrices of types Position Frequency Matrix (PFM), Position Probability Matrix (PPM), Position Weight Matrix (PWM), and cumulative Information Content (IC). It validates the provided matrix type and includes pseudo-count adjustments to ensure robust calculations.

Parameters
  • matrix_type: Type of position matrix to calculate. Accepted values are 'PFM', 'PPM', 'PWM', and 'IC'. Defaults to 'PWM'.
Raises
  • ValueError: If the provided matrix_type is not one of the accepted values.
Returns

A numpy array representing the calculated position matrix of the specified type.

def calc_percent_recovery(self) -> dict[str, float]:
931    def calc_percent_recovery(self) -> dict[str, float]:
932        """
933        Recovery per sequence either compared to the majority consensus seq
934        or the reference seq.\n
935        Defined as:\n
936
937        `(1 - sum(N/X/- characters in ungapped ref regions))*100`
938
939        This is highly similar to how nextclade calculates recovery over reference.
940
941        :return: dict
942        """
943
944        aln = self.alignment
945        ref = self._get_reference_seq()
946
947        if not any(char != '-' for char in ref):
948            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
949
950
951        # count 'N', 'X' and '-' chars in non-gapped regions
952        recovery_over_ref = dict()
953
954        # Get positions of non-gap characters in the reference
955        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
956        cumulative_length = len(non_gap_positions)
957
958        # Calculate recovery
959        for seq_id in self.sequence_ids:
960            if seq_id == self.reference_id:
961                continue
962            seq = aln[seq_id]
963            count_invalid = sum(
964                seq[pos] == '-' or
965                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
966                for pos in non_gap_positions
967            )
968            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
969
970        return recovery_over_ref

Recovery per sequence either compared to the majority consensus seq or the reference seq.

Defined as:

(1 - sum(N/X/- characters in ungapped ref regions))*100

This is highly similar to how nextclade calculates recovery over reference.

Returns

dict

def calc_character_frequencies(self) -> dict:
 972    def calc_character_frequencies(self) -> dict:
 973        """
 974        Calculate the percentage characters in the alignment:
 975        The frequencies are counted by seq and in total. The
 976        percentage of non-gap characters in the alignment is
 977        relative to the total number of non-gap characters.
 978        The gap percentage is relative to the sequence length.
 979
 980        The output is a nested dictionary.
 981
 982        :return: Character frequencies
 983        """
 984
 985        aln, aln_length = self.alignment, self.length
 986
 987        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 988
 989        for seq_id in aln:
 990            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 991            unique_chars = set(aln[seq_id])
 992            for char in unique_chars:
 993                if char == '-':
 994                    continue
 995                # add characters to dictionaries
 996                if char not in freqs[seq_id]:
 997                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 998                if char not in freqs['total']:
 999                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
1000                # count non-gap chars
1001                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
1002                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
1003                all_chars += freqs[seq_id][char]['counts']
1004            # normalize counts
1005            for char in freqs[seq_id]:
1006                if char == '-':
1007                    continue
1008                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
1009                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
1010            # count gaps
1011            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
1012            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
1013            # normalize gap counts
1014            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
1015            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
1016
1017        # normalize the total counts
1018        for char in freqs['total']:
1019            for value in freqs['total'][char]:
1020                if value == '% of alignment' or value == '% of non-gapped':
1021                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
1022
1023        return freqs

Calculate the percentage characters in the alignment: The frequencies are counted by seq and in total. The percentage of non-gap characters in the alignment is relative to the total number of non-gap characters. The gap percentage is relative to the sequence length.

The output is a nested dictionary.

Returns

Character frequencies

def calc_pairwise_identity_matrix( self, distance_type: str = 'ghd') -> msaexplorer._data_classes.PairwiseDistance:
1025    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> PairwiseDistance:
1026        """
1027        Calculate pairwise identities for an alignment. Different options are implemented:
1028
1029        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1030        \ndistance = matches / alignment_length * 100
1031
1032        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1033        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1034
1035        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1036        \ndistance = (1 - mismatches / total) * 100
1037
1038        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1039        \ndistance = matches / gap_compressed_alignment_length * 100
1040
1041        RNA/DNA only:
1042
1043        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
1044        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
1045        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)
1046
1047        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
1048        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
1049        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total
1050
1051        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
1052        :return: array with pairwise distances.
1053        """
1054
1055        distance_functions = _create_distance_calculation_function_mapping()
1056
1057        if distance_type not in distance_functions:
1058            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1059
1060        aln = self.alignment
1061
1062        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
1063            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")
1064
1065        # Compute pairwise distances
1066        distance_func = distance_functions[distance_type]
1067        distance_matrix = np.zeros((len(aln), len(aln)))
1068
1069        sequences = list(aln.values())
1070        n = len(sequences)
1071        for i in range(n):
1072            seq1 = sequences[i]
1073            for j in range(i, n):
1074                seq2 = sequences[j]
1075                dist = distance_func(seq1, seq2, self.length)
1076                distance_matrix[i, j] = dist
1077                distance_matrix[j, i] = dist
1078
1079        return PairwiseDistance(
1080            reference_id=None,
1081            sequence_ids=self.sequence_ids,
1082            distances=distance_matrix
1083        )

Calculate pairwise identities for an alignment. Different options are implemented:

    **1) ghd (global hamming distance)**: At each alignment position, check if characters match:

distance = matches / alignment_length * 100

    **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:

distance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100

    **3) ged (gap excluded distance)**: All gaps are excluded from the alignment

distance = (1 - mismatches / total) * 100

    **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.

distance = matches / gap_compressed_alignment_length * 100

    RNA/DNA only:

    **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
    the p-distance for multiple hits (assumes equal base frequencies and substitution rates).

corrected_identity = (1 - d_JC69) * 100, where d = -(3/4) * ln(1 - (4/3) * p)

    **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
    transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.

d = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q), P = Ts/total, Q = Tv/total

Parameters
  • distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
Returns

array with pairwise distances.
def calc_pairwise_distance_to_reference( self, distance_type: str = 'ghd') -> msaexplorer._data_classes.PairwiseDistance:
1085    def calc_pairwise_distance_to_reference(self, distance_type:str='ghd') -> PairwiseDistance:
1086        """
1087        Calculate pairwise identities between reference and all sequences in the alignment. Same computation as calc_pairwise_identity_matrix but compared to a single sequence. Different options are implemented:
1088
1089        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
1090        \ndistance = matches / alignment_length * 100
1091
1092        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
1093        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
1094
1095        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
1096        \ndistance = (1 - mismatches / total) * 100
1097
1098        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
1099        \ndistance = matches / gap_compressed_alignment_length * 100
1100
1101        RNA/DNA only:
1102
1103        **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
1104        the p-distance for multiple hits (assumes equal base frequencies and substitution rates).
1105        \ncorrected_identity = (1 - d_JC69) * 100,  where d = -(3/4) * ln(1 - (4/3) * p)
1106
1107        **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
1108        transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.
1109        \nd = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q),  P = Ts/total, Q = Tv/total
1110
1111        :param distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
1112        :return: array with pairwise distances.
1113        :return: dataclass with reference label, sequence ids and pairwise distances.
1114        """
1115
1116        distance_functions = _create_distance_calculation_function_mapping()
1117
1118        if distance_type not in distance_functions:
1119            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1120
1121        if self.aln_type == 'AA' and distance_type in ['jc69', 'k2p']:
1122            raise ValueError(f"JC69 and K2P are not supported for {self.aln_type} alignment.")
1123
1124        # Compute pairwise distances
1125        distance_func = distance_functions[distance_type]
1126        aln = self.alignment
1127        ref_id = self.reference_id
1128
1129        ref_seq = aln[ref_id] if ref_id is not None else self.get_consensus()
1130        distances = []
1131        distance_names = []
1132
1133        for seq_id in aln:
1134            if seq_id == ref_id:
1135                continue
1136            distance_names.append(seq_id)
1137            distances.append(distance_func(ref_seq, aln[seq_id], self.length))
1138
1139        return PairwiseDistance(
1140            reference_id=ref_id if ref_id is not None else 'consensus',
1141            sequence_ids=distance_names,
1142            distances=np.array(distances)
1143        )

Calculate pairwise identities between reference and all sequences in the alignment. Same computation as calc_pairwise_identity_matrix but compared to a single sequence. Different options are implemented:

    **1) ghd (global hamming distance)**: At each alignment position, check if characters match:

distance = matches / alignment_length * 100

    **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:

distance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100

    **3) ged (gap excluded distance)**: All gaps are excluded from the alignment

distance = (1 - mismatches / total) * 100

    **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.

distance = matches / gap_compressed_alignment_length * 100

    RNA/DNA only:

    **5) jc69 (Jukes-Cantor 1969)**: Gaps excluded. Applies the JC69 substitution model to correct
    the p-distance for multiple hits (assumes equal base frequencies and substitution rates).

corrected_identity = (1 - d_JC69) * 100, where d = -(3/4) * ln(1 - (4/3) * p)

    **6) k2p (Kimura 2-Parameter / K80)**: Gaps excluded. Distinguishes transitions (Ts) and
    transversions (Tv). Returns (1 - d_K2P) * 100 as corrected percent identity.

d = -(1/2) * ln(1 - 2P - Q) - (1/4) * ln(1 - 2Q), P = Ts/total, Q = Tv/total

Parameters
  • distance_type: type of distance computation: ghd, lhd, ged, gcd and nucleotide only: jc69 and k2p
Returns

dataclass with reference label, sequence ids and pairwise distances.
def get_snps( self, include_ambig: bool = False) -> msaexplorer._data_classes.VariantCollection:
1145    def get_snps(self, include_ambig:bool=False) -> VariantCollection:
1146        """
1147        Calculate snps similar to snp-sites (output is comparable):
1148        https://github.com/sanger-pathogens/snp-sites
1149        Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character.
1150        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1151
1152        :param include_ambig: Include ambiguous snps (default: False)
1153        :return: dataclass containing SNP positions and their variants including frequency.
1154        """
1155        aln = self.alignment
1156        ref = self._get_reference_seq()
1157        aln = {x: aln[x] for x in self.sequence_ids if x != self.reference_id}
1158        seq_ids = list(aln.keys())
1159        chrom = self.reference_id if self.reference_id is not None else 'consensus'
1160        snp_positions = {}
1161        aln_size = len(aln)
1162
1163        for pos in range(self.length):
1164            reference_char = ref[pos]
1165            if not include_ambig:
1166                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1167                    continue
1168            alt_chars, snps = [], []
1169            for i, seq_id in enumerate(seq_ids):
1170                alt_char = aln[seq_id][pos]
1171                alt_chars.append(alt_char)
1172                if reference_char != alt_char:
1173                    snps.append(i)
1174            if not snps:
1175                continue
1176            # Filter out ambiguous snps if not included
1177            if include_ambig:
1178                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1179                    continue
1180            else:
1181                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1182                if not snps:
1183                    continue
1184            # Build allele dict with counts
1185            alt_dict = {}
1186            for snp_idx in snps:
1187                alt_char = alt_chars[snp_idx]
1188                if alt_char not in alt_dict:
1189                    alt_dict[alt_char] = {'count': 0, 'seq_ids': []}
1190                alt_dict[alt_char]['count'] += 1
1191                alt_dict[alt_char]['seq_ids'].append(seq_ids[snp_idx])
1192            
1193            # Convert to final format: alt_char -> (frequency, seq_ids_tuple)
1194            alt = {
1195                alt_char: (data['count'] / aln_size, tuple(data['seq_ids']))
1196                for alt_char, data in alt_dict.items()
1197            }
1198            snp_positions[pos] = SingleNucleotidePolymorphism(ref=reference_char, alt=alt)
1199
1200        return VariantCollection(chrom=chrom, positions=snp_positions)

Calculate snps similar to snp-sites (output is comparable): https://github.com/sanger-pathogens/snp-sites Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character. The SNPs are compared to a majority consensus sequence or to a reference if it has been set.

Parameters
  • include_ambig: Include ambiguous snps (default: False)
Returns

dataclass containing SNP positions and their variants including frequency.

def calc_transition_transversion_score(self) -> msaexplorer._data_classes.AlignmentStats:
1202    def calc_transition_transversion_score(self) -> AlignmentStats:
1203        """
1204        Based on the snp positions, calculates a transition/transversions score.
1205        A positive score means higher ratio of transitions and negative score means
1206        a higher ratio of transversions.
1207        :return: list
1208        """
1209
1210        if self.aln_type == 'AA':
1211            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1212
1213        # ini
1214        snps = self.get_snps()
1215        score = [0]*self.length
1216
1217        for pos, snp in snps.positions.items():
1218            for alt, (af, _) in snp.alt.items():
1219                # check the type of substitution
1220                if snp.ref + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1221                    score[pos] += af
1222                else:
1223                    score[pos] -= af
1224
1225        return self._create_position_stat_result('ts tv score', score)

Based on the snp positions, calculates a transition/transversions score. A positive score means higher ratio of transitions and negative score means a higher ratio of transversions.

Returns

list

class Annotation:
1228class Annotation:
1229    """
1230    An annotation class that allows to read in gff, gb, or bed files and adjust its locations to that of the MSA.
1231    """
1232
1233    def __init__(self, aln: MSA, annotation: str | GenBankIterator):
1234        """
1235        The annotation class. Lets you parse multiple standard formats
1236        which might be used for annotating an alignment. The main purpose
1237        is to parse the annotation file and adapt the locations of diverse
1238        features to the locations within the alignment, considering the
1239        respective alignment positions. Importantly, IDs of the alignment
1240        and the MSA have to partly match.
1241
1242        :param aln: MSA class
1243        :param annotation: path to file (gb, bed, gff) or raw string or GenBankIterator from biopython
1244
1245        """
1246
1247        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation, aln)  # read annotation
1248        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1249        self._position_map = self._build_position_map()  # build a position map
1250        self._map_to_alignment()  # adapt feature locations
1251
1252    @staticmethod
1253    def _MSA_validation_and_seq_extraction(aln: MSA, seq_id: str) -> str:
1254        """
1255        extract gapped sequence from MSA that corresponds to annotation
1256        :param aln: MSA class
1257        :param seq_id: sequence id to extract
1258        :return: gapped sequence
1259        """
1260        if not isinstance(aln, MSA):
1261            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')
1262        else:
1263            return aln._alignment[seq_id]
1264
1265    @staticmethod
1266    def _parse_annotation(annotation: str | GenBankIterator, aln: MSA) -> tuple[str, str, str, Dict]:
1267
1268        def detect_annotation_type(handle: str | GenBankIterator) -> str:
1269            """
1270            Detect the type of annotation file (GenBank, GFF, or BED) based
1271            on the first relevant line (excluding empty and #). Also recognizes
1272            Bio.SeqIO iterators as GenBank format.
1273
1274            :param handle: Path to the annotation file or Bio.SeqIO iterator for genbank records read with biopython.
1275            :return: The detected file type ('gb', 'gff', or 'bed').
1276
1277            :raises ValueError: If the file type cannot be determined.
1278            """
1279            # Check if input is a SeqIO iterator
1280            if isinstance(handle, GenBankIterator):
1281                return 'gb'
1282
1283            with _get_line_iterator(handle) as h:
1284                for line in h:
1285                    # skip empty lines and comments
1286                    if not line.strip() or line.startswith('#'):
1287                        continue
1288                   # genbank
1289                    if line.startswith('LOCUS'):
1290                        return 'gb'
1291                    # gff
1292                    if len(line.split('\t')) == 9:
1293                        # Check for expected values
1294                        columns = line.split('\t')
1295                        if columns[6] in ['+', '-', '.'] and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$',columns[4]):
1296                            return 'gff'
1297                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
1298                    fields = line.split('\t')
1299                    if len(fields) >= 3 and re.match(r'^\d+$', fields[1]) and re.match(r'^\d+$', fields[2]):
1300                        return 'bed'
1301                    # only read in the first line
1302                    break
1303
1304            raise ValueError(
1305                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")
1306
1307        def parse_gb(file: str | GenBankIterator) -> dict:
1308            """
1309            Parse a GenBank file into the same dictionary structure used by the annotation pipeline.
1310
1311            :param file: path to genbank file, raw string, or Bio.SeqIO iterator
1312            :return: nested dictionary
1313            """
1314            records = {}
1315
1316            # Check if input is a GenBankIterator
1317            if isinstance(file, GenBankIterator):
1318                # Direct GenBankIterator input
1319                seq_records = list(file)
1320            else:
1321                # File path or string input
1322                with _get_line_iterator(file) as handle:
1323                    seq_records = list(SeqIO.parse(handle, "genbank"))
1324
1325            for seq_record in seq_records:
1326                locus = seq_record.name if seq_record.name else seq_record.id
1327                feature_container = {}
1328                feature_counter = {}
1329
1330                for feature in seq_record.features:
1331                    feature_type = feature.type
1332                    if feature_type not in feature_container:
1333                        feature_container[feature_type] = {}
1334                        feature_counter[feature_type] = 0
1335
1336                    current_idx = feature_counter[feature_type]
1337                    feature_counter[feature_type] += 1
1338
1339                    # Support simple and compound feature locations uniformly.
1340                    parts = feature.location.parts if hasattr(feature.location, "parts") else [feature.location]
1341                    locations = []
1342                    for part in parts:
1343                        try:
1344                            locations.append([int(part.start), int(part.end)])
1345                        except (TypeError, ValueError):
1346                            continue
1347
1348                    strand = '-' if feature.location.strand == -1 else '+'
1349                    parsed_feature = {
1350                        'location': locations,
1351                        'strand': strand,
1352                    }
1353
1354                    # Keep qualifier keys unchanged and flatten values to strings.
1355                    for qualifier_type, qualifier_values in feature.qualifiers.items():
1356                        if not qualifier_values:
1357                            continue
1358                        if isinstance(qualifier_values, list):
1359                            parsed_feature[qualifier_type] = qualifier_values[0] if len(qualifier_values) == 1 else ' '.join(str(x) for x in qualifier_values)
1360                        else:
1361                            parsed_feature[qualifier_type] = str(qualifier_values)
1362
1363                    feature_container[feature_type][current_idx] = parsed_feature
1364
1365                records[locus] = {
1366                    'locus': locus,
1367                    'features': feature_container,
1368                }
1369
1370            return records
1371
1372        def parse_gff(file_path) -> dict:
1373            """
1374            Parse a GFF3 (General Feature Format) file into a dictionary structure.
1375
1376            :param file_path: path to genebank file
1377            :return: nested dictionary
1378
1379            """
1380            records = {}
1381            with _get_line_iterator(file_path) as file:
1382                previous_id, previous_feature = None, None
1383                for line in file:
1384                    if line.startswith('#') or not line.strip():
1385                        continue
1386                    parts = line.strip().split('\t')
1387                    seqid, source, feature_type, start, end, score, strand, phase, attributes = parts
1388                    # ensure that region and source features are not named differently for gff and gb
1389                    if feature_type == 'region':
1390                        feature_type = 'source'
1391                    if seqid not in records:
1392                        records[seqid] = {'locus': seqid, 'features': {}}
1393                    if feature_type not in records[seqid]['features']:
1394                        records[seqid]['features'][feature_type] = {}
1395
1396                    feature_id = len(records[seqid]['features'][feature_type])
1397                    feature = {
1398                        'strand': strand,
1399                    }
1400
1401                    # Parse attributes into key-value pairs
1402                    for attr in attributes.split(';'):
1403                        if '=' in attr:
1404                            key, value = attr.split('=', 1)
1405                            feature[key.strip()] = value.strip()
1406
1407                    # check if feature are the same --> possible splicing
1408                    if previous_id is not None and previous_feature == feature:
1409                        records[seqid]['features'][feature_type][previous_id]['location'].append([int(start)-1, int(end)])
1410                    else:
1411                        records[seqid]['features'][feature_type][feature_id] = feature
1412                        records[seqid]['features'][feature_type][feature_id]['location'] = [[int(start) - 1, int(end)]]
1413                    # set new previous id and features -> new dict as 'location' is pointed in current feature and this
1414                    # is the only key different if next feature has the same entries
1415                    previous_id, previous_feature = feature_id, {key:value for key, value in feature.items() if key != 'location'}
1416
1417            return records
1418
1419        def parse_bed(file_path) -> dict:
1420            """
1421            Parse a BED file into a dictionary structure.
1422
1423            :param file_path: path to genebank file
1424            :return: nested dictionary
1425
1426            """
1427            records = {}
1428            with _get_line_iterator(file_path) as file:
1429                for line in file:
1430                    if line.startswith('#') or not line.strip():
1431                        continue
1432                    parts = line.strip().split('\t')
1433                    chrom, start, end, *optional = parts
1434
1435                    if chrom not in records:
1436                        records[chrom] = {'locus': chrom, 'features': {}}
1437                    feature_type = 'region'
1438                    if feature_type not in records[chrom]['features']:
1439                        records[chrom]['features'][feature_type] = {}
1440
1441                    feature_id = len(records[chrom]['features'][feature_type])
1442                    feature = {
1443                        'location': [[int(start), int(end)]],  # BED uses 0-based start, convert to 1-based
1444                        'strand': '+',  # assume '+' if not present
1445                    }
1446
1447                    # Handle optional columns (name, score, strand) --> ignore 7-12
1448                    if len(optional) >= 1:
1449                        feature['name'] = optional[0]
1450                    if len(optional) >= 2:
1451                        feature['score'] = optional[1]
1452                    if len(optional) >= 3:
1453                        feature['strand'] = optional[2]
1454
1455                    records[chrom]['features'][feature_type][feature_id] = feature
1456
1457            return records
1458
1459        parse_functions: Dict[str, Callable[[str], dict]] = {
1460            'gb': parse_gb,
1461            'bed': parse_bed,
1462            'gff': parse_gff,
1463        }
1464        # determine the annotation content -> should be standard formatted
1465        try:
1466            annotation_type = detect_annotation_type(annotation)
1467        except ValueError as err:
1468            raise err
1469
1470        # read in the annotation
1471        annotations = parse_functions[annotation_type](annotation)
1472
1473        # sanity check whether one of the annotation ids and alignment ids match
1474        annotation_found = False
1475        for annotation in annotations.keys():
1476            for aln_id in aln:
1477                aln_id_sanitized = aln_id.split(' ')[0]
1478                # check in both directions
1479                if aln_id_sanitized in annotation:
1480                    annotation_found = True
1481                    break
1482                if annotation in aln_id_sanitized:
1483                    annotation_found = True
1484                    break
1485
1486        if not annotation_found:
1487            raise ValueError(f'the annotations of {annotation} do not match any ids in the MSA')
1488
1489        # return only the annotation that has been found, the respective type and the seq_id to map to
1490        return annotation_type, aln_id, annotations[annotation]['locus'], annotations[annotation]['features']
1491
1492
1493    def _build_position_map(self) -> Dict[int, int]:
1494        """
1495        build a position map from a sequence.
1496
1497        :return genomic position: gapped position
1498        """
1499
1500        position_map = {}
1501        genomic_pos = 0
1502        for aln_index, char in enumerate(self._gapped_seq):
1503            if char != '-':
1504                position_map[genomic_pos] = aln_index
1505                genomic_pos += 1
1506        # ensure the last genomic position is included
1507        position_map[genomic_pos] = len(self._gapped_seq)
1508
1509        return position_map
1510
1511
1512    def _map_to_alignment(self):
1513        """
1514        Adjust all feature locations to alignment positions
1515        """
1516
1517        def map_location(position_map: Dict[int, int], locations: list) -> list:
1518            """
1519            Map genomic locations to alignment positions using a precomputed position map.
1520
1521            :param position_map: Positions mapped from gapped to ungapped
1522            :param locations: List of genomic start and end positions.
1523            :return: List of adjusted alignment positions.
1524            """
1525
1526            aligned_locs = []
1527            for start, end in locations:
1528                try:
1529                    aligned_start = position_map[start]
1530                    aligned_end = position_map[end]
1531                    aligned_locs.append([aligned_start, aligned_end])
1532                except KeyError:
1533                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
1534
1535            return aligned_locs
1536
1537        for feature_type, features in self.features.items():
1538            for feature_id, feature_data in features.items():
1539                original_locations = feature_data['location']
1540                aligned_locations = map_location(self._position_map, original_locations)
1541                feature_data['location'] = aligned_locations

An annotation class that allows to read in gff, gb, or bed files and adjust its locations to that of the MSA.

Annotation( aln: MSA, annotation: str | Bio.SeqIO.InsdcIO.GenBankIterator)
1233    def __init__(self, aln: MSA, annotation: str | GenBankIterator):
1234        """
1235        The annotation class. Lets you parse multiple standard formats
1236        which might be used for annotating an alignment. The main purpose
1237        is to parse the annotation file and adapt the locations of diverse
1238        features to the locations within the alignment, considering the
1239        respective alignment positions. Importantly, IDs of the alignment
1240        and the MSA have to partly match.
1241
1242        :param aln: MSA class
1243        :param annotation: path to file (gb, bed, gff) or raw string or GenBankIterator from biopython
1244
1245        """
1246
1247        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation, aln)  # read annotation
1248        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1249        self._position_map = self._build_position_map()  # build a position map
1250        self._map_to_alignment()  # adapt feature locations

The annotation class. Lets you parse multiple standard formats which might be used for annotating an alignment. The main purpose is to parse the annotation file and adapt the locations of its diverse features to the positions within the alignment, taking the respective gapped alignment columns into account. Importantly, the IDs of the annotation and the MSA have to partly match.

Parameters
  • aln: MSA class
  • annotation: path to file (gb, bed, gff) or raw string or GenBankIterator from biopython