diff --git a/backend/data/inventory.db b/backend/data/inventory.db index 36d6443..3aa4d25 100644 Binary files a/backend/data/inventory.db and b/backend/data/inventory.db differ diff --git a/backend/models.py b/backend/models.py index 41ab888..44b7e00 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,5 +1,23 @@ from sqlalchemy import Column, Integer, String, Float, DateTime, Text, func from database import Base +import hashlib + + +class User(Base): + """用户表""" + __tablename__ = "user" + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + username = Column(String(50), unique=True, nullable=False, index=True, comment="用户名") + password_hash = Column(String(128), nullable=False, comment="密码哈希") + nickname = Column(String(50), nullable=True, comment="昵称") + created_at = Column(DateTime, server_default=func.now(), comment="创建时间") + + def set_password(self, password): + self.password_hash = hashlib.sha256(password.encode()).hexdigest() + + def check_password(self, password): + return self.password_hash == hashlib.sha256(password.encode()).hexdigest() class Inventory(Base): diff --git a/backend/requirements.txt b/backend/requirements.txt index da43518..0f66929 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,3 +4,5 @@ sqlalchemy==2.0.35 openpyxl==3.1.5 python-multipart==0.0.12 pydantic==2.9.2 +python-jose[cryptography]==3.3.0 +passlib[bcrypt]==1.7.4 diff --git a/backend/routers.py b/backend/routers.py index 22b9275..4f3170d 100644 --- a/backend/routers.py +++ b/backend/routers.py @@ -1,20 +1,96 @@ from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query from fastapi.responses import StreamingResponse +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from sqlalchemy.orm import Session from sqlalchemy import or_ from typing import Optional import io -from datetime import datetime +from datetime import datetime, timedelta from urllib.parse import quote +from jose import JWTError, jwt +from passlib.context import CryptContext from database import get_db -from models import Inventory, TransactionLog +from models import Inventory, TransactionLog, User from schemas import InventoryCreate, InventoryUpdate, StockOperation import openpyxl +# JWT配置 +SECRET_KEY = "inventory-management-secret-key-2024" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24小时 + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") + router = APIRouter() +def create_access_token(data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): + credentials_exception = HTTPException( + status_code=401, + detail="无法验证凭据", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + except JWTError: + raise credentials_exception + user = db.query(User).filter(User.username == username).first() + if user is None: + raise credentials_exception + return user + + +# ===== 认证相关 ===== +@router.post("/auth/register") +def register(username: str, password: str, nickname: str = "", db: Session = Depends(get_db)): + """用户注册""" + existing = db.query(User).filter(User.username == username).first() + if existing: + raise HTTPException(status_code=400, detail="用户名已存在") + + user = User(username=username, nickname=nickname) + user.set_password(password) + db.add(user) + db.commit() + return {"message": "注册成功"} + + +@router.post("/auth/login") +def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): + """用户登录""" + user = db.query(User).filter(User.username == form_data.username).first() + if not user or not user.check_password(form_data.password): + raise HTTPException(status_code=401, detail="用户名或密码错误") + + access_token = create_access_token( + data={"sub": user.username}, + expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + ) + return {"access_token": access_token, "token_type": "bearer", "username": user.username, "nickname": user.nickname} + + +@router.get("/auth/me") +def get_me(current_user: User = Depends(get_current_user)): + """获取当前用户信息""" + return {"username": current_user.username, "nickname": current_user.nickname} + + def inventory_to_dict(item: Inventory) -> dict: """将Inventory模型转为字典""" return { @@ -50,7 +126,8 @@ def list_inventory( page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), search: Optional[str] = None, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """获取库存列表,支持分页和搜索""" query = db.query(Inventory) @@ -71,7 +148,7 @@ def list_inventory( @router.post("/inventory") -def create_inventory(data: InventoryCreate, db: Session = Depends(get_db)): +def create_inventory(data: InventoryCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """新增库存""" item = Inventory(**data.model_dump()) db.add(item) @@ -81,7 +158,7 @@ def create_inventory(data: InventoryCreate, db: Session = Depends(get_db)): @router.put("/inventory/{item_id}") -def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(get_db)): +def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """更新库存""" item = db.query(Inventory).filter(Inventory.id == item_id).first() if not item: @@ -95,7 +172,7 @@ def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends( @router.delete("/inventory/{item_id}") -def delete_inventory(item_id: int, db: Session = Depends(get_db)): +def delete_inventory(item_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """删除库存""" item = db.query(Inventory).filter(Inventory.id == item_id).first() if not item: @@ -108,7 +185,7 @@ def delete_inventory(item_id: int, db: Session = Depends(get_db)): # ===== 出入库 ===== @router.post("/stock/operation") -def stock_operation(op: StockOperation, db: Session = Depends(get_db)): +def stock_operation(op: StockOperation, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """出入库操作""" item = db.query(Inventory).filter(Inventory.id == op.inventory_id).first() if not item: @@ -138,7 +215,7 @@ def stock_operation(op: StockOperation, db: Session = Depends(get_db)): @router.delete("/stock/logs") -def clear_stock_logs(db: Session = Depends(get_db)): +def clear_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """清空所有出入库记录""" count = db.query(TransactionLog).delete() db.commit() @@ -150,7 +227,8 @@ def get_stock_logs( page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), search: Optional[str] = None, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """获取出入库记录""" query = db.query(TransactionLog) @@ -170,7 +248,7 @@ def get_stock_logs( # ===== Excel 导入导出 ===== @router.get("/inventory/export") -def export_inventory(db: Session = Depends(get_db)): +def export_inventory(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """导出库存为Excel""" items = db.query(Inventory).order_by(Inventory.id.asc()).all() @@ -232,7 +310,7 @@ def export_inventory(db: Session = Depends(get_db)): @router.post("/inventory/import") -def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db)): +def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """从Excel导入库存""" if not file.filename.endswith(('.xlsx', '.xls')): raise HTTPException(status_code=400, detail="只支持 .xlsx 或 .xls 文件") @@ -381,7 +459,7 @@ def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db) # ===== 导出出入库记录 ===== @router.get("/stock/export") -def export_stock_logs(db: Session = Depends(get_db)): +def export_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): """导出出入库记录为Excel""" items = db.query(TransactionLog).order_by(TransactionLog.id.desc()).all() diff --git a/backend/static/index.html b/backend/static/index.html index eb9070d..c9fae8e 100644 --- a/backend/static/index.html +++ b/backend/static/index.html @@ -134,11 +134,56 @@ 库存管理系统