-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathencoder.py
More file actions
143 lines (111 loc) · 4.82 KB
/
encoder.py
File metadata and controls
143 lines (111 loc) · 4.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn as nn
# from torchsummary import summary
from layers import MultiHeadAttention
from data import generate_data
import math
class Normalization(nn.Module):
    """Normalize features along the last (embedding) dimension.

    Args:
        embed_dim: number of features, i.e. the size of the input's last dim.
        normalization: ``'batch'`` selects ``nn.BatchNorm1d``, ``'instance'``
            selects ``nn.InstanceNorm1d`` (both with ``affine=True``). Any
            other value (e.g. ``None``) disables normalization and
            ``forward`` returns its input unchanged.
    """

    def __init__(self, embed_dim, normalization='batch'):
        super().__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d,
        }.get(normalization, None)
        # Previously an unknown/None choice crashed here ("'NoneType' object is
        # not callable") even though forward() has an identity branch for it.
        self.normalizer = (
            normalizer_class(embed_dim, affine=True)
            if normalizer_class is not None
            else None
        )
        # NOTE: a custom re-initialisation of the affine parameters was sketched
        # here originally but left disabled; PyTorch's default init is used.

    def forward(self, x):
        """Return ``x`` normalized with the configured normalizer (same shape)."""
        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d wants (N, C): flatten every leading dim so statistics
            # are taken per feature over all (batch, position) pairs, then
            # restore the original shape. reshape (unlike view) also accepts
            # non-contiguous inputs.
            # https://discuss.pytorch.org/t/batch-normalization-of-linear-layers/20989
            return self.normalizer(x.reshape(-1, x.size(-1))).view(*x.size())
        if isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d wants (N, C, L): move features to dim 1 and back.
            # Assumes a 3-D input.
            return self.normalizer(x.permute(0, 2, 1)).permute(0, 2, 1)
        # No normalizer configured: identity.
        return x
class ResidualBlock_BN(nn.Module):
    """Post-norm residual wrapper computing ``BN(x + sublayer(x[, mask]))``.

    Args:
        MHA: the wrapped sub-layer (attention or feed-forward callable).
        BN: the normalization applied after the residual sum.
    """

    def __init__(self, MHA, BN, **kwargs):
        super().__init__(**kwargs)
        self.MHA = MHA
        self.BN = BN

    def forward(self, x, mask=None):
        # The mask is only forwarded when supplied, so single-argument
        # sub-layers (e.g. an nn.Sequential) can be wrapped as well.
        if mask is not None:
            branch = self.MHA(x, mask)
        else:
            branch = self.MHA(x)
        return self.BN(x + branch)
class SelfAttention(nn.Module):
    """Adapter that calls ``MHA`` with the same tensor as query, key and value."""

    def __init__(self, MHA, **kwargs):
        super().__init__(**kwargs)
        self.MHA = MHA

    def forward(self, x, mask=None):
        # One list holding three references to the same tensor: [q, k, v].
        qkv = [x] * 3
        return self.MHA(qkv, mask=mask)
class EncoderLayer(nn.Module):
    """One post-norm encoder layer.

    Computes ``h = BN1(x + SelfAttention(x))`` followed by
    ``BN2(h + FF(h))``, where FF is Linear -> ReLU -> Linear.

    Args:
        n_heads: number of heads passed to ``MultiHeadAttention``.
        FF_hidden: hidden width of the feed-forward sub-layer.
        embed_dim: feature size of the input and output.
    """

    def __init__(self, n_heads=8, FF_hidden=512, embed_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = n_heads
        self.FF_hidden = FF_hidden
        # The norm layers are registered directly AND shared with the residual
        # blocks below. Keep this registration/construction order (BN1, BN2,
        # attention, feed-forward): it fixes parameter order, state_dict keys
        # and the order in which random initialisation is drawn.
        self.BN1 = Normalization(embed_dim, normalization='batch')
        self.BN2 = Normalization(embed_dim, normalization='batch')

        attention = SelfAttention(
            MultiHeadAttention(n_heads=self.n_heads, embed_dim=embed_dim, need_W=True)
        )
        self.MHA_sublayer = ResidualBlock_BN(attention, self.BN1)

        feed_forward = nn.Sequential(
            nn.Linear(embed_dim, FF_hidden, bias=True),
            nn.ReLU(),
            nn.Linear(FF_hidden, embed_dim, bias=True),
        )
        self.FF_sublayer = ResidualBlock_BN(feed_forward, self.BN2)

    def forward(self, x, mask=None):
        """arg x: (batch, max_stacks, embed_dim) per the original docs.

        The mask is forwarded to the attention sub-layer only.
        return: tensor of the same shape.
        """
        attended = self.MHA_sublayer(x, mask=mask)
        return self.FF_sublayer(attended)
class GraphAttentionEncoder(nn.Module):
    """LSTM input projection followed by a stack of attention encoder layers.

    Args:
        embed_dim: size of the produced embeddings (LSTM hidden size).
        n_heads: attention heads per ``EncoderLayer``.
        n_layers: number of ``EncoderLayer`` modules.
        FF_hidden: feed-forward hidden width per ``EncoderLayer``.
        n_containers: accepted for call-site compatibility; unused here.
        max_stacks: accepted for call-site compatibility; unused here.
        max_tiers: size of the input's last dimension (LSTM input size).
        batch_first: when True (default) the LSTM reads its input as
            (batch, max_stacks, max_tiers), so each sample is encoded
            independently with the recurrence running over its stacks.
            False reproduces the earlier behaviour, in which dim 0 (the
            batch) was treated as the sequence and hidden state leaked from
            one sample into the next; only use it to replicate old results.
    """

    def __init__(self, embed_dim=128, n_heads=8, n_layers=3, FF_hidden=512,
                 n_containers=8, max_stacks=4, max_tiers=4, batch_first=True):
        super().__init__()
        self.init_W = nn.LSTM(
            input_size=max_tiers, hidden_size=embed_dim, batch_first=batch_first
        )
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(n_heads, FF_hidden, embed_dim) for _ in range(n_layers)]
        )

    def forward(self, nowx, mask=None):
        """Encode a batch of stack configurations.

        Args:
            nowx: input tensor, expected (batch, max_stacks, max_tiers).
            mask: optional mask forwarded to every encoder layer.

        Returns:
            ``(node_embeddings, graph_embedding)`` with shapes
            (batch, max_stacks, embed_dim) and (batch, embed_dim); the graph
            embedding is the mean of the node embeddings over dim 1.
        """
        # The caller passes env.x. Per the original author, without a clone,
        # backprop fails because that tensor is later modified in place while
        # autograd still needs it.
        x = nowx.clone()
        x, _ = self.init_W(x)  # keep only the per-step outputs, drop (h_n, c_n)
        for layer in self.encoder_layers:
            x = layer(x, mask)
        return x, x.mean(dim=1)
if __name__ == '__main__':
    # Smoke test: encode a small generated batch, print the output shapes and
    # the per-tensor / total element counts of the encoder's state_dict.
    batch = 5
    # `device` was previously referenced while its definition was commented
    # out, which raised NameError.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    encoder = GraphAttentionEncoder(n_layers=1).to(device)
    data = generate_data(device, n_samples=batch, n_containers=8, max_stacks=4, max_tiers=4)
    output = encoder(data, mask=None)
    print('output[0].shape:', output[0].size())
    print('output[1].shape', output[1].size())

    cnt = 0
    for name, tensor in encoder.state_dict().items():
        print(name, tensor.size(), torch.numel(tensor))
        cnt += torch.numel(tensor)
    print(cnt)