-
Notifications
You must be signed in to change notification settings - Fork 6
/
corrections_craft.patch
50 lines (44 loc) · 1.69 KB
/
corrections_craft.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
diff --git a/craft_text_detector/image_utils.py b/craft_text_detector/image_utils.py
index ad097ab..d15071f 100644
--- a/craft_text_detector/image_utils.py
+++ b/craft_text_detector/image_utils.py
@@ -5,9 +5,12 @@ MIT License
import cv2
import numpy as np
+import PIL
def read_image(image):
+ img = None
+
if type(image) == str:
img = cv2.imread(image)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -24,7 +27,8 @@ def read_image(image):
img = image
elif len(image.shape) == 3 and image.shape[2] == 4: # RGBAscale
img = image[:, :, :3]
-
+ elif isinstance(image, PIL.Image.Image):
+ img = np.array(image)
return img
diff --git a/craft_text_detector/models/basenet/vgg16_bn.py b/craft_text_detector/models/basenet/vgg16_bn.py
index e58d0b6..e844093 100644
--- a/craft_text_detector/models/basenet/vgg16_bn.py
+++ b/craft_text_detector/models/basenet/vgg16_bn.py
@@ -4,8 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
-from torchvision.models.vgg import model_urls
-
+# from torchvision.models.vgg import model_urls
def init_weights(modules):
for m in modules:
@@ -24,7 +23,7 @@ def init_weights(modules):
class vgg16_bn(torch.nn.Module):
def __init__(self, pretrained=True, freeze=True):
super(vgg16_bn, self).__init__()
- model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace("https://", "http://")
+ # model_urls["vgg16_bn"] = VGG16_BN.replace("https://", "http://")
vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()