Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
719c8bd
added new viz api
chadhardin Sep 14, 2022
f51a924
Merge branch 'main' of https://github.com/chadhardin/FLAMLTest
chadhardin Sep 14, 2022
8f130b8
if statement added for different plots
chadhardin Sep 28, 2022
41b084a
Merge branch 'microsoft:main' into main
chadhardin Oct 3, 2022
58ea65f
Merge branch 'microsoft:main' into main
chadhardin Oct 10, 2022
4ca305e
Small Updates
chadhardin Oct 10, 2022
bdf75d0
Merge branch 'main' of https://github.com/chadhardin/FLAMLTest
chadhardin Oct 10, 2022
f0e7322
Merge branch 'microsoft:main' into main
chadhardin Oct 17, 2022
1d9cd45
Visualization created with important features
chadhardin Oct 17, 2022
4661b3f
Merge branch 'main' of https://github.com/chadhardin/FLAMLTest
chadhardin Oct 17, 2022
a9d5f57
Renamed type to feature_importance to be specific
chadhardin Oct 17, 2022
cb9ed36
Adding minor comments
chadhardin Oct 17, 2022
5fff7b9
Included an example in visualization
chadhardin Oct 17, 2022
92e778f
Merge branch 'microsoft:main' into main
chadhardin Oct 24, 2022
7d9ade5
Merge branch 'microsoft:main' into main
chadhardin Oct 31, 2022
5e868cf
Adding the logger warning and save fig
chadhardin Oct 31, 2022
28a2a3b
Merge branch 'main' of https://github.com/chadhardin/FLAMLTest
chadhardin Oct 31, 2022
78f08f7
Forgot a comma
chadhardin Oct 31, 2022
1d4d879
Merge branch 'microsoft:main' into main
chadhardin Nov 2, 2022
758290c
Merge branch 'main' into main
qingyun-wu Nov 7, 2022
0cfb2f7
Feature and validation updates
chadhardin Nov 7, 2022
f53aeee
test file added
chadhardin Nov 7, 2022
a940568
Updated testing file name
chadhardin Nov 7, 2022
540d40b
Merge branch 'microsoft:main' into main
chadhardin Nov 14, 2022
73cdcad
Merge branch 'microsoft:main' into main
chadhardin Nov 16, 2022
f11f811
Merge branch 'microsoft:main' into main
chadhardin Nov 17, 2022
ec2e26e
Merge remote-tracking branch 'upstream/main' into main
qingyun-wu Dec 7, 2022
aca206b
simplify api
qingyun-wu Dec 7, 2022
3eee239
Merge branch 'main' into main
qingyun-wu Dec 7, 2022
149a97e
Update setup.py
qingyun-wu Dec 7, 2022
df73b4f
merge
qingyun-wu Dec 24, 2022
7e95240
Update .pre-commit-config.yaml
qingyun-wu Dec 24, 2022
229ffc7
revert
qingyun-wu Dec 24, 2022
aa0ceeb
Merge branch 'main' into main
qingyun-wu Dec 24, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ repos:
- id: check-merge-conflict
- id: detect-private-key
- id: trailing-whitespace
- id: no-commit-to-branch
# - id: no-commit-to-branch
49 changes: 49 additions & 0 deletions flaml/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3051,6 +3051,55 @@ def is_to_reverse_metric(metric, task):
del self._state.groups, self._state.groups_all, self._state.groups_val
logger.setLevel(old_level)

