添加登录功能:JWT认证、注册、登录、前端登录页面

Bu işleme şunda yer alıyor:
cnbugs
2026-06-01 15:54:29 +08:00
ebeveyn 188edfa287
işleme 0d6c9d26c0
5 değiştirilmiş dosya ile 289 ekleme ve 23 silme
+90 -12
Dosyayı Görüntüle
@@ -1,20 +1,96 @@
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import Optional
import io
from datetime import datetime
from datetime import datetime, timedelta
from urllib.parse import quote
from jose import JWTError, jwt
from passlib.context import CryptContext
from database import get_db
from models import Inventory, TransactionLog
from models import Inventory, TransactionLog, User
from schemas import InventoryCreate, InventoryUpdate, StockOperation
import openpyxl
# JWT配置
SECRET_KEY = "inventory-management-secret-key-2024"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24小时
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
router = APIRouter()
def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
credentials_exception = HTTPException(
status_code=401,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user
# ===== 认证相关 =====
@router.post("/auth/register")
def register(username: str, password: str, nickname: str = "", db: Session = Depends(get_db)):
"""用户注册"""
existing = db.query(User).filter(User.username == username).first()
if existing:
raise HTTPException(status_code=400, detail="用户名已存在")
user = User(username=username, nickname=nickname)
user.set_password(password)
db.add(user)
db.commit()
return {"message": "注册成功"}
@router.post("/auth/login")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
"""用户登录"""
user = db.query(User).filter(User.username == form_data.username).first()
if not user or not user.check_password(form_data.password):
raise HTTPException(status_code=401, detail="用户名或密码错误")
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
return {"access_token": access_token, "token_type": "bearer", "username": user.username, "nickname": user.nickname}
@router.get("/auth/me")
def get_me(current_user: User = Depends(get_current_user)):
"""获取当前用户信息"""
return {"username": current_user.username, "nickname": current_user.nickname}
def inventory_to_dict(item: Inventory) -> dict:
"""将Inventory模型转为字典"""
return {
@@ -50,7 +126,8 @@ def list_inventory(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = None,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取库存列表,支持分页和搜索"""
query = db.query(Inventory)
@@ -71,7 +148,7 @@ def list_inventory(
@router.post("/inventory")
def create_inventory(data: InventoryCreate, db: Session = Depends(get_db)):
def create_inventory(data: InventoryCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""新增库存"""
item = Inventory(**data.model_dump())
db.add(item)
@@ -81,7 +158,7 @@ def create_inventory(data: InventoryCreate, db: Session = Depends(get_db)):
@router.put("/inventory/{item_id}")
def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(get_db)):
def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""更新库存"""
item = db.query(Inventory).filter(Inventory.id == item_id).first()
if not item:
@@ -95,7 +172,7 @@ def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(
@router.delete("/inventory/{item_id}")
def delete_inventory(item_id: int, db: Session = Depends(get_db)):
def delete_inventory(item_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""删除库存"""
item = db.query(Inventory).filter(Inventory.id == item_id).first()
if not item:
@@ -108,7 +185,7 @@ def delete_inventory(item_id: int, db: Session = Depends(get_db)):
# ===== 出入库 =====
@router.post("/stock/operation")
def stock_operation(op: StockOperation, db: Session = Depends(get_db)):
def stock_operation(op: StockOperation, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""出入库操作"""
item = db.query(Inventory).filter(Inventory.id == op.inventory_id).first()
if not item:
@@ -138,7 +215,7 @@ def stock_operation(op: StockOperation, db: Session = Depends(get_db)):
@router.delete("/stock/logs")
def clear_stock_logs(db: Session = Depends(get_db)):
def clear_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""清空所有出入库记录"""
count = db.query(TransactionLog).delete()
db.commit()
@@ -150,7 +227,8 @@ def get_stock_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = None,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取出入库记录"""
query = db.query(TransactionLog)
@@ -170,7 +248,7 @@ def get_stock_logs(
# ===== Excel 导入导出 =====
@router.get("/inventory/export")
def export_inventory(db: Session = Depends(get_db)):
def export_inventory(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""导出库存为Excel"""
items = db.query(Inventory).order_by(Inventory.id.asc()).all()
@@ -232,7 +310,7 @@ def export_inventory(db: Session = Depends(get_db)):
@router.post("/inventory/import")
def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db)):
def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""从Excel导入库存"""
if not file.filename.endswith(('.xlsx', '.xls')):
raise HTTPException(status_code=400, detail="只支持 .xlsx 或 .xls 文件")
@@ -381,7 +459,7 @@ def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db)
# ===== 导出出入库记录 =====
@router.get("/stock/export")
def export_stock_logs(db: Session = Depends(get_db)):
def export_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""导出出入库记录为Excel"""
items = db.query(TransactionLog).order_by(TransactionLog.id.desc()).all()