Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_gen.py 3.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  1. """
  2. This script downloads the Fashion MNIST dataset, randomly selects a specified number of samples
  3. from its training split, and saves them to a CSV file. Each row in the CSV file contains the sample's
  4. index in the training split, the class label name, and the JPEG-encoded image as a base64 string. Rows are sorted by label name, then by index.
  5. The Fashion MNIST dataset is a collection of 70,000 grayscale images of 28x28 pixels,
  6. each depicting one of 10 types of clothing. For more information on the dataset, see:
  7. https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist/load_data
  8. Usage:
  9. python dataset_gen.py --num_samples <num_samples> --filename <filename>
  10. Arguments:
  11. --num_samples: Number of samples to save (default: 100)
  12. --filename: Output CSV file name (default: fashion_mnist_sample_base64.csv)
  13. """
import base64
import io
from typing import List, Optional

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
  21. # set np seed for reproducibility
  22. np.random.seed(0)
  23. def get_class_names() -> dict[int, str]:
  24. """Retrieves the class names for the Fashion MNIST dataset.
  25. Returns:
  26. A dictionary mapping class indices to class names.
  27. """
  28. return {
  29. 0: "T-shirt/top",
  30. 1: "Trouser",
  31. 2: "Pullover",
  32. 3: "Dress",
  33. 4: "Coat",
  34. 5: "Sandal",
  35. 6: "Shirt",
  36. 7: "Sneaker",
  37. 8: "Bag",
  38. 9: "Ankle boot",
  39. }
  40. def image_to_base64(image: np.ndarray) -> str:
  41. """Converts an image to a base64 encoded string.
  42. Args:
  43. image: A numpy array representing the image.
  44. Returns:
  45. A base64 encoded string of the image.
  46. """
  47. buffered = io.BytesIO()
  48. pil_image = Image.fromarray(image)
  49. # NOTE: For a dataset with large images, you can resize it here to save
  50. # costs on the inference side.
  51. # pil_image = pil_image.resize((32, 32))
  52. pil_image.save(buffered, format="jpeg")
  53. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  54. def save_fashion_mnist_sample_to_csv(num_samples: int, filename: str) -> None:
  55. """Saves a sample of the Fashion MNIST dataset to a CSV file.
  56. Args:
  57. num_samples: The number of samples to save.
  58. filename: The name of the output CSV file.
  59. """
  60. # Load the Fashion MNIST dataset
  61. fashion_mnist = tf.keras.datasets.fashion_mnist
  62. (train_images, train_labels), _ = fashion_mnist.load_data()
  63. class_names = get_class_names()
  64. # Randomly sample indices without replacement
  65. sample_indices = np.random.choice(len(train_images), num_samples, replace=False)
  66. # Convert images to base64 and combine with labels and indices
  67. data: List[List] = []
  68. for sample_index in sample_indices:
  69. base64_image = image_to_base64(train_images[sample_index])
  70. label_index = train_labels[sample_index]
  71. label_name = class_names[label_index]
  72. data.append([sample_index, label_name, base64_image])
  73. pd.DataFrame(data, columns=["index", "label", "image_base64"]).sort_values(
  74. by=["label", "index"]
  75. ).to_csv(filename, index=False)
  76. print(f"CSV file '{filename}' created successfully.")
  77. if __name__ == "__main__":
  78. import argparse
  79. parser = argparse.ArgumentParser(
  80. description="Save Fashion MNIST samples to a CSV file."
  81. )
  82. parser.add_argument(
  83. "--num_samples", type=int, default=100, help="Number of samples to save"
  84. )
  85. parser.add_argument(
  86. "--filename",
  87. type=str,
  88. default="fashion_mnist_sample_base64.csv",
  89. help="Output CSV file name",
  90. )
  91. args = parser.parse_args()
  92. save_fashion_mnist_sample_to_csv(args.num_samples, args.filename)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...