Skip to content
30 changes: 18 additions & 12 deletions flaml/automl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,18 +912,24 @@ def search_space(cls, data_size, task, **params):
If OOM, user should change the search space themselves
"""

search_space_dict["model_path"] = {
"domain": tune.choice(
[
"google/electra-base-discriminator",
"bert-base-uncased",
"roberta-base",
"facebook/muppet-roberta-base",
"google/electra-small-discriminator",
]
),
"init_value": "facebook/muppet-roberta-base",
}
if task not in NLG_TASKS:
search_space_dict["model_path"] = {
"domain": tune.choice(
[
"google/electra-base-discriminator",
"bert-base-uncased",
"roberta-base",
"facebook/muppet-roberta-base",
"google/electra-small-discriminator",
]
),
"init_value": "facebook/muppet-roberta-base",
}
else:
search_space_dict["model_path"] = {
"domain": tune.choice(["t5-small", "facebook/bart-base"]),
"init_value": "t5-small",
}
return search_space_dict


Expand Down
18 changes: 9 additions & 9 deletions flaml/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings

from flaml.automl.model import *


warnings.warn(
"Importing from `flaml.model` is deprecated. Please use `flaml.automl.model`.",
DeprecationWarning,
)
import warnings
from flaml.automl.model import *
warnings.warn(
"Importing from `flaml.model` is deprecated. Please use `flaml.automl.model`.",
DeprecationWarning,
)
46 changes: 46 additions & 0 deletions test/nlp/test_autohf_modelselection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import sys
import pytest
import requests
from utils import get_toy_data_summarization, get_automl_settings
import os
import shutil


@pytest.mark.skipif(
sys.platform == "darwin" or sys.version < "3.7",
reason="do not run on mac os or py<3.7",
)
def test_hf_ms():
from flaml import AutoML

X_train, y_train, X_val, y_val, X_test = get_toy_data_summarization()

automl = AutoML()

automl_settings = {
"gpu_per_trial": 0,
"max_iter": 3,
"time_budget": 20,
"task": "summarization",
"metric": "rouge1",
"log_file_name": "seqclass.log",
"use_ray": False,
"estimator_list": ["transformer_ms"],
}

try:
automl.fit(
X_train=X_train,
y_train=y_train,
X_val=X_val,
y_val=y_val,
**automl_settings
)
automl.score(X_val, y_val, **{"metric": "accuracy"})
automl.pickle("automl.pkl")
except requests.exceptions.HTTPError:
return


if __name__ == "__main__":
test_hf_ms()