gen_mcmc_chains.py (3.2 KB)

"""Generate MCMC chains of norm expressions.

Loads experiment parameters from params.yaml, builds the robot task and
environment, draws over-dispersed starting points, runs algorithm_1 once
per chain, and pickles each chain together with its log-posteriors.
"""
import sys
sys.path.append('src')
import yaml
import tqdm
#from tqdm.notebook import tqdm, trange
import pickle
import csv
import math
from mcmc_norm_learning.rules_4 import q_dict, rule_dict, get_log_prob
from mcmc_norm_learning.robot_task_new import task
from mcmc_norm_learning.algorithm_1_v4 import algorithm_1, over_dispersed_starting_points
from mcmc_norm_learning.environment import position
from pickle_wrapper import unpickle, pickle_it
import dask
from dask.distributed import Client

# Load experiment parameters.
with open("params.yaml", 'r') as fd:
    params = yaml.safe_load(fd)

n = params['n']
m = params['m']
rf = params['rf']
colour_specific = params['colour_specific']
shape_specific = params['shape_specific']

# target_area is given as two "x,y" corner points separated by ';'.
target_area_parts = params['target_area'].replace(' ', '').split(';')
target_area_part0 = position(*map(float, target_area_parts[0].split(',')))
target_area_part1 = position(*map(float, target_area_parts[1].split(',')))
target_area = (target_area_part0, target_area_part1)
true_expression = params['true_norm']['exp']

env = unpickle('data/env.pickle')
the_task = task(colour_specific, shape_specific, target_area)
obs = unpickle('data/observations.pickle')

# Draw one over-dispersed starting point per chain and record diagnostics.
num_chains = math.ceil(m/2)
starts, info = over_dispersed_starting_points(num_chains, obs, env, the_task,
                                              time_threshold=math.inf)
with open('metrics/starts_info.txt', 'w') as chain_info:
    chain_info.write(info)

@dask.delayed
def delayed_alg1(obs, env, the_task, q_dict, rule_dict, start, rf, max_iters):
    exp_seq, log_likelihoods = algorithm_1(obs, env, the_task, q_dict, rule_dict,
                                           "dummy value",
                                           start=start,
                                           relevance_factor=rf,
                                           max_iterations=max_iters,
                                           verbose=False)
    # Log-posterior = log-prior + log-likelihood for each sampled expression.
    log_posteriors = [None]*len(exp_seq)
    for j in range(len(exp_seq)):
        exp = exp_seq[j]
        ll = log_likelihoods[j]
        log_prior = get_log_prob("NORMS", exp)  # uses the rules dict from rules_4.py
        log_posteriors[j] = log_prior + ll
    return {'chain': exp_seq, 'log_posteriors': log_posteriors}

chains_and_log_posteriors = []
#for i in trange(1, num_chains, desc="Loop for Individual Chains"):
for i in tqdm.tqdm(range(num_chains), desc="Loop for Individual Chains"):
    #chains_and_log_posteriors.append(
    #    delayed_alg1(obs, env, the_task, q_dict, rule_dict, starts[i], rf, 4*n))
    # Replaced with all the lines in the loop body below:
    exp_seq, log_likelihoods = algorithm_1(obs, env, the_task, q_dict, rule_dict,
                                           "dummy value",
                                           start=starts[i],
                                           relevance_factor=rf,
                                           max_iterations=4*n,
                                           verbose=False)
    log_posteriors = [None]*len(exp_seq)
    for j in range(len(exp_seq)):  # inner index j keeps the chain index i intact
        exp = exp_seq[j]
        ll = log_likelihoods[j]
        log_prior = get_log_prob("NORMS", exp)  # uses the rules dict from rules_4.py
        log_posteriors[j] = log_prior + ll
    chains_and_log_posteriors.append({'chain': exp_seq,
                                      'log_posteriors': log_posteriors})

#pickle_it(dask.compute(*chains_and_log_posteriors), 'data/chains_and_log_posteriors.pickle')
pickle_it(chains_and_log_posteriors, 'data/chains_and_log_posteriors.pickle')
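
For reference, a minimal params.yaml sketch matching the keys the script reads; every value below is a placeholder, not the project's actual configuration. The only format constraint visible in the code is that target_area holds two "x,y" corner points separated by a semicolon.

# Hypothetical params.yaml; all values are illustrative only.
n: 1000                          # the script runs 4*n MCMC iterations per chain
m: 8                             # number of chains is ceil(m/2)
rf: 1.0                          # relevance_factor passed to algorithm_1
colour_specific: red             # placeholder task attribute
shape_specific: cube             # placeholder task attribute
target_area: "0.1,0.1; 0.4,0.4"  # two "x,y" corners, ';'-separated
true_norm:
  exp: <expression>              # placeholder norm expression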
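
The commented-out lines show the per-chain loop once went through dask; here is a sketch of how that path could be restored, assuming a default local Client() is acceptable for the scheduler.

# Sketch: re-enabling the dask path using the delayed_alg1 wrapper above.
client = Client()  # starts a local cluster by default
delayed_chains = [delayed_alg1(obs, env, the_task, q_dict, rule_dict,
                               starts[i], rf, 4*n)
                  for i in range(num_chains)]
chains_and_log_posteriors = list(dask.compute(*delayed_chains))
pickle_it(chains_and_log_posteriors, 'data/chains_and_log_posteriors.pickle')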
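
Downstream code can recover the chains with the same pickle_wrapper helper. A minimal sketch follows; the printed summary is illustrative, not part of the pipeline.

# Sketch: loading and inspecting the pickled output.
from pickle_wrapper import unpickle

chains = unpickle('data/chains_and_log_posteriors.pickle')
for idx, entry in enumerate(chains):
    lp = entry['log_posteriors']
    print(f"chain {idx}: {len(entry['chain'])} samples, "
          f"max log-posterior {max(lp):.3f}")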