You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

43 lines
1.7 KiB

2 months ago
  1. from sklearn.linear_model import LogisticRegression
  2. from dtcontrol.benchmark_suite import BenchmarkSuite
  3. from dtcontrol.decision_tree.decision_tree import DecisionTree
  4. from dtcontrol.decision_tree.determinization.max_freq_determinizer import MaxFreqDeterminizer
  5. from dtcontrol.decision_tree.impurity.entropy import Entropy
  6. from dtcontrol.decision_tree.impurity.multi_label_entropy import MultiLabelEntropy
  7. from dtcontrol.decision_tree.splitting.axis_aligned import AxisAlignedSplittingStrategy
  8. from dtcontrol.decision_tree.splitting.linear_classifier import LinearClassifierSplittingStrategy
  9. import pydot
  10. def create_decision_tree(filename, name, output_folder,
  11. timeout=60*60*2,
  12. benchmark_file='benchmark',
  13. save_folder='saved_classifiers',
  14. export_pdf=False,
  15. classifiers=None):
  16. suite = BenchmarkSuite(timeout=timeout,
  17. save_folder=save_folder,
  18. output_folder=output_folder,
  19. benchmark_file=benchmark_file,
  20. rerun=True)
  21. suite.add_datasets([filename])
  22. if classifiers is None:
  23. aa = AxisAlignedSplittingStrategy()
  24. aa.priority = 1
  25. classifiers = [DecisionTree([aa], Entropy(), name)]
  26. suite.benchmark(classifiers)
  27. if export_pdf:
  28. for dataset in suite.datasets:
  29. for classifier in classifiers:
  30. filename = suite.get_filename(output_folder, dataset=dataset , classifier=classifier, extension='.dot')
  31. (graph,) = pydot.graph_from_dot_file(filename)
  32. graph.write_pdf(F'{name}.pdf')
  33. return suite