Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

analyse_chains.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  1. import sys
  2. sys.path.append('src')
  3. from pickle_wrapper import unpickle, pickle_it
  4. import csv
  5. from mcmc_norm_learning.mcmc_convergence import prepare_sequences
  6. from mcmc_norm_learning.algorithm_1_v4 import to_tuple
  7. from collections import defaultdict
  8. import itertools
  9. import operator
  10. import yaml
  11. import math
  12. with open("params.yaml", 'r') as fd:
  13. params = yaml.safe_load(fd)
  14. m = params['m']
  15. num_chains = math.ceil(m/2)
  16. chains_and_log_posteriors = unpickle('data/chains_and_log_posteriors.pickle')[:num_chains]
  17. with open('metrics/chain_posteriors.csv', 'w', newline='') as csvfile, \
  18. open('metrics/chain_info.txt', 'w') as chain_info:
  19. chain_info.write(f'Number of chains: {len(chains_and_log_posteriors)}\n')
  20. chain_info.write(f'Length of each chain: {len(chains_and_log_posteriors[0]["chain"])}\n')
  21. csv_writer = csv.writer(csvfile)
  22. csv_writer.writerow(('chain_number', 'chain_pos', 'expression', 'log_posterior'))
  23. exps_in_chains = [None]*len(chains_and_log_posteriors)
  24. for i,chain_data in enumerate(chains_and_log_posteriors): # Consider skipping first few entries
  25. chain = chain_data['chain']
  26. log_posteriors = chain_data['log_posteriors']
  27. exp_lp_pairs = list(zip(chain,log_posteriors))
  28. exps_in_chains[i] = set(map(to_tuple, chain))
  29. print(sorted(log_posteriors, reverse=True))
  30. lps_to_exps = defaultdict(set)
  31. for exp,lp in exp_lp_pairs:
  32. lps_to_exps[lp].add(to_tuple(exp))
  33. num_exps_in_chain = len(exps_in_chains[i])
  34. print(lps_to_exps.keys())
  35. print('\n')
  36. chain_info.write(f'Num. expressions in chain {i}: {num_exps_in_chain}\n')
  37. decreasing_lps = sorted(lps_to_exps.keys(), reverse=True)
  38. chain_info.write("Expressions by decreasing log posterior\n")
  39. for lp in decreasing_lps:
  40. chain_info.write(f'lp = {lp} [{len(lps_to_exps[lp])} exps]:\n')
  41. for exp in lps_to_exps[lp]:
  42. chain_info.write(f' {exp}\n')
  43. chain_info.write('\n')
  44. chain_info.write('\n')
  45. changed_exp_indices = [i for i in range(1,len(chain)) if chain[i] != chain[i-1]]
  46. print(f'Writing {len(exp_lp_pairs)} rows to CSV file\n')
  47. csv_writer.writerows(((i,j,chain_lp_pair[0],chain_lp_pair[1]) for j,chain_lp_pair in enumerate(exp_lp_pairs)))
  48. all_exps = set(itertools.chain(*exps_in_chains))
  49. chain_info.write(f'Total num. distinct exps across all chains (including warm-up): {len(all_exps)}\n')
  50. with open("params.yaml", 'r') as fd:
  51. params = yaml.safe_load(fd)
  52. true_norm_exp = params['true_norm']['exp']
  53. true_norm_tuple = to_tuple(true_norm_exp)
  54. chain_info.write(f'True norm in some chain(s): {true_norm_tuple in all_exps}\n')
  55. num_chains_in_to_exps = defaultdict(set)
  56. for exp in all_exps:
  57. num_chains_in = operator.countOf(map(operator.contains,
  58. exps_in_chains,
  59. (exp for _ in range(len(exps_in_chains)))
  60. ),
  61. True)
  62. num_chains_in_to_exps[num_chains_in].add(exp)
  63. for num in sorted(num_chains_in_to_exps.keys(), reverse=True):
  64. chain_info.write(f'Out of {len(exps_in_chains)} chains ...\n')
  65. chain_info.write(f'{len(num_chains_in_to_exps[num])} exps are in {num} chains.\n')
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...