-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
413 lines (366 loc) · 13 KB
/
utils.py
File metadata and controls
413 lines (366 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import argparse
import os
import logging
import random
import numpy as np
from typing import List, Optional, Type
import torch
from monai.data import DataLoader
import sys
from functools import partial
from strategies import * # Assuming these are defined in strategies/*.py
def none_or_str(value):
    """Argparse ``type=`` converter: the string 'none' (any case) becomes None.

    Any other string is returned unchanged.
    """
    is_none_literal = value.lower() == "none"
    return None if is_none_literal else value
def arg_parser(argv=None):
    """
    Build the command-line parser for training runs and parse the arguments.

    Several integer options (e.g. ``--amp``, ``--compile``, ``--dynamic_modalities``)
    act as 0/1 switches.

    Args:
        argv: optional list of argument strings to parse instead of ``sys.argv[1:]``.
            The default ``None`` keeps the original behavior of reading ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    ########### PATHS ############
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Path to save/load checkpoint"
    )
    parser.add_argument(
        "--load_checkpoint_path",
        type=none_or_str,
        default=None,
        help="Path to load checkpoint from",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/hdd/Continual_learning_data/FINAL",
        help="Root directory of data",
    )
    ################### DATASET PARAMETERS ###################
    parser.add_argument(
        "--drop_modality", type=int, default=1, help="Randomly drop modalities"
    )
    parser.add_argument(
        "--num_experts", type=int, default=4, help="Number of experts in the model"
    )
    parser.add_argument(
        "--heavy_aug",
        type=int,
        default=0,
        help="Use heavy augmentation (default: 0, no heavy augmentation)",
    )
    parser.add_argument(
        "--fixed_length",
        type=int,
        default=0,
        help="Use fixed length for sequences (default: 0, no fixed length)",
    )
    #################### TRAINING PARAMETERS ####################
    parser.add_argument(
        "--n_epochs", type=int, default=1, help="Number of epochs to train for"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument(
        "--num_workers",
        type=int,
        # os.cpu_count() may return None when the count is undeterminable;
        # leave one CPU for the main process and never go below 0.
        default=max((os.cpu_count() or 1) - 1, 0),
        help="Number of workers for dataloader",
    )
    parser.add_argument("--seed", type=int, default=12345, help="Random seed")
    parser.add_argument(
        "--name",
        type=none_or_str,
        default=None,
        help="Name of the experiment (used for saving checkpoints)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="Learning rate for optimizer"
    )
    parser.add_argument(
        "--optimizer", type=str, default="adam", help="Optimizer to use"
    )
    parser.add_argument("--amp", type=int, default=0, help="Use AMP")
    parser.add_argument(
        "--compile", type=int, default=1, help="Compile model"
    )  # it has problems working with MONAI sometimes
    parser.add_argument(
        "--network",
        type=str,
        default="unet",
        help="Network architecture to use (unet or moe)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for training (cuda or cpu)",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=5,
        help="Interval for saving checkpoints (in epochs)",
    )
    ############ MISCELLANEOUS PARAMETERS ############
    parser.add_argument(
        "--show_progress", type=int, default=1, help="Show progress bar"
    )
    parser.add_argument(
        "--sequence", type=int, default=0, help="which sequence to use for training"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="naive",
        help="Strategy to use for training (naive, joint, etc.)",
    )
    ############### BUFFER PARAMETERS ############
    parser.add_argument(
        "--buffer_size",
        type=int,
        default=100,
        help="Size of the buffer for experience replay",
    )
    ################## DECISION MAKING MODE ##################
    parser.add_argument(
        "--boundary",
        type=int,
        default=0,
        help="Use boundary uncertainty (default: 0, no boundary uncertainty)",
    )
    parser.add_argument(
        "--thickness",
        type=int,
        default=7,
        help="Thickness for boundary uncertainty (default: 7)",
    )
    ################ For the ER_4_rank strategy ################
    parser.add_argument(
        "--uncertainty_weight",
        type=float,
        default=0.0,
        help="Weight for uncertainty in buffer management",
    )
    parser.add_argument(
        "--complexity_weight",
        type=float,
        default=0.0,
        help="Weight for complexity in buffer management",
    )
    parser.add_argument(
        "--tumor_size_weight",
        type=float,
        default=0.0,
        help="Weight for tumor size in buffer management",
    )
    parser.add_argument(
        "--confidence_weight",
        type=float,
        default=0.0,
        help="Weight for confidence in buffer management",
    )
    parser.add_argument(
        "--hard_soft_weighting",
        type=float,
        nargs=2,
        default=[1.0, 1.0],
        help="Weights for hard and soft examples in buffer management (provide two numbers, e.g. --hard_soft_weighting 1.0 1.0)",
    )
    ############## For dynamic modality setting ##############
    parser.add_argument(
        "--dynamic_modalities",
        type=int,
        default=0,
        help="Use dynamic modality setting (default: 0, do not use dynamic modalities)",
    )
    ################# text embedding parameters #################
    parser.add_argument(
        "--use_text_embedding",
        type=int,
        default=0,
        help="Use text embeddings (default: 0, do not use text embeddings)",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=8,
        help="Number of heads in the text embedding transformer",
    )
    # parse_args(None) falls back to sys.argv[1:], matching the original behavior.
    return parser.parse_args(argv)
def print_args(args):
    """Log every attribute of *args* at INFO level via the root logger.

    The listing is framed by a line of ten asterisks before and after, and is
    introduced by an "Arguments:" header line. Each attribute is logged as
    ``name: value``.
    """
    emit = logging.info
    banner = "*" * 10
    emit(banner)
    emit("Arguments:")
    for attr_name in vars(args):
        emit(f"{attr_name}: {getattr(args, attr_name)}")
    emit(banner)
def setup_logging(log_file_path="app_log.txt"):
    """
    Configures the root logger to log both to stdout and to a specified file path.

    Any handlers already attached to the root logger are removed first. The
    function runs only once per process: on success it sets a ``done`` attribute
    on itself, and later calls return immediately.

    The log level is read from the ``LOG_LEVEL`` environment variable (default
    ``INFO``). It is upper-cased, so e.g. ``LOG_LEVEL=debug`` also works.

    :param log_file_path: The path where the log file will be saved. Defaults to
        'app_log.txt'. Missing parent directories are created.
    """
    # Check if logging has already been configured
    if hasattr(setup_logging, "done"):
        return

    formatter = logging.Formatter(
        "[%(levelname)s] %(asctime)s - %(message)s", datefmt="%d-%b-%y %H:%M:%S"
    )

    # Console output
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)

    # Ensure the directory exists for the log file; exist_ok avoids a race when
    # several processes start concurrently.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_file_path)
    file_handler.setFormatter(formatter)

    # getLogger() with no name is always the real root logger (the one used by
    # module-level logging.info calls); getLogger("root") only aliases it on
    # Python >= 3.9.
    logger = logging.getLogger()
    # Iterate over a copy: removeHandler mutates logger.handlers, and removing
    # while iterating the live list skips every other handler.
    for existing in list(logger.handlers):
        logger.removeHandler(existing)

    logger.setLevel(os.getenv("LOG_LEVEL", "INFO").upper())
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)

    # Mark the logging setup as done
    setattr(setup_logging, "done", True)
def set_random_seed(seed: int) -> None:
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: the value to be set

    A failure while seeding CUDA is logged and otherwise ignored, so CPU-only
    runs keep working. Only ``Exception`` is caught; ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        torch.cuda.manual_seed_all(seed)
    except Exception:
        logging.error("Could not set cuda seed.")
def worker_init_fn(worker_id, num_workers, seed, rank=1):
    """
    Sets the seeds of ``numpy`` and ``random`` for a worker of a dataloader.

    The seed of each worker is ``num_workers * rank + worker_id + seed``, reduced
    modulo 2**32 because ``np.random.seed`` only accepts values in ``[0, 2**32)``.
    For seeds already in that range the reduction is a no-op. Without it, a
    negative or very large ``seed`` would raise ``ValueError`` in every worker.
    """
    worker_seed = (num_workers * rank + worker_id + seed) % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def create_seeded_dataloader(
    args, dataset, non_verbose=False, **dataloader_args
) -> DataLoader:
    """
    Creates a dataloader object from a dataset, setting the seeds for the workers (if `--seed` is set).

    Args:
        args: the arguments of the program; ``args.num_workers`` and ``args.seed`` are read.
        dataset: the dataset to be loaded
        non_verbose: if True, do not log the number of workers used
        dataloader_args: external arguments of the dataloader. Explicit
            ``num_workers``, ``generator`` or ``worker_init_fn`` entries take
            precedence over the values computed here.
    Returns:
        the dataloader object
    """
    if args.num_workers is None:
        # Limit to 8 CPUs if not specified; use the CPUs this process may run on
        # when the platform exposes that, otherwise assume 4.
        if hasattr(os, "sched_getaffinity"):
            n_cpus = len(os.sched_getaffinity(0))
        else:
            n_cpus = 4
        num_workers = min(8, n_cpus)
    else:
        num_workers = args.num_workers

    dataloader_args.setdefault("num_workers", num_workers)
    if not non_verbose:
        logging.info(
            f'Using {dataloader_args["num_workers"]} workers for the dataloader.'
        )

    # The shuffling generator is seeded only for strictly positive seeds.
    if args.seed is not None and args.seed > 0:
        worker_generator = torch.Generator()
        worker_generator.manual_seed(args.seed)
    else:
        worker_generator = None
    dataloader_args.setdefault("generator", worker_generator)

    # NOTE(review): the per-worker init fn is installed for any non-None seed
    # (including 0 and negatives), unlike the generator above. It also uses the
    # args-derived worker count even when the caller overrode num_workers.
    init_fn = (
        partial(worker_init_fn, num_workers=num_workers, seed=args.seed)
        if args.seed is not None
        else None
    )
    dataloader_args.setdefault("worker_init_fn", init_fn)

    return DataLoader(dataset, **dataloader_args)
def _needs_context(network: str) -> bool:
"""Determine if the dataset should return context based on the network type."""
return not network in ("unet", "unet_dyn")
def build_dataset_kwargs(
    name: str, *, args, is_test: bool, prev_modalities_set: Optional[set]
):
    """
    Assemble the constructor kwargs shared by both dataset classes.

    Keeping this in one place keeps the two dataset classes configured from a
    single source of truth.

    Args:
        name: value passed through as ``dataset_type``.
        args: parsed program arguments; reads ``data_path``, ``drop_modality``,
            ``heavy_aug``, ``fixed_length`` and ``use_text_embedding``.
        is_test: test-mode flag. When True, random modality dropping is forced off.
        prev_modalities_set: forwarded under the key ``perv_modality_set``. Only the
            dynamic dataset class consumes it; ``instantiate_dataset`` strips it
            for classes whose constructor rejects it.

    Returns:
        dict of keyword arguments for the dataset constructor.
    """
    dataset_kwargs = {
        "root_path": args.data_path,
        "dataset_type": name,
        "test_mode": is_test,
        "randomly_drop_modalities": args.drop_modality if not is_test else False,
        "heavy_aug": args.heavy_aug,
        "fixed_length": args.fixed_length,
        "return_text_embedding": bool(args.use_text_embedding),
        "perv_modality_set": prev_modalities_set,
    }
    return dataset_kwargs
def instantiate_dataset(cls: Type, kwargs: dict):
    """
    Instantiate ``cls(**kwargs)``, dropping the ``perv_modality_set`` kwarg when
    ``cls`` does not accept it (e.g. a dataset class without dynamic-modality support).

    Whether ``cls`` accepts the kwarg is decided from ``inspect.signature(cls)``
    before calling. A ``TypeError`` raised *inside* the constructor therefore
    propagates unchanged, instead of triggering a second construction attempt that
    could mask the real error or repeat side effects. If the signature cannot be
    inspected, the function falls back to trying with the kwarg and retrying
    without it on ``TypeError``.

    Args:
        cls: dataset class (or any callable) to instantiate.
        kwargs: constructor keyword arguments; this dict is not mutated.

    Returns:
        The constructed instance.
    """
    import inspect

    key = "perv_modality_set"  # key name as produced by build_dataset_kwargs
    if key not in kwargs:
        return cls(**kwargs)

    try:
        params = list(inspect.signature(cls).parameters.values())
    except (TypeError, ValueError):
        params = None

    if params is None:
        # Signature unavailable: fall back to trial-and-error.
        try:
            return cls(**kwargs)
        except TypeError:
            return cls(**{k: v for k, v in kwargs.items() if k != key})

    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepts_key = any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        or (p.name == key and p.kind in keyword_kinds)
        for p in params
    )
    if accepts_key:
        return cls(**kwargs)
    return cls(**{k: v for k, v in kwargs.items() if k != key})
def get_strategy(args, train_datasets, test_datasets, writer, experiment_name):
    """
    Instantiate the training strategy class selected by ``args.strategy``.

    The strategy classes come from the ``strategies`` package (star-imported at
    the top of this module). The chosen class is called with
    ``(args, train_datasets, test_datasets, writer, experiment_name)``.

    Raises:
        ValueError: if ``args.strategy`` is not one of the registered keys.
    """
    registry = {
        "naive": Naive,
        "joint": Joint,
        "cumulative": Cumulative,
        "er": ER,
        "scratch": Scratch,
        "clmu_net": CLMUNet,
    }
    strategy_cls = registry.get(args.strategy)
    if strategy_cls is None:
        raise ValueError(
            f"Unknown strategy '{args.strategy}'. Available strategies: {list(registry.keys())}"
        )
    return strategy_cls(args, train_datasets, test_datasets, writer, experiment_name)
def get_experiment_name(args):
    """
    Derive the experiment name and checkpoint directory from parsed arguments.

    The name has the layout
    ``<strategy><mode>_network_<network><buffer part><heads part><hyperparams>[_<name>]``:

    - ``<mode>`` is ``_dyn``, ``_text`` or ``_dyn_text`` depending on
      ``--dynamic_modalities`` / ``--use_text_embedding``.
    - ``<heads part>`` is present only when text embeddings are enabled.
    - ``<buffer part>`` is present only for the "er", "rclp" and "clmu_net"
      strategies; "clmu_net" also encodes its ranking weights.

    Returns:
        ``(experiment_name, os.path.join(args.checkpoint_path, "<strategy><mode>"))``

    NOTE(review): ``--checkpoint_path`` defaults to None in ``arg_parser``, and
    ``os.path.join(None, ...)`` raises TypeError. Callers presumably always set it.
    """
    hyperparams = (
        f"_optim_{args.optimizer}_lr_{args.lr}_bs_{args.batch_size}"
        f"_epochs_{args.n_epochs}_drop_{args.drop_modality}"
        f"_amp_{args.amp}_seed_{args.seed}"
    )

    dyn_on = bool(args.dynamic_modalities)
    text_on = bool(args.use_text_embedding)
    mode = {
        (True, True): "_dyn_text",
        (True, False): "_dyn",
        (False, True): "_text",
    }.get((dyn_on, text_on), "")
    heads_part = f"_heads_{args.num_heads}" if text_on else ""

    strategy = args.strategy
    if strategy in ("er", "rclp"):
        buffer_part = f"_buffer_{args.buffer_size}"
    elif strategy == "clmu_net":
        weights = args.hard_soft_weighting
        buffer_part = (
            f"_buffer_{args.buffer_size}_boundary_{args.boundary}"
            f"_thickness_{args.thickness}"
            f"_flags_U{args.uncertainty_weight}_C{args.complexity_weight}"
            f"_TS{args.tumor_size_weight}_CF{args.confidence_weight}"
            f"_H{weights[0]}_S{weights[1]}"
        )
    else:
        buffer_part = ""

    run_suffix = "" if args.name is None else f"_{args.name}"

    experiment_name = (
        f"{strategy}{mode}_network_{args.network}"
        f"{buffer_part}{heads_part}{hyperparams}{run_suffix}"
    )
    return experiment_name, os.path.join(args.checkpoint_path, f"{strategy}{mode}")