1. main.py
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
# One training loop serves both mixed-precision (AMP) and full-precision runs.
# When opt.half is falsy:
#   - the autocast context below is disabled and does nothing;
#   - the disabled GradScaler returns the loss unchanged from scale(), calls
#     optimizer.step() directly from step(), and makes update() a no-op.
# The FP32 path is therefore exactly the plain backward()/step() sequence.
use_amp = bool(opt.half)
scaler = GradScaler(enabled=use_amp)
for i, data in enumerate(train_loader, 0):
    data_time.update(time.time() - start_time)
    Hnet.zero_grad()
    Rnet.zero_grad()

    all_pics = data
    # True batch size of this step (the last batch can be smaller).
    this_batch_size = int(all_pics.size()[0])
    # The whole loader batch is used as cover images. The secret images are
    # drawn separately by sample_secret_img. Its second return value
    # (`target`) is not used in this loop.
    cover_img = all_pics[0:this_batch_size, :, :, :].to(device)
    secret_img, target = sample_secret_img(this_batch_size)
    secret_img = secret_img.to(device)
    # NOTE: torch.autograd.Variable is deprecated and just returns a Tensor, so
    # the inputs are used directly. The forward pass must NOT run under
    # torch.no_grad(): that would detach err_sum from the graph and make
    # backward() fail.
    secret_imgv = secret_img

    # Autocast only the forward pass and the loss computation. Backward and
    # the optimizer steps run outside autocast, as the PyTorch AMP docs
    # recommend.
    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        cover_imgv = JPEG(cover_img)
        container = Hnet(cover_imgv, secret_imgv)  # H-net output: container image
        errH = criterion(container, cover_imgv)  # loss between cover and container
        # JPEG is applied to the container before revealing (presumably a
        # differentiable JPEG simulation; confirm against its definition).
        compress_img = JPEG(container)
        rev_secret_img = Rnet(compress_img)  # R-net output: revealed secret image
        errR = criterion(rev_secret_img, secret_imgv)  # loss between secret and revealed secret
        betaerrR_secret = opt.beta * errR
        err_sum = errH + betaerrR_secret

    Hlosses.update(errH.item(), this_batch_size)
    Rlosses.update(errR.item(), this_batch_size)
    SumLosses.update(err_sum.item(), this_batch_size)

    # With AMP, the loss is scaled to avoid FP16 gradient underflow, and
    # scaler.step() unscales the gradients before stepping. With two
    # optimizers, call step() once per optimizer, then update() once per
    # iteration.
    scaler.scale(err_sum).backward()
    scaler.step(optimizerH)
    scaler.step(optimizerR)
    scaler.update()

    batch_time.update(time.time() - start_time)
    start_time = time.time()
    log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
        epoch, opt.niter, i, len(train_loader),
        Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)
    if i % opt.logFrequency == 0:
        print_log(log, logPath)
    else:
        print_log(log, logPath, console=False)

    # Generate a result picture every resultPicFrequency steps.
    if i % opt.resultPicFrequency == 0:
        # Under autocast the network outputs may be FP16 while cover_img is
        # FP32. .float() hands save_result_pic FP32 tensors in both modes and
        # is a no-op for tensors that are already FP32.
        save_result_pic(this_batch_size,
                        cover_img, container.detach().float(), compress_img.detach().float(),
                        secret_img, rev_secret_img.detach().float(),
                        epoch, i, opt.trainpics)
2. Note
How should I deal with this issue?
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于