Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

gen_observations.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import sys
  2. sys.path.append('src')
  3. import yaml
  4. import shutil
  5. import math
  6. import pickle
  7. import numpy as np
  8. from pickle_wrapper import unpickle, pickle_it
  9. from mcmc_norm_learning.algorithm_1_v4 import create_data
  10. from mcmc_norm_learning.rules_4 import get_prob, get_log_prob
  11. from mcmc_norm_learning.environment import position
  12. from mcmc_norm_learning.robot_task_new import task, robot
  13. with open("params.yaml", 'r') as fd:
  14. params = yaml.safe_load(fd)
  15. true_norm_exp = params['true_norm']['exp']
  16. num_observations = params['num_observations']
  17. obs_data_set = params['obs_data_set']
  18. colour_specific = params['colour_specific']
  19. shape_specific = params['shape_specific']
  20. target_area_parts = params['target_area'].replace(' ','').split(';')
  21. target_area_part0 = position(*map(float, target_area_parts[0].split(',')))
  22. target_area_part1 = position(*map(float, target_area_parts[1].split(',')))
  23. target_area = (target_area_part0, target_area_part1)
  24. print(target_area_part0.coordinates())
  25. print(target_area_part1.coordinates())
  26. the_task = task(colour_specific, shape_specific,target_area)
  27. env = unpickle('data/env.pickle')
  28. rob = robot(the_task,env)
  29. actionable = rob.all_actionable()
  30. print(actionable)
  31. true_norm_prior = get_prob("NORMS",true_norm_exp)
  32. true_norm_log_prior = get_log_prob("NORMS",true_norm_exp)
  33. if math.isclose(true_norm_prior, 0):
  34. print(f'Stopping! True norm expression has near-zero prior ({true_norm_prior})\n')
  35. # elif (num_observations == 100 and obs_data_set == 1):
  36. # shutil.copyfile('data/default_observations.pickle', 'data/observations.pickle')
  37. else:
  38. observations = create_data(true_norm_exp,env,name=None,task=the_task,random_task=False,
  39. num_actionable=np.nan,num_repeat=num_observations,verbose=False)
  40. pickle_it(observations, 'data/observations.pickle')
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...