Add TOTP-based Multi-Factor Authentication (MFA) for local users
Global MFA toggle in Security settings, QR code setup on first login, 6-digit TOTP verification on subsequent logins. Azure AD users exempt. Admins can reset user MFA. TOTP secrets encrypted at rest with Fernet. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -80,6 +80,9 @@ def _run_migrations() -> None:
|
||||
("system_config", "default_language", "TEXT DEFAULT 'en'"),
|
||||
("users", "default_language", "TEXT"),
|
||||
("deployments", "npm_stream_id", "INTEGER"),
|
||||
("system_config", "mfa_enabled", "BOOLEAN DEFAULT 0"),
|
||||
("users", "totp_secret_encrypted", "TEXT"),
|
||||
("users", "totp_enabled", "BOOLEAN DEFAULT 0"),
|
||||
]
|
||||
for table, column, col_type in migrations:
|
||||
if not _has_column(table, column):
|
||||
|
||||
@@ -30,6 +30,39 @@ def create_access_token(username: str, expires_delta: Optional[timedelta] = None
|
||||
return jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_mfa_token(username: str) -> str:
|
||||
"""Create a short-lived JWT for the MFA verification step (5 min)."""
|
||||
expire = datetime.utcnow() + timedelta(minutes=5)
|
||||
payload = {"sub": username, "exp": expire, "purpose": "mfa"}
|
||||
return jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
|
||||
def verify_mfa_token(token: str) -> str:
|
||||
"""Verify an MFA-purpose JWT and return the username.
|
||||
|
||||
Raises HTTPException if the token is invalid, expired, or not an MFA token.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
if payload.get("purpose") != "mfa":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid MFA token.",
|
||||
)
|
||||
username: Optional[str] = payload.get("sub")
|
||||
if not username:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid MFA token.",
|
||||
)
|
||||
return username
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="MFA token expired or invalid.",
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -161,6 +161,7 @@ class SystemConfig(Base):
|
||||
)
|
||||
branding_logo_path: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
default_language: Mapped[Optional[str]] = mapped_column(String(10), default="en")
|
||||
mfa_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
azure_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
azure_tenant_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
azure_client_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
@@ -193,6 +194,7 @@ class SystemConfig(Base):
|
||||
"branding_subtitle": self.branding_subtitle or "Multi-Tenant Management Platform",
|
||||
"branding_logo_path": self.branding_logo_path,
|
||||
"default_language": self.default_language or "en",
|
||||
"mfa_enabled": bool(self.mfa_enabled),
|
||||
"azure_enabled": bool(self.azure_enabled),
|
||||
"azure_tenant_id": self.azure_tenant_id or "",
|
||||
"azure_client_id": self.azure_client_id or "",
|
||||
@@ -252,10 +254,12 @@ class User(Base):
|
||||
role: Mapped[str] = mapped_column(String(20), default="admin")
|
||||
auth_provider: Mapped[str] = mapped_column(String(20), default="local")
|
||||
default_language: Mapped[Optional[str]] = mapped_column(String(10), nullable=True, default=None)
|
||||
totp_secret_encrypted: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
totp_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize user to dictionary (no password)."""
|
||||
"""Serialize user to dictionary (no password, no TOTP secret)."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
@@ -264,5 +268,6 @@ class User(Base):
|
||||
"role": self.role or "admin",
|
||||
"auth_provider": self.auth_provider or "local",
|
||||
"default_language": self.default_language,
|
||||
"totp_enabled": bool(self.totp_enabled),
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Authentication API endpoints — login, logout, current user, password change, Azure AD."""
|
||||
"""Authentication API endpoints — login, logout, current user, password change, MFA, Azure AD."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
@@ -9,10 +11,18 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies import create_access_token, get_current_user
|
||||
from app.dependencies import create_access_token, create_mfa_token, get_current_user, verify_mfa_token
|
||||
from app.models import SystemConfig, User
|
||||
from app.utils.security import decrypt_value, hash_password, verify_password
|
||||
from app.utils.validators import ChangePasswordRequest, LoginRequest
|
||||
from app.utils.security import (
|
||||
decrypt_value,
|
||||
encrypt_value,
|
||||
generate_totp_secret,
|
||||
generate_totp_uri,
|
||||
hash_password,
|
||||
verify_password,
|
||||
verify_totp,
|
||||
)
|
||||
from app.utils.validators import ChangePasswordRequest, LoginRequest, MfaTokenRequest, MfaVerifyRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@@ -20,15 +30,7 @@ router = APIRouter()
|
||||
|
||||
@router.post("/login")
|
||||
async def login(payload: LoginRequest, db: Session = Depends(get_db)):
|
||||
"""Authenticate and return a JWT token.
|
||||
|
||||
Args:
|
||||
payload: Username and password.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
JSON with ``access_token`` and ``token_type``.
|
||||
"""
|
||||
"""Authenticate with username/password. May require MFA as a second step."""
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
@@ -41,6 +43,17 @@ async def login(payload: LoginRequest, db: Session = Depends(get_db)):
|
||||
detail="Account is disabled.",
|
||||
)
|
||||
|
||||
# Check if MFA is required (only for local users)
|
||||
if user.auth_provider == "local":
|
||||
config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
|
||||
if config and getattr(config, "mfa_enabled", False):
|
||||
mfa_token = create_mfa_token(user.username)
|
||||
return {
|
||||
"mfa_required": True,
|
||||
"mfa_token": mfa_token,
|
||||
"totp_setup_needed": not bool(user.totp_enabled),
|
||||
}
|
||||
|
||||
token = create_access_token(user.username)
|
||||
logger.info("User %s logged in.", user.username)
|
||||
return {
|
||||
@@ -50,24 +63,140 @@ async def login(payload: LoginRequest, db: Session = Depends(get_db)):
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MFA endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/mfa/setup")
|
||||
async def mfa_setup(payload: MfaTokenRequest, db: Session = Depends(get_db)):
|
||||
"""Generate a new TOTP secret and QR code for first-time MFA setup."""
|
||||
username = verify_mfa_token(payload.mfa_token)
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")
|
||||
|
||||
# Generate new secret and store encrypted (not yet enabled)
|
||||
secret = generate_totp_secret()
|
||||
user.totp_secret_encrypted = encrypt_value(secret)
|
||||
db.commit()
|
||||
|
||||
# Generate QR code as base64 data URI
|
||||
uri = generate_totp_uri(secret, username)
|
||||
import qrcode
|
||||
|
||||
img = qrcode.make(uri)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
qr_b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
return {
|
||||
"secret": secret,
|
||||
"qr_code": f"data:image/png;base64,{qr_b64}",
|
||||
"otpauth_uri": uri,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/mfa/setup/complete")
|
||||
async def mfa_setup_complete(payload: MfaVerifyRequest, db: Session = Depends(get_db)):
|
||||
"""Verify the first TOTP code to complete MFA setup, then issue access token."""
|
||||
username = verify_mfa_token(payload.mfa_token)
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")
|
||||
|
||||
if not user.totp_secret_encrypted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="TOTP setup not initiated. Call /auth/mfa/setup first.",
|
||||
)
|
||||
|
||||
secret = decrypt_value(user.totp_secret_encrypted)
|
||||
if not verify_totp(secret, payload.totp_code):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid TOTP code.",
|
||||
)
|
||||
|
||||
user.totp_enabled = True
|
||||
db.commit()
|
||||
|
||||
token = create_access_token(user.username)
|
||||
logger.info("User %s completed MFA setup and logged in.", user.username)
|
||||
return {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"user": user.to_dict(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/mfa/verify")
|
||||
async def mfa_verify(payload: MfaVerifyRequest, db: Session = Depends(get_db)):
|
||||
"""Verify a TOTP code for users who already have MFA set up."""
|
||||
username = verify_mfa_token(payload.mfa_token)
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")
|
||||
|
||||
if not user.totp_secret_encrypted or not user.totp_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="TOTP is not set up for this user.",
|
||||
)
|
||||
|
||||
secret = decrypt_value(user.totp_secret_encrypted)
|
||||
if not verify_totp(secret, payload.totp_code):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid TOTP code.",
|
||||
)
|
||||
|
||||
token = create_access_token(user.username)
|
||||
logger.info("User %s passed MFA verification.", user.username)
|
||||
return {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"user": user.to_dict(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/mfa/status")
|
||||
async def mfa_status(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Return MFA status for the current user and global setting."""
|
||||
config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
|
||||
return {
|
||||
"mfa_enabled_global": bool(config and getattr(config, "mfa_enabled", False)),
|
||||
"totp_enabled_user": bool(current_user.totp_enabled),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/mfa/disable")
|
||||
async def mfa_disable(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Disable TOTP for the current user."""
|
||||
current_user.totp_enabled = False
|
||||
current_user.totp_secret_encrypted = None
|
||||
db.commit()
|
||||
logger.info("User %s disabled their TOTP.", current_user.username)
|
||||
return {"message": "TOTP disabled successfully."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password change
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/logout")
|
||||
async def logout(current_user: User = Depends(get_current_user)):
|
||||
"""Logout (client-side token discard).
|
||||
|
||||
Returns:
|
||||
Confirmation message.
|
||||
"""
|
||||
"""Logout (client-side token discard)."""
|
||||
logger.info("User %s logged out.", current_user.username)
|
||||
return {"message": "Logged out successfully."}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
"""Return the current authenticated user's profile.
|
||||
|
||||
Returns:
|
||||
User dict (no password hash).
|
||||
"""
|
||||
"""Return the current authenticated user's profile."""
|
||||
return current_user.to_dict()
|
||||
|
||||
|
||||
@@ -77,16 +206,7 @@ async def change_password(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Change the current user's password.
|
||||
|
||||
Args:
|
||||
payload: Current and new password.
|
||||
current_user: Authenticated user.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
Confirmation message.
|
||||
"""
|
||||
"""Change the current user's password."""
|
||||
if not verify_password(payload.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -99,6 +219,9 @@ async def change_password(
|
||||
return {"message": "Password changed successfully."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Azure AD
|
||||
# ---------------------------------------------------------------------------
|
||||
class AzureCallbackRequest(BaseModel):
|
||||
"""Azure AD auth code callback payload."""
|
||||
code: str
|
||||
|
||||
@@ -129,3 +129,28 @@ async def reset_password(
|
||||
|
||||
logger.info("Password reset for user '%s' by '%s'.", user.username, current_user.username)
|
||||
return {"message": "Password reset successfully.", "new_password": new_password}
|
||||
|
||||
|
||||
@router.post("/{user_id}/reset-mfa")
|
||||
async def reset_mfa(
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Reset MFA (TOTP) for a user. They will need to set up again on next login."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")
|
||||
|
||||
if user.auth_provider != "local":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot reset MFA for Azure AD users.",
|
||||
)
|
||||
|
||||
user.totp_enabled = False
|
||||
user.totp_secret_encrypted = None
|
||||
db.commit()
|
||||
|
||||
logger.info("MFA reset for user '%s' by '%s'.", user.username, current_user.username)
|
||||
return {"message": f"MFA reset for '{user.username}'."}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Security utilities — password hashing (bcrypt) and token encryption (Fernet)."""
|
||||
"""Security utilities — password hashing (bcrypt), token encryption (Fernet), TOTP."""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
|
||||
import pyotp
|
||||
from cryptography.fernet import Fernet
|
||||
from passlib.context import CryptContext
|
||||
|
||||
@@ -91,6 +92,32 @@ def generate_relay_secret() -> str:
|
||||
return secrets.token_hex(16)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TOTP (Time-based One-Time Password)
|
||||
# ---------------------------------------------------------------------------
|
||||
def generate_totp_secret() -> str:
|
||||
"""Generate a new random TOTP secret (base32-encoded)."""
|
||||
return pyotp.random_base32()
|
||||
|
||||
|
||||
def verify_totp(secret: str, code: str) -> bool:
|
||||
"""Verify a 6-digit TOTP code against a secret.
|
||||
|
||||
Allows a window of +/- 1 interval (30s) to account for clock drift.
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(code, valid_window=1)
|
||||
|
||||
|
||||
def generate_totp_uri(secret: str, username: str, issuer: str = "NetBird MSP") -> str:
|
||||
"""Generate an otpauth:// URI for QR code generation."""
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.provisioning_uri(name=username, issuer_name=issuer)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Misc key generation
|
||||
# ---------------------------------------------------------------------------
|
||||
def generate_datastore_encryption_key() -> str:
|
||||
"""Generate a base64-encoded 32-byte key for NetBird DataStoreEncryptionKey.
|
||||
|
||||
|
||||
@@ -23,6 +23,19 @@ class ChangePasswordRequest(BaseModel):
|
||||
new_password: str = Field(..., min_length=12, max_length=128)
|
||||
|
||||
|
||||
class MfaTokenRequest(BaseModel):
|
||||
"""Request containing only an MFA token (for setup initiation)."""
|
||||
|
||||
mfa_token: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class MfaVerifyRequest(BaseModel):
|
||||
"""MFA TOTP verification payload."""
|
||||
|
||||
mfa_token: str = Field(..., min_length=1)
|
||||
totp_code: str = Field(..., min_length=6, max_length=6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Customer
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -113,6 +126,7 @@ class SystemConfigUpdate(BaseModel):
|
||||
branding_name: Optional[str] = Field(None, max_length=255)
|
||||
branding_subtitle: Optional[str] = Field(None, max_length=255)
|
||||
default_language: Optional[str] = Field(None, max_length=10)
|
||||
mfa_enabled: Optional[bool] = None
|
||||
azure_enabled: Optional[bool] = None
|
||||
azure_tenant_id: Optional[str] = Field(None, max_length=255)
|
||||
azure_client_id: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
Reference in New Issue
Block a user