技术文摘
100 行 Python 代码,轻松实现神经网络
100 行 Python 代码,轻松实现神经网络
在当今的科技领域,神经网络已经成为了一项极其重要的技术,广泛应用于图像识别、自然语言处理、预测分析等众多领域。然而,许多人可能认为实现神经网络是一项复杂且艰巨的任务。但实际上,通过 Python 语言,我们仅用 100 行左右的代码就能轻松构建一个简单的神经网络。
让我们来导入所需的库,比如 numpy 用于数值计算。
import numpy as np
接下来,我们定义神经网络的类。这个类将包含神经网络的基本结构和功能。
class NeuralNetwork:
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
# 随机初始化权重
self.W1 = np.random.randn(self.input_size, self.hidden_size)
self.b1 = np.zeros((1, self.hidden_size))
self.W2 = np.random.randn(self.hidden_size, self.output_size)
self.b2 = np.zeros((1, self.output_size))
def forward(self, X):
# 前向传播
self.z1 = np.dot(X, self.W1) + self.b1
self.a1 = np.tanh(self.z1)
self.z2 = np.dot(self.a1, self.W2) + self.b2
self.a2 = self.sigmoid(self.z2)
return self.a2
def sigmoid(self, z):
return 1 / (1 + np.exp(-z))
def loss(self, y_true, y_pred):
# 计算损失
return np.mean((y_true - y_pred) ** 2)
def backward(self, X, y_true):
# 反向传播
m = X.shape[0]
dZ2 = self.a2 - y_true
dW2 = np.dot(self.a1.T, dZ2) / m
db2 = np.sum(dZ2, axis=0, keepdims=True) / m
dZ1 = np.dot(dZ2, self.W2.T) * (1 - np.power(self.a1, 2))
dW1 = np.dot(X.T, dZ1) / m
db1 = np.sum(dZ1, axis=0, keepdims=True) / m
self.W1 -= learning_rate * dW1
self.b1 -= learning_rate * db1
self.W2 -= learning_rate * dW2
self.b2 -= learning_rate * db2
然后,我们可以使用这个神经网络类来进行训练和预测。
input_size = 2
hidden_size = 4
output_size = 1
learning_rate = 0.1
nn = NeuralNetwork(input_size, hidden_size, output_size)
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])
for _ in range(1000):
y_pred = nn.forward(X)
loss = nn.loss(y, y_pred)
nn.backward(X, y)
print("Loss:", loss)
通过这 100 行左右的 Python 代码,我们成功实现了一个简单的神经网络。虽然它还很基础,但已经为我们理解神经网络的工作原理和实现方式提供了一个很好的起点。
随着对神经网络的深入研究和不断优化,我们可以在这个基础上添加更多的功能和改进,以适应更复杂的任务和需求。希望这个简单的示例能够激发您对神经网络的兴趣,并为您在相关领域的探索和应用提供一些帮助。
TAGS: 轻松实现 Python 代码 神经网络 Python 神经网络
- 京东白条的数据架构演进揭秘
- 五张图解析 RocketMQ 消费者启动流程
- 一文弄懂 Vue3.0 采用 Proxy 的原因
- 20 行 Python 代码,便捷提取 PPT 文字至 Word
- VR 怎样使街道更安全?
- Python 中字符串格式化输出之浅议
- 我的 JavaScript 速度超你的 Rust
- ThreadLocal 会导致内存泄漏吗?
- 偷看同事代码,揭开优雅代码的神秘面纱
- 基于 Node.js 与 SQLite 打造离线优先应用
- 新一代 Pnpm 包管理工具
- 掌握 TS infer ,书写泛型超棒!
- Python 字典操作指南,一篇就够
- 消息队列堆积过多,下游处理不及该如何应对
- 浅析逻辑选择器 Is、Where、Not、Has