Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

st_inference.py 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  1. from PIL import Image
  2. # Warning: this is private internal dvc api, it may change for future dvc versions
  3. from dvc.repo.get import get
  4. import numpy as np
  5. import pandas as pd
  6. import streamlit as st
  7. import tensorflow as tf
  8. from scripts.params import TRAIN_DIR, IMG_SIZE
  9. from st_scripts.st_utils import REPO, ROOT_DIR, get_model_backbone, st_model_multiselect
  10. @st.cache
  11. def load_model(rev: str):
  12. model_cache_dir = ROOT_DIR / ".model_cache"
  13. model_cache_dir.mkdir(exist_ok=True)
  14. print(f"Loading model for revision {rev}")
  15. # 1. Download model to MODEL_CACHE dir using `dvc get`
  16. # See https://dvc.org/doc/command-reference/get
  17. out_model_dir = str(model_cache_dir / rev)
  18. # Try to load the model directly (if it is in cache dir)
  19. try:
  20. return tf.keras.models.load_model(out_model_dir)
  21. except OSError:
  22. print(f"Could not find model {rev} in cache")
  23. except Exception as e:
  24. print(f"Could not load model {rev} from cache")
  25. get(url=".", path=str(TRAIN_DIR / "model"), out=out_model_dir, rev=rev)
  26. print(f"Model downloaded to {out_model_dir}")
  27. # 2. Load the model with tf.keras.models.load_model
  28. return tf.keras.models.load_model(out_model_dir)
  29. def st_inference():
  30. st.markdown("### ModelsInference")
  31. selected_models = st_model_multiselect()
  32. models = {
  33. model_rev: load_model(rev=model_rev)
  34. for model_rev in selected_models
  35. }
  36. uploaded_file = st.file_uploader("Upload an image")
  37. if uploaded_file:
  38. image = Image.open(uploaded_file)
  39. image_name = uploaded_file.name
  40. resized_image = image.resize(IMG_SIZE)
  41. beta_column_0, beta_column_1, beta_column_2 = st.columns(3)
  42. with beta_column_0:
  43. st.write(f"Image name: {image_name}")
  44. st.write(f"Image width: {image.size[0]}")
  45. st.write(f"Image height: {image.size[1]}")
  46. with beta_column_1:
  47. st.image(image, caption="Original image")
  48. with beta_column_2:
  49. st.image(resized_image, caption=f"Input resized image : {IMG_SIZE}")
  50. input = np.expand_dims(tf.keras.preprocessing.image.img_to_array(resized_image), axis=0)
  51. if st.button(f"Run {len(models)} model(s)"):
  52. predictions = []
  53. for model_rev, model in models.items():
  54. prediction = tf.nn.sigmoid(model.predict(input).flatten()).numpy()[0]
  55. predictions.append({
  56. "backbone": get_model_backbone(model_rev),
  57. "commit_hash": model_rev,
  58. "commit_message": REPO.commit(model_rev).message,
  59. "cat": 1 - prediction,
  60. "dog": prediction,
  61. })
  62. st.dataframe(pd.DataFrame(predictions))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...