Preface
This is the story of how someone who just wants to jump on the bandwagon and slack off, and who understands basically nothing, can still ship a working feature.
About me
A Java newbie who doesn't know Python either. But, as JoJo puts it: "The anthem of mankind is the anthem of courage!! The greatness of mankind is the greatness of courage!!"
Resources
- fast_neural_style
- fast-style-transfer
- Setting up a TensorFlow environment on macOS
- Setting up TensorFlow on Windows 10 to try out fast-style-transfer
- Baidu
- My hands and my brain
Testing
I didn't bring my router along, so the first thing is to point pip at a domestic mirror.
vi ~/.config/pip/pip.conf  # per-user config, no sudo needed
[global]
#index-url = https://pypi.tuna.tsinghua.edu.cn/simple
index-url = https://mirrors.aliyun.com/pypi/simple
git clone https://github.com/lengstrom/fast-style-transfer
cd fast-style-transfer/
pyenv activate v370env
python evaluate.py --checkpoint model/udnie.ckpt --in-path xxx.jpg --out-path xxx/
Hacking up an API (Python)
mkdir neural_style
cd ..
git clone https://github.com/pytorch/examples/
cd examples/
cp fast_neural_style/neural_style/*.py ../fast-style-transfer/neural_style
How do you write a RESTful API in Python? A quick search says: Flask.
pip install flask
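Before wiring anything up, a minimal sketch (my own toy example, not part of either repo) of what a Flask endpoint looks like:

# hello.py - the smallest possible Flask endpoint
from flask import Flask, jsonify

app = Flask(__name__)

@app.route('/api/ping', methods=['GET'])
def ping():
    return jsonify({"code": 0, "msg": "pong"})

if __name__ == '__main__':
    app.run(debug=True)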
And how do you write a class in Python?
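Another throwaway sketch of my own, just to learn the @property getter/setter pattern that the Parser class below copies:

# toy.py - the @property getter/setter pattern
class Point(object):
    def __init__(self):
        self._x = 0

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = value

p = Point()
p.x = 3
print(p.x)  # prints 3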
The modified evaluate.py (I save it as evaluateapi.py, which is what app.py imports later)
from __future__ import print_function
import sys
sys.path.insert(0, 'src')
import numpy as np, os
import tensorflow as tf
from src.utils import save_img, get_img, exists, list_files
from src import transform
from collections import defaultdict
from moviepy.video.io.VideoFileClip import VideoFileClip
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
BATCH_SIZE = 4
DEVICE = '/gpu:0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
video_clip = VideoFileClip(path_in, audio=False)
video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out, video_clip.size, video_clip.fps, codec="libx264",
preset="medium", bitrate="2000k",
audiofile=path_in, threads=None,
ffmpeg_params=None)
g = tf.Graph()
soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
soft_config.gpu_options.allow_growth = True
with g.as_default(), g.device(device_t), \
tf.compat.v1.Session(config=soft_config) as sess:
batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
name='img_placeholder')
preds = transform.net(img_placeholder)
        saver = tf.compat.v1.train.Saver()  # tf.train.Saver is gone in TF2; go through compat.v1 like everywhere else
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("No checkpoint found...")
else:
saver.restore(sess, checkpoint_dir)
X = np.zeros(batch_shape, dtype=np.float32)
def style_and_write(count):
for i in range(count, batch_size):
X[i] = X[count - 1] # Use last frame to fill X
_preds = sess.run(preds, feed_dict={img_placeholder: X})
for i in range(0, count):
video_writer.write_frame(np.clip(_preds[i], 0, 255).astype(np.uint8))
        frame_count = 0  # number of frames currently buffered in X
for frame in video_clip.iter_frames():
X[frame_count] = frame
frame_count += 1
if frame_count == batch_size:
style_and_write(frame_count)
frame_count = 0
if frame_count != 0:
style_and_write(frame_count)
video_writer.close()
# get img_shape
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
assert len(paths_out) > 0
is_paths = type(data_in[0]) == str
if is_paths:
assert len(data_in) == len(paths_out)
img_shape = get_img(data_in[0]).shape
else:
        assert data_in.shape[0] == len(paths_out)
        img_shape = data_in[0].shape  # data_in is an ndarray of images here, not paths
g = tf.Graph()
batch_size = min(len(paths_out), batch_size)
soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
soft_config.gpu_options.allow_growth = True
with g.as_default(), g.device(device_t), \
tf.compat.v1.Session(config=soft_config) as sess:
batch_shape = (batch_size,) + img_shape
img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
name='img_placeholder')
preds = transform.net(img_placeholder)
saver = tf.compat.v1.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("No checkpoint found...")
else:
saver.restore(sess, checkpoint_dir)
num_iters = int(len(paths_out)/batch_size)
for i in range(num_iters):
pos = i * batch_size
curr_batch_out = paths_out[pos:pos+batch_size]
if is_paths:
curr_batch_in = data_in[pos:pos+batch_size]
X = np.zeros(batch_shape, dtype=np.float32)
for j, path_in in enumerate(curr_batch_in):
img = get_img(path_in)
assert img.shape == img_shape, \
'Images have different dimensions. ' + \
'Resize images or use --allow-different-dimensions.'
X[j] = img
else:
X = data_in[pos:pos+batch_size]
_preds = sess.run(preds, feed_dict={img_placeholder:X})
for j, path_out in enumerate(curr_batch_out):
save_img(path_out, _preds[j])
remaining_in = data_in[num_iters*batch_size:]
remaining_out = paths_out[num_iters*batch_size:]
if len(remaining_in) > 0:
ffwd(remaining_in, remaining_out, checkpoint_dir,
device_t=device_t, batch_size=1)
def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
paths_in, paths_out = [in_path], [out_path]
ffwd(paths_in, paths_out, checkpoint_dir, batch_size=1, device_t=device)
def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
device_t=DEVICE, batch_size=4):
in_path_of_shape = defaultdict(list)
out_path_of_shape = defaultdict(list)
for i in range(len(in_path)):
in_image = in_path[i]
out_image = out_path[i]
shape = "%dx%dx%d" % get_img(in_image).shape
in_path_of_shape[shape].append(in_image)
out_path_of_shape[shape].append(out_image)
for shape in in_path_of_shape:
print('Processing images of shape %s' % shape)
ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
checkpoint_dir, device_t, batch_size)
def check_opts(opts):
exists(opts.checkpoint_dir, 'Checkpoint not found!')
exists(opts.in_path, 'In path not found!')
if os.path.isdir(opts.out_path):
exists(opts.out_path, 'out dir not found!')
assert opts.batch_size > 0
# Pass the options in as an object instead of command-line args?
class Parser(object):
def __init__(self):
self._in_path = None
self._out_path = None
self._checkpoint_dir = None
        self._device = DEVICE
        self._batch_size = BATCH_SIZE
        # main() reads this flag when in_path is a directory, so make sure
        # the attribute always exists
        self._allow_different_dimensions = False
    @property
    def batch_size(self):
        return self._batch_size
    @property
    def allow_different_dimensions(self):
        return self._allow_different_dimensions
@property
def device(self):
return self._device
@property
def in_path(self):
return self._in_path
@property
def out_path(self):
return self._out_path
@property
def checkpoint_dir(self):
return self._checkpoint_dir
@in_path.setter
def in_path(self, in_path):
self._in_path=in_path
@out_path.setter
def out_path(self, out_path):
self._out_path=out_path
@checkpoint_dir.setter
def checkpoint_dir(self, checkpoint_dir):
self._checkpoint_dir=checkpoint_dir
@in_path.deleter
def in_path(self):
del self._in_path
@out_path.deleter
def out_path(self):
del self._out_path
@checkpoint_dir.deleter
def checkpoint_dir(self):
del self._checkpoint_dir
def main(opts):
check_opts(opts)
if not os.path.isdir(opts.in_path):
if os.path.exists(opts.out_path) and os.path.isdir(opts.out_path):
out_path = \
os.path.join(opts.out_path,os.path.basename(opts.in_path))
else:
out_path = opts.out_path
ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
device=opts.device)
else:
files = list_files(opts.in_path)
full_in = [os.path.join(opts.in_path,x) for x in files]
full_out = [os.path.join(opts.out_path,x) for x in files]
if opts.allow_different_dimensions:
ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
device_t=opts.device, batch_size=opts.batch_size)
else :
ffwd(full_in, full_out, opts.checkpoint_dir, device_t=opts.device,
batch_size=opts.batch_size)
What is argparse.ArgumentParser?
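For fellow newbies: argparse.ArgumentParser is the standard-library command-line parser the original scripts use. A minimal sketch of my own:

# argparse in a nutshell: declare the flags, then parse a list of
# strings (defaults to sys.argv[1:])
import argparse

parser = argparse.ArgumentParser(description="demo")
parser.add_argument("--in-path", type=str, required=True)
parser.add_argument("--batch-size", type=int, default=4)
args = parser.parse_args(["--in-path", "x.jpg"])
print(args.in_path, args.batch_size)  # x.jpg 4

Since Flask calls our code directly rather than through the shell, the modified main() below just fills in an argparse.Namespace by hand instead of parsing anything.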
The modified neural_style.py
import argparse
import os
import sys
import re
import torch
from torchvision import transforms
import torch.onnx
import neural_style.utils as utils
from neural_style.transformer_net import TransformerNet
def check_paths(args):
try:
if not os.path.exists(args.save_model_dir):
os.makedirs(args.save_model_dir)
if args.checkpoint_model_dir is not None and not (os.path.exists(args.checkpoint_model_dir)):
os.makedirs(args.checkpoint_model_dir)
except OSError as e:
print(e)
sys.exit(1)
def stylize(args):
device = torch.device("cuda" if args.cuda else "cpu")
content_image = utils.load_image(args.content_image, scale=args.content_scale)
content_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(content_image)
content_image = content_image.unsqueeze(0).to(device)
if args.model.endswith(".onnx"):
output = stylize_onnx_caffe2(content_image, args)
else:
with torch.no_grad():
style_model = TransformerNet()
state_dict = torch.load(args.model)
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for k in list(state_dict.keys()):
if re.search(r'in\d+\.running_(mean|var)$', k):
del state_dict[k]
style_model.load_state_dict(state_dict)
style_model.to(device)
if args.export_onnx:
assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
else:
output = style_model(content_image).cpu()
utils.save_image(args.output_image, output[0])
def stylize_onnx_caffe2(content_image, args):
assert not args.export_onnx
import onnx
import onnx_caffe2.backend
model = onnx.load(args.model)
prepared_backend = onnx_caffe2.backend.prepare(model, device='CUDA' if args.cuda else 'CPU')
inp = {model.graph.input[0].name: content_image.numpy()}
c2_out = prepared_backend.run(inp)[0]
return torch.from_numpy(c2_out)
def main(content_image, model, output_image):
    # Flask calls this directly, so build the options object by hand with
    # argparse.Namespace instead of parsing sys.argv
    args = argparse.Namespace()
    args.content_image = content_image
    args.model = model
    args.output_image = output_image
    args.content_scale = None
    args.export_onnx = None
    args.cuda = 0
    stylize(args)
A simple RESTful implementation
pip install jinja2
Create a new app.py:
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import time
import os
import evaluateapi
import neural_style.neural_style as nn
app = Flask(__name__)
UPLOAD_FOLDER = 'upload'
UPLOAD_PATH = '/Users/apple/PycharmProjects/fast-style-transfer/upload/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['JSON_AS_ASCII'] = False
basedir = os.path.abspath(os.path.dirname(__file__))
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'jpeg', 'JPEG', 'gif'])
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
@app.route('/api/upload', methods=['POST'], strict_slashes=False)
def api_upload():
file_dir = os.path.join(basedir, app.config['UPLOAD_FOLDER'])
    model_type = request.form['type']  # renamed so we don't shadow the built-in type()
if not os.path.exists(file_dir):
os.makedirs(file_dir)
f = request.files['myfile']
if f and allowed_file(f.filename):
fname = secure_filename(f.filename)
print(fname)
ext = fname.rsplit('.', 1)[1]
unix_time = int(time.time())
new_filename = str(unix_time) + '.' + ext
f.save(os.path.join(file_dir, new_filename))
file_path = UPLOAD_PATH + new_filename
        if model_type.endswith('.ckpt'):
opts = evaluateapi.Parser()
opts.in_path = file_path
opts.out_path = file_path
            opts.checkpoint_dir = model_type
evaluateapi.main(opts)
else:
content_image = file_path
            model = model_type
output_image = file_path
nn.main(content_image, model, output_image)
print(file_path)
return jsonify({"code": 0,"filePath": file_path})
else:
return jsonify({"code": 1001, "errmsg": "上传失败"})
if __name__ == "__main__":
app.run(debug=True)
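A quick smoke test once the Flask app is running. This is my own sketch: it assumes requests is installed, that test.jpg sits next to the script, and that model/udnie.ckpt (the checkpoint from the earlier test) is a valid value for the type field:

# upload_test.py - hit the endpoint with a multipart POST
import requests  # pip install requests

with open('test.jpg', 'rb') as f:
    resp = requests.post(
        'http://127.0.0.1:5000/api/upload',
        data={'type': 'model/udnie.ckpt'},
        files={'myfile': ('test.jpg', f, 'image/jpeg')},
    )
print(resp.json())  # {"code": 0, "filePath": "..."} on success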
Client (Java)
How do you write the Swagger annotations? The controller below is what I ended up with.
/**
* @author:czx.me 2020/3/23
*/
@Slf4j
@Api(tags = "fast-style-transfer (limited compute; images capped at 1 MB)", value = "fast-style-transfer (limited compute; images capped at 1 MB)")
@Controller
public class TransferController {
@PutMapping(value = "go", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
@ApiOperation(value = "transfer", notes = "fast-style-transfer")
    @ApiImplicitParam(name = "model",
            defaultValue = "la_muse",
            value = "style model, pick one of ten: la_muse rain_princess scream udnie wave wreck candy mosaic udnie_pth rain_princess_pth",
            dataType = "String", paramType = "query")
public void transferStart(String model,
                              @ApiParam(name = "file", value = "image file", required = true) @ModelAttribute("file") MultipartFile file,
HttpServletResponse response,
HttpServletRequest request) throws IOException {
File copyFile = null;
File inFile = null;
BufferedOutputStream out = null;
InputStream input = null;
String url = "http://127.0.0.1:5000/api/upload";
try {
String fileName = file.getOriginalFilename();
            // model enum: maps a keyword to its checkpoint path
model = Model.fromTypeName(model);
if (null == model || "".equals(model)) {
log.error("没有 [{}] 这个类型?", model);
throw new Exception();
}
inFile = new File("./temp/" + fileName);
FileUtils.writeByteArrayToFile(inFile, file.getBytes());
            // a small wrapper around HttpURLConnection
String json = Request.post(url)
.connectTimeout(50000)
.readTimeout(50000)
.contentType("multipart/form-data")
.part("myfile", fileName, "multipart/form-data", inFile)
.part("type", model)
.body();
JSONObject jsonObject = JSON.parseObject(json);
String filePath = jsonObject.getString("filePath");
String code = jsonObject.getString("code");
if ("0".equals(code)) {
copyFile = new File(filePath);
input = new FileInputStream(copyFile);
String contentType = request.getServletContext().getMimeType(fileName);
response.setContentType(contentType);
out = new BufferedOutputStream(response.getOutputStream());
IOUtils.copy(input, out);
out.flush();
input.close();
} else {
log.error("好像算不过来。");
throw new Exception();
}
} catch (Exception e) {
response.sendRedirect("404");
        } finally {
            // close the streams as well as deleting the temp files
            IOUtils.closeQuietly(input);
            IOUtils.closeQuietly(out);
            FileUtils.deleteQuietly(copyFile);
            FileUtils.deleteQuietly(inFile);
}
}
}
The models currently available: the ten keywords from the Swagger parameter above, namely la_muse, rain_princess, scream, udnie, wave, wreck (TensorFlow .ckpt checkpoints) and candy, mosaic, udnie_pth, rain_princess_pth (PyTorch .pth models).