Browse Source
* ♻️ Refactor backend, update DB session handling * ✨ Add mypy config and plugins * ➕ Use Python-jose instead of PyJWT as it has some extra functionalities and features * ✨ Add/update scripts for test, lint, format * 🔧 Update lint and format configs * 🎨 Update import format, comments, and types * 🎨 Add types to config * ✨ Add types for all the code, and small fixes * 🎨 Use global imports to simplify exploring with Jupyter * ♻️ Import schemas and models, instead of each class * 🚚 Rename db_session to db for simplicity * 📌 Update dependencies installation for testingpull/13907/head
committed by
GitHub
59 changed files with 545 additions and 443 deletions
@ -0,0 +1,3 @@ |
|||
[flake8] |
|||
max-line-length = 88 |
|||
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache |
@ -0,0 +1,3 @@ |
|||
.mypy_cache |
|||
.coverage |
|||
htmlcov |
@ -0,0 +1,61 @@ |
|||
from typing import Generator |
|||
|
|||
from fastapi import Depends, HTTPException, status |
|||
from fastapi.security import OAuth2PasswordBearer |
|||
from jose import jwt |
|||
from pydantic import ValidationError |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud, models, schemas |
|||
from app.core import security |
|||
from app.core.config import settings |
|||
from app.db.session import SessionLocal |
|||
|
|||
reusable_oauth2 = OAuth2PasswordBearer( |
|||
tokenUrl=f"{settings.API_V1_STR}/login/access-token" |
|||
) |
|||
|
|||
|
|||
def get_db() -> Generator: |
|||
try: |
|||
db = SessionLocal() |
|||
yield db |
|||
finally: |
|||
db.close() |
|||
|
|||
|
|||
def get_current_user( |
|||
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2) |
|||
) -> models.User: |
|||
try: |
|||
payload = jwt.decode( |
|||
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] |
|||
) |
|||
token_data = schemas.TokenPayload(**payload) |
|||
except (jwt.JWTError, ValidationError): |
|||
raise HTTPException( |
|||
status_code=status.HTTP_403_FORBIDDEN, |
|||
detail="Could not validate credentials", |
|||
) |
|||
user = crud.user.get(db, id=token_data.sub) |
|||
if not user: |
|||
raise HTTPException(status_code=404, detail="User not found") |
|||
return user |
|||
|
|||
|
|||
def get_current_active_user( |
|||
current_user: models.User = Depends(get_current_user), |
|||
) -> models.User: |
|||
if not crud.user.is_active(current_user): |
|||
raise HTTPException(status_code=400, detail="Inactive user") |
|||
return current_user |
|||
|
|||
|
|||
def get_current_active_superuser( |
|||
current_user: models.User = Depends(get_current_user), |
|||
) -> models.User: |
|||
if not crud.user.is_superuser(current_user): |
|||
raise HTTPException( |
|||
status_code=400, detail="The user doesn't have enough privileges" |
|||
) |
|||
return current_user |
@ -1,5 +0,0 @@ |
|||
from starlette.requests import Request |
|||
|
|||
|
|||
def get_db(request: Request): |
|||
return request.state.db |
@ -1,45 +0,0 @@ |
|||
import jwt |
|||
from fastapi import Depends, HTTPException, Security |
|||
from fastapi.security import OAuth2PasswordBearer |
|||
from jwt import PyJWTError |
|||
from sqlalchemy.orm import Session |
|||
from starlette.status import HTTP_403_FORBIDDEN |
|||
|
|||
from app import crud |
|||
from app.api.utils.db import get_db |
|||
from app.core.config import settings |
|||
from app.core.jwt import ALGORITHM |
|||
from app.models.user import User |
|||
from app.schemas.token import TokenPayload |
|||
|
|||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token") |
|||
|
|||
|
|||
def get_current_user( |
|||
db: Session = Depends(get_db), token: str = Security(reusable_oauth2) |
|||
): |
|||
try: |
|||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) |
|||
token_data = TokenPayload(**payload) |
|||
except PyJWTError: |
|||
raise HTTPException( |
|||
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials" |
|||
) |
|||
user = crud.user.get(db, id=token_data.user_id) |
|||
if not user: |
|||
raise HTTPException(status_code=404, detail="User not found") |
|||
return user |
|||
|
|||
|
|||
def get_current_active_user(current_user: User = Security(get_current_user)): |
|||
if not crud.user.is_active(current_user): |
|||
raise HTTPException(status_code=400, detail="Inactive user") |
|||
return current_user |
|||
|
|||
|
|||
def get_current_active_superuser(current_user: User = Security(get_current_user)): |
|||
if not crud.user.is_superuser(current_user): |
|||
raise HTTPException( |
|||
status_code=400, detail="The user doesn't have enough privileges" |
|||
) |
|||
return current_user |
@ -1,19 +0,0 @@ |
|||
from datetime import datetime, timedelta |
|||
|
|||
import jwt |
|||
|
|||
from app.core.config import settings |
|||
|
|||
ALGORITHM = "HS256" |
|||
access_token_jwt_subject = "access" |
|||
|
|||
|
|||
def create_access_token(*, data: dict, expires_delta: timedelta = None): |
|||
to_encode = data.copy() |
|||
if expires_delta: |
|||
expire = datetime.utcnow() + expires_delta |
|||
else: |
|||
expire = datetime.utcnow() + timedelta(minutes=15) |
|||
to_encode.update({"exp": expire, "sub": access_token_jwt_subject}) |
|||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) |
|||
return encoded_jwt |
@ -1,11 +1,34 @@ |
|||
from datetime import datetime, timedelta |
|||
from typing import Any, Union |
|||
|
|||
from jose import jwt |
|||
from passlib.context import CryptContext |
|||
|
|||
from app.core.config import settings |
|||
|
|||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
|||
|
|||
|
|||
def verify_password(plain_password: str, hashed_password: str): |
|||
ALGORITHM = "HS256" |
|||
|
|||
|
|||
def create_access_token( |
|||
subject: Union[str, Any], expires_delta: timedelta = None |
|||
) -> str: |
|||
if expires_delta: |
|||
expire = datetime.utcnow() + expires_delta |
|||
else: |
|||
expire = datetime.utcnow() + timedelta( |
|||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES |
|||
) |
|||
to_encode = {"exp": expire, "sub": str(subject)} |
|||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) |
|||
return encoded_jwt |
|||
|
|||
|
|||
def verify_password(plain_password: str, hashed_password: str) -> bool: |
|||
return pwd_context.verify(plain_password, hashed_password) |
|||
|
|||
|
|||
def get_password_hash(password: str): |
|||
def get_password_hash(password: str) -> str: |
|||
return pwd_context.hash(password) |
|||
|
@ -1,5 +1,5 @@ |
|||
# Import all the models, so that Base has them before being |
|||
# imported by Alembic |
|||
from app.db.base_class import Base # noqa |
|||
from app.models.user import User # noqa |
|||
from app.models.item import Item # noqa |
|||
from app.models.user import User # noqa |
|||
|
@ -1,9 +1,13 @@ |
|||
from typing import Any |
|||
|
|||
from sqlalchemy.ext.declarative import as_declarative, declared_attr |
|||
|
|||
|
|||
@as_declarative() |
|||
class Base: |
|||
id: Any |
|||
__name__: str |
|||
# Generate __tablename__ automatically |
|||
@declared_attr |
|||
def __tablename__(cls): |
|||
def __tablename__(cls) -> str: |
|||
return cls.__name__.lower() |
|||
|
@ -1,24 +1,25 @@ |
|||
from app import crud |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud, schemas |
|||
from app.core.config import settings |
|||
from app.schemas.user import UserCreate |
|||
from app.db import base # noqa: F401 |
|||
|
|||
# make sure all SQL Alchemy models are imported before initializing DB |
|||
# make sure all SQL Alchemy models are imported (app.db.base) before initializing DB |
|||
# otherwise, SQL Alchemy might fail to initialize relationships properly |
|||
# for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28 |
|||
from app.db import base # noqa: F401 |
|||
|
|||
|
|||
def init_db(db_session): |
|||
def init_db(db: Session) -> None: |
|||
# Tables should be created with Alembic migrations |
|||
# But if you don't want to use migrations, create |
|||
# the tables un-commenting the next line |
|||
# Base.metadata.create_all(bind=engine) |
|||
|
|||
user = crud.user.get_by_email(db_session, email=settings.FIRST_SUPERUSER) |
|||
user = crud.user.get_by_email(db, email=settings.FIRST_SUPERUSER) |
|||
if not user: |
|||
user_in = UserCreate( |
|||
user_in = schemas.UserCreate( |
|||
email=settings.FIRST_SUPERUSER, |
|||
password=settings.FIRST_SUPERUSER_PASSWORD, |
|||
is_superuser=True, |
|||
) |
|||
user = crud.user.create(db_session, obj_in=user_in) # noqa: F841 |
|||
user = crud.user.create(db, obj_in=user_in) # noqa: F841 |
|||
|
@ -1,10 +1,7 @@ |
|||
from sqlalchemy import create_engine |
|||
from sqlalchemy.orm import scoped_session, sessionmaker |
|||
from sqlalchemy.orm import sessionmaker |
|||
|
|||
from app.core.config import settings |
|||
|
|||
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True) |
|||
db_session = scoped_session( |
|||
sessionmaker(autocommit=False, autoflush=False, bind=engine) |
|||
) |
|||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
|||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
|||
|
@ -0,0 +1,2 @@ |
|||
from .item import Item |
|||
from .user import User |
@ -1,14 +1,19 @@ |
|||
from typing import TYPE_CHECKING |
|||
|
|||
from sqlalchemy import Boolean, Column, Integer, String |
|||
from sqlalchemy.orm import relationship |
|||
|
|||
from app.db.base_class import Base |
|||
|
|||
if TYPE_CHECKING: |
|||
from .item import Item # noqa: F401 |
|||
|
|||
|
|||
class User(Base): |
|||
id = Column(Integer, primary_key=True, index=True) |
|||
full_name = Column(String, index=True) |
|||
email = Column(String, unique=True, index=True) |
|||
hashed_password = Column(String) |
|||
email = Column(String, unique=True, index=True, nullable=False) |
|||
hashed_password = Column(String, nullable=False) |
|||
is_active = Column(Boolean(), default=True) |
|||
is_superuser = Column(Boolean(), default=False) |
|||
items = relationship("Item", back_populates="owner") |
|||
|
@ -0,0 +1,4 @@ |
|||
from .item import Item, ItemCreate, ItemInDB, ItemUpdate |
|||
from .msg import Msg |
|||
from .token import Token, TokenPayload |
|||
from .user import User, UserCreate, UserInDB, UserUpdate |
@ -1,20 +1,29 @@ |
|||
from typing import Dict, Iterator |
|||
|
|||
import pytest |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app.core.config import settings |
|||
from app.tests.utils.utils import get_server_api, get_superuser_token_headers |
|||
from app.db.session import SessionLocal |
|||
from app.tests.utils.user import authentication_token_from_email |
|||
from app.tests.utils.utils import get_server_api, get_superuser_token_headers |
|||
|
|||
|
|||
@pytest.fixture(scope="session") |
|||
def db() -> Iterator[Session]: |
|||
yield SessionLocal() |
|||
|
|||
|
|||
@pytest.fixture(scope="module") |
|||
def server_api(): |
|||
def server_api() -> str: |
|||
return get_server_api() |
|||
|
|||
|
|||
@pytest.fixture(scope="module") |
|||
def superuser_token_headers(): |
|||
def superuser_token_headers() -> Dict[str, str]: |
|||
return get_superuser_token_headers() |
|||
|
|||
|
|||
@pytest.fixture(scope="module") |
|||
def normal_user_token_headers(): |
|||
return authentication_token_from_email(settings.EMAIL_TEST_USER) |
|||
def normal_user_token_headers(db: Session) -> Dict[str, str]: |
|||
return authentication_token_from_email(email=settings.EMAIL_TEST_USER, db=db) |
|||
|
@ -1,94 +1,94 @@ |
|||
from fastapi.encoders import jsonable_encoder |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud |
|||
from app.core.security import get_password_hash, verify_password |
|||
from app.db.session import db_session |
|||
from app.core.security import verify_password |
|||
from app.schemas.user import UserCreate, UserUpdate |
|||
from app.tests.utils.utils import random_lower_string, random_email |
|||
from app.tests.utils.utils import random_email, random_lower_string |
|||
|
|||
|
|||
def test_create_user(): |
|||
def test_create_user(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=email, password=password) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
assert user.email == email |
|||
assert hasattr(user, "hashed_password") |
|||
|
|||
|
|||
def test_authenticate_user(): |
|||
def test_authenticate_user(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=email, password=password) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
authenticated_user = crud.user.authenticate( |
|||
db_session, email=email, password=password |
|||
) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
authenticated_user = crud.user.authenticate(db, email=email, password=password) |
|||
assert authenticated_user |
|||
assert user.email == authenticated_user.email |
|||
|
|||
|
|||
def test_not_authenticate_user(): |
|||
def test_not_authenticate_user(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user = crud.user.authenticate(db_session, email=email, password=password) |
|||
user = crud.user.authenticate(db, email=email, password=password) |
|||
assert user is None |
|||
|
|||
|
|||
def test_check_if_user_is_active(): |
|||
def test_check_if_user_is_active(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=email, password=password) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
is_active = crud.user.is_active(user) |
|||
assert is_active is True |
|||
|
|||
|
|||
def test_check_if_user_is_active_inactive(): |
|||
def test_check_if_user_is_active_inactive(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=email, password=password, disabled=True) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
is_active = crud.user.is_active(user) |
|||
assert is_active |
|||
|
|||
|
|||
def test_check_if_user_is_superuser(): |
|||
def test_check_if_user_is_superuser(db: Session) -> None: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=email, password=password, is_superuser=True) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
is_superuser = crud.user.is_superuser(user) |
|||
assert is_superuser is True |
|||
|
|||
|
|||
def test_check_if_user_is_superuser_normal_user(): |
|||
def test_check_if_user_is_superuser_normal_user(db: Session) -> None: |
|||
username = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(email=username, password=password) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
is_superuser = crud.user.is_superuser(user) |
|||
assert is_superuser is False |
|||
|
|||
|
|||
def test_get_user(): |
|||
def test_get_user(db: Session) -> None: |
|||
password = random_lower_string() |
|||
username = random_email() |
|||
user_in = UserCreate(email=username, password=password, is_superuser=True) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user_2 = crud.user.get(db_session, id=user.id) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
user_2 = crud.user.get(db, id=user.id) |
|||
assert user_2 |
|||
assert user.email == user_2.email |
|||
assert jsonable_encoder(user) == jsonable_encoder(user_2) |
|||
|
|||
|
|||
def test_update_user(): |
|||
def test_update_user(db: Session) -> None: |
|||
password = random_lower_string() |
|||
email = random_email() |
|||
user_in = UserCreate(email=email, password=password, is_superuser=True) |
|||
user = crud.user.create(db_session, obj_in=user_in) |
|||
user = crud.user.create(db, obj_in=user_in) |
|||
new_password = random_lower_string() |
|||
user_in = UserUpdate(password=new_password, is_superuser=True) |
|||
crud.user.update(db_session, db_obj=user, obj_in=user_in) |
|||
user_2 = crud.user.get(db_session, id=user.id) |
|||
user_in_update = UserUpdate(password=new_password, is_superuser=True) |
|||
crud.user.update(db, db_obj=user, obj_in=user_in_update) |
|||
user_2 = crud.user.get(db, id=user.id) |
|||
assert user_2 |
|||
assert user.email == user_2.email |
|||
assert verify_password(new_password, user_2.hashed_password) |
|||
|
@ -1,17 +1,18 @@ |
|||
from app import crud |
|||
from app.db.session import db_session |
|||
from typing import Optional |
|||
|
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud, models |
|||
from app.schemas.item import ItemCreate |
|||
from app.tests.utils.user import create_random_user |
|||
from app.tests.utils.utils import random_lower_string |
|||
|
|||
|
|||
def create_random_item(owner_id: int = None): |
|||
def create_random_item(db: Session, *, owner_id: Optional[int] = None) -> models.Item: |
|||
if owner_id is None: |
|||
user = create_random_user() |
|||
user = create_random_user(db) |
|||
owner_id = user.id |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description, id=id) |
|||
return crud.item.create_with_owner( |
|||
db_session=db_session, obj_in=item_in, owner_id=owner_id |
|||
) |
|||
return crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=owner_id) |
|||
|
@ -1,43 +1,50 @@ |
|||
from typing import Dict |
|||
|
|||
import requests |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud |
|||
from app.core.config import settings |
|||
from app.db.session import db_session |
|||
from app.models.user import User |
|||
from app.schemas.user import UserCreate, UserUpdate |
|||
from app.tests.utils.utils import get_server_api, random_lower_string, random_email |
|||
from app.tests.utils.utils import get_server_api, random_email, random_lower_string |
|||
|
|||
|
|||
def user_authentication_headers(server_api, email, password): |
|||
def user_authentication_headers( |
|||
server_api: str, email: str, password: str |
|||
) -> Dict[str, str]: |
|||
data = {"username": email, "password": password} |
|||
|
|||
r = requests.post(f"{server_api}{settings.API_V1_STR}/login/access-token", data=data) |
|||
r = requests.post( |
|||
f"{server_api}{settings.API_V1_STR}/login/access-token", data=data |
|||
) |
|||
response = r.json() |
|||
auth_token = response["access_token"] |
|||
headers = {"Authorization": f"Bearer {auth_token}"} |
|||
return headers |
|||
|
|||
|
|||
def create_random_user(): |
|||
def create_random_user(db: Session) -> User: |
|||
email = random_email() |
|||
password = random_lower_string() |
|||
user_in = UserCreate(username=email, email=email, password=password) |
|||
user = crud.user.create(db_session=db_session, obj_in=user_in) |
|||
user = crud.user.create(db=db, obj_in=user_in) |
|||
return user |
|||
|
|||
|
|||
def authentication_token_from_email(email): |
|||
def authentication_token_from_email(*, email: str, db: Session) -> Dict[str, str]: |
|||
""" |
|||
Return a valid token for the user with given email. |
|||
|
|||
If the user doesn't exist it is created first. |
|||
""" |
|||
password = random_lower_string() |
|||
user = crud.user.get_by_email(db_session, email=email) |
|||
user = crud.user.get_by_email(db, email=email) |
|||
if not user: |
|||
user_in = UserCreate(username=email, email=email, password=password) |
|||
user = crud.user.create(db_session=db_session, obj_in=user_in) |
|||
user_in_create = UserCreate(username=email, email=email, password=password) |
|||
user = crud.user.create(db, obj_in=user_in_create) |
|||
else: |
|||
user_in = UserUpdate(password=password) |
|||
user = crud.user.update(db_session, db_obj=user, obj_in=user_in) |
|||
user_in_update = UserUpdate(password=password) |
|||
user = crud.user.update(db, db_obj=user, obj_in=user_in_update) |
|||
|
|||
return user_authentication_headers(get_server_api(), email, password) |
|||
|
@ -1,11 +1,11 @@ |
|||
from raven import Client |
|||
|
|||
from app.core.config import settings |
|||
from app.core.celery_app import celery_app |
|||
from app.core.config import settings |
|||
|
|||
client_sentry = Client(settings.SENTRY_DSN) |
|||
|
|||
|
|||
@celery_app.task(acks_late=True) |
|||
def test_celery(word: str): |
|||
def test_celery(word: str) -> str: |
|||
return f"test task return {word}" |
|||
|
@ -0,0 +1,4 @@ |
|||
[mypy] |
|||
plugins = pydantic.mypy, sqlmypy |
|||
ignore_missing_imports = True |
|||
disallow_untyped_defs = True |
@ -0,0 +1,6 @@ |
|||
#!/bin/sh -e |
|||
set -x |
|||
|
|||
# Sort imports one per line, so autoflake can remove unused imports |
|||
isort --recursive --force-single-line-imports --apply app |
|||
sh ./scripts/format.sh |
@ -0,0 +1,6 @@ |
|||
#!/bin/sh -e |
|||
set -x |
|||
|
|||
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place app --exclude=__init__.py |
|||
black app |
|||
isort --recursive --apply app |
@ -0,0 +1,6 @@ |
|||
#!/usr/bin/env bash |
|||
|
|||
set -e |
|||
set -x |
|||
|
|||
bash scripts/test.sh --cov-report=html "${@}" |
@ -0,0 +1,6 @@ |
|||
#!/usr/bin/env bash |
|||
|
|||
set -e |
|||
set -x |
|||
|
|||
pytest --cov=app --cov-report=term-missing app/tests "${@}" |
Loading…
Reference in new issue