Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

instruct_generate.py 4.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  1. import sys
  2. import time
  3. import json
  4. import argparse
  5. import asyncio
  6. import itertools
  7. from pprint import pprint
  8. import instruct_few_shot_examples
  9. sys.path.append("llava")
  10. from openai_api import call_async
  11. conv_to_str = lambda conv: "\n\n".join([("User: " if x["from"] == "human" else "Assistant: ") + x["value"] for x in conv])
class PromptGenerator:
    """Assembles few-shot chat prompts that ask a model to turn a figure
    caption (plus optional in-text mentions) into a User/Assistant
    conversation about the figure image, without access to the image."""

    @staticmethod
    def few_shot_messages_gen(query_context, use_inline_mentions=True):
        """Build the full chat message list for one query.

        Layout: the fixed system instruction, then one
        (user context, assistant transcript) pair per example in
        instruct_few_shot_examples.fs, then `query_context` as the final
        user turn.
        """
        messages = [
            {"role": "system", "content": """You are an AI assistant specialized in biomedical topics.
You are provided with a text description (Figure Caption) of a figure image from a biomedical research paper. In some cases, you may have additional text (Figure Context) that mentions the image. Unfortunately, you don't have access to the actual image.
Your task is to generate a conversation between a person (User) inquiring about the image and you (Assistant) responding to their questions. The conversation should proceed as though both the User and Assistant are viewing the image, while not referring to the text information (Figure Caption and Figure Context).
Below are requirements for generating the questions and answers in the conversation:
- Avoid quoting or referring to specific facts, terms, abbreviations, dates, numbers, or names, as these may reveal the conversation is based on the text information, rather than the image itself. Focus on the visual aspects of the image that can be inferred without the text information.
- Do not use phrases like "mentioned", "caption", "context" in the conversation. Instead, refer to the information as being "in the image."
- Ensure that questions are diverse and cover a range of visual aspects of the image.
- The conversation should include at least 2-3 turns of questions and answers about the visual aspects of the image.
- Answer responsibly, avoiding overconfidence, and do not provide medical advice or diagnostic information. Encourage the user to consult a healthcare professional for advice.
"""},
        ]
        # Append each few-shot example as a user/assistant exchange so the
        # model sees worked examples before the real query.
        for ex in instruct_few_shot_examples.fs:
            messages += [
                {"role": "user", "content": PromptGenerator.context_gen(ex, use_inline_mentions)},
                {"role": "assistant", "content": conv_to_str(ex["conversations"])},
            ]
        messages.append({"role": "user", "content": query_context})
        return messages

    @staticmethod
    def context_gen(sample, use_inline_mentions=True):
        """Format one sample's caption — and, when enabled and present,
        its in-text mentions — into the textual context shown to the model.

        Raises KeyError if `sample` lacks "in_text_mention" (when
        `use_inline_mentions`), "fig_label", or "fig_caption".
        """
        ctx = []
        if use_inline_mentions and sample["in_text_mention"]:
            for sent in sample["in_text_mention"]:
                # A mention may be either a plain string or a dict carrying
                # the sentence text under "tokens".
                if isinstance(sent, dict):
                    sent = sent["tokens"]
                ctx.append(sent)
        ret = f"Figure Caption:\n{sample['fig_label']}: {sample['fig_caption']}"
        if len(ctx):
            # Render each mention as a tab-indented bullet line.
            ret += "\n\nFigure Context:\n\t- {ctx}".format(ctx="\n\t- ".join(ctx))
        return ret

    @staticmethod
    def wrap_gen_message(sample, use_inline_mentions=False):
        """Convenience wrapper: build `sample`'s context and wrap it in the
        few-shot message list.  Note the default here is False, unlike the
        True default of the two methods it calls."""
        text = PromptGenerator.context_gen(sample, use_inline_mentions=use_inline_mentions)
        context = PromptGenerator.few_shot_messages_gen(text, use_inline_mentions=use_inline_mentions)
        return context
  51. def main(args):
  52. with open(args.input_path) as f:
  53. domain_dict = json.load(f)
  54. results = []
  55. for i in range(3):
  56. print(f'round {i}')
  57. result_pair_ids = set(result['pair_id'] for result in results)
  58. batch = []
  59. counter = 0
  60. for cycle_idx, samples in enumerate(itertools.zip_longest(*domain_dict.values())):
  61. if counter>=args.max_size:
  62. break
  63. for domain_idx, sample in enumerate(samples):
  64. if not sample:
  65. continue
  66. counter+=1
  67. if counter>=args.max_size:
  68. break
  69. if sample['pair_id'] in result_pair_ids:
  70. continue
  71. batch.append(sample)
  72. if len(batch)>=args.batch_size:
  73. async_results = call_async(batch, lambda x: PromptGenerator.wrap_gen_message(x, use_inline_mentions=args.use_inline_mentions))
  74. results.extend(async_results)
  75. print(f"Result Size: {len(results)}")
  76. batch = []
  77. async_results = call_async(batch, lambda x: PromptGenerator.wrap_gen_message(x, use_inline_mentions=args.use_inline_mentions))
  78. results.extend(async_results)
  79. print(f"Result Size: {len(results)}")
  80. with open(args.output_path, 'w') as f:
  81. for line in results:
  82. f.write(json.dumps(line)+'\n')
  83. if __name__ == '__main__':
  84. parser = argparse.ArgumentParser()
  85. parser.add_argument('--input_path', type=str, default='data/instruct/llava_med_instruct_fig_captions.json')
  86. parser.add_argument('--output_path', type=str, default='data/instruct/llava_med_instruct_fig_captions_gen.json')
  87. parser.add_argument('--use_inline_mentions', type=bool, default=False)
  88. parser.add_argument('--batch_size', type=int, default=3)
  89. parser.add_argument('--max_size', type=int, default=60000)
  90. args = parser.parse_args()
  91. main(args)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...