Skip to content

Commit

Permalink
Merge branch 'yves' of https://github.com/GT4SD/gt4sd-core into yves
Browse files Browse the repository at this point in the history
  • Loading branch information
yvesnana committed Mar 6, 2024
2 parents e96d724 + 9baef5e commit 3913c8f
Show file tree
Hide file tree
Showing 7 changed files with 0 additions and 72 deletions.
1 change: 0 additions & 1 deletion src/gt4sd/frameworks/enzeptional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,3 @@
"""

from .core import EnzymeOptimizer # noqa: F401
from .core import EnzymeOptimizer # noqa: F401
Binary file not shown.
Binary file modified src/gt4sd/frameworks/enzeptional/__pycache__/core.cpython-38.pyc
Binary file not shown.
Binary file not shown.
71 changes: 0 additions & 71 deletions src/gt4sd/frameworks/enzeptional/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
#
from abc import ABC
import torch
import torch
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from tape.datasets import pad_sequences
from tape.registry import registry
from tape.tokenizers import TAPETokenizer
Expand Down Expand Up @@ -61,23 +59,12 @@ def __init__(self):
"""
self.cache = {}

def get(self, key):
"""
Retrieves a model from the cache using the given key.
"""
def __init__(self):
"""
Initializes the cache as an empty dictionary.
"""
self.cache = {}

def get(self, key):
"""
Retrieves a model from the cache using the given key.
Args:
key: The key used to store the model.
key: The key used to store the model.
Returns:
The model associated with the key, or None if not found.
Expand Down Expand Up @@ -116,18 +103,10 @@ def embed(self, samples: List[str]) -> np.ndarray:
Raises:
NotImplementedError: If the method is not implemented in the subclass.
Returns:
np.ndarray: The resulting embeddings as a NumPy array.
"""
raise NotImplementedError


class HFandTAPEModelUtility(StringEmbedding):
"""
Utility class for handling both Hugging Face and TAPE models for embedding
and unmasking tasks.
"""
class HFandTAPEModelUtility(StringEmbedding):
"""
Utility class for handling both Hugging Face and TAPE models for embedding
Expand All @@ -140,13 +119,8 @@ def __init__(
tokenizer_path: str,
unmasking_model_path: Optional[str] = None,
is_tape_model: bool = False,
embedding_model_path: str,
tokenizer_path: str,
unmasking_model_path: Optional[str] = None,
is_tape_model: bool = False,
device: Optional[Union[torch.device, str]] = None,
cache_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> None:
"""Initializes the utility with specified model and tokenizer paths.
Expand Down Expand Up @@ -259,7 +233,6 @@ def _embed_tape(self, samples: List[str]) -> np.ndarray:
Args:
samples (List[str]): List of strings to be embedded.
samples (List[str]): List of strings to be embedded.
Returns:
np.ndarray: The resulting embeddings.
Expand Down Expand Up @@ -296,7 +269,6 @@ def _embed_huggingface(self, samples: List[str]) -> np.ndarray:
Args:
samples (List[str]): List of strings to be embedded.
samples (List[str]): List of strings to be embedded.
Returns:
np.ndarray: The resulting embeddings.
Expand All @@ -315,25 +287,9 @@ def _embed_huggingface(self, samples: List[str]) -> np.ndarray:

sequence_lengths = inputs["attention_mask"].sum(1)

inputs = self.tokenizer(
samples,
add_special_tokens=True,
padding=True,
return_tensors="pt",
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}

with torch.no_grad():
outputs = self.embedding_model(**inputs)
sequence_embeddings = outputs[0].cpu().detach().numpy()

sequence_lengths = inputs["attention_mask"].sum(1)

return np.array(
[
sequence_embedding[:sequence_length].mean(0)
for sequence_embedding, sequence_length in zip(
sequence_embeddings, sequence_lengths
for sequence_embedding, sequence_length in zip(
sequence_embeddings, sequence_lengths
)
Expand Down Expand Up @@ -363,15 +319,6 @@ def unmask(self, sequence: str, top_k: int = 2) -> List[str]:
except (KeyError, NotImplementedError) as e:
logger.warning(f"{e} Standard unmasking failed ")
raise KeyError("Check the unmasking model you want to use")
if self.is_tape_model:
logger.error("Unmasking is not supported for TAPE models.")
raise NotImplementedError("Unmasking is not supported for TAPE models.")

try:
return self._unmask_with_model(sequence, top_k)
except (KeyError, NotImplementedError) as e:
logger.warning(f"{e} Standard unmasking failed ")
raise KeyError("Check the unmasking model you want to use")

def _unmask_with_model(self, sequence: str, top_k: int) -> List[str]:
"""Unmasks a sequence using the model, providing top-k predictions.
Expand Down Expand Up @@ -440,12 +387,9 @@ def mutate_sequence_with_variant(sequence: str, variant: str) -> str:
Args:
sequence (str): The original amino acid sequence.
variant (str): The variant to apply, formatted as a string.
sequence (str): The original amino acid sequence.
variant (str): The variant to apply, formatted as a string.
Returns:
str: The mutated amino acid sequence.
str: The mutated amino acid sequence.
"""
mutated_sequence = list(sequence)
for variant_string in variant.split("/"):
Expand All @@ -460,20 +404,12 @@ def sanitize_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]
Args:
intervals (List[Tuple[int, int]]): A list of
start and end points of intervals.
intervals (List[Tuple[int, int]]): A list of
start and end points of intervals.
Returns:
List[Tuple[int, int]]: A list of merged intervals.
List[Tuple[int, int]]: A list of merged intervals.
"""
intervals.sort()
merged: List[Tuple[int, int]] = []
for start, end in intervals:
if not merged or merged[-1][1] < start:
merged.append((start, end))
intervals.sort()
merged: List[Tuple[int, int]] = []
for start, end in intervals:
if not merged or merged[-1][1] < start:
merged.append((start, end))
Expand Down Expand Up @@ -551,19 +487,12 @@ def reconstruct_sequence_with_mutation_range(
mutated_sequence_range (str): The range of the sequence to be mutated.
intervals (List[Tuple[int, int]]): The intervals where
mutations are applied.
sequence (str): The original sequence.
mutated_sequence_range (str): The range of the sequence to be mutated.
intervals (List[Tuple[int, int]]): The intervals where
mutations are applied.
Returns:
str: The reconstructed sequence with mutations.
str: The reconstructed sequence with mutations.
"""
mutated_sequence = list(sequence)
range_index = 0
mutated_sequence = list(sequence)
range_index = 0
for start, end in intervals:
size_fragment = end - start
mutated_sequence[start:end] = list(
Expand Down
Binary file not shown.
Binary file not shown.

0 comments on commit 3913c8f

Please sign in to comment.