You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

132 lines
3.7 KiB

from datetime import datetime, timedelta
from typing import Optional
import jwt
from fastapi import Depends, FastAPI, Security
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt import PyJWTError
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.status import HTTP_403_FORBIDDEN
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
TOKEN_SUBJECT = "access"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
fake_users_db = {
"johndoe": {
"username": "johndoe",
"full_name": "John Doe",
"email": "[email protected]",
"password_hash": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
}
}
class Token(BaseModel):
access_token: str
token_type: str
class TokenPayload(BaseModel):
username: str = None
class User(BaseModel):
username: str
email: Optional[str] = None
full_name: Optional[str] = None
disabled: Optional[bool] = None
class UserInDB(User):
hashed_password: str
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
app = FastAPI()
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def get_user(db, username: str):
if username in db:
user_dict = db[username]
return UserInDB(**user_dict)
def authenticate_user(fake_db, username: str, password: str):
user = get_user(fake_db, username)
if not user:
return False
if not verify_password(password, user.hashed_password):
return False
return user
def create_access_token(*, data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire, "sub": TOKEN_SUBJECT})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(token: str = Security(oauth2_scheme)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
token_data = TokenPayload(**payload)
except PyJWTError:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
)
user = get_user(fake_users_db, username=token_data.username)
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
@app.post("/token", response_model=Token)
async def route_login_access_token(form_data: OAuth2PasswordRequestForm):
data = form_data.parse()
user = authenticate_user(fake_users_db, data.username, data.password)
if not user:
raise HTTPException(status_code=400, detail="Incorrect email or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
return {
"access_token": create_access_token(
data={"username": data.username}, expires_delta=access_token_expires
),
"token_type": "bearer",
}
@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
return current_user
@app.get("/users/me/items/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
return [{"item_id": "Foo", "owner": current_user.username}]