generated from GDGVIT/template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.js
109 lines (85 loc) · 3.85 KB
/
client.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');
/**
 * Style transfer driven by two TF.js graph models:
 *  - a style network that maps a (batched, [0,1]-scaled) image to a "bottleneck" tensor;
 *  - a transformer network that renders a content image given such a bottleneck.
 * Models are loaded lazily on first use and cached on the instance.
 */
class StyleTransfer {
  /**
   * @param {object} [options]
   * @param {string} [options.styleModelUrl] - model.json URL of the style network.
   * @param {string} [options.transformModelUrl] - model.json URL of the transformer network.
   *   Both default to cwd-relative `file://./...` paths.
   */
  constructor({
    styleModelUrl = 'file://./saved_model_style_inception_js/model.json',
    transformModelUrl = 'file://./saved_model_transformer_js/model.json',
  } = {}) {
    this.styleModelUrl = styleModelUrl;
    this.transformModelUrl = transformModelUrl;
    this.styleNet = null;
    this.transformNet = null;
  }

  /**
   * Loads whichever of the two graph models is not loaded yet. Calling it again
   * is cheap: already-loaded models are reused instead of being re-read from disk
   * (and the previous copies leaked) on every stylize call.
   */
  async loadModels() {
    if (this.styleNet == null) {
      this.styleNet = await tf.loadGraphModel(this.styleModelUrl);
    }
    if (this.transformNet == null) {
      this.transformNet = await tf.loadGraphModel(this.transformModelUrl);
    }
  }

  /** Releases the loaded models (if any); they will be reloaded on next use. */
  dispose() {
    if (this.styleNet) {
      this.styleNet.dispose();
    }
    if (this.transformNet) {
      this.transformNet.dispose();
    }
    this.styleNet = null;
    this.transformNet = null;
  }

  /**
   * Renders the content image in the style of the style image and writes a JPEG.
   *
   * @param {string} contentPath - image to restyle.
   * @param {string} stylePath - image whose style is applied.
   * @param {string} outputPath - destination JPEG path.
   * @param {number} [styleRatio=1.0] - weight of the style image's bottleneck; the
   *   remainder (1 - styleRatio) comes from the content image's own bottleneck.
   */
  async stylizeImage(contentPath, stylePath, outputPath, styleRatio = 1.0) {
    await this.renderToFile([contentPath, stylePath], outputPath, (contentBatch, styleBatch) => {
      const styleBottleneck = this.styleNet.predict(styleBatch);
      // At ratio 1 the style bottleneck is used as-is; otherwise mix it with the
      // content image's own bottleneck (ratio 0 keeps the content's own style).
      const bottleneck = styleRatio === 1.0
        ? styleBottleneck
        : this.blendBottlenecks(this.styleNet.predict(contentBatch), styleBottleneck, styleRatio);
      return this.transformNet.predict([contentBatch, bottleneck]).squeeze([0]);
    });
  }

  /**
   * Renders the content image with a weighted mix of two styles and writes a JPEG.
   *
   * @param {string} contentPath - image to restyle.
   * @param {string} style1Path - first style image (weight 1 - styleRatio).
   * @param {string} style2Path - second style image (weight styleRatio).
   * @param {string} outputPath - destination JPEG path.
   * @param {number} [styleRatio=0.5] - weight of the second style.
   */
  async combineStyles(contentPath, style1Path, style2Path, outputPath, styleRatio = 0.5) {
    await this.renderToFile(
      [contentPath, style1Path, style2Path],
      outputPath,
      (contentBatch, style1Batch, style2Batch) => {
        const bottleneck = this.blendBottlenecks(
          this.styleNet.predict(style1Batch),
          this.styleNet.predict(style2Batch),
          styleRatio,
        );
        return this.transformNet.predict([contentBatch, bottleneck]).squeeze([0]);
      },
    );
  }

  /**
   * Shared pipeline: ensure models are loaded, decode the images, run `compute`
   * on their batched float versions inside tf.tidy, save the result, and dispose
   * every tensor it created, even when a step throws.
   *
   * @param {string[]} imagePaths - images decoded in order and passed to `compute`.
   * @param {string} outputPath - destination JPEG path.
   * @param {(...batches: tf.Tensor[]) => tf.Tensor} compute - returns the output image tensor.
   */
  async renderToFile(imagePaths, outputPath, compute) {
    await this.loadModels();
    const owned = [];
    try {
      for (const imagePath of imagePaths) {
        owned.push(await this.loadImage(imagePath));
      }
      const stylized = tf.tidy(() => compute(...owned.map((image) => this.toBatch(image))));
      owned.push(stylized);
      await this.saveImage(stylized, outputPath);
    } finally {
      tf.dispose(owned);
    }
  }

  /** Scales a decoded image to float32 in [0, 1] and adds a leading batch axis. */
  toBatch(image) {
    return tf.tidy(() => image.cast('float32').div(255).expandDims(0));
  }

  /** Returns base * (1 - weight) + target * weight. */
  blendBottlenecks(base, target, weight) {
    return tf.tidy(() => base.mul(1 - weight).add(target.mul(weight)));
  }

  /**
   * Reads an image file and decodes it into a tensor. Three channels are forced
   * so RGBA PNGs and grayscale images decode to the same RGB layout as RGB JPEGs
   * (the models are assumed to take RGB input — TODO confirm).
   */
  async loadImage(imagePath) {
    const imageBuffer = await fs.promises.readFile(imagePath);
    return tf.node.decodeImage(imageBuffer, 3);
  }

  /**
   * Encodes a [height, width, channels] float image (expected in [0, 1]) as JPEG
   * and writes it to `outputPath`. The input tensor is not disposed here.
   */
  async saveImage(tensor, outputPath) {
    // Round (instead of truncating) and clip so slightly out-of-range model output
    // saturates at 0/255 rather than producing invalid pixel values.
    const pixels = tf.tidy(() => tensor.mul(255).round().clipByValue(0, 255).cast('int32'));
    try {
      const jpegBytes = await tf.node.encodeJpeg(pixels);
      await fs.promises.writeFile(outputPath, jpegBytes);
    } finally {
      pixels.dispose();
    }
  }
}
/**
 * Demo entry point: runs a single-style transfer and then a two-style blend on
 * the hard-coded images in the current working directory, logging each output path.
 */
async function main() {
  const styleTransfer = new StyleTransfer();

  // Single style transfer
  const contentPath = './skull.jpg';
  const stylePath = './paint.jpg';
  const outputPath = './stylized_image.jpg';
  const styleRatio = 0.95;
  await styleTransfer.stylizeImage(contentPath, stylePath, outputPath, styleRatio);
  console.log('Single style transfer: Stylized image saved to:', outputPath);

  // Combined style transfer
  const style1Path = './paint.jpg';
  const style2Path = './flowers.jpg';
  const combinedOutputPath = './combined_stylized_image.jpg';
  const combinedStyleRatio = 0.5; // Equal mix of both styles
  await styleTransfer.combineStyles(contentPath, style1Path, style2Path, combinedOutputPath, combinedStyleRatio);
  console.log('Combined style transfer: Stylized image saved to:', combinedOutputPath);
}

main().catch((err) => {
  console.error(err);
  // Report failure to the shell/CI instead of exiting with status 0.
  process.exitCode = 1;
});