
tune.py

import sys
from pathlib import Path

import click
import optuna
from sklearn.pipeline import Pipeline


def get_local_path():
    debug_local = True  # set to True to use the local checkout of yspecies
    local = (Path("..") / "yspecies").resolve()
    if debug_local and local.exists():
        sys.path.insert(0, Path("..").as_posix())
        # sys.path.insert(0, local.as_posix())
        print("extending paths with local yspecies")
        print(sys.path)
    return local


@click.command()
@click.option('--name', default="general_tuner", help='study name')
@click.option('--trials', default=10, help='number of trials in hyperparameter optimization')
@click.option('--folds', default=5, help='number of folds in cross-validation')
@click.option('--hold_outs', default=1, help='number of hold-outs in cross-validation')
@click.option('--threads', default=1, help="number of threads (1 by default); -1 tries to use all cores, which can be dangerous memory-wise")
@click.option('--species_in_validation', default=3, help="number of species in validation")
@click.option('--not_validated_species', default="", help="species to exclude from validation")
@click.option('--repeats', default=10, help="number of times to repeat validation")
@click.option("--loss", default="huber", help="loss type (huber, l1, l2), huber by default")
def tune(name: str, trials: int, folds: int, hold_outs: int, threads: int, species_in_validation: int, not_validated_species: str, repeats: int, loss: str):
    print(f"starting hyperparameter optimization script with {trials} trials, {folds} folds and {hold_outs} hold-outs!")
    local = get_local_path()

    # normalize --not_validated_species to a list
    if not_validated_species is None or not_validated_species == "":
        not_validated_species = []
    elif isinstance(not_validated_species, str):
        not_validated_species = [not_validated_species]

    # imported after get_local_path() so a local yspecies checkout is picked up
    from yspecies.dataset import ExpressionDataset
    from yspecies.partition import DataPartitioner, FeatureSelection, DataExtractor
    from yspecies.tuning import GeneralTuner, TuningResults
    from yspecies.workflow import Locations
    from yspecies.models import Metrics

    locations: Locations = Locations("./") if Path("./data").exists() else Locations("../")
    data = ExpressionDataset.from_folder(locations.interim.selected)
    selection = FeatureSelection(
        samples=["tissue", "species"],      # samples metadata to include
        species=[],                         # species metadata other than the Y label to include
        exclude_from_training=["species"],  # exclude some fields from LightGBM training
        to_predict="lifespan",              # column to predict
        categorical=["tissue"])
    ext = Pipeline([
        ('extractor', DataExtractor(selection)),  # extracts the data required for ML from the dataset
        ("partitioner", DataPartitioner(nfolds=folds, nhold_out=1, species_in_validation=species_in_validation, not_validated_species=not_validated_species))
    ])
    parts = ext.fit_transform(data)
    assert (len(parts.cv_merged_index) + len(parts.hold_out_merged_index)) == data.samples.shape[0], "cv and hold-out together should cover all samples"
    # sanity check for the default configuration (folds=5, hold_outs=1)
    assert parts.nhold_out == 1 and parts.hold_out_partition_indexes == [parts.indexes[4]], "checking that hold_out is computed in the right way"

    url = 'sqlite:///' + str((locations.output.optimization / "study.sqlite").absolute())
    print('loading (if exists) study from ' + url)
    storage = optuna.storages.RDBStorage(
        url=url
        # engine_kwargs={'check_same_thread': False}
    )
    study = optuna.create_study(storage=storage, study_name=name, direction='minimize', load_if_exists=True)
    tuner = GeneralTuner(n_trials=trials, n_jobs=threads, study=study)
    best_parameters = tuner.fit(parts)
    print("======BEST=PARAMETERS===============")
    print(best_parameters)
    print("=====BEST=RESULTS===================")
    results = tuner.transform(parts)
    print(results)

    import json
    with open(locations.output.optimization / 'parameters.json', 'w') as fp:
        json.dump(best_parameters, fp)
    if results.train_metrics is not None and results.validation_metrics is not None:
        Metrics.combine([results.train_metrics, results.validation_metrics]).to_csv(locations.output.optimization / 'metrics.tsv', sep="\t")


if __name__ == "__main__":
    tune()
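A typical invocation, assuming the yspecies package and its data folders are in place (the option values here are illustrative, not recommendations):

    python tune.py --name general_tuner --trials 50 --folds 5 --threads 4 --loss huber

Because the study is persisted to SQLite, it can be inspected or resumed later with plain Optuna, independently of this script. A minimal sketch; the storage path below is an assumption, copy the URL that tune.py prints at startup:

    import optuna

    # attach to the study that tune.py created in its SQLite storage
    study = optuna.load_study(
        study_name="general_tuner",
        storage="sqlite:///path/to/study.sqlite",  # assumed path; use the URL printed by tune.py
    )
    print(study.best_params)                # best hyperparameters found so far
    print(study.best_value)                 # best (minimized) objective value
    print(study.trials_dataframe().head())  # per-trial history as a pandas DataFrame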