Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

predict.py 1.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import numpy as np
  2. import tempfile
  3. import shutil
  4. import os
  5. from PIL import Image
  6. import subprocess
  7. from cog import BasePredictor, Input, Path
  8. class Predictor(BasePredictor):
  9. def predict(
  10. self,
  11. image: Path = Input(
  12. description="Input Image.",
  13. ),
  14. ) -> Path:
  15. input_dir = "input_dir"
  16. output_path = Path(tempfile.mkdtemp()) / "output.png"
  17. try:
  18. for d in [input_dir, "results"]:
  19. if os.path.exists(input_dir):
  20. shutil.rmtree(input_dir)
  21. os.makedirs(input_dir, exist_ok=False)
  22. input_path = os.path.join(input_dir, os.path.basename(image))
  23. shutil.copy(str(image), input_path)
  24. subprocess.call(
  25. [
  26. "python",
  27. "hat/test.py",
  28. "-opt",
  29. "options/test/HAT_SRx4_ImageNet-LR.yml",
  30. ]
  31. )
  32. res_dir = os.path.join(
  33. "results", "HAT_SRx4_ImageNet-LR", "visualization", "custom"
  34. )
  35. assert (
  36. len(os.listdir(res_dir)) == 1
  37. ), "Should contain only one result for Single prediction."
  38. res = Image.open(os.path.join(res_dir, os.listdir(res_dir)[0]))
  39. res.save(str(output_path))
  40. finally:
  41. pass
  42. shutil.rmtree(input_dir)
  43. shutil.rmtree("results")
  44. return output_path
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...