Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

evaluate.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  1. """Evaluation Script."""
  2. import argparse
  3. import torch
  4. import mlflow
  5. import os
  6. from configs import config
  7. from models import StyleTransferNetwork
  8. from utils.image_utils import *
  9. from utils.data_utils import *
  10. def evaluate(args):
  11. """Evaluate the network."""
  12. device = torch.device('cpu')
  13. # Set up MLflow
  14. mlflow.set_tracking_uri("https://dagshub.com/shatter-star/musical-octo-dollop.mlflow")
  15. os.environ["MLFLOW_TRACKING_USERNAME"] = "shatter-star"
  16. os.environ["MLFLOW_TRACKING_PASSWORD"] = "411996890a0df0c0ccf65dbd848d454f40ad3cbb"
  17. model = mlflow.pytorch.load_model(args.model_uri, map_location=device)
  18. model.eval()
  19. content_image = imload(args.content_path, imsize=args.imsize)
  20. # for all styles
  21. if args.style_index == -1:
  22. style_code = torch.eye(config.NUM_STYLE).unsqueeze(-1)
  23. content_image = content_image.repeat(config.NUM_STYLE, 1, 1, 1)
  24. # for specific style
  25. elif args.style_index in range(config.NUM_STYLE):
  26. style_code = torch.zeros(1, config.NUM_STYLE, 1)
  27. style_code[:, args.style_index, :] = 1
  28. else:
  29. raise RuntimeError("Not expected style index")
  30. stylized_image = model(content_image, style_code)
  31. imsave(stylized_image, args.output_path)
  32. return None
  33. if __name__ == '__main__':
  34. parser = argparse.ArgumentParser()
  35. # Data configurations
  36. parser.add_argument('--content_path', type=str, required=True,
  37. help='Path to content image')
  38. parser.add_argument('--imsize', type=int, default=config.IMSIZE,
  39. help='Input image size')
  40. # Other configurations
  41. parser.add_argument('--output_path', type=str, default='stylized_image.jpg',
  42. help='Path to save the stylized image')
  43. parser.add_argument('--style_index', type=int, default=0,
  44. help='Index of the style to use (-1 for all styles)')
  45. parser.add_argument('--model_uri', type=str, required=False,
  46. help='URI of the MLflow model to load')
  47. args = parser.parse_args()
  48. if not args.model_uri:
  49. args.model_uri = input("Enter the MLflow model URI: ")
  50. evaluate(args)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...