用深度学习手搓鸡兔同笼问题之pytorch版本

2023 Aug 20 See all posts


机器学习时代,记得在学校里面还是以发掘领域特征为主,比如图像识别里提取SIFT特征,走出校门后,非常遗憾没有继续搞NLP,对深度学习一直没有大的工程经验,感觉落后了一个时代。 去年曾经业余时间使用tenserflow搞了下的图片验证码识别,居然都不用图像分割,太震撼了。然而仅仅只是尝试了一下就没有下文了。现在再找个最简单的热热身吧, 借鉴了深度学习之鸡兔同笼问题的思路,他用的是tenserflow,那我这里就来个pytorch版本。

模型定义:

在构造函数中,我们按照顺序添加了各个层,每个层都是nn.Module的子类。 首先是nn.Flatten层,用于将输入数据展平为一维向量。 然后是一个全连接层nn.Linear(2, 10), 接着是ReLU激活函数nn.ReLU(), 最后是另一个全连接层nn.Linear(10, 2)。 这样,我们就定义了具有输入层、隐藏层和输出层的神经网络结构,并且使用ReLU作为激活函数。

model = nn.Sequential(
    nn.Flatten(input_shape=(2,)),
    nn.Linear(2, 10),
    nn.ReLU(),
    nn.Linear(10, 2)
)

或者写得更好复用一点:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = Model()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

接下来就很简单了:

模型训练:

可以简单生成一些数据进行训练,这里就用现成的吧

def train_data():
    x = [[2, 6], [3, 8], [3, 10], [4, 10], [4, 12], [4, 14], [5, 12], [5, 14], [5, 16], [5, 18]]
    x = np.array(x)

    y = [[1, 1], [2, 1], [1, 2], [3, 1], [2, 2], [1, 3], [4, 1], [3, 2], [2, 3], [1, 4]]
    y = np.array(y)
    return (x, y)

x_train = torch.tensor(train_data()[0], dtype=torch.float32)
y_train = torch.tensor(train_data()[1], dtype=torch.float32)

# 模型训练,总共训练5000次
for epoch in range(5000):
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 100 == 0:
        print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")

# 保存模型
torch.save(model.state_dict(), "./model.pt")

从Loss可以看到收敛得很快。

生成了模型,就可以进行预测了:

预测

# 加载模型
model = Model()
model.load_state_dict(torch.load("./model.pt"))
model.eval()

# 封装数据
data = input("请输入头和脚的数量:")
data = ast.literal_eval(data)
data = torch.tensor([data], dtype=torch.float32)

# 获取结果
with torch.no_grad():
    y_predict = model(data)
y_predict = y_predict.numpy()

print("\n鸡有" + str(round(y_predict[0][0])) + "只,兔有" + str(round(y_predict[0][1])) + "只。")

结果

请输入头和脚的数量:99,298

鸡有49只,兔有50只。

非常完美,根本不用关心鸡有多少腿,兔有多少腿这种领域知识,就可以从已知探索到未知。

Back to top