Skip to content

Commit cec1010

Browse files
committed
[bug fixed] load pre-trained checkpoint.
1 parent e4b1276 commit cec1010

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

eval_demo.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
cfg = yaml2config(args.config)
3232

3333
model = get_model(cfg.model)(cfg, args.config)
34-
model.load(args.ckpt)
34+
model.load(args.ckpt, cfg.device)
3535
if args.mode == 'style':
3636
model.eval_style()
3737
elif args.mode == 'rand':
@@ -41,4 +41,4 @@
4141
elif args.mode == 'text':
4242
model.eval_text()
4343
else:
44-
print('Unsupported mode: {} | [rand] [style] [interp]'.format(cfg.mode))
44+
print('Unsupported mode: {} | [rand] [style] [interp]'.format(cfg.mode))

networks/model.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,18 @@ def save(self, tag='best', epoch_done=0, **kwargs):
8181
ckpt_save_path = os.path.join(self.log_root, self.opt.training.ckpt_dir, tag + '.pth')
8282
torch.save(ckpt, ckpt_save_path)
8383

84-
def load(self, ckpt, modules=None):
84+
def load(self, ckpt, map_location=None, modules=None):
8585
if modules is None:
8686
modules = []
8787
elif not isinstance(modules, list):
8888
modules = [modules]
8989

9090
print('load checkpoint from ', ckpt)
91-
ckpt = torch.load(ckpt)
91+
if map_location is None:
92+
ckpt = torch.load(ckpt)
93+
else:
94+
ckpt = torch.load(ckpt, map_location=map_location)
95+
9296
if len(modules) == 0:
9397
for model in self.models.values():
9498
model.load_state_dict(ckpt[type(model).__name__])

test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@
4242
cfg.valid.dset_split = args.split
4343

4444
model = get_model(cfg.model)(cfg, args.config)
45-
model.load(args.ckpt)
46-
print(model.validate(args.guided))
45+
model.load(args.ckpt, cfg.device)
46+
print(model.validate(args.guided))

0 commit comments

Comments
 (0)