Transfer Learning: A First Taste of fast-style-transfer


Preface

This post is about jumping on a bandwagon while slacking off: a walkthrough for getting a working result even when you barely understand what is going on underneath.

About Me

A Java novice who doesn't know Python either. But, as JoJo puts it: "The anthem of mankind is an anthem of courage!! The greatness of mankind is the greatness of courage!!"

Resources

  1. fast_neural_style
  2. fast-style-transfer
  3. Setting up a TensorFlow environment on macOS
  4. Setting up TensorFlow on Windows 10 to try fast-style-transfer
  5. Baidu
  6. Hands and a brain

Testing

I don't have my router with me, so first switch pip to a domestic mirror.
sudo vi ~/.config/pip/pip.conf

[global]
# index-url = https://pypi.tuna.tsinghua.edu.cn/simple
index-url = https://mirrors.aliyun.com/pypi/simple
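You can confirm the mirror took effect with pip config list, which should print the Aliyun index-url.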
git clone https://github.com/lengstrom/fast-style-transfer
cd fast-style-transfer/
pyenv activate v370env
python evaluate.py --checkpoint model/udnie.ckpt --in-path xxx.jpg --out-path xxx/

Hacking Together an API (Python)

mkdir neural_style
cd ..
git clone https://github.com/pytorch/examples/
cd examples/
cp fast_neural_style/neural_style/*.py ../fast-style-transfer/neural_style

How do you write a RESTful API in Python?

pip install flask
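That answers the question well enough for a beginner. A minimal sketch of what Flask gives you, before wiring in any models (the route name and response fields here are made up, just to prove the install works):

import requests  # not needed here; see below for the real test client
from flask import Flask, jsonify

app = Flask(__name__)

# throwaway endpoint: returns JSON to confirm Flask is installed and serving
@app.route('/api/ping', methods=['GET'])
def ping():
    return jsonify({"code": 0, "msg": "pong"})

if __name__ == "__main__":
    app.run(debug=True)  # listens on http://127.0.0.1:5000 by default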

How do you create a class in Python?
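Short version: define it with the class keyword; the @property decorator turns methods into attribute-style getters and setters, which is exactly the pattern the Parser class below relies on. An illustrative toy (the name Options is made up here):

class Options(object):
    def __init__(self):
        self._in_path = None

    @property
    def in_path(self):           # getter: read opts.in_path
        return self._in_path

    @in_path.setter
    def in_path(self, value):    # setter: opts.in_path = 'a.jpg'
        self._in_path = value

opts = Options()
opts.in_path = 'a.jpg'
print(opts.in_path)  # -> a.jpg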

The modified evaluate.py, saved as evaluateapi.py so the Flask app can import it:
from __future__ import print_function
import sys
sys.path.insert(0, 'src')
import numpy as np, os
import tensorflow as tf
from src.utils import save_img, get_img, exists, list_files
from src import transform
from collections import defaultdict
from moviepy.video.io.VideoFileClip import VideoFileClip
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer

BATCH_SIZE = 4
DEVICE = '/gpu:0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    video_clip = VideoFileClip(path_in, audio=False)
    video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
        path_out, video_clip.size, video_clip.fps, codec="libx264",
        preset="medium", bitrate="2000k", audiofile=path_in,
        threads=None, ffmpeg_params=None)

    g = tf.Graph()
    soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.compat.v1.Session(config=soft_config) as sess:
        batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
        img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
                                                   name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.compat.v1.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        X = np.zeros(batch_shape, dtype=np.float32)

        def style_and_write(count):
            for i in range(count, batch_size):
                X[i] = X[count - 1]  # use the last frame to fill the rest of X
            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for i in range(0, count):
                video_writer.write_frame(np.clip(_preds[i], 0, 255).astype(np.uint8))

        frame_count = 0  # number of frames currently buffered in X
        for frame in video_clip.iter_frames():
            X[frame_count] = frame
            frame_count += 1
            if frame_count == batch_size:
                style_and_write(frame_count)
                frame_count = 0

        if frame_count != 0:
            style_and_write(frame_count)

        video_writer.close()


def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    assert len(paths_out) > 0
    is_paths = type(data_in[0]) == str
    # get img_shape
    if is_paths:
        assert len(data_in) == len(paths_out)
        img_shape = get_img(data_in[0]).shape
    else:
        assert data_in.size[0] == len(paths_out)
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.compat.v1.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + img_shape
        img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
                                                   name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.compat.v1.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        num_iters = int(len(paths_out) / batch_size)
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

    remaining_in = data_in[num_iters * batch_size:]
    remaining_out = paths_out[num_iters * batch_size:]
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir,
             device_t=device_t, batch_size=1)


def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    paths_in, paths_out = [in_path], [out_path]
    ffwd(paths_in, paths_out, checkpoint_dir, batch_size=1, device_t=device)


def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
                              device_t=DEVICE, batch_size=4):
    in_path_of_shape = defaultdict(list)
    out_path_of_shape = defaultdict(list)
    for i in range(len(in_path)):
        in_image = in_path[i]
        out_image = out_path[i]
        shape = "%dx%dx%d" % get_img(in_image).shape
        in_path_of_shape[shape].append(in_image)
        out_path_of_shape[shape].append(out_image)
    for shape in in_path_of_shape:
        print('Processing images of shape %s' % shape)
        ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
             checkpoint_dir, device_t, batch_size)


def check_opts(opts):
    exists(opts.checkpoint_dir, 'Checkpoint not found!')
    exists(opts.in_path, 'In path not found!')
    if os.path.isdir(opts.out_path):
        exists(opts.out_path, 'out dir not found!')
        assert opts.batch_size > 0


# Pass the options in as a plain object instead of parsing command-line
# arguments, so that other Python code (e.g. the Flask app) can call main().
class Parser(object):
    def __init__(self):
        self._in_path = None
        self._out_path = None
        self._checkpoint_dir = None
        self._device = DEVICE
        self._batch_size = BATCH_SIZE
        self._allow_different_dimensions = False

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    @property
    def allow_different_dimensions(self):
        return self._allow_different_dimensions

    @property
    def in_path(self):
        return self._in_path

    @property
    def out_path(self):
        return self._out_path

    @property
    def checkpoint_dir(self):
        return self._checkpoint_dir

    @in_path.setter
    def in_path(self, in_path):
        self._in_path = in_path

    @out_path.setter
    def out_path(self, out_path):
        self._out_path = out_path

    @checkpoint_dir.setter
    def checkpoint_dir(self, checkpoint_dir):
        self._checkpoint_dir = checkpoint_dir

    @in_path.deleter
    def in_path(self):
        del self._in_path

    @out_path.deleter
    def out_path(self):
        del self._out_path

    @checkpoint_dir.deleter
    def checkpoint_dir(self):
        del self._checkpoint_dir


def main(opts):
    check_opts(opts)
    if not os.path.isdir(opts.in_path):
        if os.path.exists(opts.out_path) and os.path.isdir(opts.out_path):
            out_path = os.path.join(opts.out_path, os.path.basename(opts.in_path))
        else:
            out_path = opts.out_path
        ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                    device=opts.device)
    else:
        files = list_files(opts.in_path)
        full_in = [os.path.join(opts.in_path, x) for x in files]
        full_out = [os.path.join(opts.out_path, x) for x in files]
        if opts.allow_different_dimensions:
            ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
                                      device_t=opts.device,
                                      batch_size=opts.batch_size)
        else:
            ffwd(full_in, full_out, opts.checkpoint_dir,
                 device_t=opts.device, batch_size=opts.batch_size)
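With argparse gone, the module can now be driven from other Python code, which is what the Flask app below does. A minimal sketch, assuming the file is saved as evaluateapi.py and the udnie checkpoint from the testing step is present (the image paths are hypothetical):

import evaluateapi

opts = evaluateapi.Parser()
opts.in_path = 'input.jpg'                # hypothetical input image
opts.out_path = 'output.jpg'              # hypothetical output path
opts.checkpoint_dir = 'model/udnie.ckpt'  # TensorFlow checkpoint to style with
evaluateapi.main(opts)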

What is argparse.ArgumentParser?
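It is the standard-library command-line parser that both upstream scripts are built around. A quick taste:

import argparse

parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--in-path', type=str, required=True)
parser.add_argument('--batch-size', type=int, default=4)
args = parser.parse_args(['--in-path', 'a.jpg'])  # normally reads sys.argv
print(args.in_path, args.batch_size)              # -> a.jpg 4

The modified script below keeps the ArgumentParser call only to obtain an args namespace, then fills the fields in by hand so the function can be called from Flask instead of the command line.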

The modified neural_style.py:
import argparse
import os
import sys
import re

import torch
from torchvision import transforms
import torch.onnx

import neural_style.utils as utils
from neural_style.transformer_net import TransformerNet


def check_paths(args):
    try:
        if not os.path.exists(args.save_model_dir):
            os.makedirs(args.save_model_dir)
        if args.checkpoint_model_dir is not None and not (os.path.exists(args.checkpoint_model_dir)):
            os.makedirs(args.checkpoint_model_dir)
    except OSError as e:
        print(e)
        sys.exit(1)


def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])


def stylize_onnx_caffe2(content_image, args):
    assert not args.export_onnx

    import onnx
    import onnx_caffe2.backend

    model = onnx.load(args.model)

    prepared_backend = onnx_caffe2.backend.prepare(model, device='CUDA' if args.cuda else 'CPU')
    inp = {model.graph.input[0].name: content_image.numpy()}
    c2_out = prepared_backend.run(inp)[0]

    return torch.from_numpy(c2_out)


def main(content_image, model, output_image):
    # Abuse argparse only to get an empty args namespace, then fill the
    # fields in by hand so this can be called from other Python code.
    args = argparse.ArgumentParser(description="parser for fast-neural-style").parse_args()
    args.content_image = content_image
    args.model = model
    args.output_image = output_image
    args.content_scale = None
    args.export_onnx = None
    args.cuda = 0
    stylize(args)
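Called from Python it looks like the sketch below. The paths are hypothetical; any model path not ending in .onnx is loaded as a PyTorch TransformerNet checkpoint.

import neural_style.neural_style as nn

# content image in, stylized image out; the .pth model path is an assumption
nn.main('input.jpg', 'model/udnie.pth', 'output.jpg')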

A Simple RESTful Implementation

pip install jinja2
Create a new app.py:
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import time
import os

import evaluateapi
import neural_style.neural_style as nn

app = Flask(__name__)

UPLOAD_FOLDER = 'upload'
UPLOAD_PATH = '/Users/apple/PycharmProjects/fast-style-transfer/upload/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['JSON_AS_ASCII'] = False
basedir = os.path.abspath(os.path.dirname(__file__))
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'jpeg', 'JPEG', 'gif'])


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS


@app.route('/api/upload', methods=['POST'], strict_slashes=False)
def api_upload():
    file_dir = os.path.join(basedir, app.config['UPLOAD_FOLDER'])
    type = request.form['type']
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    f = request.files['myfile']
    if f and allowed_file(f.filename):
        fname = secure_filename(f.filename)
        print(fname)
        ext = fname.rsplit('.', 1)[1]
        unix_time = int(time.time())
        new_filename = str(unix_time) + '.' + ext
        f.save(os.path.join(file_dir, new_filename))
        file_path = UPLOAD_PATH + new_filename
        if type.endswith('.ckpt'):
            # TensorFlow checkpoint: route through the modified evaluate.py
            opts = evaluateapi.Parser()
            opts.in_path = file_path
            opts.out_path = file_path
            opts.checkpoint_dir = type
            evaluateapi.main(opts)
        else:
            # PyTorch model: route through the modified neural_style.py
            content_image = file_path
            model = type
            output_image = file_path
            nn.main(content_image, model, output_image)
        print(file_path)
        return jsonify({"code": 0, "filePath": file_path})
    else:
        return jsonify({"code": 1001, "errmsg": "upload failed"})


if __name__ == "__main__":
    app.run(debug=True)
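Before writing the Java client, the endpoint can be smoke-tested from Python. A sketch using the requests library (an extra dependency; the image and model paths are assumptions):

import requests

# the server treats `type` as either a .ckpt checkpoint path (TensorFlow)
# or any other model path (PyTorch)
with open('test.jpg', 'rb') as f:
    r = requests.post('http://127.0.0.1:5000/api/upload',
                      files={'myfile': ('test.jpg', f, 'image/jpeg')},
                      data={'type': 'model/udnie.ckpt'})
print(r.json())  # {"code": 0, "filePath": "..."} on success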

Client (Java)

How do you write Swagger annotations?

/**
 * @author czx.me 2020/3/23
 */
@Slf4j
@Api(tags = "fast-style-transfer (limited compute; images capped at 1 MB)",
        value = "fast-style-transfer (limited compute; images capped at 1 MB)")
@Controller
public class TransferController {

    @PutMapping(value = "go", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
    @ApiOperation(value = "transfer", notes = "fast-style-transfer")
    @ApiImplicitParam(name = "model", defaultValue = "la_muse",
            value = "style model, any one of ten: la_muse rain_princess scream udnie wave wreck candy mosaic udnie_pth rain_princess_pth",
            dataType = "String", paramType = "query")
    public void transferStart(String model,
                              @ApiParam(name = "file", value = "image file", required = true)
                              @ModelAttribute("file") MultipartFile file,
                              HttpServletResponse response,
                              HttpServletRequest request) throws IOException {
        File copyFile = null;
        File inFile = null;
        BufferedOutputStream out = null;
        InputStream input = null;
        String url = "http://127.0.0.1:5000/api/upload";
        try {
            String fileName = file.getOriginalFilename();
            // model enum: maps a keyword to a checkpoint/model path
            model = Model.fromTypeName(model);
            if (null == model || "".equals(model)) {
                log.error("No such model type [{}]?", model);
                throw new Exception();
            }
            inFile = new File("./temp/" + fileName);
            FileUtils.writeByteArrayToFile(inFile, file.getBytes());
            // a thin wrapper around HttpURLConnection
            String json = Request.post(url)
                    .connectTimeout(50000)
                    .readTimeout(50000)
                    .contentType("multipart/form-data")
                    .part("myfile", fileName, "multipart/form-data", inFile)
                    .part("type", model)
                    .body();
            JSONObject jsonObject = JSON.parseObject(json);
            String filePath = jsonObject.getString("filePath");
            String code = jsonObject.getString("code");
            if ("0".equals(code)) {
                copyFile = new File(filePath);
                input = new FileInputStream(copyFile);
                String contentType = request.getServletContext().getMimeType(fileName);
                response.setContentType(contentType);
                out = new BufferedOutputStream(response.getOutputStream());
                IOUtils.copy(input, out);
                out.flush();
                input.close();
            } else {
                log.error("Looks like the server could not finish the computation.");
                throw new Exception();
            }
        } catch (Exception e) {
            response.sendRedirect("404");
        } finally {
            FileUtils.deleteQuietly(copyFile);
            FileUtils.deleteQuietly(inFile);
        }
    }
}

Models currently available: la_muse, rain_princess, scream, udnie, wave, wreck, candy, mosaic, udnie_pth, rain_princess_pth.

Mission accomplished?

Limited-Time Demo

It is now deployed on an Alibaba Cloud lightweight server.

Click here.