Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

tune.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  1. import sys
  2. from pathlib import Path
  3. import optuna
  4. import click
  5. from sklearn.pipeline import Pipeline
  6. def get_local_path():
  7. debug_local = True #to use local version
  8. local = (Path("..") / "yspecies").resolve()
  9. if debug_local and local.exists():
  10. sys.path.insert(0, Path("..").as_posix())
  11. #sys.path.insert(0, local.as_posix())
  12. print("extending pathes with local yspecies")
  13. print(sys.path)
  14. return local
  15. @click.command()
  16. @click.option('--name', default="general_tuner", help='study name')
  17. @click.option('--trials', default=10, help='Number of trials in hyper optimization')
  18. @click.option('--folds', default=5, help='Number of folds in cross-validation')
  19. @click.option('--hold_outs', default=1, help='Number of hold outs in cross-validation')
  20. @click.option('--threads', default=1, help="number of threads (1 by default). If you put -1 it will try to utilize all cores, however it can be dangerous memorywise")
  21. @click.option('--species_in_validation', default=3, help="species_in_validation")
  22. @click.option('--not_validated_species', default="Homo_sapiens", help="not_validated_species")
  23. def tune(name, trials, folds, hold_outs, threads, species_in_validation, not_validated_species):
  24. print(f"starting hyperparameters optimization script with {trials} trials, {folds} folds and {hold_outs} hold outs!")
  25. local = get_local_path()
  26. not_validated_species = [not_validated_species] if type(not_validated_species) is str else not_validated_species
  27. from yspecies.dataset import ExpressionDataset
  28. from yspecies.partition import DataPartitioner, FeatureSelection, DataExtractor
  29. from yspecies.tuning import GeneralTuner, TuningResults
  30. from yspecies.workflow import Locations
  31. from yspecies.models import Metrics
  32. locations: Locations = Locations("./") if Path("./data").exists() else Locations("../")
  33. data = ExpressionDataset.from_folder(locations.interim.selected)
  34. selection = FeatureSelection(
  35. samples = ["tissue","species"], #samples metadata to include
  36. species = [], #species metadata other then Y label to include
  37. exclude_from_training = ["species"], #exclude some fields from LightGBM training
  38. to_predict = "lifespan", #column to predict
  39. categorical = ["tissue"])
  40. ext = Pipeline([
  41. ('extractor', DataExtractor(selection)), # to extract the data required for ML from the dataset
  42. ("partitioner", DataPartitioner(nfolds = folds, nhold_out = 1, species_in_validation=species_in_validation, not_validated_species = not_validated_species))
  43. ])
  44. parts = ext.fit_transform(data)
  45. assert (len(parts.cv_merged_index) + len(parts.hold_out_merged_index)) == data.samples.shape[0], "cv and hold out should be same as samples number"
  46. assert parts.nhold_out ==1 and parts.hold_out_partition_indexes == [parts.indexes[4]], "checking that hold_out is computed in a right way"
  47. url = f'sqlite:///' +str((locations.output.optimization / "study.sqlite").absolute())
  48. print('loading (if exists) study from '+url)
  49. storage = optuna.storages.RDBStorage(
  50. url=url
  51. #engine_kwargs={'check_same_thread': False}
  52. )
  53. study = optuna.create_study(storage, study_name="general_tuner", direction='minimize', load_if_exists=True)
  54. tuner = GeneralTuner(n_trials = trials, n_jobs = threads, study = study)
  55. best_parameters = tuner.fit(parts)
  56. print("======BEST=PARAMETERS===============")
  57. print(best_parameters)
  58. print("=====BEST=RESULTS===================")
  59. results = tuner.transform(parts)
  60. #parameters: ('COMPLETE', 0.20180128076981702, '2020-08-09 09:13:47.778135', 2)]
  61. print(results)
  62. import json
  63. with open(locations.output.optimization / 'parameters.json', 'w') as fp:
  64. json.dump(best_parameters, fp)
  65. if results.train_metrics is not None and results.validation_metrics is not None:
  66. Metrics.combine([results.train_metrics, results.validation_metrics]).to_csv(locations.output.optimization / 'metrics.tsv', sep="\t")
# Script entry point: invoke the click command only when this file is executed
# directly, so importing the module does not start a tuning run.
if __name__ == "__main__":
    tune()
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...