def visualize(
    self,
    type="learning_curve",
    automl_instance=None,
    plot_filename=None,
    log_file_name=None,
    **kwargs,
):
    """Plot a summary of the AutoML run and save it to an image file.

    Args:
        type: str | The type of the plot, either "learning_curve" (default)
            or "feature_importance". Any other value only logs a warning.
            (The name shadows the builtin but is kept for API compatibility.)
        automl_instance: Unused; accepted for backward compatibility only.
        plot_filename: str | Name of the file the figure is saved to.
            Defaults to the plot type when omitted. For "feature_importance"
            a ".png" suffix is appended to the name; for "learning_curve"
            the name is handed to `savefig` unchanged.
        log_file_name: str | Search history log file used for the learning
            curve. Falls back to the `log_file_name` entry of the settings.
        **kwargs: `time_budget` (default 240) is forwarded to
            `get_output_from_log` when plotting the learning curve.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "The visualization functionality requires installation of matplotlib. "
            "Please run pip install flaml[visualization]"
        ) from err

    if plot_filename is None:
        # Avoid writing a file literally called "None"/"None.png".
        plot_filename = type

    if type == "feature_importance":
        feature_names = self.feature_names_in_
        importances = self.feature_importances_
        # NOTE(review): presumably these are None when the best model does
        # not expose importances -- confirm against the property definitions.
        if feature_names is None or importances is None:
            logger.warning(
                "Feature importances are not available for the current model."
            )
            return
        plt.barh(feature_names, importances)
        plt.savefig("{}.png".format(plot_filename))
        plt.close()
    elif type == "learning_curve":
        from flaml.data import get_output_from_log

        # The explicit argument wins; only fall back to the settings when
        # the caller did not supply a log file.
        if not log_file_name:
            log_file_name = self._settings.get("log_file_name")
        if not log_file_name:
            logger.warning("Please provide a search history log file.")
            # Nothing to read the history from, so there is nothing to plot.
            return
        (
            time_history,
            best_valid_loss_history,
            valid_loss_history,
            _config_history,
            _metric_history,
        ) = get_output_from_log(
            filename=log_file_name,
            time_budget=kwargs.get("time_budget", 240),
        )
        plt.title("Learning Curve")
        plt.xlabel("Wall Clock Time (s)")
        plt.ylabel("Validation Accuracy")
        # Losses are plotted as 1 - loss, matching the "accuracy" axis label.
        plt.scatter(time_history, 1 - np.array(valid_loss_history))
        plt.step(time_history, 1 - np.array(best_valid_loss_history), where="post")
        plt.savefig("{}".format(plot_filename))
        # Close the figure so later plots do not draw on top of this one.
        plt.close()
    else:
        logger.warning("Unsupported visualization type: %s", type)

def _search_parallel(self):
if self._use_ray is not False:
try:
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"rouge_score",
"hcrystalball==0.1.10",
"seqeval",
"matplotlib",
"pytorch-forecasting>=0.9.0,<=0.10.1",
"mlflow",
"pyspark>=3.0.0",
Expand Down Expand Up @@ -110,6 +111,7 @@
"hcrystalball==0.1.10",
"pytorch-forecasting>=0.9.0",
],
"visualization": ["matplotlib"],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
},
classifiers=[
Expand Down
24 changes: 24 additions & 0 deletions test/automl/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from flaml import AutoML
from flaml.data import load_openml_dataset


def test_fi_lc():
    """Fit AutoML briefly, then check that both plot types are written to disk."""
    import os
    import tempfile

    X_train, X_test, y_train, y_test = load_openml_dataset(
        dataset_id=1169, data_dir="./"
    )
    settings = {
        "time_budget": 10,  # total running time in seconds
        "metric": "accuracy",  # can be: 'r2', 'rmse', 'mae', 'mse', 'accuracy', 'roc_auc', 'roc_auc_ovr',
        # 'roc_auc_ovo', 'log_loss', 'mape', 'f1', 'ap', 'ndcg', 'micro_f1', 'macro_f1'
        "task": "classification",  # task type
        "log_file_name": "airlines_experiment.log",  # flaml log file
        "seed": 7654321,  # random seed
    }
    automl = AutoML(**settings)
    automl.fit(X_train=X_train, y_train=y_train)

    # Save the figures into a throwaway directory so the test leaves no plot
    # files behind in the working directory.
    with tempfile.TemporaryDirectory() as plot_dir:
        fi_plot = os.path.join(plot_dir, "feature_importance")
        lc_plot = os.path.join(plot_dir, "learning_curve")
        automl.visualize(type="feature_importance", plot_filename=fi_plot)
        automl.visualize(type="learning_curve", plot_filename=lc_plot)
        # visualize appends ".png" for feature importance; for the learning
        # curve matplotlib adds its default ".png" to an extension-less name.
        assert os.path.isfile(fi_plot + ".png")
        assert os.path.isfile(lc_plot + ".png")


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
test_fi_lc()