这里主要探讨torch.nn.CrossEntropyLoss函数的用法。
使用方法如下:
# 首先定义该类
loss = torch.nn.CrossEntropyLoss()
#然后传参进去
loss(target, label)
第一个参数的维度为m1 * m2
,第二个参数维度为m1
。我们在做多分类问题的时候,target应该为我们网络生成的值,而label则是非one-hot类型
的值。
用手写体数字识别简单举个例子:
for (trainX,trainY) in trainLoader:
# forward
out = net(trainX) # batch_size * num_classes
loss = loss_fn(out, trainY)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
另外,trainY必须为Long类型
,如果为非类型则会报错。RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'trainY'
CrossEntropyLoss还会自动对out作用softmax。
因篇幅问题不能全部显示,请点此查看更多更全内容