-
-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make retinaface compatible with tf2.16 and later
- Loading branch information
Showing
2 changed files
with
36 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# 3rd party dependencies | ||
import tensorflow as tf | ||
|
||
# project dependencies | ||
from retinaface.commons.logger import Logger | ||
|
||
logger = Logger(module="retinaface/commons/package_utils.py") | ||
|
||
|
||
def validate_for_keras3(): | ||
tf_major = int(tf.__version__.split(".", maxsplit=1)[0]) | ||
tf_minor = int(tf.__version__.split(".", maxsplit=-1)[1]) | ||
|
||
# tf_keras is a must dependency after tf 2.16 | ||
if tf_major == 1 or (tf_major == 2 and tf_minor < 16): | ||
return | ||
|
||
try: | ||
import tf_keras | ||
|
||
logger.debug(f"tf_keras is already available - {tf_keras.__version__}") | ||
except ImportError as err: | ||
# you may consider to install that package here | ||
raise ValueError( | ||
f"You have tensorflow {tf.__version__} and this requires " | ||
"tf-keras package. Please run `pip install tf-keras` " | ||
"or downgrade your tensorflow." | ||
) from err |