Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

experiment.py 3.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  1. from mcmc_norm_learning.environment import *
  2. from mcmc_norm_learning.rules_4 import *
  3. from mcmc_norm_learning.robot_task_new import *
  4. from mcmc_norm_learning.algorithm_1_v4 import create_data,algorithm_1,to_tuple
  5. from mcmc_norm_learning.mcmc_performance import performance
  6. from mcmc_norm_learning.mcmc_convergence import prepare_sequences,calculate_R
  7. #import matplotlib.pyplot as plt
  8. from collections import Counter
  9. import pickle
  10. import time
  11. #import seaborn as sns
  12. import os
  13. import sys
  14. from tqdm import tnrange, tqdm_notebook
  15. from functools import reduce
  16. from operator import concat
  17. from mcmc_norm_learning.rules_4 import q_dict, rule_dict
  18. def gen_data_for_conv_test(data, env, task1):
  19. n=500 #Length of sequence after discarding warm-up part and splitting in half
  20. m=10 #Number of sequences after splitting in half
  21. rf=0.6
  22. chains=[]
  23. for i in tnrange(1,int(m/2+1),desc="Loop for Individual Chains"):
  24. print ("\n:::::::::::::::::::: FOR SEQUENCE {} ::::::::::::::::::::".format(i))
  25. exp_seq,lik_list = algorithm_1(data,env,task1,q_dict,rule_dict,
  26. "demo/convergence/report_for_chain_{}_x".format(i),
  27. relevance_factor=rf,max_iterations=4*n,verbose=False)
  28. chains.append(exp_seq)
  29. pickle_it(sequence_list, './experiment/sequence_list.sv')
  30. return chains
  31. def conv_test(chains):
  32. convergence_result,split_data = calculate_R(chains,50)
  33. print(convergence_result)
  34. return reduce(concat, split_data)
  35. def unpickle(path):
  36. with open(path, 'rb') as fp:
  37. result = pickle.load(fp)
  38. print("Unpickled 1st element for '{}' is {}\n".format(path, result[0]))
  39. return result
  40. def pickle_it(x, path):
  41. with open(path, 'wb') as fp:
  42. pickle.dump(x, fp)
  43. def calc_precision_and_recall(posterior_sample, env, task1, true_expression, repeat=10000):
  44. learned_expressions=Counter(map(to_tuple, posterior_sample))
  45. print("Number of unique Norms in sequence={}".format(len(learned_expressions)))
  46. print("Top 5 norms:")
  47. for freq,expression in learned_expressions.most_common(5):
  48. print("Freq. {}".format(freq))
  49. print(expression,"\n")
  50. # Calculate precision and recall of top_n norms from learned expressions
  51. pr_result=performance(task1,env,true_expression,learned_expressions,
  52. folder_name=None,file_name="top_norm",
  53. top_n=5,beta=1,repeat=repeat,verbose=False)
  54. pr_result.head()
  55. def read_scenario():
  56. true_expression = unpickle('./demo/demo_exp.sv')
  57. env = unpickle('./demo/demo_env.sv')
  58. target_area=[position(-0.8,0.7),position(0.25,0.99)]
  59. task1 = task(colour_specific=env[1],shape_specific=env[2],target_area=target_area)
  60. return env,task1,true_expression
  61. def read_observations():
  62. return(unpickle('./demo/demo_data.sv'))
  63. if __name__ == "__main__":
  64. env,task1,true_expression = read_scenario()
  65. if sys.argv[1] == "pr":
  66. posterior_sample = unpickle('./experiment/posterior_sample.sv')
  67. calc_precision_and_recall(posterior_sample, env, task1, true_expression)
  68. elif sys.argv[1] == "convtest":
  69. if os.path.exists('./experiment/sequence_list.sv'):
  70. chains = unpickle('./experiment/sequence_list.sv')
  71. else:
  72. chains = gen_data_for_conv_test(data, env, task1)
  73. posterior_sample = conv_test(prepare_sequences(chains, warmup=True))
  74. pickle_it(posterior_sample, './experiment/posterior_sample.sv')
  75. else:
  76. print("Invalid argument")
  77. print(rule_dict)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...