Add an `exclude_chains_file` option to `split_data` and to the CLI.

Sequences listed in the file (one per line, blank lines ignored) are excluded
in addition to the chains given through `exclude_chains`. The exclusion scan is
skipped when neither chains nor a file are supplied, as it was before this
change. Leftover debug prints in proteinflow/split/__init__.py are removed and
no new ones are added.

NOTE(review): line breaks and indentation in the hunks below were restored from
a whitespace-collapsed copy of this patch; verify them against the target files
before applying. The `index` hashes are those of the original patch.

diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py
index 58575c9..f84593b 100644
--- a/proteinflow/__init__.py
+++ b/proteinflow/__init__.py
@@ -438,6 +438,7 @@ def split_data(
     ignore_existing=False,
     min_seq_id=0.3,
     exclude_chains=None,
+    exclude_chains_file=None,
     exclude_threshold=0.7,
     exclude_clusters=False,
     exclude_based_on_cdr=None,
@@ -485,6 +486,8 @@ def split_data(
         minimum sequence identity for `mmseqs`
     exclude_chains : list of str, optional
         a list of chains (`{pdb_id}-{chain_id}`) to exclude from the splitting (e.g. `["1A2B-A", "1A2B-B"]`); chain id is the author chain id
+    exclude_chains_file : str, optional
+        path to a file containing the sequences to exclude, one sequence per line
     exclude_threshold : float in [0, 1], default 0.7
         the sequence similarity threshold for excluding chains
     exclude_clusters : bool, default False
@@ -509,16 +512,18 @@ def split_data(
     temp_folder = os.path.join(tempfile.gettempdir(), "proteinflow")
     if not os.path.exists(temp_folder):
         os.makedirs(temp_folder)
-    if exclude_chains is None or len(exclude_chains) == 0:
-        excluded_biounits = []
-    else:
+    # an empty chain list with no file means there is nothing to exclude
+    if exclude_chains or exclude_chains_file is not None:
         excluded_biounits = _get_excluded_files(
             tag,
             local_datasets_folder,
             temp_folder,
-            exclude_chains,
+            exclude_chains or [],
+            exclude_chains_file,
             exclude_threshold,
         )
+    else:
+        excluded_biounits = []
     if exclude_chains_without_ligands:
         excluded_biounits += _exclude_files_with_no_ligand(
             tag,
diff --git a/proteinflow/cli.py b/proteinflow/cli.py
index c92468a..e583101 100644
--- a/proteinflow/cli.py
+++ b/proteinflow/cli.py
@@ -255,6 +255,11 @@ def generate(**kwargs):
     type=str,
     help="Exclude specific chains from the dataset ({pdb_id}-{chain_id}, e.g. -e 1a2b-A)",
 )
+@click.option(
+    "--exclude_chains_file",
+    type=str,
+    help="Exclude specific chains from the dataset (path to a file containing the sequences to exclude, one sequence per line)",
+)
 @click.option(
     "--exclude_threshold",
     default=0.7,
diff --git a/proteinflow/split/__init__.py b/proteinflow/split/__init__.py
index ef48c10..f46f526 100644
--- a/proteinflow/split/__init__.py
+++ b/proteinflow/split/__init__.py
@@ -142,8 +142,6 @@ def _read_clusters(tmp_folder, cdr=None):
     for k in cluster_pdb_dict.keys():
         cluster_pdb_dict[k] = np.unique(cluster_pdb_dict[k])
 
-    print(f"{cluster_dict=}")
-    print(f"{cluster_pdb_dict=}")
     return cluster_dict, cluster_pdb_dict
 
 
@@ -1304,7 +1302,12 @@ def _get_split_dictionaries(
 
 
 def _get_excluded_files(
-    tag, local_datasets_folder, tmp_folder, exclude_chains, exclude_threshold
+    tag,
+    local_datasets_folder,
+    tmp_folder,
+    exclude_chains,
+    exclude_chains_file,
+    exclude_threshold,
 ):
     """Get a list of files to exclude from the dataset.
 
@@ -1321,6 +1324,8 @@ def _get_excluded_files(
         the path to the folder that stores temporary files
     exclude_chains : list of str, optional
         a list of chains (`{pdb_id}-{chain_id}`) to exclude from the splitting (e.g. `["1A2B-A", "1A2B-B"]`); chain id is the author chain id
+    exclude_chains_file : str, optional
+        path to a file containing the sequences to exclude, one sequence per line
     exclude_threshold : float in [0, 1], default 0.7
         the sequence similarity threshold for excluding chains
 
@@ -1339,6 +1344,10 @@ def _get_excluded_files(
         chains = PDBEntry.parse_fasta(outfnm)
         sequences.append(chains[chain_id])
         os.remove(outfnm)
+    if exclude_chains_file is not None:
+        with open(exclude_chains_file) as f:
+            # blank lines would otherwise be appended as empty sequences
+            sequences += [seq for seq in map(str.strip, f) if seq]
 
     # iterate over files in the dataset to check similarity
     print("Checking excluded chains similarity...")