Files

514 rivejä
19 KiB
Python

import io
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import quote

import openpyxl
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import or_
from sqlalchemy.orm import Session

from database import get_db
from models import Inventory, TransactionLog, User
from schemas import InventoryCreate, InventoryUpdate, StockOperation
# JWT configuration.
# The signing key is taken from the SECRET_KEY environment variable so each
# deployment can use its own secret: anyone who knows the key can forge valid
# tokens, and the literal below is committed to source control. The literal is
# kept only as a fallback so existing setups keep working — set the variable
# in production. An empty variable is treated as unset.
SECRET_KEY = os.environ.get("SECRET_KEY") or "inventory-management-secret-key-2024"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours
# NOTE(review): pwd_context is not referenced in the visible code (hashing goes
# through User.set_password / User.check_password); confirm before removing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# tokenUrl tells OpenAPI clients where to obtain a token; presumably this router
# is mounted under an "/api" prefix so it resolves to the login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
router = APIRouter()
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT containing ``data`` plus an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": username}``. The dict is copied,
            never mutated.
        expires_delta: Token lifetime. When omitted (``None``) the token
            expires 15 minutes from now.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy and a truthiness
    # test would silently replace it with the 15-minute default.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer token to its User row.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``, and its ``sub``
    claim is looked up as a username.

    Raises:
        HTTPException 401: the token is invalid or expired, has no ``sub``
            claim, or names a user that no longer exists.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    account = db.query(User).filter(User.username == subject).first()
    if account is None:
        raise unauthorized
    return account
# ===== 认证相关 =====
@router.post("/auth/register")
def register(username: str, password: str, nickname: str = "", db: Session = Depends(get_db)):
    """Register a new user account.

    ``username``, ``password`` and ``nickname`` are received as query
    parameters. The password is hashed via ``User.set_password``.
    NOTE(review): query strings are commonly written to access logs, so the
    plain-text password may be logged. Moving these fields into a request body
    would change the API contract, so it is left as is here.

    Returns:
        ``{"message": "注册成功"}`` on success.

    Raises:
        HTTPException 400: blank username, empty password, or the username
            is already taken.
    """
    from sqlalchemy.exc import IntegrityError

    if not username.strip() or not password:
        raise HTTPException(status_code=400, detail="用户名和密码不能为空")
    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="用户名已存在")
    user = User(username=username, nickname=nickname)
    user.set_password(password)
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # A concurrent request may have created the same username between the
        # check above and this commit. That only raises here if there is a
        # unique constraint on User.username (not visible from this file).
        # Report it as the normal "taken" error; re-raise anything else.
        db.rollback()
        if db.query(User).filter(User.username == username).first():
            raise HTTPException(status_code=400, detail="用户名已存在")
        raise
    return {"message": "注册成功"}
@router.post("/auth/login")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate via the OAuth2 password form and issue a bearer token.

    The token lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES``. The response also
    echoes the user's username and nickname.

    Raises:
        HTTPException 401: unknown username or wrong password.
    """
    account = db.query(User).filter(User.username == form_data.username).first()
    authenticated = bool(account) and account.check_password(form_data.password)
    if not authenticated:
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    token_lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(data={"sub": account.username}, expires_delta=token_lifetime)
    return {
        "access_token": token,
        "token_type": "bearer",
        "username": account.username,
        "nickname": account.nickname,
    }
@router.get("/auth/me")
def get_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's username and nickname."""
    profile_fields = ("username", "nickname")
    return {field: getattr(current_user, field) for field in profile_fields}
def inventory_to_dict(item: Inventory) -> dict:
    """Serialize an Inventory row into a JSON-friendly dict.

    Plain columns are copied as-is. ``created_at``/``updated_at`` are rendered
    with ``isoformat()``, or ``None`` when unset.
    """
    def _iso(moment):
        return moment.isoformat() if moment else None

    plain_columns = (
        "id",
        "cInvCode",
        "supplier",
        "casing_label_remark",
        "batch",
        "current_remaining",
        "storage_location",
    )
    serialized = {column: getattr(item, column) for column in plain_columns}
    serialized["created_at"] = _iso(item.created_at)
    serialized["updated_at"] = _iso(item.updated_at)
    return serialized
def log_to_dict(log: TransactionLog) -> dict:
    """Serialize a TransactionLog row into a JSON-friendly dict.

    ``created_at`` is rendered with ``isoformat()``, or ``None`` when unset.
    """
    plain_columns = ("id", "inventory_id", "cInvCode", "type", "quantity", "remark")
    serialized = {column: getattr(log, column) for column in plain_columns}
    timestamp = log.created_at
    serialized["created_at"] = timestamp.isoformat() if timestamp else None
    return serialized
