committed by
GitHub
13 changed files with 104 additions and 307 deletions
@ -0,0 +1,55 @@ |
|||
from typing import Any |
|||
|
|||
from sqlmodel import Session, select |
|||
|
|||
from app.core.security import get_password_hash, verify_password |
|||
from app.models import Item, ItemCreate, User, UserCreate, UserUpdate |
|||
|
|||
|
|||
def create_user(*, session: Session, user_create: UserCreate) -> User: |
|||
db_obj = User.model_validate( |
|||
user_create, update={"hashed_password": get_password_hash(user_create.password)} |
|||
) |
|||
session.add(db_obj) |
|||
session.commit() |
|||
session.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
|
|||
def update_user(*, session: Session, user_id: int, user_in: UserUpdate) -> Any: |
|||
db_user = session.get(User, user_id) |
|||
if not db_user: |
|||
return None |
|||
user_data = user_in.model_dump(exclude_unset=True) |
|||
extra_data = {} |
|||
if "password" in user_data: |
|||
password = user_data["password"] |
|||
hashed_password = get_password_hash(password) |
|||
extra_data["hashed_password"] = hashed_password |
|||
db_user.sqlmodel_update(user_data, update=extra_data) |
|||
session.add(db_user) |
|||
session.commit() |
|||
session.refresh(db_user) |
|||
return db_user |
|||
|
|||
|
|||
def get_user_by_email(*, session: Session, email: str) -> User | None: |
|||
statement = select(User).where(User.email == email) |
|||
session_user = session.exec(statement).first() |
|||
return session_user |
|||
|
|||
|
|||
def authenticate(*, session: Session, email: str, password: str) -> User | None: |
|||
db_user = get_user_by_email(session=session, email=email) |
|||
if not db_user: |
|||
return None |
|||
if not verify_password(password, db_user.hashed_password): |
|||
return None |
|||
return db_user |
|||
|
|||
def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item: |
|||
db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) |
|||
session.add(db_item) |
|||
session.commit() |
|||
session.refresh(db_item) |
|||
return db_item |
@ -1,37 +0,0 @@ |
|||
# For a new basic set of CRUD operations you could just do |
|||
# from .base import CRUDBase |
|||
# from app.models.item import Item |
|||
# from app.schemas.item import ItemCreate, ItemUpdate |
|||
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item) |
|||
from sqlmodel import Session, select |
|||
|
|||
from app.core.security import get_password_hash, verify_password |
|||
from app.models import User, UserCreate |
|||
|
|||
from .crud_item import item as item |
|||
from .crud_user import user as user |
|||
|
|||
|
|||
def create_user(*, session: Session, user_create: UserCreate) -> User: |
|||
db_obj = User.from_orm( |
|||
user_create, update={"hashed_password": get_password_hash(user_create.password)} |
|||
) |
|||
session.add(db_obj) |
|||
session.commit() |
|||
session.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
|
|||
def get_user_by_email(*, session: Session, email: str) -> User | None: |
|||
statement = select(User).where(User.email == email) |
|||
session_user = session.exec(statement).first() |
|||
return session_user |
|||
|
|||
|
|||
def authenticate(*, session: Session, email: str, password: str) -> User | None: |
|||
db_user = get_user_by_email(session=session, email=email) |
|||
if not db_user: |
|||
return None |
|||
if not verify_password(password, db_user.hashed_password): |
|||
return None |
|||
return db_user |
@ -1,59 +0,0 @@ |
|||
from typing import Any, Generic, TypeVar |
|||
|
|||
from fastapi.encoders import jsonable_encoder |
|||
from pydantic import BaseModel |
|||
from sqlalchemy.orm import Session |
|||
|
|||
ModelType = TypeVar("ModelType", bound=Any) |
|||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) |
|||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) |
|||
|
|||
|
|||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): |
|||
def __init__(self, model: type[ModelType]): |
|||
""" |
|||
CRUD object with default methods to Create, Read, Update, Delete (CRUD). |
|||
|
|||
**Parameters** |
|||
|
|||
* `model`: A SQLAlchemy model class |
|||
* `schema`: A Pydantic model (schema) class |
|||
""" |
|||
self.model = model |
|||
|
|||
def get(self, db: Session, id: Any) -> ModelType | None: |
|||
return db.query(self.model).filter(self.model.id == id).first() |
|||
|
|||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: |
|||
obj_in_data = jsonable_encoder(obj_in) |
|||
db_obj = self.model(**obj_in_data) # type: ignore |
|||
db.add(db_obj) |
|||
db.commit() |
|||
db.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
def update( |
|||
self, |
|||
db: Session, |
|||
*, |
|||
db_obj: ModelType, |
|||
obj_in: UpdateSchemaType | dict[str, Any], |
|||
) -> ModelType: |
|||
obj_data = jsonable_encoder(db_obj) |
|||
if isinstance(obj_in, dict): |
|||
update_data = obj_in |
|||
else: |
|||
update_data = obj_in.dict(exclude_unset=True) |
|||
for field in obj_data: |
|||
if field in update_data: |
|||
setattr(db_obj, field, update_data[field]) |
|||
db.add(db_obj) |
|||
db.commit() |
|||
db.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
def remove(self, db: Session, *, id: int) -> ModelType: |
|||
obj = db.query(self.model).get(id) |
|||
db.delete(obj) |
|||
db.commit() |
|||
return obj |
@ -1,32 +0,0 @@ |
|||
from fastapi.encoders import jsonable_encoder |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app.crud.base import CRUDBase |
|||
from app.models import Item |
|||
from app.schemas.item import ItemCreate, ItemUpdate |
|||
|
|||
|
|||
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]): |
|||
def create_with_owner( |
|||
self, db: Session, *, obj_in: ItemCreate, owner_id: int |
|||
) -> Item: |
|||
obj_in_data = jsonable_encoder(obj_in) |
|||
db_obj = self.model(**obj_in_data, owner_id=owner_id) |
|||
db.add(db_obj) |
|||
db.commit() |
|||
db.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
def get_multi_by_owner( |
|||
self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100 |
|||
) -> list[Item]: |
|||
return ( |
|||
db.query(self.model) |
|||
.filter(Item.owner_id == owner_id) |
|||
.offset(skip) |
|||
.limit(limit) |
|||
.all() |
|||
) |
|||
|
|||
|
|||
item = CRUDItem(Item) |
@ -1,55 +0,0 @@ |
|||
from typing import Any |
|||
|
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app.core.security import get_password_hash, verify_password |
|||
from app.crud.base import CRUDBase |
|||
from app.models import User |
|||
from app.schemas.user import UserCreate, UserUpdate |
|||
|
|||
|
|||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): |
|||
def get_by_email(self, db: Session, *, email: str) -> User | None: |
|||
return db.query(User).filter(User.email == email).first() |
|||
|
|||
def create(self, db: Session, *, obj_in: UserCreate) -> User: |
|||
db_obj = User( |
|||
email=obj_in.email, |
|||
hashed_password=get_password_hash(obj_in.password), |
|||
full_name=obj_in.full_name, |
|||
is_superuser=obj_in.is_superuser, |
|||
) |
|||
db.add(db_obj) |
|||
db.commit() |
|||
db.refresh(db_obj) |
|||
return db_obj |
|||
|
|||
def update( |
|||
self, db: Session, *, db_obj: User, obj_in: UserUpdate | dict[str, Any] |
|||
) -> User: |
|||
if isinstance(obj_in, dict): |
|||
update_data = obj_in |
|||
else: |
|||
update_data = obj_in.dict(exclude_unset=True) |
|||
if update_data["password"]: |
|||
hashed_password = get_password_hash(update_data["password"]) |
|||
del update_data["password"] |
|||
update_data["hashed_password"] = hashed_password |
|||
return super().update(db, db_obj=db_obj, obj_in=update_data) |
|||
|
|||
def authenticate(self, db: Session, *, email: str, password: str) -> User | None: |
|||
user = self.get_by_email(db, email=email) |
|||
if not user: |
|||
return None |
|||
if not verify_password(password, user.hashed_password): |
|||
return None |
|||
return user |
|||
|
|||
def is_active(self, user: User) -> bool: |
|||
return user.is_active |
|||
|
|||
def is_superuser(self, user: User) -> bool: |
|||
return user.is_superuser |
|||
|
|||
|
|||
user = CRUDUser(User) |
@ -1,61 +0,0 @@ |
|||
from sqlalchemy.orm import Session |
|||
|
|||
from app import crud |
|||
from app.schemas.item import ItemCreate, ItemUpdate |
|||
from app.tests.utils.user import create_random_user |
|||
from app.tests.utils.utils import random_lower_string |
|||
|
|||
|
|||
def test_create_item(db: Session) -> None: |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description) |
|||
user = create_random_user(db) |
|||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|||
assert item.title == title |
|||
assert item.description == description |
|||
assert item.owner_id == user.id |
|||
|
|||
|
|||
def test_get_item(db: Session) -> None: |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description) |
|||
user = create_random_user(db) |
|||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|||
stored_item = crud.item.get(db=db, id=item.id) |
|||
assert stored_item |
|||
assert item.id == stored_item.id |
|||
assert item.title == stored_item.title |
|||
assert item.description == stored_item.description |
|||
assert item.owner_id == stored_item.owner_id |
|||
|
|||
|
|||
def test_update_item(db: Session) -> None: |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description) |
|||
user = create_random_user(db) |
|||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|||
description2 = random_lower_string() |
|||
item_update = ItemUpdate(description=description2) |
|||
item2 = crud.item.update(db=db, db_obj=item, obj_in=item_update) |
|||
assert item.id == item2.id |
|||
assert item.title == item2.title |
|||
assert item2.description == description2 |
|||
assert item.owner_id == item2.owner_id |
|||
|
|||
|
|||
def test_delete_item(db: Session) -> None: |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description) |
|||
user = create_random_user(db) |
|||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|||
item2 = crud.item.remove(db=db, id=item.id) |
|||
item3 = crud.item.get(db=db, id=item.id) |
|||
assert item3 is None |
|||
assert item2.id == item.id |
|||
assert item2.title == title |
|||
assert item2.description == description |
|||
assert item2.owner_id == user.id |
@ -1,16 +1,16 @@ |
|||
from sqlalchemy.orm import Session |
|||
from sqlmodel import Session |
|||
|
|||
from app import crud, models |
|||
from app.schemas.item import ItemCreate |
|||
from app import crud |
|||
from app.models import Item, ItemCreate |
|||
from app.tests.utils.user import create_random_user |
|||
from app.tests.utils.utils import random_lower_string |
|||
|
|||
|
|||
def create_random_item(db: Session, *, owner_id: int | None = None) -> models.Item: |
|||
if owner_id is None: |
|||
user = create_random_user(db) |
|||
owner_id = user.id |
|||
def create_random_item(db: Session) -> Item: |
|||
user = create_random_user(db) |
|||
owner_id = user.id |
|||
assert owner_id is not None |
|||
title = random_lower_string() |
|||
description = random_lower_string() |
|||
item_in = ItemCreate(title=title, description=description, id=id) |
|||
return crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=owner_id) |
|||
item_in = ItemCreate(title=title, description=description) |
|||
return crud.create_item(session=db, item_in=item_in, owner_id=owner_id) |
|||
|
Loading…
Reference in new issue