Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

tune.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  1. import sys
  2. from pathlib import Path
  3. import optuna
  4. import click
  5. from sklearn.pipeline import Pipeline
  6. def get_local_path():
  7. debug_local = True #to use local version
  8. local = (Path("..") / "yspecies").resolve()
  9. if debug_local and local.exists():
  10. sys.path.insert(0, Path("..").as_posix())
  11. #sys.path.insert(0, local.as_posix())
  12. print("extending pathes with local yspecies")
  13. print(sys.path)
  14. return local
  15. @click.command()
  16. def tune():
  17. print("starting hyperparameters optimization script")
  18. number_of_folds = 5
  19. #time_budget_seconds = 200
  20. n_trials = 5
  21. threads = 1
  22. local = get_local_path()
  23. from yspecies.dataset import ExpressionDataset
  24. from yspecies.partition import DataPartitioner, FeatureSelection, DataExtractor
  25. from yspecies.tuning import GeneralTuner, TuningResults
  26. from yspecies.workflow import Locations
  27. locations: Locations = Locations("./") if Path("./data").exists() else Locations("../")
  28. data = ExpressionDataset.from_folder(locations.interim.selected)
  29. selection = FeatureSelection(
  30. samples = ["tissue","species"], #samples metadata to include
  31. species = [], #species metadata other then Y label to include
  32. exclude_from_training = ["species"], #exclude some fields from LightGBM training
  33. to_predict = "lifespan", #column to predict
  34. categorical = ["tissue"])
  35. ext = Pipeline([
  36. ('extractor', DataExtractor(selection)), # to extract the data required for ML from the dataset
  37. ("partitioner", DataPartitioner(nfolds = number_of_folds, nhold_out = 1, species_in_validation=2, not_validated_species = ["Homo_sapiens"]))
  38. ])
  39. parts = ext.fit_transform(data)
  40. assert (len(parts.cv_merged_index) + len(parts.hold_out_merged_index)) == data.samples.shape[0], "cv and hold out should be same as samples number"
  41. assert parts.nhold_out ==1 and parts.hold_out_partition_indexes == [parts.indexes[4]], "checking that hold_out is computed in a right way"
  42. url = f'sqlite:///' +str((locations.output.optimization / "study.sqlite").absolute())
  43. print('loading (if exists) study from '+url)
  44. storage = optuna.storages.RDBStorage(
  45. url=url
  46. #engine_kwargs={'check_same_thread': False}
  47. )
  48. study = optuna.create_study(storage, study_name="naive_tuner", direction='minimize', load_if_exists=True)
  49. tuner = GeneralTuner(n_trials = n_trials, n_jobs = threads, study = study)
  50. best_parameters = tuner.fit(parts)
  51. print("======BEST=PARAMETERS===============")
  52. print(best_parameters)
  53. print("=====BEST=RESULTS===================")
  54. results = tuner.transform(parts)
  55. #parameters: ('COMPLETE', 0.20180128076981702, '2020-08-09 09:13:47.778135', 2)]
  56. print(results)
  57. import json
  58. with open(locations.output.optimization / 'parameters.json', 'w') as fp:
  59. json.dump(best_parameters, fp)
  60. with open(locations.output.optimization / 'results.json', 'w') as fp:
  61. json.dump(results, fp)
  62. if __name__ == "__main__":
  63. tune()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...