# ===== 库存 CRUD =====
@router.get("/inventory")
def list_inventory(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List inventory rows, paginated and optionally filtered.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1-100).
        search: Substring matched against code, supplier, batch, storage
            location and the casing/label/remark column. Surrounding
            whitespace is ignored, and a blank search means "no filter".

    Returns:
        ``{"total": <matching row count>, "items": [<serialized rows>]}``,
        ordered by id ascending.
    """
    query = db.query(Inventory)
    keyword = search.strip() if search else ""
    if keyword:
        # contains(..., autoescape=True) escapes "%" and "_" in the user's text,
        # so they match literally instead of acting as LIKE wildcards
        # (e.g. searching "A_01" must not also match "AX01").
        query = query.filter(
            or_(
                Inventory.cInvCode.contains(keyword, autoescape=True),
                Inventory.supplier.contains(keyword, autoescape=True),
                Inventory.batch.contains(keyword, autoescape=True),
                Inventory.storage_location.contains(keyword, autoescape=True),
                Inventory.casing_label_remark.contains(keyword, autoescape=True),
            )
        )
    total = query.count()
    items = (
        query.order_by(Inventory.id.asc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {"total": total, "items": [inventory_to_dict(item) for item in items]}
@router.post("/inventory")
def create_inventory(data: InventoryCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Insert a new inventory row built from the request payload and return it serialized."""
    fields = data.model_dump()
    new_item = Inventory(**fields)
    db.add(new_item)
    db.commit()
    # Reload attributes after commit so the response includes values
    # populated by the database (presumably the generated id).
    db.refresh(new_item)
    return inventory_to_dict(new_item)
