msaexplorer.explore

Explore module

This module contains the two classes MSA and Annotation which are used to read in the respective files and can be used to compute several statistics or be used as the input for the draw module functions.

Classes

   1"""
   2# Explore module
   3
   4This module contains the two classes `MSA` and `Annotation` which are used to read in the respective files
   5and can be used to compute several statistics or be used as the input for the `draw` module functions.
   6
   7## Classes
   8
   9"""
  10
  11# built-in
  12import math
  13import collections
  14import re
  15from copy import deepcopy
  16from typing import Callable, Dict
  17
  18# installed
  19import numpy as np
  20from numpy import ndarray
  21
  22# msaexplorer
  23from msaexplorer import config
  24
  25class MSA:
  26    """
  27    An alignment class that allows computation of several stats
  28    """
  29
  30    def __init__(self, alignment_path: str, reference_id: str = None, zoom_range: tuple | int = None):
  31        """
  32        Initialise an Alignment object.
  33        :param alignment_path: path to alignment file
  34        :param reference_id: reference id
  35        :param zoom_range: start and stop positions to zoom into the alignment
  36        """
  37        self._alignment = self._read_alignment(alignment_path)
  38        self._reference_id = self._validate_ref(reference_id, self._alignment)
  39        self._zoom = self._validate_zoom(zoom_range, self._alignment)
  40        self._aln_type = self._determine_aln_type(self._alignment)
  41
  42    # TODO: read in different alignment types
  43    # Static methods
  44    @staticmethod
  45    def _read_alignment(file_path: str) -> dict:
  46        """
  47        Parse MSA alignment file.
  48        :param file_path: path to alignment file
  49        :return: dictionary with ids as keys and sequences as values
  50        """
  51
  52        def add_seq(aln: dict, sequence_id: str, seq_list: list):
  53            """
  54            Add a complete sequence and check for non-allowed chars
  55            :param aln: alignment dictionary to build
  56            :param sequence_id: sequence id to add
  57            :param seq_list: sequences to add
  58            :return: alignment with added sequences
  59            """
  60            final_seq = ''.join(seq_list).upper()
  61            # Check for non-allowed characters
  62            invalid_chars = set(final_seq) - set(config.POSSIBLE_CHARS)
  63            if invalid_chars:
  64                raise ValueError(
  65                    f"{sequence_id} contains invalid characters: {', '.join(invalid_chars)}. Allowed chars are: {config.POSSIBLE_CHARS}"
  66                )
  67            aln[sequence_id] = final_seq
  68
  69            return aln
  70
  71        alignment, seq_lines = {}, []
  72        seq_id = None
  73
  74        with open(file_path, 'r') as file:
  75            for i, line in enumerate(file):
  76                line = line.strip()
  77                # initial check for fasta format
  78                if i == 0 and not line.startswith(">"):
  79                    raise ValueError('Alignment has to be in fasta format starting with >SeqID.')
  80                if line.startswith(">"):
  81                    if seq_id:
  82                        alignment = add_seq(alignment, seq_id, seq_lines)
  83                    # initialize a new sequence
  84                    seq_id, seq_lines = line[1:], []
  85                else:
  86                    seq_lines.append(line)
  87            # handle last sequence
  88            if seq_id:
  89                alignment = add_seq(alignment, seq_id, seq_lines)
  90        # final sanity checks
  91        if alignment:
  92            # alignment contains only one sequence:
  93            if len(alignment) < 2:
  94                raise ValueError("Alignment must contain more than one sequence.")
  95            # alignment sequences are not same length
  96            first_seq_len = len(next(iter(alignment.values())))
  97            for sequence_id, sequence in alignment.items():
  98                if len(sequence) != first_seq_len:
  99                    raise ValueError(
 100                        f"All alignment sequences must have the same length. Sequence '{sequence_id}' has length {len(sequence)}, expected {first_seq_len}."
 101                    )
 102            # all checks passed
 103            return alignment
 104        else:
 105            raise ValueError(f"Alignment file {file_path} does not contain any sequences in fasta format.")
 106
 107    @staticmethod
 108    def _validate_ref(reference: str | None, alignment: dict) -> str | None | ValueError:
 109        """
 110        Validate if the ref seq is indeed part of the alignment.
 111        :param reference: reference seq id
 112        :param alignment: alignment dict
 113        :return: validated reference
 114        """
 115        if reference in alignment.keys():
 116            return reference
 117        elif reference is None:
 118            return reference
 119        else:
 120            raise ValueError('Reference not in alignment.')
 121
 122    @staticmethod
 123    def _validate_zoom(zoom: tuple | int, original_aln: dict) -> ValueError | tuple | None:
 124        """
 125        Validates if the user defined zoom range is within the start, end of the initial
 126        alignment.\n
 127        :param zoom: zoom range or zoom start
 128        :param original_aln: non-zoomed alignment dict
 129        :return: validated zoom range
 130        """
 131        if zoom is not None:
 132            aln_length = len(original_aln[list(original_aln.keys())[0]])
 133            # check if only over value is provided -> stop is alignment length
 134            if isinstance(zoom, int):
 135                if 0 <= zoom < aln_length:
 136                    return zoom, aln_length - 1
 137                else:
 138                    raise ValueError('Zoom start must be within the alignment length range.')
 139            # check if more than 2 values are provided
 140            if len(zoom) != 2:
 141                raise ValueError('Zoom position have to be (zoom_start, zoom_end)')
 142            # validate zoom start/stop
 143            for position in zoom:
 144                if type(position) != int:
 145                    raise ValueError('Zoom positions have to be integers.')
 146                if position not in range(0, aln_length):
 147                    raise ValueError('Zoom position out of range')
 148
 149        return zoom
 150
 151    @staticmethod
 152    def _determine_aln_type(alignment) -> str:
 153        """
 154        Determine the most likely type of alignment
 155        if 70% of chars in the alignment are nucleotide
 156        chars it is most likely a nt alignment
 157        :return: type of alignment
 158        """
 159        counter = int()
 160        for record in alignment:
 161            if 'U' in alignment[record]:
 162                return 'RNA'
 163            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 164        # determine which is the most likely type
 165        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 166            return 'DNA'
 167        else:
 168            return 'AA'
 169
 170    # Properties with setters
 171    @property
 172    def reference_id(self):
 173        return self._reference_id
 174
 175    @reference_id.setter
 176    def reference_id(self, ref_id: str):
 177        """
 178        Set and validate the reference id.
 179        """
 180        self._reference_id = self._validate_ref(ref_id, self.alignment)
 181
 182    @property
 183    def zoom(self) -> tuple:
 184        return self._zoom
 185
 186    @zoom.setter
 187    def zoom(self, zoom_pos: tuple | int):
 188        """
 189        Validate if the user defined zoom range.
 190        """
 191        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 192
 193    # Property without setters
 194    @property
 195    def aln_type(self) -> str:
 196        """
 197        define the aln type:
 198        RNA, DNA or AA
 199        """
 200        return self._aln_type
 201
 202    # On the fly properties without setters
 203    @property
 204    def length(self) -> int:
 205        return len(next(iter(self.alignment.values())))
 206
 207    @property
 208    def alignment(self) -> dict:
 209        """
 210        (zoomed) version of the alignment.
 211        """
 212        if self.zoom is not None:
 213            zoomed_aln = dict()
 214            for seq in self._alignment:
 215                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 216            return zoomed_aln
 217        else:
 218            return self._alignment
 219
 220    # functions for different alignment stats
 221    def get_reference_coords(self) -> tuple[int, int]:
 222        """
 223        Determine the start and end coordinates of the reference sequence
 224        defined as the first/last nucleotide in the reference sequence
 225        (excluding N and gaps).
 226
 227        :return: start, end
 228        """
 229        start, end = 0, self.length
 230
 231        if self.reference_id is None:
 232            return start, end
 233        else:
 234            # 5' --> 3'
 235            for start in range(self.length):
 236                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 237                    break
 238            # 3' --> 5'
 239            for end in range(self.length - 1, 0, -1):
 240                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 241                    break
 242
 243            return start, end
 244
 245    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
 246        """
 247        Creates a non-gapped consensus sequence.
 248
 249        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
 250            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
 251            or 'X' (for as alignments) is used if none of the characters reach a cumulative frequency >= threshold.
 252        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
 253            has a frequency above the defined threshold.
 254        :return: consensus sequence
 255        """
 256
 257        # helper functions
 258        def determine_counts(alignment_dict: dict, position: int) -> dict:
 259            """
 260            count the number of each char at
 261            an idx of the alignment. return sorted dic.
 262            handles ambiguous nucleotides in sequences.
 263            also handles gaps.
 264            """
 265            nucleotide_list = []
 266
 267            # get all nucleotides
 268            for sequence in alignment_dict.items():
 269                nucleotide_list.append(sequence[1][position])
 270            # count occurences of nucleotides
 271            counter = dict(collections.Counter(nucleotide_list))
 272            # get permutations of an ambiguous nucleotide
 273            to_delete = []
 274            temp_dict = {}
 275            for nucleotide in counter:
 276                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
 277                    to_delete.append(nucleotide)
 278                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
 279                    adjusted_freq = 1 / len(permutations)
 280                    for permutation in permutations:
 281                        if permutation in temp_dict:
 282                            temp_dict[permutation] += adjusted_freq
 283                        else:
 284                            temp_dict[permutation] = adjusted_freq
 285
 286            # drop ambiguous entries and add adjusted freqs to
 287            if to_delete:
 288                for i in to_delete:
 289                    counter.pop(i)
 290                for nucleotide in temp_dict:
 291                    if nucleotide in counter:
 292                        counter[nucleotide] += temp_dict[nucleotide]
 293                    else:
 294                        counter[nucleotide] = temp_dict[nucleotide]
 295
 296            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
 297
 298        def get_consensus_char(counts: dict, cutoff: float) -> list:
 299            """
 300            get a list of nucleotides for the consensus seq
 301            """
 302            n = 0
 303
 304            consensus_chars = []
 305            for char in counts:
 306                n += counts[char]
 307                consensus_chars.append(char)
 308                if n >= cutoff:
 309                    break
 310
 311            return consensus_chars
 312
 313        def get_ambiguous_char(nucleotides: list) -> str:
 314            """
 315            get ambiguous char from a list of nucleotides
 316            """
 317            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
 318                if set(permutations) == set(nucleotides):
 319                    return ambiguous
 320
 321        # check if params have been set correctly
 322        if threshold is not None:
 323            if threshold < 0 or threshold > 1:
 324                raise ValueError('Threshold must be between 0 and 1.')
 325        if self.aln_type == 'AA' and use_ambig_nt:
 326            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
 327        if threshold is None and use_ambig_nt:
 328            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
 329
 330        alignment = self.alignment
 331        consensus = str()
 332
 333        if threshold is not None:
 334            consensus_cutoff = len(alignment) * threshold
 335        else:
 336            consensus_cutoff = 0
 337
 338        # built consensus sequences
 339        for idx in range(self.length):
 340            char_counts = determine_counts(alignment, idx)
 341            consensus_chars = get_consensus_char(
 342                char_counts,
 343                consensus_cutoff
 344            )
 345            if threshold != 0:
 346                if len(consensus_chars) > 1:
 347                    if use_ambig_nt:
 348                        char = get_ambiguous_char(consensus_chars)
 349                    else:
 350                        if self.aln_type == 'AA':
 351                            char = 'X'
 352                        else:
 353                            char = 'N'
 354                    consensus = consensus + char
 355                else:
 356                    consensus = consensus + consensus_chars[0]
 357            else:
 358                consensus = consensus + consensus_chars[0]
 359
 360        return consensus
 361
 362    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 363        """
 364        **conserved ORF definition:**
 365            - conserved starts and stops
 366            - start, stop must be on the same frame
 367            - stop - start must be at least min_length
 368            - all ungapped seqs[start:stop] must have at least min_length
 369            - no ungapped seq can have a Stop in between Start Stop
 370
 371        Conservation is measured by number of positions with identical characters divided by
 372        orf slice of the alignment.
 373
 374        **Algorithm overview:**
 375            - check for conserved start and stop codons
 376            - iterate over all three frames
 377            - check each start and next sufficiently far away stop codon
 378            - check if all ungapped seqs between start and stop codon are >= min_length
 379            - check if no ungapped seq in the alignment has a stop codon
 380            - write to dictionary
 381            - classify as internal if the stop codon has already been written with a prior start
 382            - repeat for reverse complement
 383
 384        :return: ORF positions and internal ORF positions
 385        """
 386
 387        # helper functions
 388        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
 389            """
 390            Determine all start and stop codons within an alignment.
 391            :param alignment: alignment
 392            :param alignment_length: length of alignment
 393            :return: start and stop codon positions
 394            """
 395            starts = config.START_CODONS[self.aln_type]
 396            stops = config.STOP_CODONS[self.aln_type]
 397
 398            list_of_starts, list_of_stops = [], []
 399            ref = alignment[list(alignment.keys())[0]]
 400            for nt_position in range(alignment_length):
 401                if ref[nt_position:nt_position + 3] in starts:
 402                    conserved_start = True
 403                    for sequence in alignment:
 404                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
 405                            conserved_start = False
 406                            break
 407                    if conserved_start:
 408                        list_of_starts.append(nt_position)
 409
 410                if ref[nt_position:nt_position + 3] in stops:
 411                    conserved_stop = True
 412                    for sequence in alignment:
 413                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
 414                            conserved_stop = False
 415                            break
 416                    if conserved_stop:
 417                        list_of_stops.append(nt_position)
 418
 419            return list_of_starts, list_of_stops
 420
 421        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
 422            """
 423            get ungapped sequences starting and stop codons and eliminate gaps
 424            :param alignment: alignment
 425            :param start_pos: start codon
 426            :param stop_pos: stop codon
 427            :return: sliced sequences
 428            """
 429            ungapped_seqs = []
 430            for seq_id in alignment:
 431                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
 432
 433            return ungapped_seqs
 434
 435        def additional_stops(ungapped_seqs: list) -> bool:
 436            """
 437            Checks for the presence of a stop codon
 438            :param ungapped_seqs: list of ungapped sequences
 439            :return: Additional stop codons (True/False)
 440            """
 441            stops = config.STOP_CODONS[self.aln_type]
 442
 443            for sliced_seq in ungapped_seqs:
 444                for position in range(0, len(sliced_seq) - 3, 3):
 445                    if sliced_seq[position:position + 3] in stops:
 446                        return True
 447            return False
 448
 449        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
 450            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
 451            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100
 452
 453        # checks for arguments
 454        if self.aln_type == 'AA':
 455            raise TypeError('ORF search only for RNA/DNA alignments')
 456
 457        if identity_cutoff is not None:
 458            if identity_cutoff > 100 or identity_cutoff < 0:
 459                raise ValueError('conservation cutoff must be between 0 and 100')
 460
 461        if min_length <= 0 or min_length > self.length:
 462            raise ValueError(f'min_length must be between 0 and {self.length}')
 463
 464        # ini
 465        identities = self.calc_identity_alignment()
 466        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
 467        aln_len = self.length
 468
 469        orf_counter = 0
 470        orf_dict = {}
 471
 472        for aln, direction in zip(alignments, ['+', '-']):
 473            # check for starts and stops in the first seq and then check if these are present in all seqs
 474            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
 475            # check each frame
 476            for frame in (0, 1, 2):
 477                potential_starts = [x for x in conserved_starts if x % 3 == frame]
 478                potential_stops = [x for x in conserved_stops if x % 3 == frame]
 479                last_stop = -1
 480                for start in potential_starts:
 481                    # go to the next stop that is sufficiently far away in the alignment
 482                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
 483                    if not next_stops:
 484                        continue
 485                    next_stop = next_stops[0]
 486                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
 487                    # re-check the lengths of all ungapped seqs
 488                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
 489                    if not all(ungapped_seq_lengths):
 490                        continue
 491                    # if no stop codon between start and stop --> write to dictionary
 492                    if not additional_stops(ungapped_sliced_seqs):
 493                        if direction == '+':
 494                            positions = [start, next_stop + 3]
 495                        else:
 496                            positions = [aln_len - next_stop - 3, aln_len - start]
 497                        if last_stop != next_stop:
 498                            last_stop = next_stop
 499                            conservation = calculate_identity(identities, positions)
 500                            if identity_cutoff is not None and conservation < identity_cutoff:
 501                                continue
 502                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
 503                                                              'frame': frame,
 504                                                              'strand': direction,
 505                                                              'conservation': conservation,
 506                                                              'internal': []
 507                                                              }
 508                            orf_counter += 1
 509                        else:
 510                            orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
 511
 512        return orf_dict
 513
 514    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> dict:
 515        """
 516        First calculates all ORFs and then searches from 5'
 517        all non-overlapping orfs in the fw strand and from the
 518        3' all non-overlapping orfs in th rw strand.
 519
 520        **No overlap algorithm:**
 521            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
 522
 523            **frame 2:** -------[M------*]---------[M---*]--------
 524
 525            **frame 3:** [M---*]-----[M----------*]----------[M---
 526
 527            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
 528
 529            frame:    3      2           1      2       1
 530
 531        :return: dictionary with non-overlapping orfs
 532        """
 533        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
 534
 535        fw_orfs, rw_orfs = [], []
 536
 537        for orf in orf_dict:
 538            if orf_dict[orf]['strand'] == '+':
 539                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
 540            else:
 541                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
 542
 543        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
 544        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
 545        non_overlapping_orfs = []
 546        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
 547            previous_stop = -1 if strand == '+' else self.length + 1
 548            for orf in orf_list:
 549                if strand == '+' and orf[1][0] > previous_stop:
 550                    non_overlapping_orfs.append(orf[0])
 551                    previous_stop = orf[1][1]
 552                elif strand == '-' and orf[1][1] < previous_stop:
 553                    non_overlapping_orfs.append(orf[0])
 554                    previous_stop = orf[1][0]
 555
 556        non_overlap_dict = {}
 557        for orf in orf_dict:
 558            if orf in non_overlapping_orfs:
 559                non_overlap_dict[orf] = orf_dict[orf]
 560
 561        return non_overlap_dict
 562
 563    def calc_length_stats(self) -> dict:
 564        """
 565        Determine the stats for the length of the ungapped seqs in the alignment.
 566        :return: dictionary with length stats
 567        """
 568
 569        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 570
 571        return {'number of seq': len(self.alignment),
 572                'mean length': float(np.mean(seq_lengths)),
 573                'std length': float(np.std(seq_lengths)),
 574                'min length': int(np.min(seq_lengths)),
 575                'max length': int(np.max(seq_lengths))
 576                }
 577
 578    def calc_entropy(self) -> list:
 579        """
 580        Calculate the normalized shannon's entropy for every position in an alignment:
 581
 582        - 1: high entropy
 583        - 0: low entropy
 584
 585        :return: Entropies at each position.
 586        """
 587
 588        # helper functions
 589        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
 590            """
 591            Calculate the shannon's entropy of a sequence and
 592            normalized between 0 and 1.
 593            :param character_list: characters at an alignment position
 594            :param states: number of potential characters that can be present
 595            :param aln_type: type of the alignment
 596            :returns: entropy
 597            """
 598            ent, n_chars = 0, len(character_list)
 599            # only one char is in the list
 600            if n_chars <= 1:
 601                return ent
 602            # calculate the number of unique chars and their counts
 603            chars, char_counts = np.unique(character_list, return_counts=True)
 604            char_counts = char_counts.astype(float)
 605            # ignore gaps for entropy calc
 606            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
 607            # correctly handle ambiguous chars
 608            index_to_drop = []
 609            for index, char in enumerate(chars):
 610                if char in config.AMBIG_CHARS[aln_type]:
 611                    index_to_drop.append(index)
 612                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
 613                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
 614                    # add the proportionate numbers to initial array
 615                    for amb_char, amb_count in zip(amb_chars, amb_counts):
 616                        if amb_char in chars:
 617                            char_counts[chars == amb_char] += amb_count
 618                        else:
 619                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
 620            # drop the ambiguous characters from array
 621            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
 622            # calc the entropy
 623            probs = char_counts / n_chars
 624            if np.count_nonzero(probs) <= 1:
 625                return ent
 626            for prob in probs:
 627                ent -= prob * math.log(prob, states)
 628
 629            return ent
 630
 631        aln = self.alignment
 632        entropys = []
 633
 634        if self.aln_type == 'AA':
 635            states = 20
 636        else:
 637            states = 4
 638        # iterate over alignment positions and the sequences
 639        for nuc_pos in range(self.length):
 640            pos = []
 641            for record in aln:
 642                pos.append(aln[record][nuc_pos])
 643            entropys.append(shannons_entropy(pos, states, self.aln_type))
 644
 645        return entropys
 646
 647    def calc_gc(self) -> list | TypeError:
 648        """
 649        Determine the GC content for every position in an nt alignment.
 650        :return: GC content for every position.
 651        :raises: TypeError for AA alignments
 652        """
 653        if self.aln_type == 'AA':
 654            raise TypeError("GC computation is not possible for aminoacid alignment")
 655
 656        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
 657
 658        for position in range(self.length):
 659            nucleotides = str()
 660            for record in aln:
 661                nucleotides = nucleotides + aln[record][position]
 662            # ini dict with chars that occur and which ones to
 663            # count in which freq
 664            to_count = {
 665                'G': 1,
 666                'C': 1,
 667            }
 668            # handle ambig. nuc
 669            for char in amb_nucs:
 670                if char in nucleotides:
 671                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
 672
 673            gc.append(
 674                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
 675            )
 676
 677        return gc
 678
 679    def calc_coverage(self) -> list:
 680        """
 681        Determine the coverage of every position in an alignment.
 682        This is defined as:
 683            1 - cumulative length of '-' characters
 684
 685        :return: Coverage at each alignment position.
 686        """
 687        coverage, aln = [], self.alignment
 688
 689        for nuc_pos in range(self.length):
 690            pos = str()
 691            for record in aln.keys():
 692                pos = pos + aln[record][nuc_pos]
 693            coverage.append(1 - pos.count('-') / len(pos))
 694
 695        return coverage
 696
 697    def calc_reverse_complement_alignment(self) -> dict | TypeError:
 698        """
 699        Reverse complement the alignment.
 700        :return: Alignment (rv)
 701        """
 702        if self.aln_type == 'AA':
 703            raise TypeError('Reverse complement only for RNA or DNA.')
 704
 705        aln = self.alignment
 706        reverse_complement_dict = {}
 707
 708        for seq_id in aln:
 709            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 710
 711        return reverse_complement_dict
 712
 713    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> np.ndarray:
 714        """
 715        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 716
 717        :param encode_mismatches: encode mismatch as -1
 718        :param encode_mask: encode mask with value=-2 --> also in the reference
 719        :param encode_gaps: encode gaps with np.nan --> also in the reference
 720        :param encode_ambiguities: encode ambiguities with value=-3
 721        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.DNA_colors, config.RNA_colors or config.AA_colors
 722        :return: identity alignment
 723        """
 724
 725        aln = self.alignment
 726        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 727
 728        # convert alignment to array
 729        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 730        reference = np.array(list(ref))
 731        # ini matrix
 732        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 733
 734        is_identical = sequences == reference
 735
 736        if encode_gaps:
 737            is_gap = sequences == '-'
 738        else:
 739            is_gap = np.full(sequences.shape, False)
 740
 741        if encode_mask:
 742            if self.aln_type == 'AA':
 743                is_n_or_x = np.isin(sequences, ['X'])
 744            else:
 745                is_n_or_x = np.isin(sequences, ['N'])
 746        else:
 747            is_n_or_x = np.full(sequences.shape, False)
 748
 749        if encode_ambiguities:
 750            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 751        else:
 752            is_ambig = np.full(sequences.shape, False)
 753
 754        if encode_mismatches:
 755            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 756        else:
 757            is_mismatch = np.full(sequences.shape, False)
 758
 759        # encode every different character
 760        if encode_each_mismatch_char:
 761            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]):
 762                new_encoding = np.isin(sequences, [char]) & is_mismatch
 763                identity_matrix[new_encoding] = idx + 1
 764        # or encode different with a single value
 765        else:
 766            identity_matrix[is_mismatch] = -1  # mismatch
 767
 768        identity_matrix[is_gap] = np.nan  # gap
 769        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 770        identity_matrix[is_ambig] = -3  # ambiguities
 771
 772        return identity_matrix
 773
 774    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> np.ndarray:
 775        """
 776        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
 777        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
 778        reference residue at each column. Gaps are encoded as np.nan.
 779
 780        The calculation follows these steps:
 781
 782        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
 783           a consensus sequence is generated to serve as the reference.
 784        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
 785           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
 786        3. **Per-Column Normalization (optional)**:
 787           - For each column in the alignment:
 788             - The residue in the reference sequence is treated as the baseline for that column.
 789             - The substitution scores for the reference residue are extracted from the substitution matrix.
 790             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
 791               residue.
 792           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
 793             while more dissimilar residues have lower scores.
 794        4. **Output**:
 795           - The normalized similarity scores are stored in a NumPy array.
 796           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
 797
 798        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
 799        :param: normalize: whether to normalize the similarity scores to range [0, 1]
 800        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
 801            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
 802            Gaps and invalid residues are encoded as `np.nan`.
 803        :raise: ValueError
 804            If the specified substitution matrix is not available for the given alignment type.
 805        """
 806
 807        aln = self.alignment
 808        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 809        if matrix_type is None:
 810            if self.aln_type == 'AA':
 811                matrix_type = 'BLOSUM65'
 812            else:
 813                matrix_type = 'TRANS'
 814        # load substitution matrix as dictionary
 815        try:
 816            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
 817        except KeyError:
 818            raise ValueError(
 819                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
 820            )
 821
 822        # set dtype and convert alignment to a NumPy array for vectorized processing
 823        dtype = np.dtype(float, metadata={'matrix': matrix_type})
 824        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 825        reference = np.array(list(ref))
 826        valid_chars = list(subs_matrix.keys())
 827        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
 828
 829        for j, ref_char in enumerate(reference):
 830            if ref_char not in valid_chars + ['-']:
 831                continue
 832            # Get local min and max for the reference residue
 833            if normalize and ref_char != '-':
 834                local_scores = subs_matrix[ref_char].values()
 835                local_min, local_max = min(local_scores), max(local_scores)
 836
 837            for i, char in enumerate(sequences[:, j]):
 838                if char not in valid_chars:
 839                    continue
 840                # classify the similarity as max if the reference has a gap
 841                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
 842                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
 843
 844        return similarity_array
 845
 846    def calc_position_matrix(self, matrix_type:str='PWM') -> np.ndarray | ValueError:
 847        """
 848        Calculate various position matrices (reference https://en.wikipedia.org/wiki/Position_weight_matrix)
 849
 850        **Major steps:**
 851            1) calculate character counts (PFM)
 852            2) calculate character frequencies (PPM)
 853            3) add pseudocount (square root of row length) -> scales with aln size --> needed for positions with 0 counts
 854            4) transform to PWM with M_k,j=log_2(M_k,j/b_k) with b_k assuming statistical independence (all chars are equally frequent)
 855
 856        :param matrix_type: matrix to return (PFM, PPM or PWM)
 857        :return: pwm as numpy array
 858        :raise: ValueError for incorrect matrix type
 859        """
 860
 861        # ini
 862        aln = self.alignment
 863        if matrix_type not in ['PFM', 'PPM', 'PWM']:
 864            raise ValueError('Matrix_type must be PFM, PPM or PWM.')
 865        possible_chars = list(config.CHAR_COLORS[self.aln_type].keys())[:-1]
 866        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 867
 868        # calc position frequency matrix
 869        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars]) \
 870              + np.sqrt(sequences.shape[0])  # add pseudo-counts
 871        if matrix_type == 'PFM':
 872            return pfm
 873
 874        # calc position probability matrix
 875        ppm = pfm/np.sum(pfm, 0)
 876        if matrix_type == 'PPM':
 877            return ppm
 878
 879        # calc position weight matrix
 880        pwm = np.log2(ppm*len(possible_chars))
 881        if matrix_type == 'PWM':
 882            return pwm
 883
 884    def calc_percent_recovery(self) -> dict:
 885        """
 886        Recovery per sequence either compared to the majority consensus seq
 887        or the reference seq.\n
 888        Defined as:\n
 889
 890        `(1 - sum(N/X/- characters in ungapped ref regions))*100`
 891
 892        This is highly similar to how nextclade calculates recovery over reference.
 893
 894        :return: dict
 895        """
 896
 897        aln = self.alignment
 898
 899        if self.reference_id is not None:
 900            ref = aln[self.reference_id]
 901        else:
 902            ref = self.get_consensus()  # majority consensus
 903
 904        if not any(char != '-' for char in ref):
 905            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 906
 907
 908        # count 'N', 'X' and '-' chars in non-gapped regions
 909        recovery_over_ref = dict()
 910
 911        # Get positions of non-gap characters in the reference
 912        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 913        cumulative_length = len(non_gap_positions)
 914
 915        # Calculate recovery
 916        for seq_id in aln:
 917            if seq_id == self.reference_id:
 918                continue
 919            seq = aln[seq_id]
 920            count_invalid = sum(
 921                seq[pos] == '-' or
 922                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 923                for pos in non_gap_positions
 924            )
 925            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 926
 927        return recovery_over_ref
 928
 929    def calc_character_frequencies(self) -> dict:
 930        """
 931        Calculate the percentage characters in the alignment:
 932        The frequencies are counted by seq and in total. The
 933        percentage of non-gap characters in the alignment is
 934        relative to the total number of non-gap characters.
 935        The gap percentage is relative to the sequence length.
 936
 937        The output is a nested dictionary.
 938
 939        :return: Character frequencies
 940        """
 941
 942        aln, aln_length = self.alignment, self.length
 943
 944        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 945
 946        for seq_id in aln:
 947            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 948            unique_chars = set(aln[seq_id])
 949            for char in unique_chars:
 950                if char == '-':
 951                    continue
 952                # add characters to dictionaries
 953                if char not in freqs[seq_id]:
 954                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 955                if char not in freqs['total']:
 956                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
 957                # count non-gap chars
 958                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
 959                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
 960                all_chars += freqs[seq_id][char]['counts']
 961            # normalize counts
 962            for char in freqs[seq_id]:
 963                if char == '-':
 964                    continue
 965                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
 966                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
 967            # count gaps
 968            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
 969            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
 970            # normalize gap counts
 971            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
 972            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
 973
 974        # normalize the total counts
 975        for char in freqs['total']:
 976            for value in freqs['total'][char]:
 977                if value == '% of alignment' or value == '% of non-gapped':
 978                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
 979
 980        return freqs
 981
 982    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> ndarray:
 983        """
 984        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
 985
 986        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
 987        \ndistance = matches / alignment_length * 100
 988
 989        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
 990        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
 991
 992        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
 993        \ndistance = matches / (matches + mismatches) * 100
 994
 995        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
 996        \ndistance = matches / gap_compressed_alignment_length * 100
 997
 998        :return: array with pairwise distances.
 999        """
1000
1001        def hamming_distance(seq1: str, seq2: str) -> int:
1002            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))
1003
1004        def ghd(seq1: str, seq2: str) -> float:
1005            return hamming_distance(seq1, seq2) / self.length * 100
1006
1007        def lhd(seq1: str, seq2: str) -> float:
1008            # remove 5' trailing gaps
1009            i, j = 0, self.length - 1
1010            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1011                i += 1
1012            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1013                j -= 1
1014            if i > j:
1015                return 0.0
1016            # slice seq
1017            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1018
1019            return hamming_distance(seq1_, seq2_) / min([len(seq1_), len(seq2_)]) * 100
1020
1021        def ged(seq1: str, seq2: str) -> float:
1022
1023            matches, mismatches = 0, 0
1024
1025            for c1, c2 in zip(seq1, seq2):
1026                if c1 != '-' and c2 != '-':
1027                    if c1 == c2:
1028                        matches += 1
1029                    else:
1030                        mismatches += 1
1031            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1032
1033        def gcd(seq1: str, seq2: str) -> float:
1034            matches = 0
1035            mismatches = 0
1036            in_gap = False
1037
1038            for char1, char2 in zip(seq1, seq2):
1039                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1040                    continue
1041                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1042                    if not in_gap:  # Start of a new gap stretch
1043                        mismatches += 1
1044                        in_gap = True
1045                else:  # No gaps
1046                    in_gap = False
1047                    if char1 == char2:  # Matching characters
1048                        matches += 1
1049                    else:  # Mismatched characters
1050                        mismatches += 1
1051
1052            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1053
1054
1055        # Map distance type to corresponding function
1056        distance_functions: Dict[str, Callable[[str, str], float]] = {
1057            'ghd': ghd,
1058            'lhd': lhd,
1059            'ged': ged,
1060            'gcd': gcd
1061        }
1062
1063        if distance_type not in distance_functions:
1064            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1065
1066        # Compute pairwise distances
1067        aln = self.alignment
1068        distance_func = distance_functions[distance_type]
1069        distance_matrix = np.zeros((len(aln), len(aln)))
1070
1071        sequences = list(aln.values())
1072        for i, seq1 in enumerate(sequences):
1073            for j, seq2 in enumerate(sequences):
1074                if i <= j:  # Compute only once for symmetric matrix
1075                    distance_matrix[i, j] = distance_func(seq1, seq2)
1076                    distance_matrix[j, i] = distance_matrix[i, j]
1077
1078        return distance_matrix
1079
1080    def get_snps(self, include_ambig:bool=False) -> dict:
1081        """
1082        Calculate snps similar to snp-sites (output is comparable):
1083        https://github.com/sanger-pathogens/snp-sites
1084        Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character.
1085        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1086
1087        :param include_ambig: Include ambiguous snps (default: False)
1088        :return: dictionary containing snp positions and their variants including their frequency.
1089        """
1090        aln = self.alignment
1091        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
1092        aln = {x: aln[x] for x in aln.keys() if x != self.reference_id}
1093        seq_ids = list(aln.keys())
1094        snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}
1095
1096        for pos in range(self.length):
1097            reference_char = ref[pos]
1098            if not include_ambig:
1099                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1100                    continue
1101            alt_chars, snps = [], []
1102            for i, seq_id in enumerate(aln.keys()):
1103                alt_chars.append(aln[seq_id][pos])
1104                if reference_char != aln[seq_id][pos]:
1105                    snps.append(i)
1106            if not snps:
1107                continue
1108            if include_ambig:
1109                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1110                    continue
1111            else:
1112                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1113                if not snps:
1114                    continue
1115            if pos not in snp_dict:
1116                snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': {}}
1117            for snp in snps:
1118                if alt_chars[snp] not in snp_dict['POS'][pos]['ALT']:
1119                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]] = {
1120                        'AF': 1,
1121                        'SEQ_ID': [seq_ids[snp]]
1122                    }
1123                else:
1124                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['AF'] += 1
1125                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['SEQ_ID'].append(seq_ids[snp])
1126            # calculate AF
1127            if pos in snp_dict['POS']:
1128                for alt in snp_dict['POS'][pos]['ALT']:
1129                    snp_dict['POS'][pos]['ALT'][alt]['AF'] /= len(aln)
1130
1131        return snp_dict
1132
1133    def calc_transition_transversion_score(self) -> list:
1134        """
1135        Based on the snp positions, calculates a transition/transversions score.
1136        A positive score means higher ratio of transitions and negative score means
1137        a higher ratio of transversions.
1138        :return: list
1139        """
1140
1141        if self.aln_type == 'AA':
1142            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1143
1144        # ini
1145        snps = self.get_snps()
1146        score = [0]*self.length
1147
1148        for pos in snps['POS']:
1149            t_score_temp = 0
1150            for alt in snps['POS'][pos]['ALT']:
1151                # check the type of substitution
1152                if snps['POS'][pos]['ref'] + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1153                    score[pos] += snps['POS'][pos]['ALT'][alt]['AF']
1154                else:
1155                    score[pos] -= snps['POS'][pos]['ALT'][alt]['AF']
1156
1157        return score
1158
1159
1160class Annotation:
1161    """
1162    An annotation class that allows to read in gff, gb or bed files and adjust its locations to that of the MSA.
1163    """
1164
1165    def __init__(self, aln: MSA, annotation_path: str):
1166        """
1167        The annotation class. Lets you parse multiple standard formats
1168        which might be used for annotating an alignment. The main purpose
1169        is to parse the annotation file and adapt the locations of diverse
1170        features to the locations within the alignment, considering the
1171        respective alignment positions. Importantly, IDs of the alignment
1172        and the MSA have to partly match.
1173
1174        :param aln: MSA class
1175        :param annotation_path: path to annotation file (gb, bed, gff).
1176
1177        """
1178
1179        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation_path, aln)  # read annotation
1180        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1181        self._position_map = self._build_position_map()  # build a position map
1182        self._map_to_alignment()  # adapt feature locations
1183
1184    @staticmethod
1185    def _MSA_validation_and_seq_extraction(aln: MSA, seq_id: str) -> str:
1186        """
1187        extract gapped sequence from MSA that corresponds to annotation
1188        :param aln: MSA class
1189        :param seq_id: sequence id to extract
1190        :return: gapped sequence
1191        """
1192        if not isinstance(aln, MSA):
1193            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')
1194        else:
1195            return aln._alignment[seq_id]
1196
1197    @staticmethod
1198    def _parse_annotation(annotation_path: str, aln: MSA) -> tuple[str, str, str, Dict]:
1199
1200        def detect_annotation_type(file_path: str) -> str:
1201            """
1202            Detect the type of annotation file (GenBank, GFF, or BED) based
1203            on the first relevant line (excluding empty and #)
1204
1205            :param file_path: Path to the annotation file.
1206            :return: The detected file type ('gb', 'gff', or 'bed').
1207
1208            :raises ValueError: If the file type cannot be determined.
1209            """
1210
1211            with open(file_path, 'r') as file:
1212                for line in file:
1213                    # skip empty lines and comments
1214                    if not line.strip() or line.startswith('#'):
1215                        continue
1216                   # genbank
1217                    if line.startswith('LOCUS'):
1218                        return 'gb'
1219                    # gff
1220                    if len(line.split('\t')) == 9:
1221                        # Check for expected values
1222                        columns = line.split('\t')
1223                        if columns[6] in ['+', '-', '.'] and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$',columns[4]):
1224                            return 'gff'
1225                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
1226                    fields = line.split('\t')
1227                    if len(fields) >= 3 and re.match(r'^\d+$', fields[1]) and re.match(r'^\d+$', fields[2]):
1228                        return 'bed'
1229                    # only read in the first line
1230                    break
1231
1232            raise ValueError(
1233                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")
1234
1235        def parse_gb(file_path) -> dict:
1236            """
1237            parse a genebank file to dictionary - primarily retained are the informations
1238            for qualifiers as these will be used for plotting.
1239
1240            :param file_path: path to genebank file
1241            :return: nested dictionary
1242
1243            """
1244
1245            def sanitize_gb_location(string: str) -> tuple[list, str]:
1246                """
1247                see: https://www.insdc.org/submitting-standards/feature-table/
1248                """
1249                strand = '+'
1250                locations = []
1251                # check the direction of the annotation
1252                if 'complement' in string:
1253                    strand = '-'
1254                # sanitize operators
1255                for operator in ['complement(', 'join(', 'order(']:
1256                    string = string.strip(operator)
1257                # sanitize possible chars for splitting start stop -
1258                # however in the future might not simply do this
1259                # as some useful information is retained
1260                for char in ['>', '<', ')']:
1261                    string = string.replace(char, '')
1262                # check if we have multiple location e.g. due to splicing
1263                if ',' in string:
1264                    raw_locations = string.split(',')
1265                else:
1266                    raw_locations = [string]
1267                # try to split start and stop
1268                for location in raw_locations:
1269                    for sep in ['..', '.', '^']:
1270                        if sep in location:
1271                            sanitized_locations = [int(x) for x in location.split(sep)]
1272                            sanitized_locations[0] = sanitized_locations[0] - 1  # enforce 0-based starts
1273                            locations.append(sanitized_locations)
1274                            break
1275
1276                return locations, strand
1277
1278
1279            records = {}
1280            with open(file_path, "r") as file:
1281                record = None
1282                in_features = False
1283                counter_dict = {}
1284                for line in file:
1285                    line = line.rstrip()
1286                    parts = line.split()
1287                    # extract the locus id
1288                    if line.startswith('LOCUS'):
1289                        if record:
1290                            records[record['locus']] = record
1291                        record = {
1292                            'locus': parts[1],
1293                            'features': {}
1294                        }
1295
1296                    elif line.startswith('FEATURES'):
1297                        in_features = True
1298
1299                    # ignore the sequence info
1300                    elif line.startswith('ORIGIN'):
1301                        in_features = False
1302
1303                    # now write useful feature information to dictionary
1304                    elif in_features:
1305                        if not line.strip():
1306                            continue
1307                        if line[5] != ' ':
1308                            feature_type, qualifier = parts[0], parts[1]
1309                            if feature_type not in record['features']:
1310                                record['features'][feature_type] = {}
1311                                counter_dict[feature_type] = 0
1312                            locations, strand = sanitize_gb_location(qualifier)
1313                            record['features'][feature_type][counter_dict[feature_type]] = {
1314                                'location': locations,
1315                                'strand': strand
1316                            }
1317                            counter_dict[feature_type] += 1
1318                        else:
1319                            try:
1320                                qualifier_type, qualifier = parts[0].split('=')
1321                            except ValueError:  # we are in the coding sequence
1322                                qualifier = qualifier + parts[0]
1323
1324                            qualifier_type, qualifier = qualifier_type.lstrip('/'), qualifier.strip('"')
1325                            last_index = counter_dict[feature_type] - 1
1326                            record['features'][feature_type][last_index][qualifier_type] = qualifier
1327
1328            records[record['locus']] = record
1329
1330            return records
1331
1332        def parse_gff(file_path) -> dict:
1333            """
1334            Parse a GFF3 (General Feature Format) file into a dictionary structure.
1335
1336            :param file_path: path to genebank file
1337            :return: nested dictionary
1338
1339            """
1340            records = {}
1341            with open(file_path, 'r') as file:
1342                previous_id, previous_feature = None, None
1343                for line in file:
1344                    if line.startswith('#') or not line.strip():
1345                        continue
1346                    parts = line.strip().split('\t')
1347                    seqid, source, feature_type, start, end, score, strand, phase, attributes = parts
1348                    # ensure that region and source features are not named differently for gff and gb
1349                    if feature_type == 'region':
1350                        feature_type = 'source'
1351                    if seqid not in records:
1352                        records[seqid] = {'locus': seqid, 'features': {}}
1353                    if feature_type not in records[seqid]['features']:
1354                        records[seqid]['features'][feature_type] = {}
1355
1356                    feature_id = len(records[seqid]['features'][feature_type])
1357                    feature = {
1358                        'strand': strand,
1359                    }
1360
1361                    # Parse attributes into key-value pairs
1362                    for attr in attributes.split(';'):
1363                        if '=' in attr:
1364                            key, value = attr.split('=', 1)
1365                            feature[key.strip()] = value.strip()
1366
1367                    # check if feature are the same --> possible splicing
1368                    if previous_id is not None and previous_feature == feature:
1369                        records[seqid]['features'][feature_type][previous_id]['location'].append([int(start)-1, int(end)])
1370                    else:
1371                        records[seqid]['features'][feature_type][feature_id] = feature
1372                        records[seqid]['features'][feature_type][feature_id]['location'] = [[int(start) - 1, int(end)]]
1373                    # set new previous id and features -> new dict as 'location' is pointed in current feature and this
1374                    # is the only key different if next feature has the same entries
1375                    previous_id, previous_feature = feature_id, {key:value for key, value in feature.items() if key != 'location'}
1376
1377            return records
1378
1379        def parse_bed(file_path) -> dict:
1380            """
1381            Parse a BED file into a dictionary structure.
1382
1383            :param file_path: path to genebank file
1384            :return: nested dictionary
1385
1386            """
1387            records = {}
1388            with open(file_path, 'r') as file:
1389                for line in file:
1390                    if line.startswith('#') or not line.strip():
1391                        continue
1392                    parts = line.strip().split('\t')
1393                    chrom, start, end, *optional = parts
1394
1395                    if chrom not in records:
1396                        records[chrom] = {'locus': chrom, 'features': {}}
1397                    feature_type = 'region'
1398                    if feature_type not in records[chrom]['features']:
1399                        records[chrom]['features'][feature_type] = {}
1400
1401                    feature_id = len(records[chrom]['features'][feature_type])
1402                    feature = {
1403                        'location': [[int(start), int(end)]],  # BED uses 0-based start, convert to 1-based
1404                        'strand': '+',  # assume '+' if not present
1405                    }
1406
1407                    # Handle optional columns (name, score, strand) --> ignore 7-12
1408                    if len(optional) >= 1:
1409                        feature['name'] = optional[0]
1410                    if len(optional) >= 2:
1411                        feature['score'] = optional[1]
1412                    if len(optional) >= 3:
1413                        feature['strand'] = optional[2]
1414
1415                    records[chrom]['features'][feature_type][feature_id] = feature
1416
1417            return records
1418
1419        parse_functions: Dict[str, Callable[[str], dict]] = {
1420            'gb': parse_gb,
1421            'bed': parse_bed,
1422            'gff': parse_gff,
1423        }
1424        # determine the annotation content -> should be standard formatted
1425        try:
1426            annotation_type = detect_annotation_type(annotation_path)
1427        except ValueError as err:
1428            raise err
1429
1430        # read in the annotation
1431        annotations = parse_functions[annotation_type](annotation_path)
1432
1433        # sanity check whether one of the annotation ids and alignment ids match
1434        annotation_found = False
1435        for annotation in annotations.keys():
1436            for aln_id in aln.alignment.keys():
1437                aln_id_sanitized = aln_id.split(' ')[0]
1438                # check in both directions
1439                if aln_id_sanitized in annotation:
1440                    annotation_found = True
1441                    break
1442                if annotation in aln_id_sanitized:
1443                    annotation_found = True
1444                    break
1445
1446        if not annotation_found:
1447            raise ValueError(f'the annotations of {annotation_path} do not match any ids in the MSA')
1448
1449        # return only the annotation that has been found, the respective type and the seq_id to map to
1450        return annotation_type, aln_id, annotations[annotation]['locus'], annotations[annotation]['features']
1451
1452
1453    def _build_position_map(self) -> Dict[int, int]:
1454        """
1455        build a position map from a sequence.
1456
1457        :return genomic position: gapped position
1458        """
1459
1460        position_map = {}
1461        genomic_pos = 0
1462        for aln_index, char in enumerate(self._gapped_seq):
1463            if char != '-':
1464                position_map[genomic_pos] = aln_index
1465                genomic_pos += 1
1466        # ensure the last genomic position is included
1467        position_map[genomic_pos] = len(self._gapped_seq)
1468
1469        return position_map
1470
1471
1472    def _map_to_alignment(self):
1473        """
1474        Adjust all feature locations to alignment positions
1475        """
1476
1477        def map_location(position_map: Dict[int, int], locations: list) -> list:
1478            """
1479            Map genomic locations to alignment positions using a precomputed position map.
1480
1481            :param position_map: Positions mapped from gapped to ungapped
1482            :param locations: List of genomic start and end positions.
1483            :return: List of adjusted alignment positions.
1484            """
1485
1486            aligned_locs = []
1487            for start, end in locations:
1488                try:
1489                    aligned_start = position_map[start]
1490                    aligned_end = position_map[end]
1491                    aligned_locs.append([aligned_start, aligned_end])
1492                except KeyError:
1493                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
1494
1495            return aligned_locs
1496
1497        for feature_type, features in self.features.items():
1498            for feature_id, feature_data in features.items():
1499                original_locations = feature_data['location']
1500                aligned_locations = map_location(self._position_map, original_locations)
1501                feature_data['location'] = aligned_locations
class MSA:
  26class MSA:
  27    """
  28    An alignment class that allows computation of several stats
  29    """
  30
  31    def __init__(self, alignment_path: str, reference_id: str = None, zoom_range: tuple | int = None):
  32        """
  33        Initialise an Alignment object.
  34        :param alignment_path: path to alignment file
  35        :param reference_id: reference id
  36        :param zoom_range: start and stop positions to zoom into the alignment
  37        """
  38        self._alignment = self._read_alignment(alignment_path)
  39        self._reference_id = self._validate_ref(reference_id, self._alignment)
  40        self._zoom = self._validate_zoom(zoom_range, self._alignment)
  41        self._aln_type = self._determine_aln_type(self._alignment)
  42
  43    # TODO: read in different alignment types
  44    # Static methods
  45    @staticmethod
  46    def _read_alignment(file_path: str) -> dict:
  47        """
  48        Parse MSA alignment file.
  49        :param file_path: path to alignment file
  50        :return: dictionary with ids as keys and sequences as values
  51        """
  52
  53        def add_seq(aln: dict, sequence_id: str, seq_list: list):
  54            """
  55            Add a complete sequence and check for non-allowed chars
  56            :param aln: alignment dictionary to build
  57            :param sequence_id: sequence id to add
  58            :param seq_list: sequences to add
  59            :return: alignment with added sequences
  60            """
  61            final_seq = ''.join(seq_list).upper()
  62            # Check for non-allowed characters
  63            invalid_chars = set(final_seq) - set(config.POSSIBLE_CHARS)
  64            if invalid_chars:
  65                raise ValueError(
  66                    f"{sequence_id} contains invalid characters: {', '.join(invalid_chars)}. Allowed chars are: {config.POSSIBLE_CHARS}"
  67                )
  68            aln[sequence_id] = final_seq
  69
  70            return aln
  71
  72        alignment, seq_lines = {}, []
  73        seq_id = None
  74
  75        with open(file_path, 'r') as file:
  76            for i, line in enumerate(file):
  77                line = line.strip()
  78                # initial check for fasta format
  79                if i == 0 and not line.startswith(">"):
  80                    raise ValueError('Alignment has to be in fasta format starting with >SeqID.')
  81                if line.startswith(">"):
  82                    if seq_id:
  83                        alignment = add_seq(alignment, seq_id, seq_lines)
  84                    # initialize a new sequence
  85                    seq_id, seq_lines = line[1:], []
  86                else:
  87                    seq_lines.append(line)
  88            # handle last sequence
  89            if seq_id:
  90                alignment = add_seq(alignment, seq_id, seq_lines)
  91        # final sanity checks
  92        if alignment:
  93            # alignment contains only one sequence:
  94            if len(alignment) < 2:
  95                raise ValueError("Alignment must contain more than one sequence.")
  96            # alignment sequences are not same length
  97            first_seq_len = len(next(iter(alignment.values())))
  98            for sequence_id, sequence in alignment.items():
  99                if len(sequence) != first_seq_len:
 100                    raise ValueError(
 101                        f"All alignment sequences must have the same length. Sequence '{sequence_id}' has length {len(sequence)}, expected {first_seq_len}."
 102                    )
 103            # all checks passed
 104            return alignment
 105        else:
 106            raise ValueError(f"Alignment file {file_path} does not contain any sequences in fasta format.")
 107
 108    @staticmethod
 109    def _validate_ref(reference: str | None, alignment: dict) -> str | None | ValueError:
 110        """
 111        Validate if the ref seq is indeed part of the alignment.
 112        :param reference: reference seq id
 113        :param alignment: alignment dict
 114        :return: validated reference
 115        """
 116        if reference in alignment.keys():
 117            return reference
 118        elif reference is None:
 119            return reference
 120        else:
 121            raise ValueError('Reference not in alignment.')
 122
 123    @staticmethod
 124    def _validate_zoom(zoom: tuple | int, original_aln: dict) -> ValueError | tuple | None:
 125        """
 126        Validates if the user defined zoom range is within the start, end of the initial
 127        alignment.\n
 128        :param zoom: zoom range or zoom start
 129        :param original_aln: non-zoomed alignment dict
 130        :return: validated zoom range
 131        """
 132        if zoom is not None:
 133            aln_length = len(original_aln[list(original_aln.keys())[0]])
 134            # check if only over value is provided -> stop is alignment length
 135            if isinstance(zoom, int):
 136                if 0 <= zoom < aln_length:
 137                    return zoom, aln_length - 1
 138                else:
 139                    raise ValueError('Zoom start must be within the alignment length range.')
 140            # check if more than 2 values are provided
 141            if len(zoom) != 2:
 142                raise ValueError('Zoom position have to be (zoom_start, zoom_end)')
 143            # validate zoom start/stop
 144            for position in zoom:
 145                if type(position) != int:
 146                    raise ValueError('Zoom positions have to be integers.')
 147                if position not in range(0, aln_length):
 148                    raise ValueError('Zoom position out of range')
 149
 150        return zoom
 151
 152    @staticmethod
 153    def _determine_aln_type(alignment) -> str:
 154        """
 155        Determine the most likely type of alignment
 156        if 70% of chars in the alignment are nucleotide
 157        chars it is most likely a nt alignment
 158        :return: type of alignment
 159        """
 160        counter = int()
 161        for record in alignment:
 162            if 'U' in alignment[record]:
 163                return 'RNA'
 164            counter += sum(map(alignment[record].count, ['A', 'C', 'G', 'T', 'N', '-']))
 165        # determine which is the most likely type
 166        if counter / len(alignment) >= 0.7 * len(alignment[list(alignment.keys())[0]]):
 167            return 'DNA'
 168        else:
 169            return 'AA'
 170
 171    # Properties with setters
 172    @property
 173    def reference_id(self):
 174        return self._reference_id
 175
 176    @reference_id.setter
 177    def reference_id(self, ref_id: str):
 178        """
 179        Set and validate the reference id.
 180        """
 181        self._reference_id = self._validate_ref(ref_id, self.alignment)
 182
 183    @property
 184    def zoom(self) -> tuple:
 185        return self._zoom
 186
 187    @zoom.setter
 188    def zoom(self, zoom_pos: tuple | int):
 189        """
 190        Validate if the user defined zoom range.
 191        """
 192        self._zoom = self._validate_zoom(zoom_pos, self._alignment)
 193
 194    # Property without setters
 195    @property
 196    def aln_type(self) -> str:
 197        """
 198        define the aln type:
 199        RNA, DNA or AA
 200        """
 201        return self._aln_type
 202
 203    # On the fly properties without setters
 204    @property
 205    def length(self) -> int:
 206        return len(next(iter(self.alignment.values())))
 207
 208    @property
 209    def alignment(self) -> dict:
 210        """
 211        (zoomed) version of the alignment.
 212        """
 213        if self.zoom is not None:
 214            zoomed_aln = dict()
 215            for seq in self._alignment:
 216                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
 217            return zoomed_aln
 218        else:
 219            return self._alignment
 220
 221    # functions for different alignment stats
 222    def get_reference_coords(self) -> tuple[int, int]:
 223        """
 224        Determine the start and end coordinates of the reference sequence
 225        defined as the first/last nucleotide in the reference sequence
 226        (excluding N and gaps).
 227
 228        :return: start, end
 229        """
 230        start, end = 0, self.length
 231
 232        if self.reference_id is None:
 233            return start, end
 234        else:
 235            # 5' --> 3'
 236            for start in range(self.length):
 237                if self.alignment[self.reference_id][start] not in ['-', 'N']:
 238                    break
 239            # 3' --> 5'
 240            for end in range(self.length - 1, 0, -1):
 241                if self.alignment[self.reference_id][end] not in ['-', 'N']:
 242                    break
 243
 244            return start, end
 245
 246    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
 247        """
 248        Creates a non-gapped consensus sequence.
 249
 250        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
 251            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
 252            or 'X' (for as alignments) is used if none of the characters reach a cumulative frequency >= threshold.
 253        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
 254            has a frequency above the defined threshold.
 255        :return: consensus sequence
 256        """
 257
 258        # helper functions
 259        def determine_counts(alignment_dict: dict, position: int) -> dict:
 260            """
 261            count the number of each char at
 262            an idx of the alignment. return sorted dic.
 263            handles ambiguous nucleotides in sequences.
 264            also handles gaps.
 265            """
 266            nucleotide_list = []
 267
 268            # get all nucleotides
 269            for sequence in alignment_dict.items():
 270                nucleotide_list.append(sequence[1][position])
 271            # count occurences of nucleotides
 272            counter = dict(collections.Counter(nucleotide_list))
 273            # get permutations of an ambiguous nucleotide
 274            to_delete = []
 275            temp_dict = {}
 276            for nucleotide in counter:
 277                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
 278                    to_delete.append(nucleotide)
 279                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
 280                    adjusted_freq = 1 / len(permutations)
 281                    for permutation in permutations:
 282                        if permutation in temp_dict:
 283                            temp_dict[permutation] += adjusted_freq
 284                        else:
 285                            temp_dict[permutation] = adjusted_freq
 286
 287            # drop ambiguous entries and add adjusted freqs to
 288            if to_delete:
 289                for i in to_delete:
 290                    counter.pop(i)
 291                for nucleotide in temp_dict:
 292                    if nucleotide in counter:
 293                        counter[nucleotide] += temp_dict[nucleotide]
 294                    else:
 295                        counter[nucleotide] = temp_dict[nucleotide]
 296
 297            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
 298
 299        def get_consensus_char(counts: dict, cutoff: float) -> list:
 300            """
 301            get a list of nucleotides for the consensus seq
 302            """
 303            n = 0
 304
 305            consensus_chars = []
 306            for char in counts:
 307                n += counts[char]
 308                consensus_chars.append(char)
 309                if n >= cutoff:
 310                    break
 311
 312            return consensus_chars
 313
 314        def get_ambiguous_char(nucleotides: list) -> str:
 315            """
 316            get ambiguous char from a list of nucleotides
 317            """
 318            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
 319                if set(permutations) == set(nucleotides):
 320                    return ambiguous
 321
 322        # check if params have been set correctly
 323        if threshold is not None:
 324            if threshold < 0 or threshold > 1:
 325                raise ValueError('Threshold must be between 0 and 1.')
 326        if self.aln_type == 'AA' and use_ambig_nt:
 327            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
 328        if threshold is None and use_ambig_nt:
 329            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
 330
 331        alignment = self.alignment
 332        consensus = str()
 333
 334        if threshold is not None:
 335            consensus_cutoff = len(alignment) * threshold
 336        else:
 337            consensus_cutoff = 0
 338
 339        # built consensus sequences
 340        for idx in range(self.length):
 341            char_counts = determine_counts(alignment, idx)
 342            consensus_chars = get_consensus_char(
 343                char_counts,
 344                consensus_cutoff
 345            )
 346            if threshold != 0:
 347                if len(consensus_chars) > 1:
 348                    if use_ambig_nt:
 349                        char = get_ambiguous_char(consensus_chars)
 350                    else:
 351                        if self.aln_type == 'AA':
 352                            char = 'X'
 353                        else:
 354                            char = 'N'
 355                    consensus = consensus + char
 356                else:
 357                    consensus = consensus + consensus_chars[0]
 358            else:
 359                consensus = consensus + consensus_chars[0]
 360
 361        return consensus
 362
 363    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
 364        """
 365        **conserved ORF definition:**
 366            - conserved starts and stops
 367            - start, stop must be on the same frame
 368            - stop - start must be at least min_length
 369            - all ungapped seqs[start:stop] must have at least min_length
 370            - no ungapped seq can have a Stop in between Start Stop
 371
 372        Conservation is measured by number of positions with identical characters divided by
 373        orf slice of the alignment.
 374
 375        **Algorithm overview:**
 376            - check for conserved start and stop codons
 377            - iterate over all three frames
 378            - check each start and next sufficiently far away stop codon
 379            - check if all ungapped seqs between start and stop codon are >= min_length
 380            - check if no ungapped seq in the alignment has a stop codon
 381            - write to dictionary
 382            - classify as internal if the stop codon has already been written with a prior start
 383            - repeat for reverse complement
 384
 385        :return: ORF positions and internal ORF positions
 386        """
 387
 388        # helper functions
 389        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
 390            """
 391            Determine all start and stop codons within an alignment.
 392            :param alignment: alignment
 393            :param alignment_length: length of alignment
 394            :return: start and stop codon positions
 395            """
 396            starts = config.START_CODONS[self.aln_type]
 397            stops = config.STOP_CODONS[self.aln_type]
 398
 399            list_of_starts, list_of_stops = [], []
 400            ref = alignment[list(alignment.keys())[0]]
 401            for nt_position in range(alignment_length):
 402                if ref[nt_position:nt_position + 3] in starts:
 403                    conserved_start = True
 404                    for sequence in alignment:
 405                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
 406                            conserved_start = False
 407                            break
 408                    if conserved_start:
 409                        list_of_starts.append(nt_position)
 410
 411                if ref[nt_position:nt_position + 3] in stops:
 412                    conserved_stop = True
 413                    for sequence in alignment:
 414                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
 415                            conserved_stop = False
 416                            break
 417                    if conserved_stop:
 418                        list_of_stops.append(nt_position)
 419
 420            return list_of_starts, list_of_stops
 421
 422        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
 423            """
 424            get ungapped sequences starting and stop codons and eliminate gaps
 425            :param alignment: alignment
 426            :param start_pos: start codon
 427            :param stop_pos: stop codon
 428            :return: sliced sequences
 429            """
 430            ungapped_seqs = []
 431            for seq_id in alignment:
 432                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
 433
 434            return ungapped_seqs
 435
 436        def additional_stops(ungapped_seqs: list) -> bool:
 437            """
 438            Checks for the presence of a stop codon
 439            :param ungapped_seqs: list of ungapped sequences
 440            :return: Additional stop codons (True/False)
 441            """
 442            stops = config.STOP_CODONS[self.aln_type]
 443
 444            for sliced_seq in ungapped_seqs:
 445                for position in range(0, len(sliced_seq) - 3, 3):
 446                    if sliced_seq[position:position + 3] in stops:
 447                        return True
 448            return False
 449
 450        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
 451            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
 452            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100
 453
 454        # checks for arguments
 455        if self.aln_type == 'AA':
 456            raise TypeError('ORF search only for RNA/DNA alignments')
 457
 458        if identity_cutoff is not None:
 459            if identity_cutoff > 100 or identity_cutoff < 0:
 460                raise ValueError('conservation cutoff must be between 0 and 100')
 461
 462        if min_length <= 0 or min_length > self.length:
 463            raise ValueError(f'min_length must be between 0 and {self.length}')
 464
 465        # ini
 466        identities = self.calc_identity_alignment()
 467        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
 468        aln_len = self.length
 469
 470        orf_counter = 0
 471        orf_dict = {}
 472
 473        for aln, direction in zip(alignments, ['+', '-']):
 474            # check for starts and stops in the first seq and then check if these are present in all seqs
 475            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
 476            # check each frame
 477            for frame in (0, 1, 2):
 478                potential_starts = [x for x in conserved_starts if x % 3 == frame]
 479                potential_stops = [x for x in conserved_stops if x % 3 == frame]
 480                last_stop = -1
 481                for start in potential_starts:
 482                    # go to the next stop that is sufficiently far away in the alignment
 483                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
 484                    if not next_stops:
 485                        continue
 486                    next_stop = next_stops[0]
 487                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
 488                    # re-check the lengths of all ungapped seqs
 489                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
 490                    if not all(ungapped_seq_lengths):
 491                        continue
 492                    # if no stop codon between start and stop --> write to dictionary
 493                    if not additional_stops(ungapped_sliced_seqs):
 494                        if direction == '+':
 495                            positions = [start, next_stop + 3]
 496                        else:
 497                            positions = [aln_len - next_stop - 3, aln_len - start]
 498                        if last_stop != next_stop:
 499                            last_stop = next_stop
 500                            conservation = calculate_identity(identities, positions)
 501                            if identity_cutoff is not None and conservation < identity_cutoff:
 502                                continue
 503                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
 504                                                              'frame': frame,
 505                                                              'strand': direction,
 506                                                              'conservation': conservation,
 507                                                              'internal': []
 508                                                              }
 509                            orf_counter += 1
 510                        else:
 511                            orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
 512
 513        return orf_dict
 514
 515    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> dict:
 516        """
 517        First calculates all ORFs and then searches from 5'
 518        all non-overlapping orfs in the fw strand and from the
 519        3' all non-overlapping orfs in th rw strand.
 520
 521        **No overlap algorithm:**
 522            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
 523
 524            **frame 2:** -------[M------*]---------[M---*]--------
 525
 526            **frame 3:** [M---*]-----[M----------*]----------[M---
 527
 528            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
 529
 530            frame:    3      2           1      2       1
 531
 532        :return: dictionary with non-overlapping orfs
 533        """
 534        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
 535
 536        fw_orfs, rw_orfs = [], []
 537
 538        for orf in orf_dict:
 539            if orf_dict[orf]['strand'] == '+':
 540                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
 541            else:
 542                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
 543
 544        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
 545        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
 546        non_overlapping_orfs = []
 547        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
 548            previous_stop = -1 if strand == '+' else self.length + 1
 549            for orf in orf_list:
 550                if strand == '+' and orf[1][0] > previous_stop:
 551                    non_overlapping_orfs.append(orf[0])
 552                    previous_stop = orf[1][1]
 553                elif strand == '-' and orf[1][1] < previous_stop:
 554                    non_overlapping_orfs.append(orf[0])
 555                    previous_stop = orf[1][0]
 556
 557        non_overlap_dict = {}
 558        for orf in orf_dict:
 559            if orf in non_overlapping_orfs:
 560                non_overlap_dict[orf] = orf_dict[orf]
 561
 562        return non_overlap_dict
 563
 564    def calc_length_stats(self) -> dict:
 565        """
 566        Determine the stats for the length of the ungapped seqs in the alignment.
 567        :return: dictionary with length stats
 568        """
 569
 570        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
 571
 572        return {'number of seq': len(self.alignment),
 573                'mean length': float(np.mean(seq_lengths)),
 574                'std length': float(np.std(seq_lengths)),
 575                'min length': int(np.min(seq_lengths)),
 576                'max length': int(np.max(seq_lengths))
 577                }
 578
 579    def calc_entropy(self) -> list:
 580        """
 581        Calculate the normalized shannon's entropy for every position in an alignment:
 582
 583        - 1: high entropy
 584        - 0: low entropy
 585
 586        :return: Entropies at each position.
 587        """
 588
 589        # helper functions
 590        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
 591            """
 592            Calculate the shannon's entropy of a sequence and
 593            normalized between 0 and 1.
 594            :param character_list: characters at an alignment position
 595            :param states: number of potential characters that can be present
 596            :param aln_type: type of the alignment
 597            :returns: entropy
 598            """
 599            ent, n_chars = 0, len(character_list)
 600            # only one char is in the list
 601            if n_chars <= 1:
 602                return ent
 603            # calculate the number of unique chars and their counts
 604            chars, char_counts = np.unique(character_list, return_counts=True)
 605            char_counts = char_counts.astype(float)
 606            # ignore gaps for entropy calc
 607            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
 608            # correctly handle ambiguous chars
 609            index_to_drop = []
 610            for index, char in enumerate(chars):
 611                if char in config.AMBIG_CHARS[aln_type]:
 612                    index_to_drop.append(index)
 613                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
 614                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
 615                    # add the proportionate numbers to initial array
 616                    for amb_char, amb_count in zip(amb_chars, amb_counts):
 617                        if amb_char in chars:
 618                            char_counts[chars == amb_char] += amb_count
 619                        else:
 620                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
 621            # drop the ambiguous characters from array
 622            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
 623            # calc the entropy
 624            probs = char_counts / n_chars
 625            if np.count_nonzero(probs) <= 1:
 626                return ent
 627            for prob in probs:
 628                ent -= prob * math.log(prob, states)
 629
 630            return ent
 631
 632        aln = self.alignment
 633        entropys = []
 634
 635        if self.aln_type == 'AA':
 636            states = 20
 637        else:
 638            states = 4
 639        # iterate over alignment positions and the sequences
 640        for nuc_pos in range(self.length):
 641            pos = []
 642            for record in aln:
 643                pos.append(aln[record][nuc_pos])
 644            entropys.append(shannons_entropy(pos, states, self.aln_type))
 645
 646        return entropys
 647
 648    def calc_gc(self) -> list | TypeError:
 649        """
 650        Determine the GC content for every position in an nt alignment.
 651        :return: GC content for every position.
 652        :raises: TypeError for AA alignments
 653        """
 654        if self.aln_type == 'AA':
 655            raise TypeError("GC computation is not possible for aminoacid alignment")
 656
 657        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
 658
 659        for position in range(self.length):
 660            nucleotides = str()
 661            for record in aln:
 662                nucleotides = nucleotides + aln[record][position]
 663            # ini dict with chars that occur and which ones to
 664            # count in which freq
 665            to_count = {
 666                'G': 1,
 667                'C': 1,
 668            }
 669            # handle ambig. nuc
 670            for char in amb_nucs:
 671                if char in nucleotides:
 672                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
 673
 674            gc.append(
 675                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
 676            )
 677
 678        return gc
 679
 680    def calc_coverage(self) -> list:
 681        """
 682        Determine the coverage of every position in an alignment.
 683        This is defined as:
 684            1 - cumulative length of '-' characters
 685
 686        :return: Coverage at each alignment position.
 687        """
 688        coverage, aln = [], self.alignment
 689
 690        for nuc_pos in range(self.length):
 691            pos = str()
 692            for record in aln.keys():
 693                pos = pos + aln[record][nuc_pos]
 694            coverage.append(1 - pos.count('-') / len(pos))
 695
 696        return coverage
 697
 698    def calc_reverse_complement_alignment(self) -> dict | TypeError:
 699        """
 700        Reverse complement the alignment.
 701        :return: Alignment (rv)
 702        """
 703        if self.aln_type == 'AA':
 704            raise TypeError('Reverse complement only for RNA or DNA.')
 705
 706        aln = self.alignment
 707        reverse_complement_dict = {}
 708
 709        for seq_id in aln:
 710            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
 711
 712        return reverse_complement_dict
 713
 714    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> np.ndarray:
 715        """
 716        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
 717
 718        :param encode_mismatches: encode mismatch as -1
 719        :param encode_mask: encode mask with value=-2 --> also in the reference
 720        :param encode_gaps: encode gaps with np.nan --> also in the reference
 721        :param encode_ambiguities: encode ambiguities with value=-3
 722        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.DNA_colors, config.RNA_colors or config.AA_colors
 723        :return: identity alignment
 724        """
 725
 726        aln = self.alignment
 727        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 728
 729        # convert alignment to array
 730        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 731        reference = np.array(list(ref))
 732        # ini matrix
 733        identity_matrix = np.full(sequences.shape, 0, dtype=float)
 734
 735        is_identical = sequences == reference
 736
 737        if encode_gaps:
 738            is_gap = sequences == '-'
 739        else:
 740            is_gap = np.full(sequences.shape, False)
 741
 742        if encode_mask:
 743            if self.aln_type == 'AA':
 744                is_n_or_x = np.isin(sequences, ['X'])
 745            else:
 746                is_n_or_x = np.isin(sequences, ['N'])
 747        else:
 748            is_n_or_x = np.full(sequences.shape, False)
 749
 750        if encode_ambiguities:
 751            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
 752        else:
 753            is_ambig = np.full(sequences.shape, False)
 754
 755        if encode_mismatches:
 756            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
 757        else:
 758            is_mismatch = np.full(sequences.shape, False)
 759
 760        # encode every different character
 761        if encode_each_mismatch_char:
 762            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]):
 763                new_encoding = np.isin(sequences, [char]) & is_mismatch
 764                identity_matrix[new_encoding] = idx + 1
 765        # or encode different with a single value
 766        else:
 767            identity_matrix[is_mismatch] = -1  # mismatch
 768
 769        identity_matrix[is_gap] = np.nan  # gap
 770        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
 771        identity_matrix[is_ambig] = -3  # ambiguities
 772
 773        return identity_matrix
 774
 775    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> np.ndarray:
 776        """
 777        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
 778        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
 779        reference residue at each column. Gaps are encoded as np.nan.
 780
 781        The calculation follows these steps:
 782
 783        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
 784           a consensus sequence is generated to serve as the reference.
 785        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
 786           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
 787        3. **Per-Column Normalization (optional)**:
 788           - For each column in the alignment:
 789             - The residue in the reference sequence is treated as the baseline for that column.
 790             - The substitution scores for the reference residue are extracted from the substitution matrix.
 791             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
 792               residue.
 793           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
 794             while more dissimilar residues have lower scores.
 795        4. **Output**:
 796           - The normalized similarity scores are stored in a NumPy array.
 797           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
 798
 799        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
 800        :param: normalize: whether to normalize the similarity scores to range [0, 1]
 801        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
 802            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
 803            Gaps and invalid residues are encoded as `np.nan`.
 804        :raise: ValueError
 805            If the specified substitution matrix is not available for the given alignment type.
 806        """
 807
 808        aln = self.alignment
 809        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
 810        if matrix_type is None:
 811            if self.aln_type == 'AA':
 812                matrix_type = 'BLOSUM65'
 813            else:
 814                matrix_type = 'TRANS'
 815        # load substitution matrix as dictionary
 816        try:
 817            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
 818        except KeyError:
 819            raise ValueError(
 820                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
 821            )
 822
 823        # set dtype and convert alignment to a NumPy array for vectorized processing
 824        dtype = np.dtype(float, metadata={'matrix': matrix_type})
 825        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 826        reference = np.array(list(ref))
 827        valid_chars = list(subs_matrix.keys())
 828        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
 829
 830        for j, ref_char in enumerate(reference):
 831            if ref_char not in valid_chars + ['-']:
 832                continue
 833            # Get local min and max for the reference residue
 834            if normalize and ref_char != '-':
 835                local_scores = subs_matrix[ref_char].values()
 836                local_min, local_max = min(local_scores), max(local_scores)
 837
 838            for i, char in enumerate(sequences[:, j]):
 839                if char not in valid_chars:
 840                    continue
 841                # classify the similarity as max if the reference has a gap
 842                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
 843                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
 844
 845        return similarity_array
 846
 847    def calc_position_matrix(self, matrix_type:str='PWM') -> np.ndarray | ValueError:
 848        """
 849        Calculate various position matrices (reference https://en.wikipedia.org/wiki/Position_weight_matrix)
 850
 851        **Major steps:**
 852            1) calculate character counts (PFM)
 853            2) calculate character frequencies (PPM)
 854            3) add pseudocount (square root of row length) -> scales with aln size --> needed for positions with 0 counts
 855            4) transform to PWM with M_k,j=log_2(M_k,j/b_k) with b_k assuming statistical independence (all chars are equally frequent)
 856
 857        :param matrix_type: matrix to return (PFM, PPM or PWM)
 858        :return: pwm as numpy array
 859        :raise: ValueError for incorrect matrix type
 860        """
 861
 862        # ini
 863        aln = self.alignment
 864        if matrix_type not in ['PFM', 'PPM', 'PWM']:
 865            raise ValueError('Matrix_type must be PFM, PPM or PWM.')
 866        possible_chars = list(config.CHAR_COLORS[self.aln_type].keys())[:-1]
 867        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
 868
 869        # calc position frequency matrix
 870        pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars]) \
 871              + np.sqrt(sequences.shape[0])  # add pseudo-counts
 872        if matrix_type == 'PFM':
 873            return pfm
 874
 875        # calc position probability matrix
 876        ppm = pfm/np.sum(pfm, 0)
 877        if matrix_type == 'PPM':
 878            return ppm
 879
 880        # calc position weight matrix
 881        pwm = np.log2(ppm*len(possible_chars))
 882        if matrix_type == 'PWM':
 883            return pwm
 884
 885    def calc_percent_recovery(self) -> dict:
 886        """
 887        Recovery per sequence either compared to the majority consensus seq
 888        or the reference seq.\n
 889        Defined as:\n
 890
 891        `(1 - sum(N/X/- characters in ungapped ref regions))*100`
 892
 893        This is highly similar to how nextclade calculates recovery over reference.
 894
 895        :return: dict
 896        """
 897
 898        aln = self.alignment
 899
 900        if self.reference_id is not None:
 901            ref = aln[self.reference_id]
 902        else:
 903            ref = self.get_consensus()  # majority consensus
 904
 905        if not any(char != '-' for char in ref):
 906            raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")
 907
 908
 909        # count 'N', 'X' and '-' chars in non-gapped regions
 910        recovery_over_ref = dict()
 911
 912        # Get positions of non-gap characters in the reference
 913        non_gap_positions = [i for i, char in enumerate(ref) if char != '-']
 914        cumulative_length = len(non_gap_positions)
 915
 916        # Calculate recovery
 917        for seq_id in aln:
 918            if seq_id == self.reference_id:
 919                continue
 920            seq = aln[seq_id]
 921            count_invalid = sum(
 922                seq[pos] == '-' or
 923                (seq[pos] == 'X' if self.aln_type == "AA" else seq[pos] == 'N')
 924                for pos in non_gap_positions
 925            )
 926            recovery_over_ref[seq_id] = (1 - count_invalid / cumulative_length) * 100
 927
 928        return recovery_over_ref
 929
 930    def calc_character_frequencies(self) -> dict:
 931        """
 932        Calculate the percentage characters in the alignment:
 933        The frequencies are counted by seq and in total. The
 934        percentage of non-gap characters in the alignment is
 935        relative to the total number of non-gap characters.
 936        The gap percentage is relative to the sequence length.
 937
 938        The output is a nested dictionary.
 939
 940        :return: Character frequencies
 941        """
 942
 943        aln, aln_length = self.alignment, self.length
 944
 945        freqs = {'total': {'-': {'counts': 0, '% of alignment': float()}}}
 946
 947        for seq_id in aln:
 948            freqs[seq_id], all_chars = {'-': {'counts': 0, '% of alignment': float()}}, 0
 949            unique_chars = set(aln[seq_id])
 950            for char in unique_chars:
 951                if char == '-':
 952                    continue
 953                # add characters to dictionaries
 954                if char not in freqs[seq_id]:
 955                    freqs[seq_id][char] = {'counts': 0, '% of non-gapped': 0}
 956                if char not in freqs['total']:
 957                    freqs['total'][char] = {'counts': 0, '% of non-gapped': 0}
 958                # count non-gap chars
 959                freqs[seq_id][char]['counts'] += aln[seq_id].count(char)
 960                freqs['total'][char]['counts'] += freqs[seq_id][char]['counts']
 961                all_chars += freqs[seq_id][char]['counts']
 962            # normalize counts
 963            for char in freqs[seq_id]:
 964                if char == '-':
 965                    continue
 966                freqs[seq_id][char]['% of non-gapped'] = freqs[seq_id][char]['counts'] / all_chars * 100
 967                freqs['total'][char]['% of non-gapped'] += freqs[seq_id][char]['% of non-gapped']
 968            # count gaps
 969            freqs[seq_id]['-']['counts'] = aln[seq_id].count('-')
 970            freqs['total']['-']['counts'] += freqs[seq_id]['-']['counts']
 971            # normalize gap counts
 972            freqs[seq_id]['-']['% of alignment'] = freqs[seq_id]['-']['counts'] / aln_length * 100
 973            freqs['total']['-']['% of alignment'] += freqs[seq_id]['-']['% of alignment']
 974
 975        # normalize the total counts
 976        for char in freqs['total']:
 977            for value in freqs['total'][char]:
 978                if value == '% of alignment' or value == '% of non-gapped':
 979                    freqs['total'][char][value] = freqs['total'][char][value] / len(aln)
 980
 981        return freqs
 982
 983    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> ndarray:
 984        """
 985        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
 986
 987        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
 988        \ndistance = matches / alignment_length * 100
 989
 990        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
 991        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
 992
 993        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
 994        \ndistance = matches / (matches + mismatches) * 100
 995
 996        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
 997        \ndistance = matches / gap_compressed_alignment_length * 100
 998
 999        :return: array with pairwise distances.
1000        """
1001
1002        def hamming_distance(seq1: str, seq2: str) -> int:
1003            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))
1004
1005        def ghd(seq1: str, seq2: str) -> float:
1006            return hamming_distance(seq1, seq2) / self.length * 100
1007
1008        def lhd(seq1: str, seq2: str) -> float:
1009            # remove 5' trailing gaps
1010            i, j = 0, self.length - 1
1011            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1012                i += 1
1013            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1014                j -= 1
1015            if i > j:
1016                return 0.0
1017            # slice seq
1018            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1019
1020            return hamming_distance(seq1_, seq2_) / min([len(seq1_), len(seq2_)]) * 100
1021
1022        def ged(seq1: str, seq2: str) -> float:
1023
1024            matches, mismatches = 0, 0
1025
1026            for c1, c2 in zip(seq1, seq2):
1027                if c1 != '-' and c2 != '-':
1028                    if c1 == c2:
1029                        matches += 1
1030                    else:
1031                        mismatches += 1
1032            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1033
1034        def gcd(seq1: str, seq2: str) -> float:
1035            matches = 0
1036            mismatches = 0
1037            in_gap = False
1038
1039            for char1, char2 in zip(seq1, seq2):
1040                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1041                    continue
1042                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1043                    if not in_gap:  # Start of a new gap stretch
1044                        mismatches += 1
1045                        in_gap = True
1046                else:  # No gaps
1047                    in_gap = False
1048                    if char1 == char2:  # Matching characters
1049                        matches += 1
1050                    else:  # Mismatched characters
1051                        mismatches += 1
1052
1053            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1054
1055
1056        # Map distance type to corresponding function
1057        distance_functions: Dict[str, Callable[[str, str], float]] = {
1058            'ghd': ghd,
1059            'lhd': lhd,
1060            'ged': ged,
1061            'gcd': gcd
1062        }
1063
1064        if distance_type not in distance_functions:
1065            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1066
1067        # Compute pairwise distances
1068        aln = self.alignment
1069        distance_func = distance_functions[distance_type]
1070        distance_matrix = np.zeros((len(aln), len(aln)))
1071
1072        sequences = list(aln.values())
1073        for i, seq1 in enumerate(sequences):
1074            for j, seq2 in enumerate(sequences):
1075                if i <= j:  # Compute only once for symmetric matrix
1076                    distance_matrix[i, j] = distance_func(seq1, seq2)
1077                    distance_matrix[j, i] = distance_matrix[i, j]
1078
1079        return distance_matrix
1080
1081    def get_snps(self, include_ambig:bool=False) -> dict:
1082        """
1083        Calculate snps similar to snp-sites (output is comparable):
1084        https://github.com/sanger-pathogens/snp-sites
1085        Importantly, SNPs are only considered if at least one of the snps is not an ambiguous character.
1086        The SNPs are compared to a majority consensus sequence or to a reference if it has been set.
1087
1088        :param include_ambig: Include ambiguous snps (default: False)
1089        :return: dictionary containing snp positions and their variants including their frequency.
1090        """
1091        aln = self.alignment
1092        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
1093        aln = {x: aln[x] for x in aln.keys() if x != self.reference_id}
1094        seq_ids = list(aln.keys())
1095        snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}
1096
1097        for pos in range(self.length):
1098            reference_char = ref[pos]
1099            if not include_ambig:
1100                if reference_char in config.AMBIG_CHARS[self.aln_type] and reference_char != '-':
1101                    continue
1102            alt_chars, snps = [], []
1103            for i, seq_id in enumerate(aln.keys()):
1104                alt_chars.append(aln[seq_id][pos])
1105                if reference_char != aln[seq_id][pos]:
1106                    snps.append(i)
1107            if not snps:
1108                continue
1109            if include_ambig:
1110                if all(alt_chars[x] in config.AMBIG_CHARS[self.aln_type] for x in snps):
1111                    continue
1112            else:
1113                snps = [x for x in snps if alt_chars[x] not in config.AMBIG_CHARS[self.aln_type]]
1114                if not snps:
1115                    continue
1116            if pos not in snp_dict:
1117                snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': {}}
1118            for snp in snps:
1119                if alt_chars[snp] not in snp_dict['POS'][pos]['ALT']:
1120                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]] = {
1121                        'AF': 1,
1122                        'SEQ_ID': [seq_ids[snp]]
1123                    }
1124                else:
1125                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['AF'] += 1
1126                    snp_dict['POS'][pos]['ALT'][alt_chars[snp]]['SEQ_ID'].append(seq_ids[snp])
1127            # calculate AF
1128            if pos in snp_dict['POS']:
1129                for alt in snp_dict['POS'][pos]['ALT']:
1130                    snp_dict['POS'][pos]['ALT'][alt]['AF'] /= len(aln)
1131
1132        return snp_dict
1133
1134    def calc_transition_transversion_score(self) -> list:
1135        """
1136        Based on the snp positions, calculates a transition/transversions score.
1137        A positive score means higher ratio of transitions and negative score means
1138        a higher ratio of transversions.
1139        :return: list
1140        """
1141
1142        if self.aln_type == 'AA':
1143            raise TypeError('TS/TV scoring only for RNA/DNA alignments')
1144
1145        # ini
1146        snps = self.get_snps()
1147        score = [0]*self.length
1148
1149        for pos in snps['POS']:
1150            t_score_temp = 0
1151            for alt in snps['POS'][pos]['ALT']:
1152                # check the type of substitution
1153                if snps['POS'][pos]['ref'] + alt in ['AG', 'GA', 'CT', 'TC', 'CU', 'UC']:
1154                    score[pos] += snps['POS'][pos]['ALT'][alt]['AF']
1155                else:
1156                    score[pos] -= snps['POS'][pos]['ALT'][alt]['AF']
1157
1158        return score

An alignment class that allows the computation of several statistics.

MSA( alignment_path: str, reference_id: str = None, zoom_range: tuple | int = None)
31    def __init__(self, alignment_path: str, reference_id: str = None, zoom_range: tuple | int = None):
32        """
33        Initialise an Alignment object.
34        :param alignment_path: path to alignment file
35        :param reference_id: reference id
36        :param zoom_range: start and stop positions to zoom into the alignment
37        """
38        self._alignment = self._read_alignment(alignment_path)
39        self._reference_id = self._validate_ref(reference_id, self._alignment)
40        self._zoom = self._validate_zoom(zoom_range, self._alignment)
41        self._aln_type = self._determine_aln_type(self._alignment)

Initialise an MSA object.

Parameters
  • alignment_path: path to alignment file
  • reference_id: reference id
  • zoom_range: start and stop positions to zoom into the alignment
reference_id
172    @property
173    def reference_id(self):
174        return self._reference_id

Set and validate the reference id.

zoom: tuple
183    @property
184    def zoom(self) -> tuple:
185        return self._zoom

Validate the user-defined zoom range.

aln_type: str
195    @property
196    def aln_type(self) -> str:
197        """
198        define the aln type:
199        RNA, DNA or AA
200        """
201        return self._aln_type

Define the alignment type: RNA, DNA or AA.

length: int
204    @property
205    def length(self) -> int:
206        return len(next(iter(self.alignment.values())))
alignment: dict
208    @property
209    def alignment(self) -> dict:
210        """
211        (zoomed) version of the alignment.
212        """
213        if self.zoom is not None:
214            zoomed_aln = dict()
215            for seq in self._alignment:
216                zoomed_aln[seq] = self._alignment[seq][self.zoom[0]:self.zoom[1]]
217            return zoomed_aln
218        else:
219            return self._alignment

(zoomed) version of the alignment.

def get_reference_coords(self) -> tuple[int, int]:
222    def get_reference_coords(self) -> tuple[int, int]:
223        """
224        Determine the start and end coordinates of the reference sequence
225        defined as the first/last nucleotide in the reference sequence
226        (excluding N and gaps).
227
228        :return: start, end
229        """
230        start, end = 0, self.length
231
232        if self.reference_id is None:
233            return start, end
234        else:
235            # 5' --> 3'
236            for start in range(self.length):
237                if self.alignment[self.reference_id][start] not in ['-', 'N']:
238                    break
239            # 3' --> 5'
240            for end in range(self.length - 1, 0, -1):
241                if self.alignment[self.reference_id][end] not in ['-', 'N']:
242                    break
243
244            return start, end

Determine the start and end coordinates of the reference sequence defined as the first/last nucleotide in the reference sequence (excluding N and gaps).

Returns

start, end

def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
246    def get_consensus(self, threshold: float = None, use_ambig_nt: bool = False) -> str:
247        """
248        Creates a non-gapped consensus sequence.
249
250        :param threshold: Threshold for consensus sequence. If use_ambig_nt = True the ambig. char that encodes
251            the nucleotides that reach a cumulative frequency >= threshold is used. Otherwise 'N' (for nt alignments)
252            or 'X' (for as alignments) is used if none of the characters reach a cumulative frequency >= threshold.
253        :param use_ambig_nt: Use ambiguous character nt if none of the possible nt at a alignment position
254            has a frequency above the defined threshold.
255        :return: consensus sequence
256        """
257
258        # helper functions
259        def determine_counts(alignment_dict: dict, position: int) -> dict:
260            """
261            count the number of each char at
262            an idx of the alignment. return sorted dic.
263            handles ambiguous nucleotides in sequences.
264            also handles gaps.
265            """
266            nucleotide_list = []
267
268            # get all nucleotides
269            for sequence in alignment_dict.items():
270                nucleotide_list.append(sequence[1][position])
271            # count occurences of nucleotides
272            counter = dict(collections.Counter(nucleotide_list))
273            # get permutations of an ambiguous nucleotide
274            to_delete = []
275            temp_dict = {}
276            for nucleotide in counter:
277                if nucleotide in config.AMBIG_CHARS[self.aln_type]:
278                    to_delete.append(nucleotide)
279                    permutations = config.AMBIG_CHARS[self.aln_type][nucleotide]
280                    adjusted_freq = 1 / len(permutations)
281                    for permutation in permutations:
282                        if permutation in temp_dict:
283                            temp_dict[permutation] += adjusted_freq
284                        else:
285                            temp_dict[permutation] = adjusted_freq
286
287            # drop ambiguous entries and add adjusted freqs to
288            if to_delete:
289                for i in to_delete:
290                    counter.pop(i)
291                for nucleotide in temp_dict:
292                    if nucleotide in counter:
293                        counter[nucleotide] += temp_dict[nucleotide]
294                    else:
295                        counter[nucleotide] = temp_dict[nucleotide]
296
297            return dict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
298
299        def get_consensus_char(counts: dict, cutoff: float) -> list:
300            """
301            get a list of nucleotides for the consensus seq
302            """
303            n = 0
304
305            consensus_chars = []
306            for char in counts:
307                n += counts[char]
308                consensus_chars.append(char)
309                if n >= cutoff:
310                    break
311
312            return consensus_chars
313
314        def get_ambiguous_char(nucleotides: list) -> str:
315            """
316            get ambiguous char from a list of nucleotides
317            """
318            for ambiguous, permutations in config.AMBIG_CHARS[self.aln_type].items():
319                if set(permutations) == set(nucleotides):
320                    return ambiguous
321
322        # check if params have been set correctly
323        if threshold is not None:
324            if threshold < 0 or threshold > 1:
325                raise ValueError('Threshold must be between 0 and 1.')
326        if self.aln_type == 'AA' and use_ambig_nt:
327            raise ValueError('Ambiguous characters can not be calculated for amino acid alignments.')
328        if threshold is None and use_ambig_nt:
329            raise ValueError('To calculate ambiguous nucleotides, set a threshold > 0.')
330
331        alignment = self.alignment
332        consensus = str()
333
334        if threshold is not None:
335            consensus_cutoff = len(alignment) * threshold
336        else:
337            consensus_cutoff = 0
338
339        # built consensus sequences
340        for idx in range(self.length):
341            char_counts = determine_counts(alignment, idx)
342            consensus_chars = get_consensus_char(
343                char_counts,
344                consensus_cutoff
345            )
346            if threshold != 0:
347                if len(consensus_chars) > 1:
348                    if use_ambig_nt:
349                        char = get_ambiguous_char(consensus_chars)
350                    else:
351                        if self.aln_type == 'AA':
352                            char = 'X'
353                        else:
354                            char = 'N'
355                    consensus = consensus + char
356                else:
357                    consensus = consensus + consensus_chars[0]
358            else:
359                consensus = consensus + consensus_chars[0]
360
361        return consensus

Creates a non-gapped consensus sequence.

Parameters
  • threshold: Threshold for the consensus sequence. If use_ambig_nt = True, the ambiguous character that encodes the nucleotides reaching a cumulative frequency >= threshold is used. Otherwise, 'N' (for nt alignments) or 'X' (for aa alignments) is used if no single character reaches a frequency >= threshold.
  • use_ambig_nt: Use an ambiguous nucleotide character if none of the possible nucleotides at an alignment position has a frequency above the defined threshold.
Returns

consensus sequence

def get_conserved_orfs( self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
363    def get_conserved_orfs(self, min_length: int = 100, identity_cutoff: float | None = None) -> dict:
364        """
365        **conserved ORF definition:**
366            - conserved starts and stops
367            - start, stop must be on the same frame
368            - stop - start must be at least min_length
369            - all ungapped seqs[start:stop] must have at least min_length
370            - no ungapped seq can have a Stop in between Start Stop
371
372        Conservation is measured by number of positions with identical characters divided by
373        orf slice of the alignment.
374
375        **Algorithm overview:**
376            - check for conserved start and stop codons
377            - iterate over all three frames
378            - check each start and next sufficiently far away stop codon
379            - check if all ungapped seqs between start and stop codon are >= min_length
380            - check if no ungapped seq in the alignment has a stop codon
381            - write to dictionary
382            - classify as internal if the stop codon has already been written with a prior start
383            - repeat for reverse complement
384
385        :return: ORF positions and internal ORF positions
386        """
387
388        # helper functions
389        def determine_conserved_start_stops(alignment: dict, alignment_length: int) -> tuple:
390            """
391            Determine all start and stop codons within an alignment.
392            :param alignment: alignment
393            :param alignment_length: length of alignment
394            :return: start and stop codon positions
395            """
396            starts = config.START_CODONS[self.aln_type]
397            stops = config.STOP_CODONS[self.aln_type]
398
399            list_of_starts, list_of_stops = [], []
400            ref = alignment[list(alignment.keys())[0]]
401            for nt_position in range(alignment_length):
402                if ref[nt_position:nt_position + 3] in starts:
403                    conserved_start = True
404                    for sequence in alignment:
405                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in starts:
406                            conserved_start = False
407                            break
408                    if conserved_start:
409                        list_of_starts.append(nt_position)
410
411                if ref[nt_position:nt_position + 3] in stops:
412                    conserved_stop = True
413                    for sequence in alignment:
414                        if not alignment[sequence][nt_position:].replace('-', '')[0:3] in stops:
415                            conserved_stop = False
416                            break
417                    if conserved_stop:
418                        list_of_stops.append(nt_position)
419
420            return list_of_starts, list_of_stops
421
422        def get_ungapped_sliced_seqs(alignment: dict, start_pos: int, stop_pos: int) -> list:
423            """
424            get ungapped sequences starting and stop codons and eliminate gaps
425            :param alignment: alignment
426            :param start_pos: start codon
427            :param stop_pos: stop codon
428            :return: sliced sequences
429            """
430            ungapped_seqs = []
431            for seq_id in alignment:
432                ungapped_seqs.append(alignment[seq_id][start_pos:stop_pos + 3].replace('-', ''))
433
434            return ungapped_seqs
435
436        def additional_stops(ungapped_seqs: list) -> bool:
437            """
438            Checks for the presence of a stop codon
439            :param ungapped_seqs: list of ungapped sequences
440            :return: Additional stop codons (True/False)
441            """
442            stops = config.STOP_CODONS[self.aln_type]
443
444            for sliced_seq in ungapped_seqs:
445                for position in range(0, len(sliced_seq) - 3, 3):
446                    if sliced_seq[position:position + 3] in stops:
447                        return True
448            return False
449
450        def calculate_identity(identity_matrix: ndarray, aln_slice:list) -> float:
451            sliced_array = identity_matrix[:,aln_slice[0]:aln_slice[1]] + 1  # identical = 0, different = -1 --> add 1
452            return np.sum(np.all(sliced_array == 1, axis=0))/(aln_slice[1] - aln_slice[0]) * 100
453
454        # checks for arguments
455        if self.aln_type == 'AA':
456            raise TypeError('ORF search only for RNA/DNA alignments')
457
458        if identity_cutoff is not None:
459            if identity_cutoff > 100 or identity_cutoff < 0:
460                raise ValueError('conservation cutoff must be between 0 and 100')
461
462        if min_length <= 0 or min_length > self.length:
463            raise ValueError(f'min_length must be between 0 and {self.length}')
464
465        # ini
466        identities = self.calc_identity_alignment()
467        alignments = [self.alignment, self.calc_reverse_complement_alignment()]
468        aln_len = self.length
469
470        orf_counter = 0
471        orf_dict = {}
472
473        for aln, direction in zip(alignments, ['+', '-']):
474            # check for starts and stops in the first seq and then check if these are present in all seqs
475            conserved_starts, conserved_stops = determine_conserved_start_stops(aln, aln_len)
476            # check each frame
477            for frame in (0, 1, 2):
478                potential_starts = [x for x in conserved_starts if x % 3 == frame]
479                potential_stops = [x for x in conserved_stops if x % 3 == frame]
480                last_stop = -1
481                for start in potential_starts:
482                    # go to the next stop that is sufficiently far away in the alignment
483                    next_stops = [x for x in potential_stops if x + 3 >= start + min_length]
484                    if not next_stops:
485                        continue
486                    next_stop = next_stops[0]
487                    ungapped_sliced_seqs = get_ungapped_sliced_seqs(aln, start, next_stop)
488                    # re-check the lengths of all ungapped seqs
489                    ungapped_seq_lengths = [len(x) >= min_length for x in ungapped_sliced_seqs]
490                    if not all(ungapped_seq_lengths):
491                        continue
492                    # if no stop codon between start and stop --> write to dictionary
493                    if not additional_stops(ungapped_sliced_seqs):
494                        if direction == '+':
495                            positions = [start, next_stop + 3]
496                        else:
497                            positions = [aln_len - next_stop - 3, aln_len - start]
498                        if last_stop != next_stop:
499                            last_stop = next_stop
500                            conservation = calculate_identity(identities, positions)
501                            if identity_cutoff is not None and conservation < identity_cutoff:
502                                continue
503                            orf_dict[f'ORF_{orf_counter}'] = {'location': [positions],
504                                                              'frame': frame,
505                                                              'strand': direction,
506                                                              'conservation': conservation,
507                                                              'internal': []
508                                                              }
509                            orf_counter += 1
510                        else:
511                            orf_dict[f'ORF_{orf_counter - 1}']['internal'].append(positions)
512
513        return orf_dict

Conserved ORF definition:
  • conserved start and stop codons
  • start and stop codons must be in the same frame
  • stop - start must be at least min_length
  • all ungapped seqs[start:stop] must have a length of at least min_length
  • no ungapped seq may contain a stop codon between the start and stop codons

Conservation is measured as the number of positions with identical characters divided by the length of the ORF slice of the alignment, expressed as a percentage.

Algorithm overview:
  • check for conserved start and stop codons
  • iterate over all three frames
  • check each start codon and the next sufficiently distant stop codon
  • check that all ungapped seqs between the start and stop codons are >= min_length
  • check that no ungapped seq in the alignment has a stop codon in between
  • write to dictionary
  • classify as internal if the stop codon has already been written with a prior start
  • repeat for the reverse complement

Returns

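Dictionary of ORFs (keys 'ORF_0', 'ORF_1', ...). Each entry holds the ORF's location, frame, strand, conservation and internal ORF positions.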

def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff: float = None) -> dict:
515    def get_non_overlapping_conserved_orfs(self, min_length: int = 100, identity_cutoff:float = None) -> dict:
516        """
517        First calculates all ORFs and then searches from 5'
518        all non-overlapping orfs in the fw strand and from the
519        3' all non-overlapping orfs in th rw strand.
520
521        **No overlap algorithm:**
522            **frame 1:** -[M------*]--- ----[M--*]---------[M-----
523
524            **frame 2:** -------[M------*]---------[M---*]--------
525
526            **frame 3:** [M---*]-----[M----------*]----------[M---
527
528            **results:** [M---*][M------*]--[M--*]-[M---*]-[M-----
529
530            frame:    3      2           1      2       1
531
532        :return: dictionary with non-overlapping orfs
533        """
534        orf_dict = self.get_conserved_orfs(min_length, identity_cutoff)
535
536        fw_orfs, rw_orfs = [], []
537
538        for orf in orf_dict:
539            if orf_dict[orf]['strand'] == '+':
540                fw_orfs.append((orf, orf_dict[orf]['location'][0]))
541            else:
542                rw_orfs.append((orf, orf_dict[orf]['location'][0]))
543
544        fw_orfs.sort(key=lambda x: x[1][0])  # sort by start pos
545        rw_orfs.sort(key=lambda x: x[1][1], reverse=True)  # sort by stop pos
546        non_overlapping_orfs = []
547        for orf_list, strand in zip([fw_orfs, rw_orfs], ['+', '-']):
548            previous_stop = -1 if strand == '+' else self.length + 1
549            for orf in orf_list:
550                if strand == '+' and orf[1][0] > previous_stop:
551                    non_overlapping_orfs.append(orf[0])
552                    previous_stop = orf[1][1]
553                elif strand == '-' and orf[1][1] < previous_stop:
554                    non_overlapping_orfs.append(orf[0])
555                    previous_stop = orf[1][0]
556
557        non_overlap_dict = {}
558        for orf in orf_dict:
559            if orf in non_overlapping_orfs:
560                non_overlap_dict[orf] = orf_dict[orf]
561
562        return non_overlap_dict

First calculates all ORFs. It then searches from the 5' end for all non-overlapping ORFs on the forward strand, and from the 3' end for all non-overlapping ORFs on the reverse strand.

No overlap algorithm:

    frame 1: -[M------*]--------[M--*]---------[M-----
    frame 2: -------[M------*]---------[M---*]--------
    frame 3: [M---*]-----[M----------*]----------[M---
    results: [M---*][M------*]--[M--*]-[M---*]-[M-----
    frame:    3      2           1      2       1
Returns

dictionary with non-overlapping orfs

def calc_length_stats(self) -> dict:
564    def calc_length_stats(self) -> dict:
565        """
566        Determine the stats for the length of the ungapped seqs in the alignment.
567        :return: dictionary with length stats
568        """
569
570        seq_lengths = [len(self.alignment[x].replace('-', '')) for x in self.alignment]
571
572        return {'number of seq': len(self.alignment),
573                'mean length': float(np.mean(seq_lengths)),
574                'std length': float(np.std(seq_lengths)),
575                'min length': int(np.min(seq_lengths)),
576                'max length': int(np.max(seq_lengths))
577                }

Determine length statistics of the ungapped sequences in the alignment: number of sequences, mean, standard deviation, minimum and maximum length.

Returns

dictionary with length stats

def calc_entropy(self) -> list:
579    def calc_entropy(self) -> list:
580        """
581        Calculate the normalized shannon's entropy for every position in an alignment:
582
583        - 1: high entropy
584        - 0: low entropy
585
586        :return: Entropies at each position.
587        """
588
589        # helper functions
590        def shannons_entropy(character_list: list, states: int, aln_type: str) -> float:
591            """
592            Calculate the shannon's entropy of a sequence and
593            normalized between 0 and 1.
594            :param character_list: characters at an alignment position
595            :param states: number of potential characters that can be present
596            :param aln_type: type of the alignment
597            :returns: entropy
598            """
599            ent, n_chars = 0, len(character_list)
600            # only one char is in the list
601            if n_chars <= 1:
602                return ent
603            # calculate the number of unique chars and their counts
604            chars, char_counts = np.unique(character_list, return_counts=True)
605            char_counts = char_counts.astype(float)
606            # ignore gaps for entropy calc
607            char_counts, chars = char_counts[chars != "-"], chars[chars != "-"]
608            # correctly handle ambiguous chars
609            index_to_drop = []
610            for index, char in enumerate(chars):
611                if char in config.AMBIG_CHARS[aln_type]:
612                    index_to_drop.append(index)
613                    amb_chars, amb_counts = np.unique(config.AMBIG_CHARS[aln_type][char], return_counts=True)
614                    amb_counts = amb_counts / len(config.AMBIG_CHARS[aln_type][char])
615                    # add the proportionate numbers to initial array
616                    for amb_char, amb_count in zip(amb_chars, amb_counts):
617                        if amb_char in chars:
618                            char_counts[chars == amb_char] += amb_count
619                        else:
620                            chars, char_counts = np.append(chars, amb_char), np.append(char_counts, amb_count)
621            # drop the ambiguous characters from array
622            char_counts, chars = np.delete(char_counts, index_to_drop), np.delete(chars, index_to_drop)
623            # calc the entropy
624            probs = char_counts / n_chars
625            if np.count_nonzero(probs) <= 1:
626                return ent
627            for prob in probs:
628                ent -= prob * math.log(prob, states)
629
630            return ent
631
632        aln = self.alignment
633        entropys = []
634
635        if self.aln_type == 'AA':
636            states = 20
637        else:
638            states = 4
639        # iterate over alignment positions and the sequences
640        for nuc_pos in range(self.length):
641            pos = []
642            for record in aln:
643                pos.append(aln[record][nuc_pos])
644            entropys.append(shannons_entropy(pos, states, self.aln_type))
645
646        return entropys

Calculate the normalized Shannon entropy for every position in an alignment:

  • 1: high entropy
  • 0: low entropy
Returns

Entropies at each position.

def calc_gc(self) -> list | TypeError:
648    def calc_gc(self) -> list | TypeError:
649        """
650        Determine the GC content for every position in an nt alignment.
651        :return: GC content for every position.
652        :raises: TypeError for AA alignments
653        """
654        if self.aln_type == 'AA':
655            raise TypeError("GC computation is not possible for aminoacid alignment")
656
657        gc, aln, amb_nucs = [], self.alignment, config.AMBIG_CHARS[self.aln_type]
658
659        for position in range(self.length):
660            nucleotides = str()
661            for record in aln:
662                nucleotides = nucleotides + aln[record][position]
663            # ini dict with chars that occur and which ones to
664            # count in which freq
665            to_count = {
666                'G': 1,
667                'C': 1,
668            }
669            # handle ambig. nuc
670            for char in amb_nucs:
671                if char in nucleotides:
672                    to_count[char] = (amb_nucs[char].count('C') + amb_nucs[char].count('G')) / len(amb_nucs[char])
673
674            gc.append(
675                sum([nucleotides.count(x) * to_count[x] for x in to_count]) / len(nucleotides)
676            )
677
678        return gc

Determine the GC content for every position in an nt alignment.

Returns

GC content for every position.

Raises
  • TypeError for AA alignments
def calc_coverage(self) -> list:
680    def calc_coverage(self) -> list:
681        """
682        Determine the coverage of every position in an alignment.
683        This is defined as:
684            1 - cumulative length of '-' characters
685
686        :return: Coverage at each alignment position.
687        """
688        coverage, aln = [], self.alignment
689
690        for nuc_pos in range(self.length):
691            pos = str()
692            for record in aln.keys():
693                pos = pos + aln[record][nuc_pos]
694            coverage.append(1 - pos.count('-') / len(pos))
695
696        return coverage

Determine the coverage of every position in an alignment. Coverage is defined as 1 minus the fraction of sequences that have a gap ('-') at that position.

Returns

Coverage at each alignment position.

def calc_reverse_complement_alignment(self) -> dict | TypeError:
698    def calc_reverse_complement_alignment(self) -> dict | TypeError:
699        """
700        Reverse complement the alignment.
701        :return: Alignment (rv)
702        """
703        if self.aln_type == 'AA':
704            raise TypeError('Reverse complement only for RNA or DNA.')
705
706        aln = self.alignment
707        reverse_complement_dict = {}
708
709        for seq_id in aln:
710            reverse_complement_dict[seq_id] = ''.join(config.COMPLEMENT[base] for base in reversed(aln[seq_id]))
711
712        return reverse_complement_dict

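Reverse complement the alignment.

Returns

Reverse-complemented alignment (dictionary with ids as keys and sequences as values).

Raises
  • TypeError for AA alignments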

def calc_identity_alignment( self, encode_mismatches: bool = True, encode_mask: bool = False, encode_gaps: bool = True, encode_ambiguities: bool = False, encode_each_mismatch_char: bool = False) -> numpy.ndarray:
714    def calc_identity_alignment(self, encode_mismatches:bool=True, encode_mask:bool=False, encode_gaps:bool=True, encode_ambiguities:bool=False, encode_each_mismatch_char:bool=False) -> np.ndarray:
715        """
716        Converts alignment to identity array (identical=0) compared to majority consensus or reference:\n
717
718        :param encode_mismatches: encode mismatch as -1
719        :param encode_mask: encode mask with value=-2 --> also in the reference
720        :param encode_gaps: encode gaps with np.nan --> also in the reference
721        :param encode_ambiguities: encode ambiguities with value=-3
722        :param encode_each_mismatch_char: for each mismatch encode characters separately - these values represent the idx+1 values of config.DNA_colors, config.RNA_colors or config.AA_colors
723        :return: identity alignment
724        """
725
726        aln = self.alignment
727        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
728
729        # convert alignment to array
730        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
731        reference = np.array(list(ref))
732        # ini matrix
733        identity_matrix = np.full(sequences.shape, 0, dtype=float)
734
735        is_identical = sequences == reference
736
737        if encode_gaps:
738            is_gap = sequences == '-'
739        else:
740            is_gap = np.full(sequences.shape, False)
741
742        if encode_mask:
743            if self.aln_type == 'AA':
744                is_n_or_x = np.isin(sequences, ['X'])
745            else:
746                is_n_or_x = np.isin(sequences, ['N'])
747        else:
748            is_n_or_x = np.full(sequences.shape, False)
749
750        if encode_ambiguities:
751            is_ambig = np.isin(sequences, [key for key in config.AMBIG_CHARS[self.aln_type] if key not in ['N', 'X', '-']])
752        else:
753            is_ambig = np.full(sequences.shape, False)
754
755        if encode_mismatches:
756            is_mismatch = ~is_gap & ~is_identical & ~is_n_or_x & ~is_ambig
757        else:
758            is_mismatch = np.full(sequences.shape, False)
759
760        # encode every different character
761        if encode_each_mismatch_char:
762            for idx, char in enumerate(config.CHAR_COLORS[self.aln_type]):
763                new_encoding = np.isin(sequences, [char]) & is_mismatch
764                identity_matrix[new_encoding] = idx + 1
765        # or encode different with a single value
766        else:
767            identity_matrix[is_mismatch] = -1  # mismatch
768
769        identity_matrix[is_gap] = np.nan  # gap
770        identity_matrix[is_n_or_x] = -2  # 'N' or 'X'
771        identity_matrix[is_ambig] = -3  # ambiguities
772
773        return identity_matrix

Converts the alignment to an identity array (identical = 0). Each sequence is compared with the reference sequence or, if no reference is set, with the majority consensus.

Parameters
  • encode_mismatches: encode mismatch as -1
  • encode_mask: encode mask with value=-2 --> also in the reference
  • encode_gaps: encode gaps with np.nan --> also in the reference
  • encode_ambiguities: encode ambiguities with value=-3
  • encode_each_mismatch_char: encode each mismatching character separately. The values are the idx+1 positions of the characters in config.CHAR_COLORS for the alignment type.
Returns

identity alignment

def calc_similarity_alignment( self, matrix_type: str | None = None, normalize: bool = True) -> numpy.ndarray:
775    def calc_similarity_alignment(self, matrix_type:str|None=None, normalize:bool=True) -> np.ndarray:
776        """
777        Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight
778        differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the
779        reference residue at each column. Gaps are encoded as np.nan.
780
781        The calculation follows these steps:
782
783        1. **Reference Sequence**: If a reference sequence is provided (via `self.reference_id`), it is used. Otherwise,
784           a consensus sequence is generated to serve as the reference.
785        2. **Substitution Matrix**: The similarity between residues is determined using a substitution matrix, such as
786           BLOSUM65 for amino acids or BLASTN for nucleotides. The matrix is loaded based on the alignment type.
787        3. **Per-Column Normalization (optional)**:
788           - For each column in the alignment:
789             - The residue in the reference sequence is treated as the baseline for that column.
790             - The substitution scores for the reference residue are extracted from the substitution matrix.
791             - The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference
792               residue.
793           - This ensures that identical residues (or those with high similarity to the reference) have high scores,
794             while more dissimilar residues have lower scores.
795        4. **Output**:
796           - The normalized similarity scores are stored in a NumPy array.
797           - Gaps (if any) or residues not present in the substitution matrix are encoded as `np.nan`.
798
799        :param: matrix_type: type of similarity score (if not set - AA: BLOSSUM65, RNA/DNA: BLASTN)
800        :param: normalize: whether to normalize the similarity scores to range [0, 1]
801        :return: A 2D NumPy array where each entry corresponds to the normalized similarity score between the aligned residue
802            and the reference residue for that column. Values range from 0 (low similarity) to 1 (high similarity).
803            Gaps and invalid residues are encoded as `np.nan`.
804        :raise: ValueError
805            If the specified substitution matrix is not available for the given alignment type.
806        """
807
808        aln = self.alignment
809        ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
810        if matrix_type is None:
811            if self.aln_type == 'AA':
812                matrix_type = 'BLOSUM65'
813            else:
814                matrix_type = 'TRANS'
815        # load substitution matrix as dictionary
816        try:
817            subs_matrix = config.SUBS_MATRICES[self.aln_type][matrix_type]
818        except KeyError:
819            raise ValueError(
820                f'The specified matrix does not exist for alignment type.\nAvailable matrices for {self.aln_type} are:\n{list(config.SUBS_MATRICES[self.aln_type].keys())}'
821            )
822
823        # set dtype and convert alignment to a NumPy array for vectorized processing
824        dtype = np.dtype(float, metadata={'matrix': matrix_type})
825        sequences = np.array([list(aln[seq_id]) for seq_id in list(aln.keys())])
826        reference = np.array(list(ref))
827        valid_chars = list(subs_matrix.keys())
828        similarity_array = np.full(sequences.shape, np.nan, dtype=dtype)
829
830        for j, ref_char in enumerate(reference):
831            if ref_char not in valid_chars + ['-']:
832                continue
833            # Get local min and max for the reference residue
834            if normalize and ref_char != '-':
835                local_scores = subs_matrix[ref_char].values()
836                local_min, local_max = min(local_scores), max(local_scores)
837
838            for i, char in enumerate(sequences[:, j]):
839                if char not in valid_chars:
840                    continue
841                # classify the similarity as max if the reference has a gap
842                similarity_score = subs_matrix[char][ref_char] if ref_char != '-' else 1
843                similarity_array[i, j] = (similarity_score - local_min) / (local_max - local_min) if normalize and ref_char != '-' else similarity_score
844
845        return similarity_array

Calculate the similarity score between the alignment and the reference sequence, with normalization to highlight differences. The similarity scores are scaled to the range [0, 1] based on the substitution matrix values for the reference residue at each column. Gaps are encoded as np.nan.

The calculation follows these steps:

  1. Reference Sequence: If a reference sequence is provided (via self.reference_id), it is used. Otherwise, a consensus sequence is generated to serve as the reference.
  2. Substitution Matrix: The similarity between residues is determined using a substitution matrix (by default BLOSUM65 for amino acids and TRANS for nucleotides). The matrix is loaded based on the alignment type.
  3. Per-Column Normalization (optional):
    • For each column in the alignment:
      • The residue in the reference sequence is treated as the baseline for that column.
      • The substitution scores for the reference residue are extracted from the substitution matrix.
      • The scores are normalized to the range [0, 1] using the minimum and maximum possible scores for the reference residue.
    • This ensures that identical residues (or those with high similarity to the reference) have high scores, while more dissimilar residues have lower scores.
  4. Output:
    • The normalized similarity scores are stored in a NumPy array.
    • Gaps (if any) or residues not present in the substitution matrix are encoded as np.nan.
Parameters
  • matrix_type: type of similarity score (if not set - AA: BLOSUM65, RNA/DNA: TRANS)
  • normalize: whether to normalize the similarity scores to range [0, 1]
Returns

A 2D NumPy array where each entry corresponds to the (normalized) similarity score between the aligned residue and the reference residue for that column. With normalize=True, values range from 0 (low similarity) to 1 (high similarity). Gaps and invalid residues are encoded as np.nan.

Raises
  • ValueError if the specified substitution matrix is not available for the given alignment type.

def calc_position_matrix(self, matrix_type: str = 'PWM') -> numpy.ndarray | ValueError:
def calc_position_matrix(self, matrix_type: str = 'PWM') -> np.ndarray:
    """
    Calculate various position matrices (reference https://en.wikipedia.org/wiki/Position_weight_matrix)

    **Major steps:**
        1) count each character per alignment column and add a pseudocount of
           sqrt(number of sequences) to every count (PFM) -> the pseudocount scales
           with the alignment size and is needed for positions with 0 counts
        2) normalize the counts per column to frequencies (PPM)
        3) transform to PWM with M_k,j=log_2(M_k,j/b_k) with b_k assuming statistical
           independence (all chars are equally frequent, b_k = 1/number of chars)

    :param matrix_type: matrix to return (PFM, PPM or PWM)
    :return: the requested matrix as numpy array with one row per character and
             one column per alignment position
    :raise: ValueError for incorrect matrix type
    """
    if matrix_type not in ('PFM', 'PPM', 'PWM'):
        raise ValueError('Matrix_type must be PFM, PPM or PWM.')

    aln = self.alignment
    # every character of the colour scheme except its last entry
    # NOTE(review): the excluded last entry is presumably the gap character -- confirm in config
    possible_chars = list(config.CHAR_COLORS[self.aln_type].keys())[:-1]
    # 2D character array: one row per sequence, one column per alignment position
    sequences = np.array([list(seq) for seq in aln.values()])

    # position frequency matrix (counts per column) including pseudo-counts
    pfm = np.array([np.sum(sequences == char, 0) for char in possible_chars]) \
        + np.sqrt(sequences.shape[0])
    if matrix_type == 'PFM':
        return pfm

    # position probability matrix: column-wise normalisation of the counts
    ppm = pfm / np.sum(pfm, 0)
    if matrix_type == 'PPM':
        return ppm

    # position weight matrix: log2 odds against a uniform background (b_k = 1/len(chars))
    return np.log2(ppm * len(possible_chars))

Calculate various position matrices (reference https://en.wikipedia.org/wiki/Position_weight_matrix)

Major steps: 1) calculate character counts and add a pseudocount (square root of the number of sequences) to every count (PFM) -> the pseudocount scales with alignment size and is needed for positions with 0 counts; 2) calculate character frequencies per column (PPM); 3) transform to PWM with $M_{k,j} = \log_2(M_{k,j} / b_k)$, with $b_k$ assuming statistical independence (all chars are equally frequent).

Parameters
  • matrix_type: matrix to return (PFM, PPM or PWM)
Returns

the requested matrix (PFM, PPM or PWM) as a NumPy array.

Raises

ValueError for an incorrect matrix type.

def calc_percent_recovery(self) -> dict:
def calc_percent_recovery(self) -> dict:
    """
    Recovery of every sequence, measured against the reference sequence if one
    is set and otherwise against the majority consensus.

    Defined as:

    `(1 - invalid / ungapped_ref_positions) * 100`

    where `invalid` counts the '-' characters plus the 'N' (nucleotide
    alignments) or 'X' (amino acid alignments) characters found at alignment
    positions where the reference is not gapped. This is highly similar to how
    nextclade calculates recovery over reference.

    :return: dict mapping every sequence id (the reference itself excluded) to its percent recovery
    :raise: ValueError if the reference consists of gaps only
    """
    aln = self.alignment
    ref = self.get_consensus() if self.reference_id is None else aln[self.reference_id]

    # alignment columns in which the reference carries a character
    ref_columns = [col for col, residue in enumerate(ref) if residue != '-']
    if not ref_columns:
        raise ValueError("Reference sequence is entirely gapped, cannot calculate recovery.")

    # characters that count as "not recovered"
    unknown_char = 'X' if self.aln_type == 'AA' else 'N'
    not_recovered = {'-', unknown_char}
    n_columns = len(ref_columns)

    return {
        seq_id: (1 - sum(seq[col] in not_recovered for col in ref_columns) / n_columns) * 100
        for seq_id, seq in aln.items()
        if seq_id != self.reference_id
    }

Recovery per sequence either compared to the majority consensus seq or the reference seq.

Defined as:

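(1 - (number of N/X/- characters at ungapped reference positions) / (number of ungapped reference positions)) * 100, where N is counted for nucleotide alignments and X for amino acid alignments.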

This is highly similar to how nextclade calculates recovery over reference.

Returns

dict mapping each sequence id (excluding the reference) to its percent recovery

def calc_character_frequencies(self) -> dict:
def calc_character_frequencies(self) -> dict:
    """
    Calculate the character frequencies of the alignment, per sequence and in total.

    For every sequence, each non-gap character is reported with its count and its
    percentage relative to the number of non-gap characters of that sequence
    ('% of non-gapped'). Gaps ('-') are reported with their count and their
    percentage relative to the alignment length ('% of alignment').

    The 'total' entry holds the counts summed over all sequences and, for the
    percentages, the mean of the per-sequence percentages.

    NOTE(review): a sequence with the id 'total' collides with the summary entry.

    :return: Character frequencies as nested dictionary
             ({'total' | seq_id: {char: {'counts': ..., '% of ...': ...}}})
    """
    aln, aln_length = self.alignment, self.length

    totals = {'-': {'counts': 0, '% of alignment': 0.0}}
    freqs = {'total': totals}

    for seq_id, seq in aln.items():
        # one pass over the sequence instead of one str.count() per unique character
        char_counts = collections.Counter(seq)
        gap_count = char_counts.pop('-', 0)
        non_gap_total = sum(char_counts.values())

        # guard: a zero-length alignment has no meaningful gap percentage
        gap_percent = gap_count / aln_length * 100 if aln_length else 0.0
        seq_freqs = {'-': {'counts': gap_count, '% of alignment': gap_percent}}

        for char, count in char_counts.items():
            percent = count / non_gap_total * 100
            seq_freqs[char] = {'counts': count, '% of non-gapped': percent}
            total_entry = totals.setdefault(char, {'counts': 0, '% of non-gapped': 0})
            total_entry['counts'] += count
            total_entry['% of non-gapped'] += percent

        totals['-']['counts'] += gap_count
        totals['-']['% of alignment'] += gap_percent
        freqs[seq_id] = seq_freqs

    # turn the summed per-sequence percentages into means (skip for an empty alignment)
    n_seqs = len(aln)
    if n_seqs:
        for entry in totals.values():
            for key in ('% of alignment', '% of non-gapped'):
                if key in entry:
                    entry[key] /= n_seqs

    return freqs

Calculate the character frequencies in the alignment, per sequence and in total. The percentage of a non-gap character is relative to the number of non-gap characters of the sequence. The gap percentage is relative to the alignment length. The total percentages are the mean of the per-sequence percentages.

The output is a nested dictionary.

Returns

Character frequencies

def calc_pairwise_identity_matrix(self, distance_type: str = 'ghd') -> numpy.ndarray:
 983    def calc_pairwise_identity_matrix(self, distance_type:str='ghd') -> ndarray:
 984        """
 985        Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:
 986
 987        **1) ghd (global hamming distance)**: At each alignment position, check if characters match:
 988        \ndistance = matches / alignment_length * 100
 989
 990        **2) lhd (local hamming distance)**: Restrict the alignment to the region in both sequences that do not start and end with gaps:
 991        \ndistance = matches / min(5'3' ungapped seq1, 5'3' ungapped seq2) * 100
 992
 993        **3) ged (gap excluded distance)**: All gaps are excluded from the alignment
 994        \ndistance = matches / (matches + mismatches) * 100
 995
 996        **4) gcd (gap compressed distance)**: All consecutive gaps are compressed to one mismatch.
 997        \ndistance = matches / gap_compressed_alignment_length * 100
 998
 999        :return: array with pairwise distances.
1000        """
1001
1002        def hamming_distance(seq1: str, seq2: str) -> int:
1003            return sum(c1 == c2 for c1, c2 in zip(seq1, seq2))
1004
1005        def ghd(seq1: str, seq2: str) -> float:
1006            return hamming_distance(seq1, seq2) / self.length * 100
1007
1008        def lhd(seq1: str, seq2: str) -> float:
1009            # remove 5' trailing gaps
1010            i, j = 0, self.length - 1
1011            while i < self.length and (seq1[i] == '-' or seq2[i] == '-'):
1012                i += 1
1013            while j >= 0 and (seq1[j] == '-' or seq2[j] == '-'):
1014                j -= 1
1015            if i > j:
1016                return 0.0
1017            # slice seq
1018            seq1_, seq2_ = seq1[i:j + 1], seq2[i:j + 1]
1019
1020            return hamming_distance(seq1_, seq2_) / min([len(seq1_), len(seq2_)]) * 100
1021
1022        def ged(seq1: str, seq2: str) -> float:
1023
1024            matches, mismatches = 0, 0
1025
1026            for c1, c2 in zip(seq1, seq2):
1027                if c1 != '-' and c2 != '-':
1028                    if c1 == c2:
1029                        matches += 1
1030                    else:
1031                        mismatches += 1
1032            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1033
1034        def gcd(seq1: str, seq2: str) -> float:
1035            matches = 0
1036            mismatches = 0
1037            in_gap = False
1038
1039            for char1, char2 in zip(seq1, seq2):
1040                if char1 == '-' and char2 == '-':  # Shared gap: do nothing
1041                    continue
1042                elif char1 == '-' or char2 == '-':  # Gap in only one sequence
1043                    if not in_gap:  # Start of a new gap stretch
1044                        mismatches += 1
1045                        in_gap = True
1046                else:  # No gaps
1047                    in_gap = False
1048                    if char1 == char2:  # Matching characters
1049                        matches += 1
1050                    else:  # Mismatched characters
1051                        mismatches += 1
1052
1053            return matches / (matches + mismatches) * 100 if (matches + mismatches) > 0 else 0
1054
1055
1056        # Map distance type to corresponding function
1057        distance_functions: Dict[str, Callable[[str, str], float]] = {
1058            'ghd': ghd,
1059            'lhd': lhd,
1060            'ged': ged,
1061            'gcd': gcd
1062        }
1063
1064        if distance_type not in distance_functions:
1065            raise ValueError(f"Invalid distance type '{distance_type}'. Choose from {list(distance_functions.keys())}.")
1066
1067        # Compute pairwise distances
1068        aln = self.alignment
1069        distance_func = distance_functions[distance_type]
1070        distance_matrix = np.zeros((len(aln), len(aln)))
1071
1072        sequences = list(aln.values())
1073        for i, seq1 in enumerate(sequences):
1074            for j, seq2 in enumerate(sequences):
1075                if i <= j:  # Compute only once for symmetric matrix
1076                    distance_matrix[i, j] = distance_func(seq1, seq2)
1077                    distance_matrix[j, i] = distance_matrix[i, j]
1078
1079        return distance_matrix

Calculate pairwise identities for an alignment. As there are different definitions of sequence identity, there are different options implemented:

    1) ghd (global hamming distance): At each alignment position, check if characters match:

identity = matches / alignment_length * 100

    2) lhd (local hamming distance): Restrict the comparison to the region between the first and the last position at which neither sequence has a gap:

identity = matches / length of that region * 100

    3) ged (gap excluded distance): All positions with a gap in either sequence are excluded from the alignment:

identity = matches / (matches + mismatches) * 100

    4) gcd (gap compressed distance): All consecutive gaps are compressed to one mismatch; positions gapped in both sequences are ignored:

identity = matches / gap_compressed_alignment_length * 100

    Returns: symmetric array with the pairwise identities (in percent).
def get_snps(self, include_ambig: bool = False) -> dict:
def get_snps(self, include_ambig: bool = False) -> dict:
    """
    Calculate snps similar to snp-sites (output is comparable):
    https://github.com/sanger-pathogens/snp-sites
    The SNPs are compared to a majority consensus sequence or to a reference if it
    has been set; the reference sequence itself is not reported as a variant.

    Ambiguity handling:
        - include_ambig=False: positions with an ambiguous (non-gap) reference character
          are skipped and ambiguous alternative characters are dropped.
        - include_ambig=True: a position is only skipped if all differing characters are
          ambiguous, i.e. SNPs are only considered if at least one of them is not ambiguous.

    Output structure::

        {'#CHROM': reference id or 'consensus',
         'POS': {0-based alignment position: {
             'ref': reference character,
             'ALT': {character: {'AF': fraction of the compared (non-reference) sequences,
                                 'SEQ_ID': [ids of the sequences carrying the character]}}}}}

    :param include_ambig: Include ambiguous snps (default: False)
    :return: dictionary containing snp positions and their variants including their frequency.
    """
    aln = self.alignment
    ref = aln[self.reference_id] if self.reference_id is not None else self.get_consensus()
    # every sequence except the reference is compared against ref
    aln = {seq_id: seq for seq_id, seq in aln.items() if seq_id != self.reference_id}
    seq_ids = list(aln)
    n_seqs = len(aln)
    # hoisted: the ambiguity alphabet is identical for every position
    ambig_chars = set(config.AMBIG_CHARS[self.aln_type])
    snp_dict = {'#CHROM': self.reference_id if self.reference_id is not None else 'consensus', 'POS': {}}

    for pos in range(self.length):
        reference_char = ref[pos]
        if not include_ambig and reference_char in ambig_chars and reference_char != '-':
            continue

        alt_chars = [aln[seq_id][pos] for seq_id in seq_ids]
        snps = [idx for idx, char in enumerate(alt_chars) if char != reference_char]
        if not snps:
            continue
        if include_ambig:
            if all(alt_chars[idx] in ambig_chars for idx in snps):
                continue
        else:
            snps = [idx for idx in snps if alt_chars[idx] not in ambig_chars]
            if not snps:
                continue

        alts = {}
        for idx in snps:
            char = alt_chars[idx]
            if char not in alts:
                alts[char] = {'AF': 1, 'SEQ_ID': [seq_ids[idx]]}
            else:
                alts[char]['AF'] += 1
                alts[char]['SEQ_ID'].append(seq_ids[idx])
        # convert counts to allele frequencies over the compared sequences
        for entry in alts.values():
            entry['AF'] /= n_seqs
        snp_dict['POS'][pos] = {'ref': reference_char, 'ALT': alts}

    return snp_dict

Calculate SNPs similar to snp-sites (output is comparable): https://github.com/sanger-pathogens/snp-sites. Importantly, SNPs are only considered if at least one of the differing characters is not ambiguous. The SNPs are compared to a majority consensus sequence, or to the reference if it has been set.

Parameters
  • include_ambig: Include ambiguous snps (default: False)
Returns

dictionary containing snp positions and their variants including their frequency.

def calc_transition_transversion_score(self) -> list:
def calc_transition_transversion_score(self) -> list:
    """
    Based on the snp positions, calculate a transition/transversion score per
    alignment position. At every SNP position, the allele frequencies (AF) of
    transitions (A<->G, C<->T, C<->U) are added and those of transversions are
    subtracted. A positive score means a higher ratio of transitions and a
    negative score a higher ratio of transversions; positions without SNPs score 0.
    Alleles involving a gap ('-') are indels rather than substitutions and are ignored.

    :return: list with one score per alignment position
    :raise: TypeError for amino acid alignments
    """
    if self.aln_type == 'AA':
        raise TypeError('TS/TV scoring only for RNA/DNA alignments')

    transitions = {'AG', 'GA', 'CT', 'TC', 'CU', 'UC'}
    snps = self.get_snps()
    score = [0] * self.length

    for pos, snp in snps['POS'].items():
        ref_char = snp['ref']
        for alt, alt_info in snp['ALT'].items():
            # insertions/deletions are neither transitions nor transversions
            if ref_char == '-' or alt == '-':
                continue
            if ref_char + alt in transitions:
                score[pos] += alt_info['AF']
            else:
                score[pos] -= alt_info['AF']

    return score

Based on the SNP positions, calculates a transition/transversion score. A positive score means a higher ratio of transitions and a negative score means a higher ratio of transversions.

Returns

list

class Annotation:
class Annotation:
    """
    An annotation class that reads GFF, GenBank or BED files and maps the
    feature locations onto the columns of an MSA.
    """

    def __init__(self, aln: 'MSA', annotation_path: str):
        """
        The annotation class. Lets you parse multiple standard formats
        which might be used for annotating an alignment. The main purpose
        is to parse the annotation file and adapt the locations of diverse
        features to the locations within the alignment, considering the
        respective alignment positions. Importantly, the IDs of the annotation
        and of the MSA have to partly match.

        :param aln: MSA class
        :param annotation_path: path to annotation file (gb, bed, gff).
        :raise: ValueError if aln is not an MSA, the annotation type cannot be
                determined or no annotation id matches an alignment id
        """
        # validate first: parsing already reads the alignment ids
        if not isinstance(aln, MSA):
            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')

        self.ann_type, self._seq_id, self.locus, self.features = self._parse_annotation(annotation_path, aln)  # read annotation
        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
        self._position_map = self._build_position_map()  # build a position map
        self._map_to_alignment()  # adapt feature locations

    @staticmethod
    def _MSA_validation_and_seq_extraction(aln: 'MSA', seq_id: str) -> str:
        """
        extract gapped sequence from MSA that corresponds to annotation
        :param aln: MSA class
        :param seq_id: sequence id to extract
        :return: gapped sequence
        """
        if not isinstance(aln, MSA):
            raise ValueError('alignment has to be an MSA class. use explore.MSA() to read in alignment')
        else:
            return aln._alignment[seq_id]

    @staticmethod
    def _parse_annotation(annotation_path: str, aln: 'MSA') -> tuple[str, str, str, Dict]:
        """
        Detect the annotation format, parse the file and select the first record
        whose id matches an alignment id.

        Parsed features have the structure
        {feature_type: {index: {'location': [[start, end], ...], 'strand': ..., <qualifiers>}}}
        with 0-based, end-exclusive locations.

        :param annotation_path: path to the annotation file (gb, gff or bed)
        :param aln: MSA class
        :return: annotation type, matching alignment id, locus and features of the matching record
        :raise: ValueError if the type cannot be determined or no id matches
        """

        def detect_annotation_type(file_path: str) -> str:
            """
            Detect the type of annotation file (GenBank, GFF, or BED) based
            on the first relevant line (excluding empty lines, # comments and
            BED 'track'/'browser' header lines).

            :param file_path: Path to the annotation file.
            :return: The detected file type ('gb', 'gff', or 'bed').

            :raises ValueError: If the file type cannot be determined.
            """
            with open(file_path, 'r') as file:
                for line in file:
                    # skip empty lines, comments and BED header lines
                    if not line.strip() or line.startswith(('#', 'track ', 'browser ')):
                        continue
                    # genbank
                    if line.startswith('LOCUS'):
                        return 'gb'
                    columns = line.split('\t')
                    # gff: 9 columns with a strand and integer start/end
                    if len(columns) == 9 and columns[6] in ('+', '-', '.') \
                            and re.match(r'^\d+$', columns[3]) and re.match(r'^\d+$', columns[4]):
                        return 'gff'
                    # BED files are tab-delimited with at least 3 fields: chrom, start, end
                    if len(columns) >= 3 and re.match(r'^\d+$', columns[1]) and re.match(r'^\d+$', columns[2]):
                        return 'bed'
                    # only the first relevant line is inspected
                    break

            raise ValueError(
                "File type could not be determined. Ensure the file follows a recognized format (GenBank, GFF, or BED).")

        def parse_gb(file_path) -> dict:
            """
            parse a genbank file to dictionary - primarily retained are the informations
            for qualifiers as these will be used for plotting.

            :param file_path: path to genbank file
            :return: nested dictionary
            """

            def sanitize_gb_location(string: str) -> tuple[list, str]:
                """
                Convert an INSDC location string into 0-based, end-exclusive [start, end] pairs.
                see: https://www.insdc.org/submitting-standards/feature-table/

                :param string: raw location, e.g. 'complement(join(1..10,20..30))'
                :return: list of [start, end] pairs and the strand ('+' or '-')
                """
                strand = '-' if 'complement' in string else '+'
                # remove operators and partial markers; str.replace removes the exact
                # substring, whereas str.strip would strip a *set of characters*
                for token in ('complement(', 'join(', 'order(', '>', '<', ')'):
                    string = string.replace(token, '')
                locations = []
                # multiple locations e.g. due to splicing
                for location in string.split(','):
                    for sep in ('..', '.', '^'):
                        if sep in location:
                            bounds = [int(x) for x in location.split(sep)]
                            bounds[0] -= 1  # enforce 0-based starts
                            locations.append(bounds)
                            break
                    else:
                        # single base location, e.g. '467'
                        if location.isdigit():
                            locations.append([int(location) - 1, int(location)])

                return locations, strand

            records = {}
            record = None
            in_features = False
            current = None          # feature that following qualifier lines belong to
            location_text = ''      # raw (possibly multi-line) location of `current`
            qualifier_type = None   # qualifier that continuation lines extend (None -> location)
            with open(file_path, 'r') as file:
                for line in file:
                    line = line.rstrip()
                    # extract the locus id
                    if line.startswith('LOCUS'):
                        if record:
                            records[record['locus']] = record
                        record = {'locus': line.split()[1], 'features': {}}
                        in_features, current = False, None

                    elif line.startswith('FEATURES'):
                        in_features = True

                    # the sequence section or the record end closes the feature table
                    elif line.startswith(('ORIGIN', '//')):
                        in_features = False

                    # now write useful feature information to dictionary
                    elif in_features and record is not None and line.strip():
                        if len(line) > 5 and line[5] != ' ':
                            # new feature: key in column 6 followed by its location
                            parts = line.split()
                            feature_type = parts[0]
                            location_text = parts[1] if len(parts) > 1 else ''
                            locations, strand = sanitize_gb_location(location_text)
                            features = record['features'].setdefault(feature_type, {})
                            current = {'location': locations, 'strand': strand}
                            features[len(features)] = current
                            qualifier_type = None
                        elif current is not None:
                            text = line.strip()
                            if text.startswith('/'):
                                # new qualifier; split only at the first '=' (values may contain '='),
                                # value-less flags such as /pseudo get an empty string
                                qualifier_type, _, value = text[1:].partition('=')
                                current[qualifier_type] = value.strip('"')
                            elif qualifier_type is None:
                                # continuation of a location spanning several lines
                                location_text += text
                                current['location'], current['strand'] = sanitize_gb_location(location_text)
                            else:
                                # continuation of a qualifier value; sequences are joined without spaces
                                joiner = '' if qualifier_type == 'translation' else ' '
                                current[qualifier_type] = (current[qualifier_type] + joiner + text).strip('"')

            if record:
                records[record['locus']] = record

            return records

        def parse_gff(file_path) -> dict:
            """
            Parse a GFF3 (General Feature Format) file into a dictionary structure.
            Consecutive lines of the same seqid and feature type with identical
            attributes and strand (e.g. exons of a spliced CDS) are merged into one
            feature with multiple locations.

            :param file_path: path to gff file
            :return: nested dictionary
            """
            records = {}
            previous_key, previous_id, previous_attrs = None, None, None
            with open(file_path, 'r') as file:
                for line in file:
                    # everything after the ##FASTA directive is sequence data
                    if line.startswith('##FASTA'):
                        break
                    if line.startswith('#') or not line.strip():
                        continue
                    parts = line.rstrip('\r\n').split('\t')
                    if len(parts) != 9:
                        raise ValueError(f'GFF line does not have 9 tab-separated columns: {line.strip()}')
                    seqid, _source, feature_type, start, end, _score, strand, _phase, attributes = parts
                    # ensure that region and source features are not named differently for gff and gb
                    if feature_type == 'region':
                        feature_type = 'source'
                    features = records.setdefault(seqid, {'locus': seqid, 'features': {}})['features'] \
                        .setdefault(feature_type, {})

                    feature = {'strand': strand}
                    # Parse attributes into key-value pairs
                    for attr in attributes.split(';'):
                        if '=' in attr:
                            key, value = attr.split('=', 1)
                            feature[key.strip()] = value.strip()

                    location = [int(start) - 1, int(end)]
                    current_key = (seqid, feature_type)
                    if previous_key == current_key and previous_attrs == feature:
                        # same feature as before --> possible splicing; previous_id stays
                        # pointed at the merged feature so further segments are appended too
                        features[previous_id]['location'].append(location)
                    else:
                        previous_id = len(features)
                        previous_attrs = dict(feature)  # attributes only, without 'location'
                        feature['location'] = [location]
                        features[previous_id] = feature
                    previous_key = current_key

            return records

        def parse_bed(file_path) -> dict:
            """
            Parse a BED file into a dictionary structure.

            :param file_path: path to bed file
            :return: nested dictionary
            """
            records = {}
            with open(file_path, 'r') as file:
                for line in file:
                    # skip comments, empty lines and track/browser header lines
                    if line.startswith(('#', 'track ', 'browser ')) or not line.strip():
                        continue
                    chrom, start, end, *optional = line.strip().split('\t')

                    features = records.setdefault(chrom, {'locus': chrom, 'features': {}})['features'] \
                        .setdefault('region', {})
                    # BED is already 0-based and end-exclusive, the convention used by the other parsers
                    feature = {
                        'location': [[int(start), int(end)]],
                        'strand': '+',  # assume '+' if not present
                    }
                    # Handle optional columns (name, score, strand) --> ignore 7-12
                    for key, value in zip(('name', 'score', 'strand'), optional):
                        feature[key] = value

                    features[len(features)] = feature

            return records

        parse_functions: Dict[str, Callable[[str], dict]] = {
            'gb': parse_gb,
            'bed': parse_bed,
            'gff': parse_gff,
        }
        # determine the annotation content -> should be standard formatted
        annotation_type = detect_annotation_type(annotation_path)

        # read in the annotation
        annotations = parse_functions[annotation_type](annotation_path)

        # first annotation record whose id matches an alignment id (checked in both directions);
        # stopping at the first hit keeps the record and the alignment id consistent
        match = next(
            (
                (annotation_id, aln_id)
                for annotation_id in annotations
                for aln_id in aln.alignment
                if aln_id.split(' ')[0] in annotation_id or annotation_id in aln_id.split(' ')[0]
            ),
            None,
        )
        if match is None:
            raise ValueError(f'the annotations of {annotation_path} do not match any ids in the MSA')
        annotation_id, aln_id = match

        # return only the annotation that has been found, the respective type and the seq_id to map to
        return annotation_type, aln_id, annotations[annotation_id]['locus'], annotations[annotation_id]['features']

    def _build_position_map(self) -> Dict[int, int]:
        """
        Map every ungapped (genomic) position of the annotated sequence to its
        alignment column. An additional key (the ungapped length) maps to the
        alignment length and serves as end-of-sequence sentinel.

        :return: dict genomic position -> alignment position
        """
        position_map = {}
        genomic_pos = 0
        for aln_index, char in enumerate(self._gapped_seq):
            if char != '-':
                position_map[genomic_pos] = aln_index
                genomic_pos += 1
        # ensure the last genomic position is included
        position_map[genomic_pos] = len(self._gapped_seq)

        return position_map

    def _map_to_alignment(self):
        """
        Adjust all feature locations (0-based, end-exclusive genomic positions)
        to alignment positions, in place.

        :raise: ValueError if a location lies outside of the annotated sequence
        """

        def map_location(position_map: Dict[int, int], locations: list) -> list:
            """
            Map genomic locations to alignment positions using a precomputed position map.

            :param position_map: genomic position -> alignment position
            :param locations: List of genomic start and end positions.
            :return: List of adjusted alignment positions.
            """
            aligned_locs = []
            for start, end in locations:
                if start not in position_map or end not in position_map:
                    raise ValueError(f"Positions {start}-{end} lie outside of the position map.")
                aligned_start = position_map[start]
                # end is exclusive: map the feature's last residue (end - 1) and step one
                # column past it, so gap columns following the feature are not included
                aligned_end = position_map[end - 1] + 1 if end > start else position_map[end]
                aligned_locs.append([aligned_start, aligned_end])

            return aligned_locs

        for features in self.features.values():
            for feature_data in features.values():
                feature_data['location'] = map_location(self._position_map, feature_data['location'])

An annotation class that reads in gff, gb or bed files and adjusts their feature locations to the coordinates of the MSA.

Annotation(aln: MSA, annotation_path: str)
1166    def __init__(self, aln: MSA, annotation_path: str):
1167        """
1168        The annotation class. Lets you parse multiple standard formats
1169        which might be used for annotating an alignment. The main purpose
1170        is to parse the annotation file and adapt the locations of diverse
1171        features to the locations within the alignment, considering the
1172        respective alignment positions. Importantly, IDs of the alignment
1173        and the MSA have to partly match.
1174
1175        :param aln: MSA class
1176        :param annotation_path: path to annotation file (gb, bed, gff).
1177
1178        """
1179
1180        self.ann_type, self._seq_id, self.locus, self.features  = self._parse_annotation(annotation_path, aln)  # read annotation
1181        self._gapped_seq = self._MSA_validation_and_seq_extraction(aln, self._seq_id)  # extract gapped sequence
1182        self._position_map = self._build_position_map()  # build a position map
1183        self._map_to_alignment()  # adapt feature locations

The annotation class. Lets you parse several standard formats that may be used to annotate an alignment. Its main purpose is to parse the annotation file and adapt the locations of the features to the corresponding positions within the alignment. Importantly, the IDs of the annotation and the MSA have to partly match.

Parameters
  • aln: MSA object (an instance of the MSA class)
  • annotation_path: path to annotation file (gb, bed, gff).