@router.put("/inventory/{item_id}")
def update_inventory(item_id: int, data: InventoryUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Partially update an inventory row.

    Only the fields explicitly present in the request body are applied
    (``exclude_unset=True``).

    Raises:
        HTTPException 404: no row with ``item_id``.
    """
    item = db.query(Inventory).filter(Inventory.id == item_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="库存记录不存在")

    changes = data.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(item, field_name, changes[field_name])

    db.commit()
    db.refresh(item)
    return inventory_to_dict(item)
@router.delete("/inventory/{item_id}")
def delete_inventory(item_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Delete the inventory row with ``item_id``.

    NOTE(review): TransactionLog rows reference inventory_id. Whether they are
    cascaded, orphaned, or block the delete depends on the model's FK
    definition, which is not visible here — confirm.

    Raises:
        HTTPException 404: no row with ``item_id``.
    """
    target = db.query(Inventory).filter(Inventory.id == item_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="库存记录不存在")
    db.delete(target)
    db.commit()
    return {"message": "删除成功"}
# ===== 出入库 =====
@router.post("/stock/operation")
def stock_operation(op: StockOperation, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Apply a stock-in ("in") or stock-out ("out") movement and log it.

    Adjusts ``current_remaining`` of the target inventory row and adds a
    TransactionLog entry. Both are persisted in the same commit.

    Returns:
        ``{"message": "操作成功", "current_remaining": <new balance>}``.

    Raises:
        HTTPException 404: the inventory row does not exist.
        HTTPException 400: unknown operation type, non-positive quantity, or
            an "out" larger than the remaining stock.
    """
    # Lock the row (SELECT ... FOR UPDATE) so two concurrent operations on the
    # same item cannot both read the old balance and overwrite each other's
    # result. Dialects without row locking (e.g. SQLite) omit the clause.
    item = (
        db.query(Inventory)
        .filter(Inventory.id == op.inventory_id)
        .with_for_update()
        .first()
    )
    if not item:
        raise HTTPException(status_code=404, detail="库存记录不存在")
    if op.type not in ("in", "out"):
        raise HTTPException(status_code=400, detail="类型必须是 in 或 out")
    # A zero or negative quantity would let "out" slip past the stock check
    # below and increase the balance (or let "in" silently decrease it).
    if op.quantity is None or op.quantity <= 0:
        raise HTTPException(status_code=400, detail="数量必须大于0")

    if op.type == "out":
        if item.current_remaining < op.quantity:
            raise HTTPException(status_code=400, detail=f"库存不足,当前剩余: {item.current_remaining}")
        item.current_remaining -= op.quantity
    else:
        item.current_remaining += op.quantity

    # Record the movement in the audit log.
    db.add(
        TransactionLog(
            inventory_id=item.id,
            cInvCode=item.cInvCode,
            type=op.type,
            quantity=op.quantity,
            remark=op.remark,
        )
    )
    db.commit()
    db.refresh(item)
    return {"message": "操作成功", "current_remaining": item.current_remaining}
@router.delete("/stock/logs")
def clear_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Delete every TransactionLog row and report how many were removed.

    NOTE(review): any authenticated user can wipe the whole stock-movement
    history through this endpoint — confirm that is intended.
    """
    all_logs = db.query(TransactionLog)
    removed = all_logs.delete()
    db.commit()
    return {"message": f"已清空 {removed} 条出入库记录"}
@router.get("/stock/logs")
def get_stock_logs(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List stock in/out log entries, paginated and optionally filtered.

    Args:
        page: 1-based page number.
        page_size: Entries per page (1-100).
        search: Substring matched against product code and remark.
            Surrounding whitespace is ignored, and a blank search means
            "no filter".

    Returns:
        ``{"total": <matching entry count>, "items": [<serialized entries>]}``,
        ordered by id ascending.
    """
    query = db.query(TransactionLog)
    keyword = search.strip() if search else ""
    if keyword:
        # autoescape=True makes "%" and "_" in the search text match literally
        # rather than acting as LIKE wildcards.
        query = query.filter(
            or_(
                TransactionLog.cInvCode.contains(keyword, autoescape=True),
                TransactionLog.remark.contains(keyword, autoescape=True),
            )
        )
    total = query.count()
    entries = (
        query.order_by(TransactionLog.id.asc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {"total": total, "items": [log_to_dict(entry) for entry in entries]}
# ===== Excel 导入导出 =====
@router.get("/inventory/export")
def export_inventory(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Export all inventory rows (ordered by id) as an .xlsx download.

    The sheet "库存数据" has a styled header row followed by one row per
    record. Missing optional text fields are written as "".

    Returns:
        StreamingResponse carrying the workbook, with a timestamped filename
        sent both plain and RFC 5987 (``filename*``) encoded.
    """
    from openpyxl.styles import Font, Alignment, PatternFill, Border, Side

    items = db.query(Inventory).order_by(Inventory.id.asc()).all()
    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "库存数据"

    # Header row, then styling for each of its cells.
    headers = ["序号", "产品编码", "供应商", "现外壳&标签&备注", "批次", "当前时间剩余", "存货地点"]
    ws.append(headers)
    header_font = Font(bold=True, size=12, color="FFFFFF")
    header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
    header_alignment = Alignment(horizontal="center", vertical="center")
    thin = Side(style="thin")
    thin_border = Border(left=thin, right=thin, top=thin, bottom=thin)
    for cell in ws[1]:
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = header_alignment
        cell.border = thin_border

    # Data rows
    for item in items:
        ws.append([
            item.id,
            item.cInvCode,
            item.supplier or "",
            item.casing_label_remark or "",
            item.batch or "",
            item.current_remaining,
            item.storage_location or "",
        ])

    # openpyxl stores any string value starting with "=" as a formula, and
    # Excel evaluates it when the file is opened. These cells hold
    # user-entered data, so force them back to plain text
    # (spreadsheet formula injection).
    for row in ws.iter_rows(min_row=2):
        for cell in row:
            if cell.data_type == "f":
                cell.data_type = "s"

    # Column widths
    col_widths = [8, 20, 20, 30, 15, 15, 20]
    for idx, width in enumerate(col_widths, 1):
        ws.column_dimensions[openpyxl.utils.get_column_letter(idx)].width = width

    # Serialize to memory and stream back as an attachment.
    output = io.BytesIO()
    wb.save(output)
    output.seek(0)
    filename = f"库存数据_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
    filename_encoded = quote(filename)
    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename={filename_encoded}; filename*=UTF-8''{filename_encoded}"}
    )
# Header text recognised by the inventory importer -> Inventory field it feeds.
# Both spellings of the casing/label/remark column are accepted.
_IMPORT_HEADER_FIELDS = {
    "序号": "id",
    "产品编码": "cInvCode",
    "供应商": "supplier",
    "现外壳&标签&备注": "casing_label_remark",
    "外壳&标签&备注": "casing_label_remark",
    "批次": "batch",
    "当前时间剩余": "current_remaining",
    "存货地点": "storage_location",
}


def _find_import_header(ws, max_scan_rows: int = 5):
    """Locate the header row within the first ``max_scan_rows`` rows of ``ws``.

    Returns ``(row_number, header_map)`` for the first row containing a
    "产品编码" cell. ``row_number`` is 1-based, and ``header_map`` maps each
    recognised header text to its 0-based column index (the first occurrence
    wins). Returns ``(None, {})`` when no such row exists.
    """
    for row_number, row in enumerate(ws.iter_rows(min_row=1, max_row=max_scan_rows), 1):
        header_map = {}
        for col_idx, cell in enumerate(row):
            if cell.value is None:
                continue
            text = str(cell.value).strip()
            if text in _IMPORT_HEADER_FIELDS:
                header_map.setdefault(text, col_idx)
        if "产品编码" in header_map:
            return row_number, header_map
    return None, {}


def _import_cell(row, col):
    """Return the raw value of ``row`` at 0-based ``col``, or None when the column is absent."""
    if col is None or col >= len(row):
        return None
    return row[col].value


def _import_text(row, header_map, header):
    """Return the stripped text under ``header`` in ``row``; "" when the cell or column is missing."""
    value = _import_cell(row, header_map.get(header))
    return "" if value is None else str(value).strip()


@router.post("/inventory/import")
def import_inventory(file: UploadFile = File(...), db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Import inventory rows from the active sheet of an uploaded .xlsx file.

    The header row is the first of rows 1-5 that contains "产品编码". Data is
    read from the rows below it. For each data row:

    * If 序号 parses to a non-zero integer, that row is overwritten in full
      (blank cells become "" / 0). If no row has that id, a new row is
      inserted with it.
    * Otherwise the row is matched on 产品编码 + 批次. Only non-blank cells
      are applied to an existing match, or a new row is inserted.

    Rows without a 产品编码 are skipped. Rows that fail (including a
    non-numeric 当前时间剩余) are reported in ``errors``: the first 10 are
    returned, and the total is in ``error_count``. Everything else is
    committed together at the end.

    Raises:
        HTTPException 400: unsupported file type, or no header row found.
        HTTPException 500: the workbook cannot be read or the commit fails.
    """
    # Case-insensitive, and safe when the client sends no filename. Legacy
    # .xls (BIFF) files are rejected up front: openpyxl cannot read them, so
    # they used to fail later with a 500.
    filename = (file.filename or "").lower()
    if not filename.endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="只支持 .xlsx 文件(旧版 .xls 请先另存为 .xlsx)")
    try:
        contents = file.file.read()
        # data_only=True returns the value Excel last calculated for formula
        # cells, instead of the formula text (which cannot be parsed as a number).
        wb = openpyxl.load_workbook(io.BytesIO(contents), data_only=True)
        ws = wb.active
        header_row, header_map = _find_import_header(ws)
        if header_row is None:
            raise HTTPException(status_code=400, detail="未找到有效的表头行,请确保包含'产品编码'")

        imported = 0
        updated = 0
        errors = []
        # Data starts immediately below the detected header row, so a header
        # on row 2+ is never mistaken for data.
        first_data_row = header_row + 1
        for row_idx, row in enumerate(ws.iter_rows(min_row=first_data_row), first_data_row):
            # NOTE(review): a database error raised while autoflushing a pending
            # row surfaces inside a later row's try-block and leaves the session
            # unusable, so subsequent rows and the final commit fail as well.
            # Isolating those would need per-row savepoints.
            try:
                record_id = None
                id_val = _import_cell(row, header_map.get("序号"))
                if id_val is not None:
                    try:
                        record_id = int(float(id_val))
                    except (ValueError, TypeError, OverflowError):
                        record_id = None

                raw_code = _import_cell(row, header_map["产品编码"])
                if not raw_code:
                    continue
                cInvCode = str(raw_code).strip()
                if not cInvCode:
                    continue

                supplier = _import_text(row, header_map, "供应商")
                casing_label_remark = (
                    _import_text(row, header_map, "现外壳&标签&备注")
                    or _import_text(row, header_map, "外壳&标签&备注")
                )
                batch = _import_text(row, header_map, "批次")
                storage_location = _import_text(row, header_map, "存货地点")

                # `remaining` is None for a blank/absent cell, which is kept
                # distinct from an explicit 0. Unparseable text is reported
                # rather than silently treated as 0 (which would wipe stock).
                raw_remaining = _import_cell(row, header_map.get("当前时间剩余"))
                remaining = None
                if raw_remaining is not None and str(raw_remaining).strip() != "":
                    try:
                        remaining = float(raw_remaining)
                    except (ValueError, TypeError):
                        errors.append(f"{row_idx}行: 当前时间剩余 '{raw_remaining}' 不是有效数字")
                        continue
                remaining_or_zero = remaining if remaining is not None else 0

                if record_id:
                    # Match by id (序号): full overwrite, so blank cells clear values.
                    existing = db.query(Inventory).filter(Inventory.id == record_id).first()
                    if existing:
                        existing.cInvCode = cInvCode
                        existing.supplier = supplier
                        existing.casing_label_remark = casing_label_remark
                        existing.batch = batch
                        existing.current_remaining = remaining_or_zero
                        existing.storage_location = storage_location
                        updated += 1
                    else:
                        # NOTE(review): inserting an explicit id may not advance
                        # the backend's id sequence (e.g. PostgreSQL) — confirm
                        # for the target database.
                        db.add(Inventory(
                            id=record_id,
                            cInvCode=cInvCode,
                            supplier=supplier,
                            casing_label_remark=casing_label_remark,
                            batch=batch,
                            current_remaining=remaining_or_zero,
                            storage_location=storage_location,
                        ))
                        imported += 1
                else:
                    # Match by product code + batch: only non-blank cells update.
                    existing = db.query(Inventory).filter(
                        Inventory.cInvCode == cInvCode,
                        Inventory.batch == batch,
                    ).first()
                    if existing:
                        existing.supplier = supplier or existing.supplier
                        existing.casing_label_remark = casing_label_remark or existing.casing_label_remark
                        if remaining is not None:
                            existing.current_remaining = remaining
                        existing.storage_location = storage_location or existing.storage_location
                        updated += 1
                    else:
                        db.add(Inventory(
                            cInvCode=cInvCode,
                            supplier=supplier,
                            casing_label_remark=casing_label_remark,
                            batch=batch,
                            current_remaining=remaining_or_zero,
                            storage_location=storage_location,
                        ))
                        imported += 1
            except Exception as e:
                errors.append(f"{row_idx}行: {str(e)}")

        db.commit()
        return {
            "message": f"导入完成:新增 {imported} 条,更新 {updated} 条",
            "imported": imported,
            "updated": updated,
            "errors": errors[:10],
            "error_count": len(errors),
        }
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}") from e
# ===== 导出出入库记录 =====
@router.get("/stock/export")
def export_stock_logs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Export all stock in/out log entries (newest id first) as an .xlsx download.

    The type column shows "入库" for "in" and "出库" otherwise. Timestamps are
    formatted as ``YYYY-mm-dd HH:MM:SS``, or "" when missing.

    Returns:
        StreamingResponse carrying the workbook, with a timestamped filename
        sent both plain and RFC 5987 (``filename*``) encoded.
    """
    from openpyxl.styles import Font, Alignment, PatternFill, Border, Side

    items = db.query(TransactionLog).order_by(TransactionLog.id.desc()).all()
    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "出入库记录"

    # Header row, then styling for each of its cells.
    headers = ["序号", "产品编码", "类型", "数量", "备注", "操作时间"]
    ws.append(headers)
    header_font = Font(bold=True, size=12, color="FFFFFF")
    header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
    header_alignment = Alignment(horizontal="center", vertical="center")
    thin = Side(style="thin")
    thin_border = Border(left=thin, right=thin, top=thin, bottom=thin)
    for cell in ws[1]:
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = header_alignment
        cell.border = thin_border

    # Data rows
    for item in items:
        ws.append([
            item.id,
            item.cInvCode,
            "入库" if item.type == "in" else "出库",
            item.quantity,
            item.remark or "",
            item.created_at.strftime("%Y-%m-%d %H:%M:%S") if item.created_at else "",
        ])

    # openpyxl stores any string value starting with "=" as a formula, and
    # Excel evaluates it when the file is opened. Remarks and codes are
    # user-entered, so force such cells back to plain text
    # (spreadsheet formula injection).
    for row in ws.iter_rows(min_row=2):
        for cell in row:
            if cell.data_type == "f":
                cell.data_type = "s"

    # Column widths
    col_widths = [8, 20, 10, 10, 30, 20]
    for idx, width in enumerate(col_widths, 1):
        ws.column_dimensions[openpyxl.utils.get_column_letter(idx)].width = width

    # Serialize to memory and stream back as an attachment.
    output = io.BytesIO()
    wb.save(output)
    output.seek(0)
    filename = f"出入库记录_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
    filename_encoded = quote(filename)
    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename={filename_encoded}; filename*=UTF-8''{filename_encoded}"}
    )